summary | refs | log | tree | commit | diff | stats
path: root/mlir/lib
diff options
context:
space:
mode:
author: Nicolas Vasilache <ntv@google.com> 2019-11-14 08:10:36 -0800
committer: A. Unique TensorFlower <gardener@tensorflow.org> 2019-11-14 08:15:23 -0800
commit: f2b6ae99913d0049c7929160aed5f213b1081abb (patch)
tree: 41ccb1f9a181eaada07f21cbc7da75ef2d229575 /mlir/lib
parent: 7c28de4aef6da3ab2f53118ecf717e56c68352e7 (diff)
download: bcm5719-llvm-f2b6ae99913d0049c7929160aed5f213b1081abb.tar.gz
download: bcm5719-llvm-f2b6ae99913d0049c7929160aed5f213b1081abb.zip
Move VectorOps to Tablegen - (almost) NFC
This CL moves VectorOps to Tablegen and cleans up the implementation. This is almost NFC but 2 changes occur: 1. an interface change occurs in the padding value specification in vector_transfer_read: the value becomes non-optional. As a shortcut we currently use %f0 for all paddings. This should become an OpInterface for vectorization in the future. 2. the return type of vector.type_cast is trivial and simplified to `memref<vector<...>>` Relevant roundtrip and invalid tests that used to sit in core are moved to the vector dialect. The op documentation is moved to the .td file. PiperOrigin-RevId: 280430869
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Analysis/VectorAnalysis.cpp2
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp6
-rw-r--r--mlir/lib/Dialect/VectorOps/VectorOps.cpp385
-rw-r--r--mlir/lib/Transforms/LowerVectorTransfers.cpp34
-rw-r--r--mlir/lib/Transforms/MaterializeVectors.cpp18
-rw-r--r--mlir/lib/Transforms/Vectorize.cpp14
6 files changed, 136 insertions, 323 deletions
diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp
index e765ce35e74..2dab3481e56 100644
--- a/mlir/lib/Analysis/VectorAnalysis.cpp
+++ b/mlir/lib/Analysis/VectorAnalysis.cpp
@@ -195,7 +195,7 @@ bool mlir::matcher::operatesOnSuperVectorsOf(Operation &op,
(void)mustDivide;
VectorType superVectorType;
if (auto read = dyn_cast<vector::VectorTransferReadOp>(op)) {
- superVectorType = read.getResultType();
+ superVectorType = read.getVectorType();
mustDivide = true;
} else if (auto write = dyn_cast<vector::VectorTransferWriteOp>(op)) {
superVectorType = write.getVectorType();
diff --git a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
index 5ccf740f2fb..21bcdc9a6db 100644
--- a/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/VectorToLLVM.cpp
@@ -196,10 +196,10 @@ public:
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides =
- getStridesAndOffset(targetMemRefType, strides, offset);
+ getStridesAndOffset(sourceMemRefType, strides, offset);
bool isContiguous = (strides.back() == 1);
if (isContiguous) {
- auto sizes = targetMemRefType.getShape();
+ auto sizes = sourceMemRefType.getShape();
for (int index = 0, e = strides.size() - 2; index < e; ++index) {
if (strides[index] != strides[index + 1] * sizes[index + 1]) {
isContiguous = false;
@@ -207,7 +207,7 @@ public:
}
}
}
- // Only contiguous tensors supported atm.
+ // Only contiguous source tensors supported atm.
if (failed(successStrides) || !isContiguous)
return matchFailure();
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index 8626f241955..215e92d0947 100644
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -37,8 +37,6 @@ using namespace mlir::vector;
mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
- addOperations<VectorTransferReadOp, VectorTransferWriteOp,
- VectorTypeCastOp>();
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/VectorOps/VectorOps.cpp.inc"
@@ -195,354 +193,165 @@ static LogicalResult verifyPermutationMap(AffineMap permutationMap,
return success();
}
-void VectorTransferReadOp::build(Builder *builder, OperationState &result,
- VectorType vectorType, Value *srcMemRef,
- ArrayRef<Value *> srcIndices,
- AffineMap permutationMap,
- Optional<Value *> paddingValue) {
- result.addOperands(srcMemRef);
- result.addOperands(srcIndices);
- if (paddingValue) {
- result.addOperands({*paddingValue});
- }
- result.addAttribute(getPermutationMapAttrName(),
- AffineMapAttr::get(permutationMap));
- result.addTypes(vectorType);
-}
-
-auto VectorTransferReadOp::getIndices() -> operand_range {
- auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset;
- auto end = begin + getMemRefType().getRank();
- return {begin, end};
-}
-
-Optional<Value *> VectorTransferReadOp::getPaddingValue() {
- auto memRefRank = getMemRefType().getRank();
- if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) {
- return None;
- }
- return Optional<Value *>(getOperand(Offsets::FirstIndexOffset + memRefRank));
-}
-
-AffineMap VectorTransferReadOp::getPermutationMap() {
- return getAttrOfType<AffineMapAttr>(getPermutationMapAttrName()).getValue();
-}
-
-void VectorTransferReadOp::print(OpAsmPrinter &p) {
- p << getOperationName() << " ";
- p.printOperand(getMemRef());
+static void print(OpAsmPrinter &p, VectorTransferReadOp op) {
+ p << op.getOperationName() << " ";
+ p.printOperand(op.memref());
p << "[";
- p.printOperands(getIndices());
- p << "]";
- auto optionalPaddingValue = getPaddingValue();
- if (optionalPaddingValue) {
- p << ", (";
- p.printOperand(*optionalPaddingValue);
- p << ")";
- }
- p.printOptionalAttrDict(getAttrs());
- p << " : " << getMemRefType();
- p << ", " << getResultType();
+ p.printOperands(op.indices());
+ p << "], ";
+ p.printOperand(op.padding());
+ p << " ";
+ p.printOptionalAttrDict(op.getAttrs());
+ p << " : " << op.getMemRefType();
+ p << ", " << op.getVectorType();
}
-ParseResult VectorTransferReadOp::parse(OpAsmParser &parser,
- OperationState &result) {
+ParseResult parseVectorTransferReadOp(OpAsmParser &parser,
+ OperationState &result) {
+ llvm::SMLoc typesLoc;
OpAsmParser::OperandType memrefInfo;
SmallVector<OpAsmParser::OperandType, 8> indexInfo;
- SmallVector<OpAsmParser::OperandType, 8> paddingInfo;
+ OpAsmParser::OperandType paddingInfo;
SmallVector<Type, 2> types;
-
// Parsing with support for optional paddingValue.
if (parser.parseOperand(memrefInfo) ||
parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
- parser.parseTrailingOperandList(paddingInfo,
- OpAsmParser::Delimiter::Paren) ||
+ parser.parseComma() || parser.parseOperand(paddingInfo) ||
parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonTypeList(types))
+ parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
return failure();
-
- // Resolution.
if (types.size() != 2)
- return parser.emitError(parser.getNameLoc(), "expected 2 types");
- MemRefType memrefType = types[0].dyn_cast<MemRefType>();
- if (!memrefType)
- return parser.emitError(parser.getNameLoc(), "memRef type expected");
- VectorType vectorType = types[1].dyn_cast<VectorType>();
- if (!vectorType)
- return parser.emitError(parser.getNameLoc(), "vector type expected");
-
- // Extract optional paddingValue.
- // At this point, indexInfo may contain the optional paddingValue, pop it
- // out.
- if (static_cast<int64_t>(indexInfo.size()) != memrefType.getRank())
- return parser.emitError(parser.getNameLoc(),
- "expected " + Twine(memrefType.getRank()) +
- " indices to the memref");
- if (paddingInfo.size() > 1)
- return parser.emitError(parser.getNameLoc(),
- "expected at most one padding value");
- Type paddingType;
- bool hasOptionalPaddingValue = !paddingInfo.empty();
- if (hasOptionalPaddingValue) {
- paddingType = vectorType.getElementType();
- }
+ return parser.emitError(typesLoc, "two types required");
auto indexType = parser.getBuilder().getIndexType();
+ MemRefType memRefType = types[0].dyn_cast<MemRefType>();
+ if (!memRefType)
+ return parser.emitError(typesLoc, "memref type required"), failure();
+ Type vectorType = types[1];
return failure(
- parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
+ parser.resolveOperand(memrefInfo, memRefType, result.operands) ||
parser.resolveOperands(indexInfo, indexType, result.operands) ||
- (hasOptionalPaddingValue &&
- parser.resolveOperand(paddingInfo[0], paddingType, result.operands)) ||
+ parser.resolveOperand(paddingInfo, memRefType.getElementType(),
+ result.operands) ||
parser.addTypeToList(vectorType, result.types));
}
-LogicalResult VectorTransferReadOp::verify() {
- // Consistency of memref type in function type.
- if (llvm::empty(getOperands())) {
- return emitOpError(
- "requires at least a memref operand followed by 'rank' indices");
- }
- if (!getMemRef()->getType().isa<MemRefType>()) {
- return emitOpError("requires a memref as first operand");
- }
- // Consistency of vector type in function type.
- if (!getResult()->getType().isa<VectorType>()) {
- return emitOpError("should have a vector result type in function type: "
- "memref_type<...xelemental_type>, vector_type");
- }
+static LogicalResult verify(VectorTransferReadOp op) {
// Consistency of elemental types in memref and vector.
- MemRefType memrefType = getMemRefType();
- VectorType vectorType = getResultType();
+ MemRefType memrefType = op.getMemRefType();
+ VectorType vectorType = op.getVectorType();
if (memrefType.getElementType() != vectorType.getElementType())
- return emitOpError(
+ return op.emitOpError(
"requires memref and vector types of the same elemental type");
- // Consistency of number of input types.
- auto optionalPaddingValue = getPaddingValue();
- unsigned expectedNumOperands = Offsets::FirstIndexOffset +
- memrefType.getRank() +
- (optionalPaddingValue ? 1 : 0);
- // Checks on the actual operands and their types.
- if (getNumOperands() != expectedNumOperands) {
- return emitOpError("expects ")
- << expectedNumOperands << " operands (of which "
- << memrefType.getRank() << " indices)";
- }
- // Consistency of padding value with vector type.
- if (optionalPaddingValue) {
- auto paddingValue = *optionalPaddingValue;
- auto elementalType = paddingValue->getType();
- if (!VectorType::isValidElementType(elementalType)) {
- return emitOpError("requires valid padding vector elemental type");
- }
- if (elementalType != vectorType.getElementType()) {
- return emitOpError(
- "requires formal padding and vector of the same elemental type");
- }
- }
- // Consistency of indices types.
- unsigned numIndices = 0;
- for (auto *idx : getIndices()) {
- if (!idx->getType().isIndex()) {
- return emitOpError(
- "index to vector.transfer_read must have 'index' type");
- }
- ++numIndices;
- }
- if (numIndices != memrefType.getRank()) {
- return emitOpError("requires at least a memref operand followed by ")
- << memrefType.getRank() << " indices";
- }
-
- // Consistency of AffineMap attribute.
- if (!getAttrOfType<AffineMapAttr>(getPermutationMapAttrName())) {
- return emitOpError("requires an AffineMapAttr named 'permutation_map'");
- }
- auto permutationMap = getPermutationMap();
- if (permutationMap.getNumSymbols() != 0) {
- return emitOpError("requires a permutation_map without symbols");
- }
- if (permutationMap.getNumInputs() != memrefType.getRank()) {
- return emitOpError("requires a permutation_map with input dims of the "
- "same rank as the memref type");
- }
- if (permutationMap.getNumResults() != vectorType.getRank()) {
- return emitOpError("requires a permutation_map with result dims of the "
- "same rank as the vector type (")
- << permutationMap.getNumResults() << " vs " << vectorType.getRank();
- }
+ auto elementalType = op.padding()->getType();
+ if (!VectorType::isValidElementType(elementalType))
+ return op.emitOpError("requires valid padding vector elemental type");
+ if (elementalType != vectorType.getElementType())
+ return op.emitOpError(
+ "requires formal padding and vector of the same elemental type");
+ if (llvm::size(op.indices()) != memrefType.getRank())
+ return op.emitOpError("requires ") << memrefType.getRank() << " indices";
+ auto permutationMap = op.permutation_map();
+ if (permutationMap.getNumSymbols() != 0)
+ return op.emitOpError("requires permutation_map without symbols");
+ if (permutationMap.getNumInputs() != memrefType.getRank())
+ return op.emitOpError("requires a permutation_map with input dims of the "
+ "same rank as the memref type");
+ if (permutationMap.getNumResults() != vectorType.getRank())
+ return op.emitOpError("requires a permutation_map with result dims of the "
+ "same rank as the vector type");
return verifyPermutationMap(permutationMap,
- [this](Twine t) { return emitOpError(t); });
+ [&op](Twine t) { return op.emitOpError(t); });
}
//===----------------------------------------------------------------------===//
// VectorTransferWriteOp
//===----------------------------------------------------------------------===//
-void VectorTransferWriteOp::build(Builder *builder, OperationState &result,
- Value *srcVector, Value *dstMemRef,
- ArrayRef<Value *> dstIndices,
- AffineMap permutationMap) {
- result.addOperands({srcVector, dstMemRef});
- result.addOperands(dstIndices);
- result.addAttribute(getPermutationMapAttrName(),
- AffineMapAttr::get(permutationMap));
-}
-
-auto VectorTransferWriteOp::getIndices() -> operand_range {
- auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset;
- auto end = begin + getMemRefType().getRank();
- return {begin, end};
-}
-
-AffineMap VectorTransferWriteOp::getPermutationMap() {
- return getAttrOfType<AffineMapAttr>(getPermutationMapAttrName()).getValue();
-}
-
-void VectorTransferWriteOp::print(OpAsmPrinter &p) {
- p << getOperationName();
- p << " " << *getVector();
- p << ", " << *getMemRef();
+static void print(OpAsmPrinter &p, VectorTransferWriteOp op) {
+ p << op.getOperationName() << " " << *op.vector() << ", " << *op.memref();
p << "[";
- p.printOperands(getIndices());
+ p.printOperands(op.indices());
p << "]";
- p.printOptionalAttrDict(getAttrs());
+ p.printOptionalAttrDict(op.getAttrs());
p << " : ";
- p.printType(getVectorType());
+ p.printType(op.getVectorType());
p << ", ";
- p.printType(getMemRefType());
+ p.printType(op.getMemRefType());
}
-ParseResult VectorTransferWriteOp::parse(OpAsmParser &parser,
- OperationState &result) {
+ParseResult parseVectorTransferWriteOp(OpAsmParser &parser,
+ OperationState &result) {
+ llvm::SMLoc typesLoc;
OpAsmParser::OperandType storeValueInfo;
- OpAsmParser::OperandType memrefInfo;
+ OpAsmParser::OperandType memRefInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
SmallVector<Type, 2> types;
- auto indexType = parser.getBuilder().getIndexType();
if (parser.parseOperand(storeValueInfo) || parser.parseComma() ||
- parser.parseOperand(memrefInfo) ||
+ parser.parseOperand(memRefInfo) ||
parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonTypeList(types))
+ parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
return failure();
-
if (types.size() != 2)
- return parser.emitError(parser.getNameLoc(), "expected 2 types");
- VectorType vectorType = types[Offsets::VectorOffset].dyn_cast<VectorType>();
- if (!vectorType)
- return parser.emitError(parser.getNameLoc(), "vector type expected");
- MemRefType memrefType = types[Offsets::MemRefOffset].dyn_cast<MemRefType>();
- if (!memrefType)
- return parser.emitError(parser.getNameLoc(), "memRef type expected");
-
+ return parser.emitError(typesLoc, "two types required");
+ auto indexType = parser.getBuilder().getIndexType();
+ Type vectorType = types[0], memRefType = types[1];
return failure(
- parser.resolveOperands(storeValueInfo, vectorType, result.operands) ||
- parser.resolveOperands(memrefInfo, memrefType, result.operands) ||
+ parser.resolveOperand(storeValueInfo, vectorType, result.operands) ||
+ parser.resolveOperand(memRefInfo, memRefType, result.operands) ||
parser.resolveOperands(indexInfo, indexType, result.operands));
}
-LogicalResult VectorTransferWriteOp::verify() {
- // Consistency of memref type in function type.
- if (llvm::empty(getOperands())) {
- return emitOpError(
- "requires at least a memref operand followed by 'rank' indices");
- }
- if (!getMemRef()->getType().isa<MemRefType>()) {
- return emitOpError("requires a memref first operand");
- }
- // Consistency of vector type in function type.
- if (!getVector()->getType().isa<VectorType>()) {
- return emitOpError("should have a vector input type in function type: "
- "(vector_type, memref_type [, elemental_type]) -> ()");
- }
+static LogicalResult verify(VectorTransferWriteOp op) {
// Consistency of elemental types in memref and vector.
- MemRefType memrefType = getMemRefType();
- VectorType vectorType = getVectorType();
+ MemRefType memrefType = op.getMemRefType();
+ VectorType vectorType = op.getVectorType();
if (memrefType.getElementType() != vectorType.getElementType())
- return emitOpError(
+ return op.emitOpError(
"requires memref and vector types of the same elemental type");
- // Consistency of number of input types.
- unsigned expectedNumOperands =
- Offsets::FirstIndexOffset + memrefType.getRank();
- // Checks on the actual operands and their types.
- if (getNumOperands() != expectedNumOperands) {
- return emitOpError() << "expects " << expectedNumOperands
- << " operands (of which " << memrefType.getRank()
- << " indices)";
- }
- // Consistency of indices types.
- unsigned numIndices = 0;
- for (auto *idx : getIndices()) {
- if (!idx->getType().isIndex()) {
- return emitOpError(
- "index to vector.transfer_write must have 'index' type");
- }
- numIndices++;
- }
- if (numIndices != memrefType.getRank()) {
- return emitOpError("requires at least a memref operand followed by ")
- << memrefType.getRank() << " indices";
- }
+ if (llvm::size(op.indices()) != memrefType.getRank())
+ return op.emitOpError("requires ") << memrefType.getRank() << " indices";
// Consistency of AffineMap attribute.
- if (!getAttrOfType<AffineMapAttr>(getPermutationMapAttrName())) {
- return emitOpError("requires an AffineMapAttr named 'permutation_map'");
- }
- auto permutationMap = getPermutationMap();
- if (permutationMap.getNumSymbols() != 0) {
- return emitOpError("requires a permutation_map without symbols");
- }
- if (permutationMap.getNumInputs() != memrefType.getRank()) {
- return emitOpError("requires a permutation_map with input dims of the "
- "same rank as the memref type");
- }
- if (permutationMap.getNumResults() != vectorType.getRank()) {
- return emitOpError("requires a permutation_map with result dims of the "
- "same rank as the vector type (")
- << permutationMap.getNumResults() << " vs " << vectorType.getRank();
- }
+ auto permutationMap = op.permutation_map();
+ if (permutationMap.getNumSymbols() != 0)
+ return op.emitOpError("requires a symbol-less permutation_map");
+ if (permutationMap.getNumInputs() != memrefType.getRank())
+ return op.emitOpError("requires a permutation_map with input dims of the "
+ "same rank as the memref type: ")
+ << permutationMap.getNumInputs() << " vs " << memrefType;
+ if (permutationMap.getNumResults() != vectorType.getRank())
+ return op.emitOpError("requires a permutation_map with result dims of the "
+ "same rank as the vector type.")
+ << permutationMap.getNumResults() << " vs " << vectorType;
return verifyPermutationMap(permutationMap,
- [this](Twine t) { return emitOpError(t); });
+ [&op](Twine t) { return op.emitOpError(t); });
}
//===----------------------------------------------------------------------===//
// VectorTypeCastOp
//===----------------------------------------------------------------------===//
-void VectorTypeCastOp::build(Builder *builder, OperationState &result,
- Value *srcVector, Type dstType) {
- result.addOperands(srcVector);
- result.addTypes(dstType);
-}
-ParseResult VectorTypeCastOp::parse(OpAsmParser &parser,
- OperationState &result) {
- OpAsmParser::OperandType operand;
- Type srcType, dstType;
- return failure(parser.parseOperand(operand) ||
- parser.parseOptionalAttrDict(result.attributes) ||
- parser.parseColonType(srcType) || parser.parseComma() ||
- parser.parseType(dstType) ||
- parser.addTypeToList(dstType, result.types) ||
- parser.resolveOperand(operand, srcType, result.operands));
+static MemRefType inferVectorTypeCastResultType(MemRefType t) {
+ return MemRefType::get({}, VectorType::get(t.getShape(), t.getElementType()));
}
-void VectorTypeCastOp::print(OpAsmPrinter &p) {
- p << getOperationName() << ' ' << *getOperand() << " : "
- << getOperand()->getType() << ", " << getType();
+void VectorTypeCastOp::build(Builder *builder, OperationState &result,
+ Value *source) {
+ result.addOperands(source);
+ result.addTypes(
+ inferVectorTypeCastResultType(source->getType().cast<MemRefType>()));
}
-LogicalResult VectorTypeCastOp::verify() {
- auto dstMemrefType = getType().dyn_cast<MemRefType>();
- if (!dstMemrefType)
- return emitOpError("expects target type to be a memref type");
- auto dstVectorType = dstMemrefType.getElementType().dyn_cast<VectorType>();
- if (!dstVectorType)
- return emitOpError(
- "expects vector as an element of the target memref type");
- if (!dstMemrefType.hasStaticShape())
- return emitOpError("does not support dynamic shapes");
-
- if (!getOperand()->getType().isa<MemRefType>())
- return emitOpError("expects source type to be a memref type");
+static void print(OpAsmPrinter &p, VectorTypeCastOp &op) {
+ auto type = op.getOperand()->getType().cast<MemRefType>();
+ p << op.getOperationName() << ' ' << *op.memref() << " : " << type << " to "
+ << inferVectorTypeCastResultType(type);
+}
+static LogicalResult verify(VectorTypeCastOp &op) {
+ auto resultType = inferVectorTypeCastResultType(op.getMemRefType());
+ if (op.getResultMemRefType() != resultType)
+ return op.emitOpError("expects result type to be: ") << resultType;
return success();
}
diff --git a/mlir/lib/Transforms/LowerVectorTransfers.cpp b/mlir/lib/Transforms/LowerVectorTransfers.cpp
index c517d74f221..57dd18dac0f 100644
--- a/mlir/lib/Transforms/LowerVectorTransfers.cpp
+++ b/mlir/lib/Transforms/LowerVectorTransfers.cpp
@@ -113,12 +113,6 @@ struct VectorTransferRewriter : public RewritePattern {
{}, 0);
}
- /// View of tmpMemRefType as one vector, used in vector load/store to tmp
- /// buffer.
- MemRefType vectorMemRefType(VectorTransferOpTy transfer) const {
- return MemRefType::get({1}, transfer.getVectorType(), {}, 0);
- }
-
/// Performs the rewrite.
PatternMatchResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;
@@ -139,7 +133,7 @@ void coalesceCopy(VectorTransferOpTy transfer,
// the loop order for creating pointwise copies between remote and local
// memories.
int coalescedIdx = -1;
- auto exprs = transfer.getPermutationMap().getResults();
+ auto exprs = transfer.permutation_map().getResults();
for (auto en : llvm::enumerate(exprs)) {
auto dim = en.value().template dyn_cast<AffineDimExpr>();
if (!dim) {
@@ -170,7 +164,7 @@ llvm::SmallVector<edsc::ValueHandle, 8> clip(VectorTransferOpTy transfer,
using edsc::intrinsics::select;
IndexHandle zero(index_t(0)), one(index_t(1));
- llvm::SmallVector<edsc::ValueHandle, 8> memRefAccess(transfer.getIndices());
+ llvm::SmallVector<edsc::ValueHandle, 8> memRefAccess(transfer.indices());
llvm::SmallVector<edsc::ValueHandle, 8> clippedScalarAccessExprs(
memRefAccess.size(), edsc::IndexHandle());
@@ -180,7 +174,7 @@ llvm::SmallVector<edsc::ValueHandle, 8> clip(VectorTransferOpTy transfer,
++memRefDim) {
// Linear search on a small number of entries.
int loopIndex = -1;
- auto exprs = transfer.getPermutationMap().getResults();
+ auto exprs = transfer.permutation_map().getResults();
for (auto en : llvm::enumerate(exprs)) {
auto expr = en.value();
auto dim = expr.template dyn_cast<AffineDimExpr>();
@@ -273,9 +267,9 @@ VectorTransferRewriter<VectorTransferReadOp>::matchAndRewrite(
// 1. Setup all the captures.
ScopedContext scope(rewriter, transfer.getLoc());
- IndexedValue remote(transfer.getMemRef());
- MemRefView view(transfer.getMemRef());
- VectorView vectorView(transfer.getVector());
+ IndexedValue remote(transfer.memref());
+ MemRefView view(transfer.memref());
+ VectorView vectorView(transfer.vector());
SmallVector<IndexHandle, 8> ivs = makeIndexHandles(vectorView.rank());
SmallVector<ValueHandle *, 8> pivs =
makeIndexHandlePointers(MutableArrayRef<IndexHandle>(ivs));
@@ -291,12 +285,12 @@ VectorTransferRewriter<VectorTransferReadOp>::matchAndRewrite(
// 2. Emit alloc-copy-load-dealloc.
ValueHandle tmp = alloc(tmpMemRefType(transfer));
IndexedValue local(tmp);
- ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer));
+ ValueHandle vec = vector_type_cast(tmp);
LoopNestBuilder(pivs, lbs, ubs, steps)([&] {
// Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
local(ivs) = remote(clip(transfer, view, ivs));
});
- ValueHandle vectorValue = std_load(vec, {constant_index(0)});
+ ValueHandle vectorValue = std_load(vec);
(dealloc(tmp)); // vexing parse
// 3. Propagate.
@@ -336,10 +330,10 @@ VectorTransferRewriter<VectorTransferWriteOp>::matchAndRewrite(
// 1. Setup all the captures.
ScopedContext scope(rewriter, transfer.getLoc());
- IndexedValue remote(transfer.getMemRef());
- MemRefView view(transfer.getMemRef());
- ValueHandle vectorValue(transfer.getVector());
- VectorView vectorView(transfer.getVector());
+ IndexedValue remote(transfer.memref());
+ MemRefView view(transfer.memref());
+ ValueHandle vectorValue(transfer.vector());
+ VectorView vectorView(transfer.vector());
SmallVector<IndexHandle, 8> ivs = makeIndexHandles(vectorView.rank());
SmallVector<ValueHandle *, 8> pivs = makeIndexHandlePointers(ivs);
coalesceCopy(transfer, &pivs, &vectorView);
@@ -354,8 +348,8 @@ VectorTransferRewriter<VectorTransferWriteOp>::matchAndRewrite(
// 2. Emit alloc-store-copy-dealloc.
ValueHandle tmp = alloc(tmpMemRefType(transfer));
IndexedValue local(tmp);
- ValueHandle vec = vector_type_cast(tmp, vectorMemRefType(transfer));
- std_store(vectorValue, vec, {constant_index(0)});
+ ValueHandle vec = vector_type_cast(tmp);
+ std_store(vectorValue, vec);
LoopNestBuilder(pivs, lbs, ubs, steps)([&] {
// Computes clippedScalarAccessExprs in the loop nest scope (ivs exist).
remote(clip(transfer, view, ivs)) = local(ivs);
diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp
index a0b60dd3648..06016da5caa 100644
--- a/mlir/lib/Transforms/MaterializeVectors.cpp
+++ b/mlir/lib/Transforms/MaterializeVectors.cpp
@@ -465,7 +465,7 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy transfer,
++dim;
},
superVectorType.getShape(), *optionalRatio);
- auto permutationMap = transfer.getPermutationMap();
+ auto permutationMap = transfer.permutation_map();
LLVM_DEBUG(permutationMap.print(dbgs() << "\npermutationMap: "));
if (keep.empty()) {
return permutationMap;
@@ -486,16 +486,16 @@ static Operation *instantiate(OpBuilder b, VectorTransferReadOp read,
ArrayRef<unsigned> hwVectorInstance,
DenseMap<Value *, Value *> *substitutionsMap) {
SmallVector<Value *, 8> indices =
- map(makePtrDynCaster<Value>(), read.getIndices());
+ map(makePtrDynCaster<Value>(), read.indices());
auto affineIndices =
reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices);
auto map = projectedPermutationMap(read, hwVectorType);
if (!map) {
return nullptr;
}
- auto cloned = b.create<VectorTransferReadOp>(read.getLoc(), hwVectorType,
- read.getMemRef(), affineIndices,
- map, read.getPaddingValue());
+ auto cloned = b.create<VectorTransferReadOp>(
+ read.getLoc(), hwVectorType, read.memref(), affineIndices,
+ AffineMapAttr::get(map), read.padding());
return cloned.getOperation();
}
@@ -510,14 +510,14 @@ static Operation *instantiate(OpBuilder b, VectorTransferWriteOp write,
ArrayRef<unsigned> hwVectorInstance,
DenseMap<Value *, Value *> *substitutionsMap) {
SmallVector<Value *, 8> indices =
- map(makePtrDynCaster<Value>(), write.getIndices());
+ map(makePtrDynCaster<Value>(), write.indices());
auto affineIndices =
reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices);
auto cloned = b.create<VectorTransferWriteOp>(
write.getLoc(),
- substitute(write.getVector(), hwVectorType, substitutionsMap),
- write.getMemRef(), affineIndices,
- projectedPermutationMap(write, hwVectorType));
+ substitute(write.vector(), hwVectorType, substitutionsMap),
+ write.memref(), affineIndices,
+ AffineMapAttr::get(projectedPermutationMap(write, hwVectorType)));
return cloned.getOperation();
}
diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp
index a1e87568745..b3eea35a55f 100644
--- a/mlir/lib/Transforms/Vectorize.cpp
+++ b/mlir/lib/Transforms/Vectorize.cpp
@@ -35,6 +35,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
@@ -718,6 +719,8 @@ struct VectorizationState {
// Checks that the type of `op` is AffineStoreOp and adds it to the terminals
// set.
void registerTerminal(Operation *op);
+ // Folder used to factor out constant creation.
+ OperationFolder *folder;
private:
void registerReplacement(Value *key, Value *value);
@@ -832,7 +835,11 @@ static LogicalResult vectorizeRootOrTerminal(Value *iv,
LLVM_DEBUG(permutationMap.print(dbgs()));
auto transfer = b.create<vector::VectorTransferReadOp>(
opInst->getLoc(), vectorType, memoryOp.getMemRef(),
- map(makePtrDynCaster<Value>(), indices), permutationMap);
+ map(makePtrDynCaster<Value>(), indices),
+ AffineMapAttr::get(permutationMap),
+ // TODO(b/144455320) add a proper padding value, not just 0.0 : f32
+ state->folder->create<ConstantFloatOp>(
+ b, opInst->getLoc(), llvm::APFloat(0.0f), b.getF32Type()));
state->registerReplacement(opInst, transfer.getOperation());
} else {
state->registerTerminal(opInst);
@@ -1058,7 +1065,8 @@ static Operation *vectorizeOneOperation(Operation *opInst,
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
LLVM_DEBUG(permutationMap.print(dbgs()));
auto transfer = b.create<vector::VectorTransferWriteOp>(
- opInst->getLoc(), vectorValue, memRef, indices, permutationMap);
+ opInst->getLoc(), vectorValue, memRef, indices,
+ AffineMapAttr::get(permutationMap));
auto *res = transfer.getOperation();
LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res);
// "Terminals" (i.e. AffineStoreOps) are erased on the spot.
@@ -1152,8 +1160,10 @@ static LogicalResult vectorizeNonTerminals(VectorizationState *state) {
static LogicalResult vectorizeRootMatch(NestedMatch m,
VectorizationStrategy *strategy) {
auto loop = cast<AffineForOp>(m.getMatchedOperation());
+ OperationFolder folder(loop.getContext());
VectorizationState state;
state.strategy = strategy;
+ state.folder = &folder;
// Since patterns are recursive, they can very well intersect.
// Since we do not want a fully greedy strategy in general, we decouple
OpenPOWER on IntegriCloud