Diffstat (limited to 'mlir/lib/Dialect/VectorOps/VectorTransforms.cpp')
-rw-r--r-- | mlir/lib/Dialect/VectorOps/VectorTransforms.cpp | 76
1 file changed, 37 insertions, 39 deletions
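Note: the diff below is a mechanical rename of `Value *` to the `ValuePtr` alias across VectorTransforms.cpp, part of MLIR's migration toward a value-typed `Value`. As a minimal sketch of what the alias is assumed to look like during that transition (the alias itself is not part of this diff):

    // Assumed transition alias (not shown in this diff): ValuePtr names the
    // old pointer type so call sites can migrate before Value itself
    // becomes value-typed.
    namespace mlir {
    class Value;
    using ValuePtr = Value *;
    } // namespace mlir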
diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
index 64cacb28720..e5c281cbf64 100644
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -106,17 +106,17 @@ static SmallVector<int64_t, 8> delinearize(int64_t linearIndex,
 // `resultTypes`.
 static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder,
                                               Location loc, Operation *op,
-                                              ArrayRef<Value *> operands,
+                                              ArrayRef<ValuePtr> operands,
                                               ArrayRef<Type> resultTypes) {
   OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                      op->getAttrs());
   return builder.createOperation(res);
 }
 
-static Value *makeSplatZero(Location loc, PatternRewriter &rewriter,
-                            VectorType vt) {
+static ValuePtr makeSplatZero(Location loc, PatternRewriter &rewriter,
+                              VectorType vt) {
   auto t = vt.getElementType();
-  Value *f = nullptr;
+  ValuePtr f = nullptr;
   if (t.isBF16() || t.isF16())
     f = rewriter.create<ConstantOp>(loc, t, rewriter.getF64FloatAttr(0.0f));
   else if (t.isF32())
@@ -190,12 +190,12 @@ struct UnrolledVectorState {
   SmallVector<int64_t, 4> unrollFactors;
   SmallVector<int64_t, 8> basis;
   int64_t numInstances;
-  Value *slicesTuple;
+  ValuePtr slicesTuple;
 };
 
 // Populates 'state' with unrolled shape, unroll factors, basis and
 // num unrolled instances for 'vectorType'.
-static void initUnrolledVectorState(VectorType vectorType, Value *initValue,
+static void initUnrolledVectorState(VectorType vectorType, ValuePtr initValue,
                                     const DenseMap<int64_t, int64_t> &indexMap,
                                     ArrayRef<int64_t> targetShape,
                                     UnrolledVectorState &state,
@@ -239,10 +239,10 @@ getUnrolledVectorLinearIndex(UnrolledVectorState &state,
 // Returns an unrolled vector at 'vectorOffsets' within the vector
 // represented by 'state'. The vector is created from a slice of 'initValue'
 // if not present in 'cache'.
-static Value *getOrCreateUnrolledVectorSlice(
+static ValuePtr getOrCreateUnrolledVectorSlice(
     Location loc, UnrolledVectorState &state, ArrayRef<int64_t> vectorOffsets,
     ArrayRef<int64_t> offsets, DenseMap<int64_t, int64_t> &indexMap,
-    Value *initValue, SmallVectorImpl<Value *> &cache,
+    ValuePtr initValue, SmallVectorImpl<ValuePtr> &cache,
    PatternRewriter &builder) {
   // Compute slice offsets.
   SmallVector<int64_t, 4> sliceOffsets(state.unrolledShape.size());
@@ -253,7 +253,7 @@ static Value *getOrCreateUnrolledVectorSlice(
   int64_t sliceLinearIndex =
       getUnrolledVectorLinearIndex(state, vectorOffsets, indexMap);
   assert(sliceLinearIndex < static_cast<int64_t>(cache.size()));
-  auto *valueSlice = cache[sliceLinearIndex];
+  auto valueSlice = cache[sliceLinearIndex];
   if (valueSlice == nullptr) {
     // Return tuple element at 'sliceLinearIndex'.
     auto tupleIndex = builder.getI64IntegerAttr(sliceLinearIndex);
@@ -330,12 +330,10 @@ struct VectorState {
 
 // TODO(andydavis) Generalize this to support structured ops beyond
 // vector ContractionOp, and merge it with 'unrollSingleResultOpMatchingType'
-static Value *unrollSingleResultStructuredOp(Operation *op,
-                                             ArrayRef<int64_t> iterationBounds,
-                                             std::vector<VectorState> &vectors,
-                                             unsigned resultIndex,
-                                             ArrayRef<int64_t> targetShape,
-                                             PatternRewriter &builder) {
+static ValuePtr unrollSingleResultStructuredOp(
+    Operation *op, ArrayRef<int64_t> iterationBounds,
+    std::vector<VectorState> &vectors, unsigned resultIndex,
+    ArrayRef<int64_t> targetShape, PatternRewriter &builder) {
   auto shapedType = op->getResult(0)->getType().dyn_cast_or_null<ShapedType>();
   if (!shapedType || !shapedType.hasStaticShape())
     assert(false && "Expected a statically shaped result type");
@@ -351,7 +349,7 @@ static Value *unrollSingleResultStructuredOp(Operation *op,
   SmallVector<UnrolledVectorState, 3> unrolledVectorState(numVectors);
   for (unsigned i = 0; i < numVectors; ++i) {
     int64_t operandIndex = vectors[i].operandIndex;
-    auto *operand = operandIndex >= 0 ? op->getOperand(operandIndex) : nullptr;
+    auto operand = operandIndex >= 0 ? op->getOperand(operandIndex) : nullptr;
     initUnrolledVectorState(vectors[i].type, operand, vectors[i].indexMap,
                             targetShape, unrolledVectorState[i], builder);
   }
@@ -364,7 +362,7 @@ static Value *unrollSingleResultStructuredOp(Operation *op,
                                           shapedType.getElementType());
 
   // Initialize caches for intermediate vector results.
-  std::vector<SmallVector<Value *, 4>> caches(numVectors);
+  std::vector<SmallVector<ValuePtr, 4>> caches(numVectors);
   for (unsigned i = 0; i < numVectors; ++i)
     caches[i].resize(unrolledVectorState[i].numInstances);
 
@@ -376,13 +374,13 @@ static Value *unrollSingleResultStructuredOp(Operation *op,
     auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
                           vectorOffsets, targetShape);
     // Get cached slice (or create slice) for each operand at 'offsets'.
-    SmallVector<Value *, 3> operands;
+    SmallVector<ValuePtr, 3> operands;
     operands.resize(op->getNumOperands());
     for (unsigned i = 0; i < numVectors; ++i) {
       int64_t operandIndex = vectors[i].operandIndex;
       if (operandIndex < 0)
         continue; // Output
-      auto *operand = op->getOperand(operandIndex);
+      auto operand = op->getOperand(operandIndex);
       operands[operandIndex] = getOrCreateUnrolledVectorSlice(
           op->getLoc(), unrolledVectorState[i], vectorOffsets, offsets,
           vectors[i].indexMap, operand, caches[i], builder);
@@ -402,21 +400,21 @@ static Value *unrollSingleResultStructuredOp(Operation *op,
 
   // Create TupleOp of unrolled result vectors.
   SmallVector<Type, 4> vectorTupleTypes(resultValueState.numInstances);
-  SmallVector<Value *, 4> vectorTupleValues(resultValueState.numInstances);
+  SmallVector<ValuePtr, 4> vectorTupleValues(resultValueState.numInstances);
   for (unsigned i = 0; i < resultValueState.numInstances; ++i) {
     vectorTupleTypes[i] = caches[resultIndex][i]->getType().cast<VectorType>();
     vectorTupleValues[i] = caches[resultIndex][i];
   }
   TupleType tupleType = builder.getTupleType(vectorTupleTypes);
-  Value *tupleOp = builder.create<vector::TupleOp>(op->getLoc(), tupleType,
-                                                   vectorTupleValues);
+  ValuePtr tupleOp = builder.create<vector::TupleOp>(op->getLoc(), tupleType,
+                                                     vectorTupleValues);
 
   // Create InsertSlicesOp(Tuple(result_vectors)).
   auto resultVectorType = op->getResult(0)->getType().cast<VectorType>();
   SmallVector<int64_t, 4> sizes(resultValueState.unrolledShape);
   SmallVector<int64_t, 4> strides(resultValueState.unrollFactors.size(), 1);
-  Value *insertSlicesOp = builder.create<vector::InsertSlicesOp>(
+  ValuePtr insertSlicesOp = builder.create<vector::InsertSlicesOp>(
       op->getLoc(), resultVectorType, tupleOp, builder.getI64ArrayAttr(sizes),
       builder.getI64ArrayAttr(strides));
   return insertSlicesOp;
 }
@@ -487,7 +485,7 @@ getVectorElementwiseOpUnrollState(Operation *op, ArrayRef<int64_t> targetShape,
 }
 
 // Entry point for unrolling declarative pattern rewrites.
-Value *mlir::vector::unrollSingleResultOpMatchingType(
+ValuePtr mlir::vector::unrollSingleResultOpMatchingType(
     PatternRewriter &builder, Operation *op, ArrayRef<int64_t> targetShape) {
   assert(op->getNumResults() == 1 && "Expected single result operation");
 
@@ -516,8 +514,8 @@ Value *mlir::vector::unrollSingleResultOpMatchingType(
 static void
 generateTransferOpSlices(VectorType vectorType, TupleType tupleType,
                          ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides,
-                         ArrayRef<Value *> indices, PatternRewriter &rewriter,
-                         function_ref<void(unsigned, ArrayRef<Value *>)> fn) {
+                         ArrayRef<ValuePtr> indices, PatternRewriter &rewriter,
+                         function_ref<void(unsigned, ArrayRef<ValuePtr>)> fn) {
   // Compute strides w.r.t. to slice counts in each dimension.
   auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes);
   assert(maybeDimSliceCounts.hasValue());
@@ -534,13 +532,13 @@ generateTransferOpSlices(VectorType vectorType, TupleType tupleType,
     auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
                           vectorOffsets, sizes);
     // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
-    SmallVector<Value *, 4> sliceIndices(numSliceIndices);
+    SmallVector<ValuePtr, 4> sliceIndices(numSliceIndices);
     for (auto it : llvm::enumerate(indices)) {
       auto expr = getAffineDimExpr(0, ctx) +
                   getAffineConstantExpr(offsets[it.index()], ctx);
       auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
       sliceIndices[it.index()] = rewriter.create<AffineApplyOp>(
-          it.value()->getLoc(), map, ArrayRef<Value *>(it.value()));
+          it.value()->getLoc(), map, ArrayRef<ValuePtr>(it.value()));
     }
     // Call 'fn' to generate slice 'i' at 'sliceIndices'.
     fn(i, sliceIndices);
@@ -559,7 +557,7 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
     if (!xferReadOp.permutation_map().isIdentity())
       return matchFailure();
     // Return unless the unique 'xferReadOp' user is an ExtractSlicesOp.
-    Value *xferReadResult = xferReadOp.getResult();
+    ValuePtr xferReadResult = xferReadOp.getResult();
     auto extractSlicesOp =
         dyn_cast<vector::ExtractSlicesOp>(*xferReadResult->getUsers().begin());
     if (!xferReadResult->hasOneUse() || !extractSlicesOp)
@@ -576,10 +574,10 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
 
     Location loc = xferReadOp.getLoc();
     int64_t numSlices = resultTupleType.size();
-    SmallVector<Value *, 4> vectorTupleValues(numSlices);
-    SmallVector<Value *, 4> indices(xferReadOp.indices().begin(),
-                                    xferReadOp.indices().end());
-    auto createSlice = [&](unsigned index, ArrayRef<Value *> sliceIndices) {
+    SmallVector<ValuePtr, 4> vectorTupleValues(numSlices);
+    SmallVector<ValuePtr, 4> indices(xferReadOp.indices().begin(),
+                                     xferReadOp.indices().end());
+    auto createSlice = [&](unsigned index, ArrayRef<ValuePtr> sliceIndices) {
       // Get VectorType for slice 'i'.
       auto sliceVectorType = resultTupleType.getType(index);
       // Create split TransferReadOp for 'sliceUser'.
@@ -591,8 +589,8 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
                              indices, rewriter, createSlice);
 
     // Create tuple of splice xfer read operations.
-    Value *tupleOp = rewriter.create<vector::TupleOp>(loc, resultTupleType,
-                                                      vectorTupleValues);
+    ValuePtr tupleOp = rewriter.create<vector::TupleOp>(loc, resultTupleType,
+                                                        vectorTupleValues);
 
     // Replace 'xferReadOp' with result 'insertSlicesResult'.
     rewriter.replaceOpWithNewOp<vector::InsertSlicesOp>(
         xferReadOp, sourceVectorType, tupleOp, extractSlicesOp.sizes(),
@@ -632,9 +630,9 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
     insertSlicesOp.getStrides(strides);
 
     Location loc = xferWriteOp.getLoc();
-    SmallVector<Value *, 4> indices(xferWriteOp.indices().begin(),
-                                    xferWriteOp.indices().end());
-    auto createSlice = [&](unsigned index, ArrayRef<Value *> sliceIndices) {
+    SmallVector<ValuePtr, 4> indices(xferWriteOp.indices().begin(),
+                                     xferWriteOp.indices().end());
+    auto createSlice = [&](unsigned index, ArrayRef<ValuePtr> sliceIndices) {
      // Create split TransferWriteOp for source vector 'tupleOp.operand[i]'.
       rewriter.create<vector::TransferWriteOp>(
           loc, tupleOp.getOperand(index), xferWriteOp.memref(), sliceIndices,
@@ -676,7 +674,7 @@ struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> {
       return matchFailure();
 
     // Forward Value from 'tupleOp' at 'tupleGetOp.index'.
-    Value *tupleValue = tupleOp.getOperand(tupleGetOp.getIndex());
+    ValuePtr tupleValue = tupleOp.getOperand(tupleGetOp.getIndex());
     rewriter.replaceOp(tupleGetOp, tupleValue);
     return matchSuccess();
   }
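For context on the hunk at @@ -534,13 above: generateTransferOpSlices computes each slice index as the original transfer index plus a constant offset, materialized through an AffineApplyOp over the map (d0) -> (d0 + offset). A minimal sketch of that arithmetic as a standalone helper; the helper name and parameters are hypothetical, while the API calls mirror those visible in the hunk:

    // Hypothetical helper mirroring the index arithmetic in
    // generateTransferOpSlices: apply (d0) -> (d0 + offset) to 'baseIndex'.
    static ValuePtr makeSliceIndex(PatternRewriter &rewriter, Location loc,
                                   MLIRContext *ctx, ValuePtr baseIndex,
                                   int64_t offset) {
      // Build the affine expression d0 + offset and wrap it in a 1-dim map.
      auto expr = getAffineDimExpr(0, ctx) + getAffineConstantExpr(offset, ctx);
      auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
      // Materialize the addition as an affine.apply on the base index.
      return rewriter.create<AffineApplyOp>(loc, map,
                                            ArrayRef<ValuePtr>(baseIndex));
    }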