diff options
author | Nicolas Vasilache <ntv@google.com> | 2019-11-14 08:10:36 -0800 |
---|---|---|
committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-11-14 08:15:23 -0800 |
commit | f2b6ae99913d0049c7929160aed5f213b1081abb (patch) | |
tree | 41ccb1f9a181eaada07f21cbc7da75ef2d229575 /mlir/lib | |
parent | 7c28de4aef6da3ab2f53118ecf717e56c68352e7 (diff) | |
download | bcm5719-llvm-f2b6ae99913d0049c7929160aed5f213b1081abb.tar.gz bcm5719-llvm-f2b6ae99913d0049c7929160aed5f213b1081abb.zip |
Move VectorOps to Tablegen - (almost) NFC
This CL moves VectorOps to Tablegen and cleans up the implementation.
This is almost NFC but 2 changes occur:
1. an interface change occurs in the padding value specification in vector_transfer_read:
the value becomes non-optional. As a shortcut we currently use %f0 for all paddings.
This should become an OpInterface for vectorization in the future.
2. the return type of vector.type_cast is trivial and simplified to `memref<vector<...>>`
Relevant roundtrip and invalid tests that used to sit in core are moved to the vector dialect.
The op documentation is moved to the .td file.
PiperOrigin-RevId: 280430869
Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Analysis/VectorAnalysis.cpp | 2 | ||||
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp | 6 | ||||
-rw-r--r-- | mlir/lib/Dialect/VectorOps/VectorOps.cpp | 385 | ||||
-rw-r--r-- | mlir/lib/Transforms/LowerVectorTransfers.cpp | 34 | ||||
-rw-r--r-- | mlir/lib/Transforms/MaterializeVectors.cpp | 18 | ||||
-rw-r--r-- | mlir/lib/Transforms/Vectorize.cpp | 14 |
6 files changed, 136 insertions, 323 deletions
diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index e765ce35e74..2dab3481e56 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -195,7 +195,7 @@ bool mlir::matcher::operatesOnSuperVectorsOf(Operation &op, (void)mustDivide; VectorType superVectorType; if (auto read = dyn_cast<vector::VectorTransferReadOp>(op)) { - superVectorType = read.getResultType(); + superVectorType = read.getVectorType(); mustDivide = true; } else if (auto write = dyn_cast<vector::VectorTransferWriteOp>(op)) { superVectorType = write.getVectorType(); diff --git a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp index 5ccf740f2fb..21bcdc9a6db 100644 --- a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp @@ -196,10 +196,10 @@ public: int64_t offset; SmallVector<int64_t, 4> strides; auto successStrides = - getStridesAndOffset(targetMemRefType, strides, offset); + getStridesAndOffset(sourceMemRefType, strides, offset); bool isContiguous = (strides.back() == 1); if (isContiguous) { - auto sizes = targetMemRefType.getShape(); + auto sizes = sourceMemRefType.getShape(); for (int index = 0, e = strides.size() - 2; index < e; ++index) { if (strides[index] != strides[index + 1] * sizes[index + 1]) { isContiguous = false; @@ -207,7 +207,7 @@ public: } } } - // Only contiguous tensors supported atm. + // Only contiguous source tensors supported atm. if (failed(successStrides) || !isContiguous) return matchFailure(); diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index 8626f241955..215e92d0947 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -37,8 +37,6 @@ using namespace mlir::vector; mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { - addOperations<VectorTransferReadOp, VectorTransferWriteOp, - VectorTypeCastOp>(); addOperations< #define GET_OP_LIST #include "mlir/Dialect/VectorOps/VectorOps.cpp.inc" @@ -195,354 +193,165 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap, return success(); } -void VectorTransferReadOp::build(Builder *builder, OperationState &result, - VectorType vectorType, Value *srcMemRef, - ArrayRef<Value *> srcIndices, - AffineMap permutationMap, - Optional<Value *> paddingValue) { - result.addOperands(srcMemRef); - result.addOperands(srcIndices); - if (paddingValue) { - result.addOperands({*paddingValue}); - } - result.addAttribute(getPermutationMapAttrName(), - AffineMapAttr::get(permutationMap)); - result.addTypes(vectorType); -} - -auto VectorTransferReadOp::getIndices() -> operand_range { - auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset; - auto end = begin + getMemRefType().getRank(); - return {begin, end}; -} - -Optional<Value *> VectorTransferReadOp::getPaddingValue() { - auto memRefRank = getMemRefType().getRank(); - if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) { - return None; - } - return Optional<Value *>(getOperand(Offsets::FirstIndexOffset + memRefRank)); -} - -AffineMap VectorTransferReadOp::getPermutationMap() { - return getAttrOfType<AffineMapAttr>(getPermutationMapAttrName()).getValue(); -} - -void VectorTransferReadOp::print(OpAsmPrinter &p) { - p << getOperationName() << " "; - p.printOperand(getMemRef()); +static void print(OpAsmPrinter &p, VectorTransferReadOp op) { + p << op.getOperationName() << " "; + p.printOperand(op.memref()); p << "["; - p.printOperands(getIndices()); - p << "]"; - auto optionalPaddingValue = getPaddingValue(); - if (optionalPaddingValue) { - p << ", ("; - p.printOperand(*optionalPaddingValue); - p << ")"; - } - p.printOptionalAttrDict(getAttrs()); - p << " : " << getMemRefType(); - p << ", " << getResultType(); + p.printOperands(op.indices()); + p << "], "; + p.printOperand(op.padding()); + p << " "; + p.printOptionalAttrDict(op.getAttrs()); + p << " : " << op.getMemRefType(); + p << ", " << op.getVectorType(); } -ParseResult VectorTransferReadOp::parse(OpAsmParser &parser, - OperationState &result) { +ParseResult parseVectorTransferReadOp(OpAsmParser &parser, + OperationState &result) { + llvm::SMLoc typesLoc; OpAsmParser::OperandType memrefInfo; SmallVector<OpAsmParser::OperandType, 8> indexInfo; - SmallVector<OpAsmParser::OperandType, 8> paddingInfo; + OpAsmParser::OperandType paddingInfo; SmallVector<Type, 2> types; - // Parsing with support for optional paddingValue. if (parser.parseOperand(memrefInfo) || parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || - parser.parseTrailingOperandList(paddingInfo, - OpAsmParser::Delimiter::Paren) || + parser.parseComma() || parser.parseOperand(paddingInfo) || parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonTypeList(types)) + parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types)) return failure(); - - // Resolution. if (types.size() != 2) - return parser.emitError(parser.getNameLoc(), "expected 2 types"); - MemRefType memrefType = types[0].dyn_cast<MemRefType>(); - if (!memrefType) - return parser.emitError(parser.getNameLoc(), "memRef type expected"); - VectorType vectorType = types[1].dyn_cast<VectorType>(); - if (!vectorType) - return parser.emitError(parser.getNameLoc(), "vector type expected"); - - // Extract optional paddingValue. - // At this point, indexInfo may contain the optional paddingValue, pop it - // out. - if (static_cast<int64_t>(indexInfo.size()) != memrefType.getRank()) - return parser.emitError(parser.getNameLoc(), - "expected " + Twine(memrefType.getRank()) + - " indices to the memref"); - if (paddingInfo.size() > 1) - return parser.emitError(parser.getNameLoc(), - "expected at most one padding value"); - Type paddingType; - bool hasOptionalPaddingValue = !paddingInfo.empty(); - if (hasOptionalPaddingValue) { - paddingType = vectorType.getElementType(); - } + return parser.emitError(typesLoc, "two types required"); auto indexType = parser.getBuilder().getIndexType(); + MemRefType memRefType = types[0].dyn_cast<MemRefType>(); + if (!memRefType) + return parser.emitError(typesLoc, "memref type required"), failure(); + Type vectorType = types[1]; return failure( - parser.resolveOperand(memrefInfo, memrefType, result.operands) || + parser.resolveOperand(memrefInfo, memRefType, result.operands) || parser.resolveOperands(indexInfo, indexType, result.operands) || - (hasOptionalPaddingValue && - parser.resolveOperand(paddingInfo[0], paddingType, result.operands)) || + parser.resolveOperand(paddingInfo, memRefType.getElementType(), + result.operands) || parser.addTypeToList(vectorType, result.types)); } -LogicalResult VectorTransferReadOp::verify() { - // Consistency of memref type in function type. - if (llvm::empty(getOperands())) { - return emitOpError( - "requires at least a memref operand followed by 'rank' indices"); - } - if (!getMemRef()->getType().isa<MemRefType>()) { - return emitOpError("requires a memref as first operand"); - } - // Consistency of vector type in function type. - if (!getResult()->getType().isa<VectorType>()) { - return emitOpError("should have a vector result type in function type: " - "memref_type<...xelemental_type>, vector_type"); - } +static LogicalResult verify(VectorTransferReadOp op) { // Consistency of elemental types in memref and vector. - MemRefType memrefType = getMemRefType(); - VectorType vectorType = getResultType(); + MemRefType memrefType = op.getMemRefType(); + VectorType vectorType = op.getVectorType(); if (memrefType.getElementType() != vectorType.getElementType()) - return emitOpError( + return op.emitOpError( "requires memref and vector types of the same elemental type"); - // Consistency of number of input types. - auto optionalPaddingValue = getPaddingValue(); - unsigned expectedNumOperands = Offsets::FirstIndexOffset + - memrefType.getRank() + - (optionalPaddingValue ? 1 : 0); - // Checks on the actual operands and their types. - if (getNumOperands() != expectedNumOperands) { - return emitOpError("expects ") - << expectedNumOperands << " operands (of which " - << memrefType.getRank() << " indices)"; - } - // Consistency of padding value with vector type. - if (optionalPaddingValue) { - auto paddingValue = *optionalPaddingValue; - auto elementalType = paddingValue->getType(); - if (!VectorType::isValidElementType(elementalType)) { - return emitOpError("requires valid padding vector elemental type"); - } - if (elementalType != vectorType.getElementType()) { - return emitOpError( - "requires formal padding and vector of the same elemental type"); - } - } - // Consistency of indices types. - unsigned numIndices = 0; - for (auto *idx : getIndices()) { - if (!idx->getType().isIndex()) { - return emitOpError( - "index to vector.transfer_read must have 'index' type"); - } - ++numIndices; - } - if (numIndices != memrefType.getRank()) { - return emitOpError("requires at least a memref operand followed by ") - << memrefType.getRank() << " indices"; - } - - // Consistency of AffineMap attribute. - if (!getAttrOfType<AffineMapAttr>(getPermutationMapAttrName())) { - return emitOpError("requires an AffineMapAttr named 'permutation_map'"); - } - auto permutationMap = getPermutationMap(); - if (permutationMap.getNumSymbols() != 0) { - return emitOpError("requires a permutation_map without symbols"); - } - if (permutationMap.getNumInputs() != memrefType.getRank()) { - return emitOpError("requires a permutation_map with input dims of the " - "same rank as the memref type"); - } - if (permutationMap.getNumResults() != vectorType.getRank()) { - return emitOpError("requires a permutation_map with result dims of the " - "same rank as the vector type (") - << permutationMap.getNumResults() << " vs " << vectorType.getRank(); - } + auto elementalType = op.padding()->getType(); + if (!VectorType::isValidElementType(elementalType)) + return op.emitOpError("requires valid padding vector elemental type"); + if (elementalType != vectorType.getElementType()) + return op.emitOpError( + "requires formal padding and vector of the same elemental type"); + if (llvm::size(op.indices()) != memrefType.getRank()) + return op.emitOpError("requires ") << memrefType.getRank() << " indices"; + auto permutationMap = op.permutation_map(); + if (permutationMap.getNumSymbols() != 0) + return op.emitOpError("requires permutation_map without symbols"); + if (permutationMap.getNumInputs() != memrefType.getRank()) + return op.emitOpError("requires a permutation_map with input dims of the " + "same rank as the memref type"); + if (permutationMap.getNumResults() != vectorType.getRank()) + return op.emitOpError("requires a permutation_map with result dims of the " + "same rank as the vector type"); return verifyPermutationMap(permutationMap, - [this](Twine t) { return emitOpError(t); }); + [&op](Twine t) { return op.emitOpError(t); }); } //===----------------------------------------------------------------------===// // VectorTransferWriteOp //===----------------------------------------------------------------------===// -void VectorTransferWriteOp::build(Builder *builder, OperationState &result, - Value *srcVector, Value *dstMemRef, - ArrayRef<Value *> dstIndices, - AffineMap permutationMap) { - result.addOperands({srcVector, dstMemRef}); - result.addOperands(dstIndices); - result.addAttribute(getPermutationMapAttrName(), - AffineMapAttr::get(permutationMap)); -} - -auto VectorTransferWriteOp::getIndices() -> operand_range { - auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset; - auto end = begin + getMemRefType().getRank(); - return {begin, end}; -} - -AffineMap VectorTransferWriteOp::getPermutationMap() { - return getAttrOfType<AffineMapAttr>(getPermutationMapAttrName()).getValue(); -} - -void VectorTransferWriteOp::print(OpAsmPrinter &p) { - p << getOperationName(); - p << " " << *getVector(); - p << ", " << *getMemRef(); +static void print(OpAsmPrinter &p, VectorTransferWriteOp op) { + p << op.getOperationName() << " " << *op.vector() << ", " << *op.memref(); p << "["; - p.printOperands(getIndices()); + p.printOperands(op.indices()); p << "]"; - p.printOptionalAttrDict(getAttrs()); + p.printOptionalAttrDict(op.getAttrs()); p << " : "; - p.printType(getVectorType()); + p.printType(op.getVectorType()); p << ", "; - p.printType(getMemRefType()); + p.printType(op.getMemRefType()); } -ParseResult VectorTransferWriteOp::parse(OpAsmParser &parser, - OperationState &result) { +ParseResult parseVectorTransferWriteOp(OpAsmParser &parser, + OperationState &result) { + llvm::SMLoc typesLoc; OpAsmParser::OperandType storeValueInfo; - OpAsmParser::OperandType memrefInfo; + OpAsmParser::OperandType memRefInfo; SmallVector<OpAsmParser::OperandType, 4> indexInfo; SmallVector<Type, 2> types; - auto indexType = parser.getBuilder().getIndexType(); if (parser.parseOperand(storeValueInfo) || parser.parseComma() || - parser.parseOperand(memrefInfo) || + parser.parseOperand(memRefInfo) || parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonTypeList(types)) + parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types)) return failure(); - if (types.size() != 2) - return parser.emitError(parser.getNameLoc(), "expected 2 types"); - VectorType vectorType = types[Offsets::VectorOffset].dyn_cast<VectorType>(); - if (!vectorType) - return parser.emitError(parser.getNameLoc(), "vector type expected"); - MemRefType memrefType = types[Offsets::MemRefOffset].dyn_cast<MemRefType>(); - if (!memrefType) - return parser.emitError(parser.getNameLoc(), "memRef type expected"); - + return parser.emitError(typesLoc, "two types required"); + auto indexType = parser.getBuilder().getIndexType(); + Type vectorType = types[0], memRefType = types[1]; return failure( - parser.resolveOperands(storeValueInfo, vectorType, result.operands) || - parser.resolveOperands(memrefInfo, memrefType, result.operands) || + parser.resolveOperand(storeValueInfo, vectorType, result.operands) || + parser.resolveOperand(memRefInfo, memRefType, result.operands) || parser.resolveOperands(indexInfo, indexType, result.operands)); } -LogicalResult VectorTransferWriteOp::verify() { - // Consistency of memref type in function type. - if (llvm::empty(getOperands())) { - return emitOpError( - "requires at least a memref operand followed by 'rank' indices"); - } - if (!getMemRef()->getType().isa<MemRefType>()) { - return emitOpError("requires a memref first operand"); - } - // Consistency of vector type in function type. - if (!getVector()->getType().isa<VectorType>()) { - return emitOpError("should have a vector input type in function type: " - "(vector_type, memref_type [, elemental_type]) -> ()"); - } +static LogicalResult verify(VectorTransferWriteOp op) { // Consistency of elemental types in memref and vector. - MemRefType memrefType = getMemRefType(); - VectorType vectorType = getVectorType(); + MemRefType memrefType = op.getMemRefType(); + VectorType vectorType = op.getVectorType(); if (memrefType.getElementType() != vectorType.getElementType()) - return emitOpError( + return op.emitOpError( "requires memref and vector types of the same elemental type"); - // Consistency of number of input types. - unsigned expectedNumOperands = - Offsets::FirstIndexOffset + memrefType.getRank(); - // Checks on the actual operands and their types. - if (getNumOperands() != expectedNumOperands) { - return emitOpError() << "expects " << expectedNumOperands - << " operands (of which " << memrefType.getRank() - << " indices)"; - } - // Consistency of indices types. - unsigned numIndices = 0; - for (auto *idx : getIndices()) { - if (!idx->getType().isIndex()) { - return emitOpError( - "index to vector.transfer_write must have 'index' type"); - } - numIndices++; - } - if (numIndices != memrefType.getRank()) { - return emitOpError("requires at least a memref operand followed by ") - << memrefType.getRank() << " indices"; - } + if (llvm::size(op.indices()) != memrefType.getRank()) + return op.emitOpError("requires ") << memrefType.getRank() << " indices"; // Consistency of AffineMap attribute. - if (!getAttrOfType<AffineMapAttr>(getPermutationMapAttrName())) { - return emitOpError("requires an AffineMapAttr named 'permutation_map'"); - } - auto permutationMap = getPermutationMap(); - if (permutationMap.getNumSymbols() != 0) { - return emitOpError("requires a permutation_map without symbols"); - } - if (permutationMap.getNumInputs() != memrefType.getRank()) { - return emitOpError("requires a permutation_map with input dims of the " - "same rank as the memref type"); - } - if (permutationMap.getNumResults() != vectorType.getRank()) { - return emitOpError("requires a permutation_map with result dims of the " - "same rank as the vector type (") - << permutationMap.getNumResults() << " vs " << vectorType.getRank(); - } + auto permutationMap = op.permutation_map(); + if (permutationMap.getNumSymbols() != 0) + return op.emitOpError("requires a symbol-less permutation_map"); + if (permutationMap.getNumInputs() != memrefType.getRank()) + return op.emitOpError("requires a permutation_map with input dims of the " + "same rank as the memref type: ") + << permutationMap.getNumInputs() << " vs " << memrefType; + if (permutationMap.getNumResults() != vectorType.getRank()) + return op.emitOpError("requires a permutation_map with result dims of the " + "same rank as the vector type.") + << permutationMap.getNumResults() << " vs " << vectorType; return verifyPermutationMap(permutationMap, - [this](Twine t) { return emitOpError(t); }); + [&op](Twine t) { return op.emitOpError(t); }); } //===----------------------------------------------------------------------===// // VectorTypeCastOp //===----------------------------------------------------------------------===// -void VectorTypeCastOp::build(Builder *builder, OperationState &result, - Value *srcVector, Type dstType) { - result.addOperands(srcVector); - result.addTypes(dstType); -} -ParseResult VectorTypeCastOp::parse(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::OperandType operand; - Type srcType, dstType; - return failure(parser.parseOperand(operand) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(srcType) || parser.parseComma() || - parser.parseType(dstType) || - parser.addTypeToList(dstType, result.types) || - parser.resolveOperand(operand, srcType, result.operands)); +static MemRefType inferVectorTypeCastResultType(MemRefType t) { + return MemRefType::get({}, VectorType::get(t.getShape(), t.getElementType())); } -void VectorTypeCastOp::print(OpAsmPrinter &p) { - p << getOperationName() << ' ' << *getOperand() << " : " - << getOperand()->getType() << ", " << getType(); +void VectorTypeCastOp::build(Builder *builder, OperationState &result, + Value *source) { + result.addOperands(source); + result.addTypes( + inferVectorTypeCastResultType(source->getType().cast<MemRefType>())); } -LogicalResult VectorTypeCastOp::verify() { - auto dstMemrefType = getType().dyn_cast<MemRefType>(); - if (!dstMemrefType) - return emitOpError("expects target type to be a memref type"); - auto dstVectorType = dstMemrefType.getElementType().dyn_cast<VectorType>(); - if (!dstVectorType) - return emitOpError( - "expects vector as an element of the target memref type"); - if (!dstMemrefType.hasStaticShape()) - return emitOpError("does not support dynamic shapes"); - - if (!getOperand()->getType().isa<MemRefType>()) - return emitOpError("expects source type to be a memref type"); +static void print(OpAsmPrinter &p, VectorTypeCastOp &op) { + auto type = op.getOperand()->getType().cast<MemRefType>(); + p << op.getOperationName() << ' ' << *op.memref() << " : " << type << " to " + << inferVectorTypeCastResultType(type); +} +static LogicalResult verify(VectorTypeCastOp &op) { + auto resultType = inferVectorTypeCastResultType(op.getMemRefType()); + if (op.getResultMemRefType() != resultType) + return op.emitOpError("expects result type to be: ") << resultType; return success(); } diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp index c517d74f221..57dd18dac0f 100644 --- a/mlir/lib/Transforms/LowerVectorTransfers.cpp +++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp @@ -113,12 +113,6 @@ struct VectorTransferRewriter : public RewritePattern { {}, 0); } - /// View of tmpMemRefType as one vector, used in vector load/store to tmp - /// buffer. - MemRefType vectorMemRefType(VectorTransferOpTy transfer) const { - return MemRefType::get({1}, transfer.getVectorType(), {}, 0); - } - /// Performs the rewrite. PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override; @@ -139,7 +133,7 @@ void coalesceCopy(VectorTransferOpTy transfer, // the loop order for creating pointwise copies between remote and local // memories. int coalescedIdx = -1; - auto exprs = transfer.getPermutationMap().getResults(); + auto exprs = transfer.permutation_map().getResults(); for (auto en : llvm::enumerate(exprs)) { auto dim = en.value().template dyn_cast<AffineDimExpr>(); if (!dim) { @@ -170,7 +164,7 @@ llvm::SmallVector<edsc::ValueHandle, 8> clip(VectorTransferOpTy transfer, using edsc::intrinsics::select; IndexHandle zero(index_t(0)), one(index_t(1)); - llvm::SmallVector<edsc::ValueHandle, 8> memRefAccess(transfer.getIndices()); + llvm::SmallVector<edsc::ValueHandle, 8> memRefAccess(transfer.indices()); llvm::SmallVector<edsc::ValueHandle, 8> clippedScalarAccessExprs( memRefAccess.size(), edsc::IndexHandle()); @@ -180,7 +174,7 @@ llvm::SmallVector<edsc::ValueHandle, 8> clip(VectorTransferOpTy transfer, ++memRefDim) { // Linear search on a small number of entries. int loopIndex = -1; - auto exprs = transfer.getPermutationMap().getResults(); + auto exprs = transfer.permutation_map().getResults(); for (auto en : llvm::enumerate(exprs)) { auto expr = en.value(); auto dim = expr.template dyn_cast<AffineDimExpr>(); @@ -273,9 +267,9 @@ VectorTransferRewriter<VectorTransferReadOp>::matchAndRewrite( // 1. Setup all the captures. ScopedContext scope(rewriter, transfer.getLoc()); - IndexedValue remote(transfer.getMemRef()); - MemRefView view(transfer.getMemRef()); - VectorView vectorView(transfer.getVector()); + IndexedValue remote(transfer.memref()); + MemRefView view(transfer.memref()); + VectorView vectorView(transfer.vector()); SmallVector<IndexHandle, 8> ivs = makeIndexHandles(vectorView.rank()); SmallVector<ValueHandle *, 8> pivs = makeIndexHandlePointers(MutableArrayRef<IndexHandle>(ivs)); @@ -291,12 +285,12 @@ VectorTransferRewriter<VectorTransferReadOp>::matchAndRewrite( // 2. Emit alloc-copy-load-dealloc. ValueHandle tmp = alloc(tmpMemRefType(transfer)); IndexedValue local(tmp); - ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer)); + ValueHandle vec = vector_type_cast(tmp); LoopNestBuilder(pivs, lbs, ubs, steps)([&] { // Computes clippedScalarAccessExprs in the loop nest scope (ivs exist). local(ivs) = remote(clip(transfer, view, ivs)); }); - ValueHandle vectorValue = std_load(vec, {constant_index(0)}); + ValueHandle vectorValue = std_load(vec); (dealloc(tmp)); // vexing parse // 3. Propagate. @@ -336,10 +330,10 @@ VectorTransferRewriter<VectorTransferWriteOp>::matchAndRewrite( // 1. Setup all the captures. ScopedContext scope(rewriter, transfer.getLoc()); - IndexedValue remote(transfer.getMemRef()); - MemRefView view(transfer.getMemRef()); - ValueHandle vectorValue(transfer.getVector()); - VectorView vectorView(transfer.getVector()); + IndexedValue remote(transfer.memref()); + MemRefView view(transfer.memref()); + ValueHandle vectorValue(transfer.vector()); + VectorView vectorView(transfer.vector()); SmallVector<IndexHandle, 8> ivs = makeIndexHandles(vectorView.rank()); SmallVector<ValueHandle *, 8> pivs = makeIndexHandlePointers(ivs); coalesceCopy(transfer, &pivs, &vectorView); @@ -354,8 +348,8 @@ VectorTransferRewriter<VectorTransferWriteOp>::matchAndRewrite( // 2. Emit alloc-store-copy-dealloc. ValueHandle tmp = alloc(tmpMemRefType(transfer)); IndexedValue local(tmp); - ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer)); - std_store(vectorValue, vec, {constant_index(0)}); + ValueHandle vec = vector_type_cast(tmp); + std_store(vectorValue, vec); LoopNestBuilder(pivs, lbs, ubs, steps)([&] { // Computes clippedScalarAccessExprs in the loop nest scope (ivs exist). remote(clip(transfer, view, ivs)) = local(ivs); diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index a0b60dd3648..06016da5caa 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -465,7 +465,7 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy transfer, ++dim; }, superVectorType.getShape(), *optionalRatio); - auto permutationMap = transfer.getPermutationMap(); + auto permutationMap = transfer.permutation_map(); LLVM_DEBUG(permutationMap.print(dbgs() << "\npermutationMap: ")); if (keep.empty()) { return permutationMap; @@ -486,16 +486,16 @@ static Operation *instantiate(OpBuilder b, VectorTransferReadOp read, ArrayRef<unsigned> hwVectorInstance, DenseMap<Value *, Value *> *substitutionsMap) { SmallVector<Value *, 8> indices = - map(makePtrDynCaster<Value>(), read.getIndices()); + map(makePtrDynCaster<Value>(), read.indices()); auto affineIndices = reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices); auto map = projectedPermutationMap(read, hwVectorType); if (!map) { return nullptr; } - auto cloned = b.create<VectorTransferReadOp>(read.getLoc(), hwVectorType, - read.getMemRef(), affineIndices, - map, read.getPaddingValue()); + auto cloned = b.create<VectorTransferReadOp>( + read.getLoc(), hwVectorType, read.memref(), affineIndices, + AffineMapAttr::get(map), read.padding()); return cloned.getOperation(); } @@ -510,14 +510,14 @@ static Operation *instantiate(OpBuilder b, VectorTransferWriteOp write, ArrayRef<unsigned> hwVectorInstance, DenseMap<Value *, Value *> *substitutionsMap) { SmallVector<Value *, 8> indices = - map(makePtrDynCaster<Value>(), write.getIndices()); + map(makePtrDynCaster<Value>(), write.indices()); auto affineIndices = reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices); auto cloned = b.create<VectorTransferWriteOp>( write.getLoc(), - substitute(write.getVector(), hwVectorType, substitutionsMap), - write.getMemRef(), affineIndices, - projectedPermutationMap(write, hwVectorType)); + substitute(write.vector(), hwVectorType, substitutionsMap), + write.memref(), affineIndices, + AffineMapAttr::get(projectedPermutationMap(write, hwVectorType))); return cloned.getOperation(); } diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index a1e87568745..b3eea35a55f 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -35,6 +35,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/Functional.h" #include "mlir/Support/LLVM.h" +#include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/DenseMap.h" @@ -718,6 +719,8 @@ struct VectorizationState { // Checks that the type of `op` is AffineStoreOp and adds it to the terminals // set. void registerTerminal(Operation *op); + // Folder used to factor out constant creation. + OperationFolder *folder; private: void registerReplacement(Value *key, Value *value); @@ -832,7 +835,11 @@ static LogicalResult vectorizeRootOrTerminal(Value *iv, LLVM_DEBUG(permutationMap.print(dbgs())); auto transfer = b.create<vector::VectorTransferReadOp>( opInst->getLoc(), vectorType, memoryOp.getMemRef(), - map(makePtrDynCaster<Value>(), indices), permutationMap); + map(makePtrDynCaster<Value>(), indices), + AffineMapAttr::get(permutationMap), + // TODO(b/144455320) add a proper padding value, not just 0.0 : f32 + state->folder->create<ConstantFloatOp>( + b, opInst->getLoc(), llvm::APFloat(0.0f), b.getF32Type())); state->registerReplacement(opInst, transfer.getOperation()); } else { state->registerTerminal(opInst); @@ -1058,7 +1065,8 @@ static Operation *vectorizeOneOperation(Operation *opInst, LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: "); LLVM_DEBUG(permutationMap.print(dbgs())); auto transfer = b.create<vector::VectorTransferWriteOp>( - opInst->getLoc(), vectorValue, memRef, indices, permutationMap); + opInst->getLoc(), vectorValue, memRef, indices, + AffineMapAttr::get(permutationMap)); auto *res = transfer.getOperation(); LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res); // "Terminals" (i.e. AffineStoreOps) are erased on the spot. @@ -1152,8 +1160,10 @@ static LogicalResult vectorizeNonTerminals(VectorizationState *state) { static LogicalResult vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) { auto loop = cast<AffineForOp>(m.getMatchedOperation()); + OperationFolder folder(loop.getContext()); VectorizationState state; state.strategy = strategy; + state.folder = &folder; // Since patterns are recursive, they can very well intersect. // Since we do not want a fully greedy strategy in general, we decouple |