author     Andy Davis <andydavis@google.com>                 2019-12-17 06:26:31 -0800
committer  A. Unique TensorFlower <gardener@tensorflow.org>  2019-12-17 06:27:01 -0800
commit     4e825c59be48b602a4790c91df0801138f3cbb6e (patch)
tree       b3efe66a150eca10d5d60b58e80855a3ff5412e5 /mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
parent     80ec474a65a29a740b2edf7cc77d493ab4013a6b (diff)
Update vector op unrolling transformation to generate ExtractSlicesOp and InsertSlicesOp (instead of less structured chain of StridedSliceOps and InsertStridedSliceOps).
PiperOrigin-RevId: 285968051
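
For orientation, the sketch below contrasts the two forms at the IR level for a hypothetical unrolling of a vector<4x4xf32> value into 2x2 tiles. The shapes, SSA names, and exact assembly syntax are illustrative only (reconstructed from the op names used in this patch, not copied from a test); the point is the structure: per-tile strided-slice ops are replaced by a single extract_slices/insert_slices pair around a tuple of tiles.

```mlir
// Before (illustrative): one strided_slice per tile, and results chained back
// into a zero splat through insert_strided_slice ops.
%t0 = vector.strided_slice %v {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
    : vector<4x4xf32> to vector<2x2xf32>
// ... three more strided_slice ops, per-tile computation, then:
%r0 = vector.insert_strided_slice %u0, %zero {offsets = [0, 0], strides = [1, 1]}
    : vector<2x2xf32> into vector<4x4xf32>
// ... three more insert_strided_slice ops chained on %r0.

// After (illustrative): all tiles extracted at once into a tuple, read back with
// tuple_get, and the unrolled results reassembled with insert_slices.
%s = vector.extract_slices %v, [2, 2], [1, 1]
    : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>,
                                 vector<2x2xf32>, vector<2x2xf32>>
%e0 = vector.tuple_get %s, 0
    : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// ... per-tile computation producing %u0..%u3, then:
%u = vector.tuple %u0, %u1, %u2, %u3
    : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
%r = vector.insert_slices %u, [2, 2], [1, 1]
    : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
      into vector<4x4xf32>
```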
Diffstat (limited to 'mlir/lib/Dialect/VectorOps/VectorTransforms.cpp')
-rw-r--r--   mlir/lib/Dialect/VectorOps/VectorTransforms.cpp   115
1 file changed, 90 insertions(+), 25 deletions(-)
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
index 6825709334b..8d70f4ac83f 100644
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -142,6 +142,47 @@ static void getMappedElements(const DenseMap<int64_t, int64_t> &indexMap,
}
}
+// Returns a tuple type with vector element types for each resulting slice
+// of 'vectorType' unrolled by 'sizes' and 'strides'.
+// TODO(andydavis) Move this to a utility function and share it with
+// Extract/InsertSlicesOp verification.
+static TupleType generateExtractSlicesOpResultType(VectorType vectorType,
+ ArrayRef<int64_t> sizes,
+ ArrayRef<int64_t> strides,
+ PatternRewriter &builder) {
+ assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
+ unsigned rank = vectorType.getRank();
+ assert(sizes.size() == rank);
+ assert(strides.size() == rank);
+
+ // Compute shape ratio of 'shape' and 'sizes'.
+ auto shape = vectorType.getShape();
+ auto maybeDimSliceCounts = shapeRatio(shape, sizes);
+ assert(maybeDimSliceCounts.hasValue());
+ auto sliceDimCounts = *maybeDimSliceCounts;
+
+ // Compute strides w.r.t number of slices in each dimension.
+ auto basis = computeStrides(sliceDimCounts);
+ int64_t sliceCount = computeMaxLinearIndex(sliceDimCounts);
+ SmallVector<Type, 4> vectorTypes(sliceCount);
+ for (unsigned i = 0; i < sliceCount; ++i) {
+ // De-linearize w.r.t. 'basis'.
+ auto vectorOffsets = delinearize(i, basis);
+ // Convert from unrolled vector-space offsets to element-space offsets.
+ auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
+ vectorOffsets, sizes);
+ // Initialize 'sliceSizes' to target 'sizes'
+ SmallVector<int64_t, 4> sliceSizes(sizes.begin(), sizes.end());
+ for (unsigned j = 0; j < rank; ++j) {
+ // Based on 'offsets' and 'shape' clip some dim sizes for partial tiles.
+ sliceSizes[j] = std::min(sliceSizes[j], shape[j] - offsets[j]);
+ }
+ // Create Vector type and add to 'vectorTypes[i]'.
+ vectorTypes[i] = VectorType::get(sliceSizes, vectorType.getElementType());
+ }
+ return TupleType::get(vectorTypes, builder.getContext());
+}
+
// UnrolledVectorState aggregates per-operand/result vector state required for
// unrolling.
struct UnrolledVectorState {
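
As a worked example of the index arithmetic in generateExtractSlicesOpResultType above, take a hypothetical vector<4x6xf32> unrolled with sizes [2, 3] and unit strides: shapeRatio yields per-dimension slice counts [2, 2], computeStrides turns that into the basis [2, 1], and computeMaxLinearIndex gives 4 slices. Linear index 3, for instance, delinearizes against the basis to vector offsets [1, 1], which zipMap scales by the sizes to element offsets [2, 3]; because the sizes divide the shape evenly here, the clipping loop leaves every slice at [2, 3]. The resulting op (assembly syntax approximate) would be:

```mlir
// Hypothetical shapes; tuple element order follows the linearized slice index.
%slices = vector.extract_slices %arg, [2, 3], [1, 1]
    : vector<4x6xf32> into tuple<vector<2x3xf32>, vector<2x3xf32>,
                                 vector<2x3xf32>, vector<2x3xf32>>
```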
@@ -149,14 +190,16 @@ struct UnrolledVectorState {
SmallVector<int64_t, 4> unrollFactors;
SmallVector<int64_t, 8> basis;
int64_t numInstances;
+ Value *slicesTuple;
};
// Populates 'state' with unrolled shape, unroll factors, basis and
// num unrolled instances for 'vectorType'.
-static void initUnrolledVectorState(VectorType vectorType,
+static void initUnrolledVectorState(VectorType vectorType, Value *initValue,
const DenseMap<int64_t, int64_t> &indexMap,
ArrayRef<int64_t> targetShape,
- UnrolledVectorState &state) {
+ UnrolledVectorState &state,
+ PatternRewriter &builder) {
// Compute unrolled shape of 'vectorType'.
state.unrolledShape.resize(vectorType.getRank());
getMappedElements(indexMap, targetShape, state.unrolledShape);
@@ -168,6 +211,16 @@ static void initUnrolledVectorState(VectorType vectorType,
// Compute 'basis' and 'numInstances' based on 'state.unrollFactors'.
state.basis = computeStrides(state.unrollFactors);
state.numInstances = computeMaxLinearIndex(state.unrollFactors);
+ state.slicesTuple = nullptr;
+ if (initValue != nullptr) {
+ // Create ExtractSlicesOp.
+ SmallVector<int64_t, 4> sizes(state.unrolledShape);
+ SmallVector<int64_t, 4> strides(state.unrollFactors.size(), 1);
+ auto tupleType =
+ generateExtractSlicesOpResultType(vectorType, sizes, strides, builder);
+ state.slicesTuple = builder.create<vector::ExtractSlicesOp>(
+ initValue->getLoc(), tupleType, initValue, sizes, strides);
+ }
}
// Computes and returns the linear index of the unrolled vector at
@@ -202,10 +255,14 @@ static Value *getOrCreateUnrolledVectorSlice(
assert(sliceLinearIndex < static_cast<int64_t>(cache.size()));
auto *valueSlice = cache[sliceLinearIndex];
if (valueSlice == nullptr) {
- assert(initValue != nullptr);
- // Initialize 'cache' with slice from 'state.value'.
- valueSlice = builder.create<vector::StridedSliceOp>(
- loc, initValue, sliceOffsets, state.unrolledShape, sliceStrides);
+ // Return tuple element at 'sliceLinearIndex'.
+ auto tupleIndex = builder.getI64IntegerAttr(sliceLinearIndex);
+ auto initValueType = initValue->getType().cast<VectorType>();
+ auto vectorType =
+ VectorType::get(state.unrolledShape, initValueType.getElementType());
+ // Initialize 'cache' with slice from 'initValue'.
+ valueSlice = builder.create<vector::TupleGetOp>(
+ loc, vectorType, state.slicesTuple, tupleIndex);
// Store value back to 'cache'.
cache[sliceLinearIndex] = valueSlice;
}
@@ -293,8 +350,10 @@ static Value *unrollSingleResultStructuredOp(Operation *op,
unsigned numVectors = vectors.size();
SmallVector<UnrolledVectorState, 3> unrolledVectorState(numVectors);
for (unsigned i = 0; i < numVectors; ++i) {
- initUnrolledVectorState(vectors[i].type, vectors[i].indexMap, targetShape,
- unrolledVectorState[i]);
+ int64_t operandIndex = vectors[i].operandIndex;
+ auto *operand = operandIndex >= 0 ? op->getOperand(operandIndex) : nullptr;
+ initUnrolledVectorState(vectors[i].type, operand, vectors[i].indexMap,
+ targetShape, unrolledVectorState[i], builder);
}
// Compute number of total unrolled instances.
auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors);
@@ -341,21 +400,26 @@ static Value *unrollSingleResultStructuredOp(Operation *op,
caches[resultIndex][linearIndex] = resultVector;
}
- // Make zero splat into which we will insert results from
- // 'cache[resultIndex]'
- auto resultVectorType = op->getResult(0)->getType().cast<VectorType>();
- auto *res = makeSplatZero(op->getLoc(), builder, resultVectorType);
- SmallVector<int64_t, 4> strides(resultValueState.unrollFactors.size(), 1);
- // Insert vector accumulators into output.
+ // Create TupleOp of unrolled result vectors.
+ SmallVector<Type, 4> vectorTupleTypes(resultValueState.numInstances);
+ SmallVector<Value *, 4> vectorTupleValues(resultValueState.numInstances);
for (unsigned i = 0; i < resultValueState.numInstances; ++i) {
- auto vectorOffsets = delinearize(i, resultValueState.basis);
- // Convert from unrolled vector-space offsets to element-space offsets.
- auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
- vectorOffsets, resultValueState.unrolledShape);
- res = builder.create<vector::InsertStridedSliceOp>(
- op->getLoc(), caches[resultIndex][i], res, offsets, strides);
+ vectorTupleTypes[i] = caches[resultIndex][i]->getType().cast<VectorType>();
+ vectorTupleValues[i] = caches[resultIndex][i];
}
- return res;
+ TupleType tupleType = builder.getTupleType(vectorTupleTypes);
+ Value *tupleOp = builder.create<vector::TupleOp>(op->getLoc(), tupleType,
+ vectorTupleValues);
+
+ // Create InsertSlicesOp(Tuple(result_vectors)).
+ auto resultVectorType = op->getResult(0)->getType().cast<VectorType>();
+ SmallVector<int64_t, 4> sizes(resultValueState.unrolledShape);
+ SmallVector<int64_t, 4> strides(resultValueState.unrollFactors.size(), 1);
+
+ Value *insertSlicesOp = builder.create<vector::InsertSlicesOp>(
+ op->getLoc(), resultVectorType, tupleOp, builder.getI64ArrayAttr(sizes),
+ builder.getI64ArrayAttr(strides));
+ return insertSlicesOp;
}
static void getVectorContractionOpUnrollState(
@@ -381,10 +445,10 @@ static void getVectorContractionOpUnrollState(
if (llvm::size(contractionOp.masks()) == 2) {
// Add vectors for lhs/rhs vector mask arguments. Masks have the
// same vector shape lhs/rhs args, so copy their index maps.
- vectors.push_back(
- {vectors[0].type, vectors[0].indexMap, accOperandIndex + 1, false});
- vectors.push_back(
- {vectors[1].type, vectors[1].indexMap, accOperandIndex + 2, false});
+ vectors.push_back({contractionOp.getLHSVectorMaskType(),
+ vectors[0].indexMap, accOperandIndex + 1, false});
+ vectors.push_back({contractionOp.getRHSVectorMaskType(),
+ vectors[1].indexMap, accOperandIndex + 2, false});
}
// Unroll 'op' 'iterationBounds' to 'targetShape'.
// TODO(andydavis) Use linalg style 'args_in'/'args_out' to partition
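
A note on the mask change above: the optional mask operands of vector.contract have i1 elements, so seeding the unroll state from vectors[0].type / vectors[1].type would carry the lhs/rhs element type (e.g. f32) into the unrolled mask slices; getLHSVectorMaskType()/getRHSVectorMaskType() return the i1 mask types instead. A hypothetical shape pairing, purely for illustration:

```mlir
// Hypothetical contraction operand shapes:
//   lhs      : vector<4x8xf32>    rhs      : vector<8x6xf32>
// The masks share those shapes but use i1 elements, which is what
// getLHSVectorMaskType()/getRHSVectorMaskType() return:
//   lhs mask : vector<4x8xi1>     rhs mask : vector<8x6xi1>
```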
@@ -509,6 +573,7 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
}
};
+// TODO(andydavis) Add pattern to rewrite ExtractSlices(ConstantMaskOp).
// TODO(andydavis) Add this as DRR pattern.
void mlir::vector::populateVectorToVectorTransformationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {