diff options
author | Andy Davis <andydavis@google.com> | 2019-12-17 06:26:31 -0800 |
---|---|---|
committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-17 06:27:01 -0800 |
commit | 4e825c59be48b602a4790c91df0801138f3cbb6e (patch) | |
tree | b3efe66a150eca10d5d60b58e80855a3ff5412e5 /mlir/lib/Dialect/VectorOps/VectorTransforms.cpp | |
parent | 80ec474a65a29a740b2edf7cc77d493ab4013a6b (diff) | |
download | bcm5719-llvm-4e825c59be48b602a4790c91df0801138f3cbb6e.tar.gz bcm5719-llvm-4e825c59be48b602a4790c91df0801138f3cbb6e.zip |
Update vector op unrolling transformation to generate ExtractSlicesOp and InsertSlicesOp (instead of a less structured chain of StridedSliceOps and InsertStridedSliceOps).
PiperOrigin-RevId: 285968051
Diffstat (limited to 'mlir/lib/Dialect/VectorOps/VectorTransforms.cpp')
-rw-r--r-- | mlir/lib/Dialect/VectorOps/VectorTransforms.cpp | 115 |
1 file changed, 90 insertions, 25 deletions
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp index 6825709334b..8d70f4ac83f 100644 --- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp @@ -142,6 +142,47 @@ static void getMappedElements(const DenseMap<int64_t, int64_t> &indexMap, } } +// Returns a tuple type with vector element types for each resulting slice +// of 'vectorType' unrolled by 'sizes' and 'strides'. +// TODO(andydavis) Move this to a utility function and share it with +// Extract/InsertSlicesOp verification. +static TupleType generateExtractSlicesOpResultType(VectorType vectorType, + ArrayRef<int64_t> sizes, + ArrayRef<int64_t> strides, + PatternRewriter &builder) { + assert(llvm::all_of(strides, [](int64_t s) { return s == 1; })); + unsigned rank = vectorType.getRank(); + assert(sizes.size() == rank); + assert(strides.size() == rank); + + // Compute shape ratio of 'shape' and 'sizes'. + auto shape = vectorType.getShape(); + auto maybeDimSliceCounts = shapeRatio(shape, sizes); + assert(maybeDimSliceCounts.hasValue()); + auto sliceDimCounts = *maybeDimSliceCounts; + + // Compute strides w.r.t number of slices in each dimension. + auto basis = computeStrides(sliceDimCounts); + int64_t sliceCount = computeMaxLinearIndex(sliceDimCounts); + SmallVector<Type, 4> vectorTypes(sliceCount); + for (unsigned i = 0; i < sliceCount; ++i) { + // De-linearize w.r.t. 'basis'. + auto vectorOffsets = delinearize(i, basis); + // Convert from unrolled vector-space offsets to element-space offsets. + auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; }, + vectorOffsets, sizes); + // Initialize 'sliceSizes' to target 'sizes' + SmallVector<int64_t, 4> sliceSizes(sizes.begin(), sizes.end()); + for (unsigned j = 0; j < rank; ++j) { + // Based on 'offsets' and 'shape' clip some dim sizes for partial tiles. 
+ sliceSizes[j] = std::min(sliceSizes[j], shape[j] - offsets[j]); + } + // Create Vector type and add to 'vectorTypes[i]'. + vectorTypes[i] = VectorType::get(sliceSizes, vectorType.getElementType()); + } + return TupleType::get(vectorTypes, builder.getContext()); +} + // UnrolledVectorState aggregates per-operand/result vector state required for // unrolling. struct UnrolledVectorState { @@ -149,14 +190,16 @@ struct UnrolledVectorState { SmallVector<int64_t, 4> unrollFactors; SmallVector<int64_t, 8> basis; int64_t numInstances; + Value *slicesTuple; }; // Populates 'state' with unrolled shape, unroll factors, basis and // num unrolled instances for 'vectorType'. -static void initUnrolledVectorState(VectorType vectorType, +static void initUnrolledVectorState(VectorType vectorType, Value *initValue, const DenseMap<int64_t, int64_t> &indexMap, ArrayRef<int64_t> targetShape, - UnrolledVectorState &state) { + UnrolledVectorState &state, + PatternRewriter &builder) { // Compute unrolled shape of 'vectorType'. state.unrolledShape.resize(vectorType.getRank()); getMappedElements(indexMap, targetShape, state.unrolledShape); @@ -168,6 +211,16 @@ static void initUnrolledVectorState(VectorType vectorType, // Compute 'basis' and 'numInstances' based on 'state.unrollFactors'. state.basis = computeStrides(state.unrollFactors); state.numInstances = computeMaxLinearIndex(state.unrollFactors); + state.slicesTuple = nullptr; + if (initValue != nullptr) { + // Create ExtractSlicesOp. 
+ SmallVector<int64_t, 4> sizes(state.unrolledShape); + SmallVector<int64_t, 4> strides(state.unrollFactors.size(), 1); + auto tupleType = + generateExtractSlicesOpResultType(vectorType, sizes, strides, builder); + state.slicesTuple = builder.create<vector::ExtractSlicesOp>( + initValue->getLoc(), tupleType, initValue, sizes, strides); + } } // Computes and returns the linear index of the unrolled vector at @@ -202,10 +255,14 @@ static Value *getOrCreateUnrolledVectorSlice( assert(sliceLinearIndex < static_cast<int64_t>(cache.size())); auto *valueSlice = cache[sliceLinearIndex]; if (valueSlice == nullptr) { - assert(initValue != nullptr); - // Initialize 'cache' with slice from 'state.value'. - valueSlice = builder.create<vector::StridedSliceOp>( - loc, initValue, sliceOffsets, state.unrolledShape, sliceStrides); + // Return tuple element at 'sliceLinearIndex'. + auto tupleIndex = builder.getI64IntegerAttr(sliceLinearIndex); + auto initValueType = initValue->getType().cast<VectorType>(); + auto vectorType = + VectorType::get(state.unrolledShape, initValueType.getElementType()); + // Initialize 'cache' with slice from 'initValue'. + valueSlice = builder.create<vector::TupleGetOp>( + loc, vectorType, state.slicesTuple, tupleIndex); // Store value back to 'cache'. cache[sliceLinearIndex] = valueSlice; } @@ -293,8 +350,10 @@ static Value *unrollSingleResultStructuredOp(Operation *op, unsigned numVectors = vectors.size(); SmallVector<UnrolledVectorState, 3> unrolledVectorState(numVectors); for (unsigned i = 0; i < numVectors; ++i) { - initUnrolledVectorState(vectors[i].type, vectors[i].indexMap, targetShape, - unrolledVectorState[i]); + int64_t operandIndex = vectors[i].operandIndex; + auto *operand = operandIndex >= 0 ? op->getOperand(operandIndex) : nullptr; + initUnrolledVectorState(vectors[i].type, operand, vectors[i].indexMap, + targetShape, unrolledVectorState[i], builder); } // Compute number of total unrolled instances. 
auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors); @@ -341,21 +400,26 @@ static Value *unrollSingleResultStructuredOp(Operation *op, caches[resultIndex][linearIndex] = resultVector; } - // Make zero splat into which we will insert results from - // 'cache[resultIndex]' - auto resultVectorType = op->getResult(0)->getType().cast<VectorType>(); - auto *res = makeSplatZero(op->getLoc(), builder, resultVectorType); - SmallVector<int64_t, 4> strides(resultValueState.unrollFactors.size(), 1); - // Insert vector accumulators into output. + // Create TupleOp of unrolled result vectors. + SmallVector<Type, 4> vectorTupleTypes(resultValueState.numInstances); + SmallVector<Value *, 4> vectorTupleValues(resultValueState.numInstances); for (unsigned i = 0; i < resultValueState.numInstances; ++i) { - auto vectorOffsets = delinearize(i, resultValueState.basis); - // Convert from unrolled vector-space offsets to element-space offsets. - auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; }, - vectorOffsets, resultValueState.unrolledShape); - res = builder.create<vector::InsertStridedSliceOp>( - op->getLoc(), caches[resultIndex][i], res, offsets, strides); + vectorTupleTypes[i] = caches[resultIndex][i]->getType().cast<VectorType>(); + vectorTupleValues[i] = caches[resultIndex][i]; } - return res; + TupleType tupleType = builder.getTupleType(vectorTupleTypes); + Value *tupleOp = builder.create<vector::TupleOp>(op->getLoc(), tupleType, + vectorTupleValues); + + // Create InsertSlicesOp(Tuple(result_vectors)). 
+ auto resultVectorType = op->getResult(0)->getType().cast<VectorType>(); + SmallVector<int64_t, 4> sizes(resultValueState.unrolledShape); + SmallVector<int64_t, 4> strides(resultValueState.unrollFactors.size(), 1); + + Value *insertSlicesOp = builder.create<vector::InsertSlicesOp>( + op->getLoc(), resultVectorType, tupleOp, builder.getI64ArrayAttr(sizes), + builder.getI64ArrayAttr(strides)); + return insertSlicesOp; } static void getVectorContractionOpUnrollState( @@ -381,10 +445,10 @@ static void getVectorContractionOpUnrollState( if (llvm::size(contractionOp.masks()) == 2) { // Add vectors for lhs/rhs vector mask arguments. Masks have the // same vector shape lhs/rhs args, so copy their index maps. - vectors.push_back( - {vectors[0].type, vectors[0].indexMap, accOperandIndex + 1, false}); - vectors.push_back( - {vectors[1].type, vectors[1].indexMap, accOperandIndex + 2, false}); + vectors.push_back({contractionOp.getLHSVectorMaskType(), + vectors[0].indexMap, accOperandIndex + 1, false}); + vectors.push_back({contractionOp.getRHSVectorMaskType(), + vectors[1].indexMap, accOperandIndex + 2, false}); } // Unroll 'op' 'iterationBounds' to 'targetShape'. // TODO(andydavis) Use linalg style 'args_in'/'args_out' to partition @@ -509,6 +573,7 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> { } }; +// TODO(andydavis) Add pattern to rewrite ExtractSlices(ConstantMaskOp). // TODO(andydavis) Add this as DRR pattern. void mlir::vector::populateVectorToVectorTransformationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { |