| -rw-r--r-- | mlir/include/mlir/Analysis/VectorAnalysis.h | 18 |
| -rw-r--r-- | mlir/include/mlir/StandardOps/StandardOps.h | 144 |
| -rw-r--r-- | mlir/include/mlir/Support/Functional.h | 9 |
| -rw-r--r-- | mlir/lib/Analysis/LoopAnalysis.cpp | 3 |
| -rw-r--r-- | mlir/lib/Analysis/VectorAnalysis.cpp | 35 |
| -rw-r--r-- | mlir/lib/StandardOps/StandardOps.cpp | 427 |
| -rw-r--r-- | mlir/lib/Transforms/MaterializeVectors.cpp | 245 |
| -rw-r--r-- | mlir/lib/Transforms/Vectorize.cpp | 54 |
| -rw-r--r-- | mlir/test/IR/core-ops.mlir | 23 |
| -rw-r--r-- | mlir/test/IR/invalid-ops.mlir | 181 |
| -rw-r--r-- | mlir/test/Transforms/materialize_vectors.mlir | 52 |
| -rw-r--r-- | mlir/test/Transforms/vectorize.mlir | 67 |
12 files changed, 1036 insertions, 222 deletions
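For orientation before the per-file diffs: this patch introduces two new standard ops, `vector_transfer_read` and `vector_transfer_write`. A minimal sketch of their textual form, mirroring the cases exercised in `mlir/test/IR/core-ops.mlir` further down (the SSA names `%A`, `%i`, `%j`, `%pad`, `%v` are illustrative, not taken from the diff):

```mlir
// Read a vector<128xf32> slice of a 2-d memref. The permutation_map selects and
// transposes memref dimensions into the vector shape; the trailing f32 operand
// is the optional padding value used where the read would run out of bounds.
%v = vector_transfer_read %A, %i, %j, %pad {permutation_map: (d0, d1) -> (d1)}
       : (memref<?x?xf32>, index, index, f32) -> vector<128xf32>

// Write the super-vector back. Out-of-bounds handling is left to the op's
// lowering (predication, explicit control flow, read-modify-write, ...).
vector_transfer_write %v, %A, %i, %j {permutation_map: (d0, d1) -> (d1)}
  : vector<128xf32>, memref<?x?xf32>, index, index
```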
diff --git a/mlir/include/mlir/Analysis/VectorAnalysis.h b/mlir/include/mlir/Analysis/VectorAnalysis.h index 82bffb8fa7d..a3d31b2f964 100644 --- a/mlir/include/mlir/Analysis/VectorAnalysis.h +++ b/mlir/include/mlir/Analysis/VectorAnalysis.h @@ -22,15 +22,11 @@ namespace mlir { +class AffineMap; +class MemRefType; class OperationStmt; class VectorType; -// TODO(ntv): Drop this once we have proper Ops. -static constexpr auto kVectorTransferReadOpName = "vector_transfer_read"; -static constexpr auto kVectorTransferWriteOpName = "vector_transfer_write"; -bool isaVectorTransferRead(const OperationStmt &stmt); -bool isaVectorTransferWrite(const OperationStmt &stmt); - /// Computes and returns the multi-dimensional ratio of `superShape` to /// `subShape`. This is calculated by performing a traversal from minor to major /// dimensions (i.e. in reverse shape order). If integral division is not @@ -49,6 +45,16 @@ shapeRatio(ArrayRef<int> superShape, ArrayRef<int> subShape); llvm::Optional<llvm::SmallVector<unsigned, 4>> shapeRatio(VectorType superVectorType, VectorType subVectorType); +/// Creates a permutation map to be used as an attribute in VectorTransfer ops. +/// Currently only returns the minor vectorType.rank identity submatrix. +/// +/// For example, assume memrefType is of rank 5 and vectorType is of rank 3, +/// returns the affine map: +/// (d0, d1, d2, d3, d4) -> (d2, d3, d4) +/// +/// TODO(ntv): support real permutations. +AffineMap makePermutationMap(MemRefType memrefType, VectorType vectorType); + namespace matcher { /// Matches vector_transfer_read, vector_transfer_write and ops that return a diff --git a/mlir/include/mlir/StandardOps/StandardOps.h b/mlir/include/mlir/StandardOps/StandardOps.h index 8301b131add..ffd903aa779 100644 --- a/mlir/include/mlir/StandardOps/StandardOps.h +++ b/mlir/include/mlir/StandardOps/StandardOps.h @@ -28,6 +28,7 @@ #include "mlir/IR/OpDefinition.h" namespace mlir { +class AffineMap; class Builder; class MLValue; @@ -795,6 +796,149 @@ private: explicit TensorCastOp(const Operation *state) : CastOp(state) {} }; +/// VectorTransferReadOp performs a blocking read from a scalar memref +/// location into a super-vector of the same elemental type. This operation is +/// called 'read' by opposition to 'load' because the super-vector granularity +/// is generally not representable with a single hardware register. As a +/// consequence, memory transfers will generally be required when lowering +/// VectorTransferReadOp. A VectorTransferReadOp is thus a mid-level abstraction +/// that supports super-vectorization with non-effecting padding for full-tile +/// only code. +// +/// A vector transfer read has semantics similar to a vector load, with +/// additional support for: +/// 1. an optional value of the elemental type of the MemRef. This value +/// supports non-effecting padding and is inserted in places where the +/// vector read exceeds the MemRef bounds. If the value is not specified, +/// the access is statically guaranteed to be within bounds; +/// 2. an attribute of type AffineMap to specify a slice of the original +/// MemRef access and its transposition into the super-vector shape. The +/// permutation_map is an unbounded AffineMap that must represent a +/// permutation from the MemRef dim space projected onto the vector dim +/// space. +// +/// Example: +/// ```mlir +/// %A = alloc(%size1, %size2, %size3, %size4) : memref<?x?x?x?xf32> +/// ... 
+/// %val = `ssa-value` : f32 +/// // let %i, %j, %k, %l be ssa-values of type index +/// %v0 = vector_transfer_read %src, %i, %j, %k, %l +/// {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} : +/// (memref<?x?x?x?xf32>, index, index, index, index) -> +/// vector<16x32x64xf32> +/// %v1 = vector_transfer_read %src, %i, %j, %k, %l, %val +/// {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} : +/// (memref<?x?x?x?xf32>, index, index, index, index, f32) -> +/// vector<16x32x64xf32> +/// ``` +class VectorTransferReadOp + : public Op<VectorTransferReadOp, OpTrait::VariadicOperands, + OpTrait::OneResult> { + enum Offsets : unsigned { MemRefOffset = 0, FirstIndexOffset = 1 }; + +public: + static StringRef getOperationName() { return "vector_transfer_read"; } + static StringRef getPermutationMapAttrName() { return "permutation_map"; } + static void build(Builder *builder, OperationState *result, + VectorType vectorType, SSAValue *srcMemRef, + ArrayRef<SSAValue *> srcIndices, AffineMap permutationMap, + Optional<SSAValue *> paddingValue = None); + VectorType getResultType() const { + return getResult()->getType().cast<VectorType>(); + } + SSAValue *getMemRef() { return getOperand(Offsets::MemRefOffset); } + const SSAValue *getMemRef() const { + return getOperand(Offsets::MemRefOffset); + } + MemRefType getMemRefType() const { + return getMemRef()->getType().cast<MemRefType>(); + } + llvm::iterator_range<Operation::operand_iterator> getIndices(); + llvm::iterator_range<Operation::const_operand_iterator> getIndices() const; + Optional<SSAValue *> getPaddingValue(); + Optional<const SSAValue *> getPaddingValue() const; + AffineMap getPermutationMap() const; + + static bool parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p) const; + bool verify() const; + +private: + friend class Operation; + explicit VectorTransferReadOp(const Operation *state) : Op(state) {} +}; + +/// VectorTransferWriteOp performs a blocking write from a super-vector to +/// a scalar memref of the same elemental type. This operation is +/// called 'write' by opposition to 'store' because the super-vector granularity +/// is generally not representable with a single hardware register. As a +/// consequence, memory transfers will generally be required when lowering +/// VectorTransferWriteOp. A VectorTransferWriteOp is thus a mid-level +/// abstraction that supports super-vectorization with non-effecting padding for +/// full-tile only code. +/// +/// A vector transfer write has semantics similar to a vector store, with +/// additional support for handling out-of-bounds situations. It is the +/// responsibility of vector_transfer_write's implementation to ensure the +/// memory writes are valid. Different implementations may be pertinent +/// depending on the hardware support including: +/// 1. predication; +/// 2. explicit control-flow; +/// 3. Read-Modify-Write; +/// 4. writing out of bounds of the memref when the allocation allows it. +/// +/// Example: +/// ```mlir +/// %A = alloc(%size1, %size2, %size3, %size4) : memref<?x?x?x?xf32>. 
+/// %val = `ssa-value` : vector<16x32x64xf32> +/// // let %i, %j, %k, %l be ssa-values of type index +/// vector_transfer_write %val, %src, %i, %j, %k, %l +/// {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} : +/// vector<16x32x64xf32>, memref<?x?x?x?xf32>, index, index, index, index +/// ``` +class VectorTransferWriteOp + : public Op<VectorTransferWriteOp, OpTrait::VariadicOperands, + OpTrait::ZeroResult> { + enum Offsets : unsigned { + VectorOffset = 0, + MemRefOffset = 1, + FirstIndexOffset = 2 + }; + +public: + static StringRef getOperationName() { return "vector_transfer_write"; } + static StringRef getPermutationMapAttrName() { return "permutation_map"; } + static void build(Builder *builder, OperationState *result, + SSAValue *srcVector, SSAValue *dstMemRef, + ArrayRef<SSAValue *> dstIndices, AffineMap permutationMap); + SSAValue *getVector() { return getOperand(Offsets::VectorOffset); } + const SSAValue *getVector() const { + return getOperand(Offsets::VectorOffset); + } + VectorType getVectorType() const { + return getVector()->getType().cast<VectorType>(); + } + SSAValue *getMemRef() { return getOperand(Offsets::MemRefOffset); } + const SSAValue *getMemRef() const { + return getOperand(Offsets::MemRefOffset); + } + MemRefType getMemRefType() const { + return getMemRef()->getType().cast<MemRefType>(); + } + llvm::iterator_range<Operation::operand_iterator> getIndices(); + llvm::iterator_range<Operation::const_operand_iterator> getIndices() const; + AffineMap getPermutationMap() const; + + static bool parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p) const; + bool verify() const; + +private: + friend class Operation; + explicit VectorTransferWriteOp(const Operation *state) : Op(state) {} +}; + } // end namespace mlir #endif diff --git a/mlir/include/mlir/Support/Functional.h b/mlir/include/mlir/Support/Functional.h index 071a611151c..e1b1ee5ce58 100644 --- a/mlir/include/mlir/Support/Functional.h +++ b/mlir/include/mlir/Support/Functional.h @@ -19,6 +19,7 @@ #define MLIR_SUPPORT_FUNCTIONAL_H_ #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" /// This file provides some simple template functional-style sugar to operate /// on **value** types. Make sure when using that the stored type is cheap to @@ -78,6 +79,14 @@ void zipApply(Fun fun, ContainerType1 input1, ContainerType2 input2) { } } +/// Unwraps a pointer type to another type (possibly the same). +/// Used in particular to allow easier compositions of +/// llvm::iterator_range<ForStmt::operand_iterator> types. +template <typename T, typename ToType = T> +inline std::function<ToType *(T *)> makePtrDynCaster() { + return [](T *val) { return llvm::dyn_cast<ToType>(val); }; +} + /// Simple ScopeGuard. struct ScopeGuard { explicit ScopeGuard(std::function<void(void)> destruct) diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 8406a37d793..de98849136c 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -194,7 +194,8 @@ static bool isVectorElement(LoadOrStoreOpPointer memoryOp) { // TODO(ntv): make the following into MLIR instructions, then use isa<>. 
static bool isVectorTransferReadOrWrite(const Statement &stmt) { const auto *opStmt = cast<OperationStmt>(&stmt); - return isaVectorTransferRead(*opStmt) || isaVectorTransferWrite(*opStmt); + return opStmt->isa<VectorTransferReadOp>() || + opStmt->isa<VectorTransferWriteOp>(); } using VectorizableStmtFun = diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index 75f62299e1b..9c2160c1450 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -18,6 +18,7 @@ #include "mlir/Analysis/VectorAnalysis.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Statements.h" +#include "mlir/StandardOps/StandardOps.h" #include "mlir/Support/Functional.h" #include "mlir/Support/STLExtras.h" @@ -28,14 +29,6 @@ using namespace mlir; -bool mlir::isaVectorTransferRead(const OperationStmt &stmt) { - return stmt.getName().getStringRef().str() == kVectorTransferReadOpName; -} - -bool mlir::isaVectorTransferWrite(const OperationStmt &stmt) { - return stmt.getName().getStringRef().str() == kVectorTransferWriteOpName; -} - Optional<SmallVector<unsigned, 4>> mlir::shapeRatio(ArrayRef<int> superShape, ArrayRef<int> subShape) { if (superShape.size() < subShape.size()) { @@ -83,6 +76,20 @@ Optional<SmallVector<unsigned, 4>> mlir::shapeRatio(VectorType superVectorType, return shapeRatio(superVectorType.getShape(), subVectorType.getShape()); } +AffineMap mlir::makePermutationMap(MemRefType memrefType, + VectorType vectorType) { + unsigned memRefRank = memrefType.getRank(); + unsigned vectorRank = vectorType.getRank(); + assert(memRefRank >= vectorRank && "Broadcast not supported"); + unsigned offset = memRefRank - vectorRank; + SmallVector<AffineExpr, 4> perm; + perm.reserve(memRefRank); + for (unsigned i = 0; i < vectorRank; ++i) { + perm.push_back(getAffineDimExpr(offset + i, memrefType.getContext())); + } + return AffineMap::get(memRefRank, 0, perm, {}); +} + bool mlir::matcher::operatesOnStrictSuperVectors(const OperationStmt &opStmt, VectorType subVectorType) { // First, extract the vector type and ditinguish between: @@ -96,15 +103,11 @@ bool mlir::matcher::operatesOnStrictSuperVectors(const OperationStmt &opStmt, /// do not have to special case. Maybe a trait, or just a method, unclear atm. 
bool mustDivide = false; VectorType superVectorType; - if (isaVectorTransferRead(opStmt)) { - superVectorType = opStmt.getResult(0)->getType().cast<VectorType>(); + if (auto read = opStmt.dyn_cast<VectorTransferReadOp>()) { + superVectorType = read->getResultType(); mustDivide = true; - } else if (isaVectorTransferWrite(opStmt)) { - // TODO(ntv): if vector_transfer_write had store-like semantics we could - // have written something similar to: - // auto store = storeOp->cast<StoreOp>(); - // auto *value = store->getValueToStore(); - superVectorType = opStmt.getOperand(0)->getType().cast<VectorType>(); + } else if (auto write = opStmt.dyn_cast<VectorTransferWriteOp>()) { + superVectorType = write->getVectorType(); mustDivide = true; } else if (opStmt.getNumResults() == 0) { assert(opStmt.isa<ReturnOp>() && diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp index 4de951a12f9..4d71ddeab16 100644 --- a/mlir/lib/StandardOps/StandardOps.cpp +++ b/mlir/lib/StandardOps/StandardOps.cpp @@ -40,7 +40,8 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context) addOperations<AddFOp, AddIOp, AllocOp, CallOp, CallIndirectOp, CmpIOp, DeallocOp, DimOp, DmaStartOp, DmaWaitOp, ExtractElementOp, LoadOp, MemRefCastOp, MulFOp, MulIOp, SelectOp, StoreOp, SubFOp, - SubIOp, TensorCastOp>(); + SubIOp, TensorCastOp, VectorTransferReadOp, + VectorTransferWriteOp>(); } //===----------------------------------------------------------------------===// @@ -1321,3 +1322,427 @@ bool TensorCastOp::verify() const { return false; } + +//===----------------------------------------------------------------------===// +// VectorTransferReadOp +//===----------------------------------------------------------------------===// +template <typename EmitFun> +static bool verifyPermutationMap(AffineMap permutationMap, + EmitFun emitOpError) { + SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false); + for (auto expr : permutationMap.getResults()) { + auto dim = expr.dyn_cast<AffineDimExpr>(); + if (!dim) { + return emitOpError( + "requires a permutation_map that is an actual permutation"); + } + if (seen[dim.getPosition()]) { + return emitOpError( + "requires a permutation_map that is a full column-rank " + "permutation (i.e. 
a permutation composed with an " + "orthogonal projection)"); + } + seen[dim.getPosition()] = true; + } + return false; +} + +void VectorTransferReadOp::build(Builder *builder, OperationState *result, + VectorType vectorType, SSAValue *srcMemRef, + ArrayRef<SSAValue *> srcIndices, + AffineMap permutationMap, + Optional<SSAValue *> paddingValue) { + result->addOperands(srcMemRef); + result->addOperands(srcIndices); + if (paddingValue) { + result->addOperands({*paddingValue}); + } + result->addAttribute(getPermutationMapAttrName(), + builder->getAffineMapAttr(permutationMap)); + result->addTypes(vectorType); +} + +llvm::iterator_range<Operation::operand_iterator> +VectorTransferReadOp::getIndices() { + auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset; + auto end = begin + getMemRefType().getRank(); + return {begin, end}; +} + +llvm::iterator_range<Operation::const_operand_iterator> +VectorTransferReadOp::getIndices() const { + auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset; + auto end = begin + getMemRefType().getRank(); + return {begin, end}; +} + +Optional<SSAValue *> VectorTransferReadOp::getPaddingValue() { + auto memRefRank = getMemRefType().getRank(); + if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) { + return None; + } + return Optional<SSAValue *>( + getOperand(Offsets::FirstIndexOffset + memRefRank)); +} + +Optional<const SSAValue *> VectorTransferReadOp::getPaddingValue() const { + auto memRefRank = getMemRefType().getRank(); + if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) { + return None; + } + return Optional<const SSAValue *>( + getOperand(Offsets::FirstIndexOffset + memRefRank)); +} + +AffineMap VectorTransferReadOp::getPermutationMap() const { + return getAttrOfType<AffineMapAttr>(getPermutationMapAttrName()).getValue(); +} + +void VectorTransferReadOp::print(OpAsmPrinter *p) const { + *p << getOperationName() << " "; + p->printOperand(getMemRef()); + *p << ", "; + p->printOperands(getIndices()); + auto optionalPaddingValue = getPaddingValue(); + if (optionalPaddingValue) { + *p << ", "; + p->printOperand(*optionalPaddingValue); + } + p->printOptionalAttrDict(getAttrs()); + // Construct the FunctionType and print it. + llvm::SmallVector<Type, 8> inputs{getMemRefType()}; + // Must have at least one actual index, see verify. + const SSAValue *firstIndex = *(getIndices().begin()); + Type indexType = firstIndex->getType(); + inputs.append(getMemRefType().getRank(), indexType); + if (optionalPaddingValue) { + inputs.push_back((*optionalPaddingValue)->getType()); + } + *p << " : " + << FunctionType::get(inputs, {getResultType()}, indexType.getContext()); +} + +bool VectorTransferReadOp::parse(OpAsmParser *parser, OperationState *result) { + SmallVector<OpAsmParser::OperandType, 8> parsedOperands; + Type type; + + // Parsing with support for optional paddingValue. + auto fail = parser->parseOperandList(parsedOperands) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonType(type); + if (fail) { + return true; + } + + // Resolution. 
+ auto funType = type.dyn_cast<FunctionType>(); + if (!funType) { + parser->emitError(parser->getNameLoc(), "Function type expected"); + return true; + } + if (funType.getNumInputs() < 1) { + parser->emitError(parser->getNameLoc(), + "Function type expects at least one input"); + return true; + } + MemRefType memrefType = + funType.getInput(Offsets::MemRefOffset).dyn_cast<MemRefType>(); + if (!memrefType) { + parser->emitError(parser->getNameLoc(), + "MemRef type expected for first input"); + return true; + } + if (funType.getNumResults() < 1) { + parser->emitError(parser->getNameLoc(), + "Function type expects exactly one vector result"); + return true; + } + VectorType vectorType = funType.getResult(0).dyn_cast<VectorType>(); + if (!vectorType) { + parser->emitError(parser->getNameLoc(), + "Vector type expected for first result"); + return true; + } + if (parsedOperands.size() != funType.getNumInputs()) { + parser->emitError(parser->getNameLoc(), "requires " + + Twine(funType.getNumInputs()) + + " operands"); + return true; + } + + // Extract optional paddingValue. + OpAsmParser::OperandType memrefInfo = parsedOperands[0]; + // At this point, indexInfo may contain the optional paddingValue, pop it out. + SmallVector<OpAsmParser::OperandType, 8> indexInfo{ + parsedOperands.begin() + Offsets::FirstIndexOffset, parsedOperands.end()}; + Type paddingType; + OpAsmParser::OperandType paddingValue; + bool hasPaddingValue = indexInfo.size() > memrefType.getRank(); + unsigned expectedNumOperands = Offsets::FirstIndexOffset + + memrefType.getRank() + + (hasPaddingValue ? 1 : 0); + if (hasPaddingValue) { + paddingType = funType.getInputs().back(); + paddingValue = indexInfo.pop_back_val(); + } + if (funType.getNumInputs() != expectedNumOperands) { + parser->emitError( + parser->getNameLoc(), + "requires actual number of operands to match function type"); + return true; + } + + auto indexType = parser->getBuilder().getIndexType(); + return parser->resolveOperand(memrefInfo, memrefType, result->operands) || + parser->resolveOperands(indexInfo, indexType, result->operands) || + (hasPaddingValue && parser->resolveOperand(paddingValue, paddingType, + result->operands)) || + parser->addTypeToList(vectorType, result->types); +} + +bool VectorTransferReadOp::verify() const { + // Consistency of memref type in function type. + if (llvm::empty(getOperands())) { + return emitOpError( + "requires at least a memref operand followed by 'rank' indices"); + } + if (!getMemRef()->getType().isa<MemRefType>()) { + return emitOpError("requires a memref as first operand"); + } + // Consistency of vector type in function type. + if (!getResult()->getType().isa<VectorType>()) { + return emitOpError("should have a vector result type in function type: " + "(memref_type [, elemental_type]) -> vector_type"); + } + // Consistency of elemental types in memref and vector. + MemRefType memrefType = getMemRefType(); + VectorType vectorType = getResultType(); + if (memrefType.getElementType() != vectorType.getElementType()) + return emitOpError( + "requires memref and vector types of the same elemental type"); + // Consistency of number of input types. + auto optionalPaddingValue = getPaddingValue(); + unsigned expectedNumOperands = Offsets::FirstIndexOffset + + memrefType.getRank() + + (optionalPaddingValue ? 1 : 0); + // Checks on the actual operands and their types. 
+ if (getNumOperands() != expectedNumOperands) { + return emitOpError("expects " + Twine(expectedNumOperands) + + " operands to match the types"); + } + // Consistency of padding value with vector type. + if (optionalPaddingValue) { + auto paddingValue = *optionalPaddingValue; + auto elementalType = paddingValue->getType(); + if (!VectorType::isValidElementType(elementalType)) { + return emitOpError("requires valid padding vector elemental type"); + } + if (elementalType != vectorType.getElementType()) { + return emitOpError( + "requires formal padding and vector of the same elemental type"); + } + } + // Consistency of indices types. + unsigned numIndices = 0; + for (auto *idx : getIndices()) { + if (!idx->getType().isIndex()) { + return emitOpError( + "index to vector_transfer_read must have 'index' type"); + } + ++numIndices; + } + if (numIndices != memrefType.getRank()) { + return emitOpError("requires at least a memref operand followed by " + + Twine(memrefType.getRank()) + " indices"); + } + + // Consistency of AffineMap attribute. + if (!getAttrOfType<AffineMapAttr>(getPermutationMapAttrName())) { + return emitOpError("requires an AffineMapAttr named 'permutation_map'"); + } + auto permutationMap = getPermutationMap(); + if (!permutationMap.getRangeSizes().empty()) { + return emitOpError("requires an unbounded permutation_map"); + } + if (permutationMap.getNumSymbols() != 0) { + return emitOpError("requires a permutation_map without symbols"); + } + if (permutationMap.getNumInputs() != memrefType.getRank()) { + return emitOpError("requires a permutation_map with input dims of the " + "same rank as the memref type"); + } + if (permutationMap.getNumResults() != vectorType.getRank()) { + return emitOpError("requires a permutation_map with result dims of the " + "same rank as the vector type"); + } + return verifyPermutationMap(permutationMap, + [this](Twine t) { return emitOpError(t); }); +} + +//===----------------------------------------------------------------------===// +// VectorTransferWriteOp +//===----------------------------------------------------------------------===// +void VectorTransferWriteOp::build(Builder *builder, OperationState *result, + SSAValue *srcVector, SSAValue *dstMemRef, + ArrayRef<SSAValue *> dstIndices, + AffineMap permutationMap) { + result->addOperands({srcVector, dstMemRef}); + result->addOperands(dstIndices); + result->addAttribute(getPermutationMapAttrName(), + builder->getAffineMapAttr(permutationMap)); +} + +llvm::iterator_range<Operation::operand_iterator> +VectorTransferWriteOp::getIndices() { + auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset; + auto end = begin + getMemRefType().getRank(); + return {begin, end}; +} + +llvm::iterator_range<Operation::const_operand_iterator> +VectorTransferWriteOp::getIndices() const { + auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset; + auto end = begin + getMemRefType().getRank(); + return {begin, end}; +} + +AffineMap VectorTransferWriteOp::getPermutationMap() const { + return getAttrOfType<AffineMapAttr>(getPermutationMapAttrName()).getValue(); +} + +void VectorTransferWriteOp::print(OpAsmPrinter *p) const { + *p << getOperationName(); + *p << " " << *getVector(); + *p << ", " << *getMemRef(); + *p << ", "; + p->printOperands(getIndices()); + p->printOptionalAttrDict(getAttrs()); + Type indexType = (*getIndices().begin())->getType(); + *p << " : "; + p->printType(getVectorType()); + *p << ", "; + p->printType(getMemRefType()); + for (unsigned r = 0, n = 
getMemRefType().getRank(); r < n; ++r) { + *p << ", "; + p->printType(indexType); + } +} + +bool VectorTransferWriteOp::parse(OpAsmParser *parser, OperationState *result) { + SmallVector<OpAsmParser::OperandType, 8> parsedOperands; + SmallVector<Type, 8> types; + + // Parsing with support for optional paddingValue. + auto fail = parser->parseOperandList(parsedOperands) || + parser->parseOptionalAttributeDict(result->attributes) || + parser->parseColonTypeList(types); + if (fail) { + return true; + } + + // Resolution. + if (parsedOperands.size() != types.size()) { + parser->emitError(parser->getNameLoc(), + "requires number of operands and input types to match"); + return true; + } + if (parsedOperands.size() < Offsets::FirstIndexOffset) { + parser->emitError(parser->getNameLoc(), + "requires at least vector and memref operands"); + return true; + } + VectorType vectorType = types[Offsets::VectorOffset].dyn_cast<VectorType>(); + if (!vectorType) { + parser->emitError(parser->getNameLoc(), + "Vector type expected for first input type"); + return true; + } + MemRefType memrefType = types[Offsets::MemRefOffset].dyn_cast<MemRefType>(); + if (!memrefType) { + parser->emitError(parser->getNameLoc(), + "MemRef type expected for second input type"); + return true; + } + + unsigned expectedNumOperands = + Offsets::FirstIndexOffset + memrefType.getRank(); + if (parsedOperands.size() != expectedNumOperands) { + parser->emitError(parser->getNameLoc(), + "requires " + Twine(expectedNumOperands) + " operands"); + return true; + } + + OpAsmParser::OperandType vectorInfo = parsedOperands[Offsets::VectorOffset]; + OpAsmParser::OperandType memrefInfo = parsedOperands[Offsets::MemRefOffset]; + SmallVector<OpAsmParser::OperandType, 8> indexInfo{ + parsedOperands.begin() + Offsets::FirstIndexOffset, parsedOperands.end()}; + auto indexType = parser->getBuilder().getIndexType(); + return parser->resolveOperand(vectorInfo, vectorType, result->operands) || + parser->resolveOperand(memrefInfo, memrefType, result->operands) || + parser->resolveOperands(indexInfo, indexType, result->operands); +} + +bool VectorTransferWriteOp::verify() const { + // Consistency of memref type in function type. + if (llvm::empty(getOperands())) { + return emitOpError( + "requires at least a memref operand followed by 'rank' indices"); + } + if (!getMemRef()->getType().isa<MemRefType>()) { + return emitOpError("requires a memref first operand"); + } + // Consistency of vector type in function type. + if (!getVector()->getType().isa<VectorType>()) { + return emitOpError("should have a vector input type in function type: " + "(vector_type, memref_type [, elemental_type]) -> ()"); + } + // Consistency of elemental types in memref and vector. + MemRefType memrefType = getMemRefType(); + VectorType vectorType = getVectorType(); + if (memrefType.getElementType() != vectorType.getElementType()) + return emitOpError( + "requires memref and vector types of the same elemental type"); + // Consistency of number of input types. + unsigned expectedNumOperands = + Offsets::FirstIndexOffset + memrefType.getRank(); + // Checks on the actual operands and their types. + if (getNumOperands() != expectedNumOperands) { + return emitOpError("expects " + Twine(expectedNumOperands) + + " operands to match the types"); + } + // Consistency of indices types. 
+ unsigned numIndices = 0; + for (auto *idx : getIndices()) { + if (!idx->getType().isIndex()) { + return emitOpError( + "index to vector_transfer_write must have 'index' type"); + } + numIndices++; + } + if (numIndices != memrefType.getRank()) { + return emitOpError("requires at least a memref operand followed by " + + Twine(memrefType.getRank()) + " indices"); + } + + // Consistency of AffineMap attribute. + if (!getAttrOfType<AffineMapAttr>(getPermutationMapAttrName())) { + return emitOpError("requires an AffineMapAttr named 'permutation_map'"); + } + auto permutationMap = getPermutationMap(); + if (!permutationMap.getRangeSizes().empty()) { + return emitOpError("requires an unbounded permutation_map"); + } + if (permutationMap.getNumSymbols() != 0) { + return emitOpError("requires a permutation_map without symbols"); + } + if (permutationMap.getNumInputs() != memrefType.getRank()) { + return emitOpError("requires a permutation_map with input dims of the " + "same rank as the memref type"); + } + if (permutationMap.getNumResults() != vectorType.getRank()) { + return emitOpError("requires a permutation_map with result dims of the " + "same rank as the vector type"); + } + return verifyPermutationMap(permutationMap, + [this](Twine t) { return emitOpError(t); }); +} diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 60f0c06aad5..400b4fdf934 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -89,6 +89,7 @@ using llvm::SetVector; using namespace mlir; +using functional::makePtrDynCaster; using functional::map; static llvm::cl::list<int> @@ -243,11 +244,11 @@ substitute(SSAValue *v, /// TODO(ntv): support a concrete AffineMap and compose with it. /// TODO(ntv): these implementation details should be captured in a /// vectorization trait at the op level directly. -static SmallVector<MLValue *, 8> -reindexAffineIndices(MLFuncBuilder *b, Type hwVectorType, +static SmallVector<SSAValue *, 8> +reindexAffineIndices(MLFuncBuilder *b, VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance, ArrayRef<SSAValue *> memrefIndices) { - auto vectorShape = hwVectorType.cast<VectorType>().getShape(); + auto vectorShape = hwVectorType.getShape(); assert(hwVectorInstance.size() >= vectorShape.size()); unsigned numIndices = memrefIndices.size(); @@ -287,78 +288,21 @@ reindexAffineIndices(MLFuncBuilder *b, Type hwVectorType, // TODO(ntv): support a concrete map and composition. auto app = b->create<AffineApplyOp>(b->getInsertionPoint()->getLoc(), affineMap, memrefIndices); - unsigned numResults = app->getNumResults(); - SmallVector<MLValue *, 8> res; - for (unsigned i = 0; i < numResults; ++i) { - res.push_back(cast<MLValue>(app->getResult(i))); - } - return res; + return SmallVector<SSAValue *, 8>{app->getResults()}; } -/// Returns the cloned operands of `opStmt` for the instance of -/// `hwVectorInstance` when lowering from a super-vector type to -/// `hwVectorType`. `hwVectorInstance` represents one particular instance of -/// `hwVectorType` int the covering of the super-vector type. For a more -/// detailed description of the problem, see the description of -/// reindexAffineIndices. 
-static SmallVector<MLValue *, 8> -cloneAndUnrollOperands(OperationStmt *opStmt, Type hwVectorType, - ArrayRef<unsigned> hwVectorInstance, - DenseMap<const MLValue *, MLValue *> *substitutionsMap) { - using functional::map; - - // For Ops that are not vector_transfer_read/vector_transfer_write we can just - // substitute and be done. - if (!isaVectorTransferRead(*opStmt) && !isaVectorTransferWrite(*opStmt)) { - return map([substitutionsMap]( - SSAValue *v) { return substitute(v, *substitutionsMap); }, - opStmt->getOperands()); - } - - // TODO(ntv): this error-prone boilerplate can be removed once we have a - // proper Op for vectr_transfer. - unsigned offset = 0; - unsigned numIndices = 0; - SmallVector<MLValue *, 8> res; - auto operands = opStmt->getOperands(); - if (isaVectorTransferRead(*opStmt)) { - offset = 1; - numIndices = opStmt->getNumOperands() - 1; - } else if (isaVectorTransferWrite(*opStmt)) { - offset = 2; - numIndices = opStmt->getNumOperands() - 2; - } - // Copy as-is the [optional valueToStore], memref. - for (unsigned i = 0; i < offset; ++i) { - res.push_back(substitute(*(operands.begin() + i), *substitutionsMap)); - } - - MLFuncBuilder b(opStmt); - // TODO(ntv): indices extraction is brittle and unsafe before we have an Op. - SmallVector<SSAValue *, 8> indices; - for (auto it = operands.begin() + offset; it != operands.end(); ++it) { - indices.push_back(*it); - } - auto affineValues = - reindexAffineIndices(&b, hwVectorType, hwVectorInstance, indices); - res.append(affineValues.begin(), affineValues.end()); - - return res; -} - -// Returns attributes with the following substitutions applied: -// - splat of `superVectorType` is replaced by splat of `hwVectorType`. -// TODO(ntv): add more substitutions on a per-need basis. -static SmallVector<NamedAttribute, 2> +/// Returns attributes with the following substitutions applied: +/// - splat of `superVectorType` is replaced by splat of `hwVectorType`. +/// TODO(ntv): add more substitutions on a per-need basis. +static SmallVector<NamedAttribute, 1> materializeAttributes(OperationStmt *opStmt, VectorType superVectorType, VectorType hwVectorType) { - SmallVector<NamedAttribute, 2> res; + SmallVector<NamedAttribute, 1> res; for (auto a : opStmt->getAttrs()) { auto splat = a.second.dyn_cast<SplatElementsAttr>(); bool splatOfSuperVectorType = splat && (splat.getType() == superVectorType); if (splatOfSuperVectorType) { - auto attr = SplatElementsAttr::get(hwVectorType.cast<VectorType>(), - splat.getValue()); + auto attr = SplatElementsAttr::get(hwVectorType, splat.getValue()); res.push_back(NamedAttribute(a.first, attr)); } else { res.push_back(a); @@ -367,6 +311,70 @@ materializeAttributes(OperationStmt *opStmt, VectorType superVectorType, return res; } +/// Creates an instantiated version of `opStmt`. +/// Ops other than VectorTransferReadOp/VectorTransferWriteOp require no +/// affine reindexing. Just substitute their SSAValue* operands and be done. For +/// this case the actual instance is irrelevant. Just use the SSA values in +/// substitutionsMap. 
+static OperationStmt * +instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType superVectorType, + VectorType hwVectorType, + DenseMap<const MLValue *, MLValue *> *substitutionsMap) { + assert(!opStmt->isa<VectorTransferReadOp>() && + "Should call the function specialized for VectorTransferReadOp"); + assert(!opStmt->isa<VectorTransferWriteOp>() && + "Should call the function specialized for VectorTransferWriteOp"); + auto operands = + map([substitutionsMap]( + SSAValue *v) { return substitute(v, *substitutionsMap); }, + opStmt->getOperands()); + return b->createOperation( + opStmt->getLoc(), opStmt->getName(), operands, {hwVectorType}, + materializeAttributes(opStmt, superVectorType, hwVectorType)); +} + +/// Creates an instantiated version of `read` for the instance of +/// `hwVectorInstance` when lowering from a super-vector type to +/// `hwVectorType`. `hwVectorInstance` represents one particular instance of +/// `hwVectorType` int the covering of the super-vector type. For a more +/// detailed description of the problem, see the description of +/// reindexAffineIndices. +static OperationStmt * +instantiate(MLFuncBuilder *b, VectorTransferReadOp *read, + VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance, + DenseMap<const MLValue *, MLValue *> *substitutionsMap) { + SmallVector<SSAValue *, 8> indices = + map(makePtrDynCaster<SSAValue>(), read->getIndices()); + auto affineIndices = + reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices); + auto cloned = b->create<VectorTransferReadOp>( + read->getLoc(), hwVectorType, read->getMemRef(), affineIndices, + makePermutationMap(read->getMemRefType(), hwVectorType), + read->getPaddingValue()); + return cast<OperationStmt>(cloned->getOperation()); +} + +/// Creates an instantiated version of `write` for the instance of +/// `hwVectorInstance` when lowering from a super-vector type to +/// `hwVectorType`. `hwVectorInstance` represents one particular instance of +/// `hwVectorType` int the covering of th3e super-vector type. For a more +/// detailed description of the problem, see the description of +/// reindexAffineIndices. +static OperationStmt * +instantiate(MLFuncBuilder *b, VectorTransferWriteOp *write, + VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance, + DenseMap<const MLValue *, MLValue *> *substitutionsMap) { + SmallVector<SSAValue *, 8> indices = + map(makePtrDynCaster<SSAValue>(), write->getIndices()); + auto affineIndices = + reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices); + auto cloned = b->create<VectorTransferWriteOp>( + write->getLoc(), substitute(write->getVector(), *substitutionsMap), + write->getMemRef(), affineIndices, + makePermutationMap(write->getMemRefType(), hwVectorType)); + return cast<OperationStmt>(cloned->getOperation()); +} + /// Returns `true` if stmt instance is properly cloned and inserted, false /// otherwise. /// The multi-dimensional `hwVectorInstance` belongs to the shapeRatio of @@ -386,45 +394,52 @@ materializeAttributes(OperationStmt *opStmt, VectorType superVectorType, /// type, all operands are substituted according to `substitutions`. Thanks /// to the topological order of a slice, the substitution is always /// possible. 
-static bool cloneAndInsertHardwareVectorInstance(Statement *stmt, - MaterializationState *state) { - LLVM_DEBUG(dbgs() << "\nclone" << *stmt); - if (auto *opStmt = dyn_cast<OperationStmt>(stmt)) { - // TODO(ntv): Is it worth considering an OperationStmt.clone operation - // which changes the type so we can promote an OperationStmt with less - // boilerplate? - assert(opStmt->getNumResults() <= 1 && "NYI: opStmt has > 1 results"); - auto operands = cloneAndUnrollOperands(opStmt, state->hwVectorType, - state->hwVectorInstance, - state->substitutionsMap); - MLFuncBuilder b(stmt); - if (opStmt->getNumResults() == 0) { - // vector_transfer_write - b.createOperation(stmt->getLoc(), opStmt->getName(), operands, {}, - materializeAttributes(opStmt, state->superVectorType, - state->hwVectorType)); - } else { - // vector_transfer_read - auto *cloned = b.createOperation( - stmt->getLoc(), opStmt->getName(), operands, {state->hwVectorType}, - materializeAttributes(opStmt, state->superVectorType, - state->hwVectorType)); - state->substitutionsMap->insert(std::make_pair( - cast<MLValue>(opStmt->getResult(0)), - cast<MLValue>(cast<OperationStmt>(cloned)->getResult(0)))); - } - return false; - } +static bool instantiateMaterialization(Statement *stmt, + MaterializationState *state) { + LLVM_DEBUG(dbgs() << "\ninstantiate: " << *stmt); + // Fail hard and wake up when needed. if (isa<ForStmt>(stmt)) { - // Fail hard and wake up when needed. stmt->emitError("NYI path ForStmt"); return true; } // Fail hard and wake up when needed. - stmt->emitError("NYI path IfStmt"); - return true; + if (isa<IfStmt>(stmt)) { + stmt->emitError("NYI path IfStmt"); + return true; + } + + // Create a builder here for unroll-and-jam effects. + MLFuncBuilder b(stmt); + auto *opStmt = cast<OperationStmt>(stmt); + if (auto write = opStmt->dyn_cast<VectorTransferWriteOp>()) { + instantiate(&b, &*write, state->hwVectorType, state->hwVectorInstance, + state->substitutionsMap); + return false; + } else if (auto read = opStmt->dyn_cast<VectorTransferReadOp>()) { + auto *clone = instantiate(&b, &*read, state->hwVectorType, + state->hwVectorInstance, state->substitutionsMap); + state->substitutionsMap->insert(std::make_pair( + cast<MLValue>(read->getResult()), cast<MLValue>(clone->getResult(0)))); + return false; + } + // The only op with 0 results reaching this point must, by construction, be + // VectorTransferWriteOps and have been caught above. Ops with >= 2 results + // are not yet supported. So just support 1 result. + if (opStmt->getNumResults() != 1) { + stmt->emitError("NYI: ops with != 1 results"); + return true; + } + if (opStmt->getResult(0)->getType() != state->superVectorType) { + stmt->emitError("Op does not return a supervector."); + return true; + } + auto *clone = instantiate(&b, opStmt, state->superVectorType, + state->hwVectorType, state->substitutionsMap); + state->substitutionsMap->insert(std::make_pair( + cast<MLValue>(opStmt->getResult(0)), cast<MLValue>(clone->getResult(0)))); + return false; } /// Takes a slice and rewrites the operations in it so that occurrences @@ -463,15 +478,22 @@ static void emitSlice(MaterializationState *state, scopedState.substitutionsMap = &substitutionMap; // slice are topologically sorted, we can just clone them in order. 
for (auto *stmt : *slice) { - auto fail = cloneAndInsertHardwareVectorInstance(stmt, &scopedState); + auto fail = instantiateMaterialization(stmt, &scopedState); (void)fail; assert(!fail && "Unhandled super-vector materialization failure"); } } + + LLVM_DEBUG(dbgs() << "\nMLFunction is now\n"); + LLVM_DEBUG( + cast<OperationStmt>((*slice)[0])->getOperationFunction()->print(dbgs())); + // slice are topologically sorted, we can just erase them in reverse // order. Reverse iterator does not just work simply with an operator* // dereference. for (int idx = slice->size() - 1; idx >= 0; --idx) { + LLVM_DEBUG(dbgs() << "\nErase: "); + LLVM_DEBUG((*slice)[idx]->print(dbgs())); (*slice)[idx]->erase(); } } @@ -497,25 +519,21 @@ static void materialize(MLFunction *f, const SetVector<OperationStmt *> &terminators, MaterializationState *state) { DenseSet<Statement *> seen; - for (auto terminator : terminators) { - LLVM_DEBUG(dbgs() << "\nFrom terminator:" << *terminator); - + for (auto *term : terminators) { // Short-circuit test, a given terminator may have been reached by some // other previous transitive use-def chains. - if (seen.count(terminator) > 0) { + if (seen.count(term) > 0) { continue; } - // Terminators are vector_transfer_write with 0 results by construction atm. - assert(isaVectorTransferWrite(*terminator) && ""); - assert(terminator->getNumResults() == 0 && - "NYI: terminators must have 0 results"); + auto terminator = term->cast<VectorTransferWriteOp>(); + LLVM_DEBUG(dbgs() << "\nFrom terminator:" << *term); // Get the transitive use-defs starting from terminator, limited to the // current enclosing scope of the terminator. See the top of the function // Note for the justification of this restriction. // TODO(ntv): relax scoping constraints. - auto *enclosingScope = terminator->getParentStmt(); + auto *enclosingScope = term->getParentStmt(); auto keepIfInSameScope = [enclosingScope](Statement *stmt) { assert(stmt && "NULL stmt"); if (!enclosingScope) { @@ -525,7 +543,7 @@ static void materialize(MLFunction *f, return properlyDominates(*enclosingScope, *stmt); }; SetVector<Statement *> slice = - getSlice(terminator, keepIfInSameScope, keepIfInSameScope); + getSlice(term, keepIfInSameScope, keepIfInSameScope); assert(!slice.empty()); // Sanity checks: transitive slice must be completely disjoint from @@ -540,10 +558,9 @@ static void materialize(MLFunction *f, // Emit the current slice. // Set scoped super-vector and corresponding hw vector types. - state->superVectorType = - terminator->getOperand(0)->getType().cast<VectorType>(); + state->superVectorType = terminator->getVectorType(); assert((state->superVectorType.getElementType() == - Type::getF32(terminator->getContext())) && + Type::getF32(term->getContext())) && "Only f32 supported for now"); state->hwVectorType = VectorType::get( state->hwVectorSize, state->superVectorType.getElementType()); @@ -568,7 +585,7 @@ PassResult MaterializeVectors::runOnMLFunction(MLFunction *f) { // super-vector of subVectorType. 
auto filter = [subVectorType](const Statement &stmt) { const auto &opStmt = cast<OperationStmt>(stmt); - if (!isaVectorTransferWrite(opStmt)) { + if (!opStmt.isa<VectorTransferWriteOp>()) { return false; } return matcher::operatesOnStrictSuperVectors(opStmt, subVectorType); diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 5a408b0a2d7..e4822c27ac9 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -541,6 +541,7 @@ using namespace mlir; #define DEBUG_TYPE "early-vect" using functional::apply; +using functional::makePtrDynCaster; using functional::map; using functional::ScopeGuard; using llvm::dbgs; @@ -820,23 +821,15 @@ void VectorizationState::registerReplacement(const SSAValue *key, /// TODO(andydavis,bondhugula,ntv): /// 1. generalize to support padding semantics and offsets within vector type. static OperationStmt * -createVectorTransferRead(MLFuncBuilder *b, Location loc, VectorType vectorType, +createVectorTransferRead(OperationStmt *loadOp, VectorType vectorType, SSAValue *srcMemRef, ArrayRef<SSAValue *> srcIndices) { - SmallVector<SSAValue *, 8> operands; - operands.reserve(1 + srcIndices.size()); - operands.insert(operands.end(), srcMemRef); - operands.insert(operands.end(), srcIndices.begin(), srcIndices.end()); - OperationState opState(b->getContext(), loc, kVectorTransferReadOpName, - operands, vectorType); - return b->createOperation(opState); -} - -/// Unwraps a pointer type to another type (possibly the same). -/// Used in particular to allow easier compositions of -/// llvm::iterator_range<ForStmt::operand_iterator> types. -template <typename T, typename ToType = T> -static std::function<ToType *(T *)> unwrapPtr() { - return [](T *val) { return dyn_cast<ToType>(val); }; + auto memRefType = srcMemRef->getType().cast<MemRefType>(); + MLFuncBuilder b(loadOp); + // TODO(ntv): neutral for noneffective padding. + auto transfer = b.create<VectorTransferReadOp>( + loadOp->getLoc(), vectorType, srcMemRef, srcIndices, + makePermutationMap(memRefType, vectorType)); + return cast<OperationStmt>(transfer->getOperation()); } /// Handles the vectorization of load and store MLIR operations. @@ -865,15 +858,14 @@ static bool vectorizeRootOrTerminal(MLValue *iv, LoadOrStoreOpPointer memoryOp, // Materialize a MemRef with 1 vector. auto *opStmt = cast<OperationStmt>(memoryOp->getOperation()); - MLFuncBuilder b(opStmt); // For now, vector_transfers must be aligned, operate only on indices with an // identity subset of AffineMap and do not change layout. // TODO(ntv): increase the expressiveness power of vector_transfer operations // as needed by various targets. if (opStmt->template isa<LoadOp>()) { auto *transfer = createVectorTransferRead( - &b, opStmt->getLoc(), vectorType, memoryOp->getMemRef(), - map(unwrapPtr<SSAValue>(), memoryOp->getIndices())); + opStmt, vectorType, memoryOp->getMemRef(), + map(makePtrDynCaster<SSAValue>(), memoryOp->getIndices())); state->registerReplacement(opStmt, transfer); } else { state->registerTerminator(opStmt); @@ -1008,7 +1000,7 @@ static MLValue *vectorizeConstant(Statement *stmt, const ConstantOp &constant, auto *splat = cast<OperationStmt>(b.createOperation( loc, constantOpStmt->getName(), {}, {vectorType}, {make_pair(Identifier::get("value", b.getContext()), attr)})); - return cast<MLValue>(cast<OperationStmt>(splat)->getResult(0)); + return cast<MLValue>(splat->getResult(0)); } /// Returns a uniqu'ed VectorType. 
@@ -1106,17 +1098,17 @@ static MLValue *vectorizeOperand(SSAValue *operand, Statement *stmt, static OperationStmt *createVectorTransferWrite(OperationStmt *storeOp, VectorizationState *state) { auto store = storeOp->cast<StoreOp>(); + auto *memRef = store->getMemRef(); + auto memRefType = memRef->getType().cast<MemRefType>(); auto *value = store->getValueToStore(); - auto indices = map(unwrapPtr<SSAValue>(), store->getIndices()); - SmallVector<SSAValue *, 8> operands; - operands.reserve(1 + 1 + indices.size()); - operands.insert(operands.end(), vectorizeOperand(value, storeOp, state)); - operands.insert(operands.end(), store->getMemRef()); - operands.insert(operands.end(), indices.begin(), indices.end()); + auto *vectorValue = vectorizeOperand(value, storeOp, state); + auto vectorType = vectorValue->getType().cast<VectorType>(); + auto indices = map(makePtrDynCaster<SSAValue>(), store->getIndices()); MLFuncBuilder b(storeOp); - OperationState opState(b.getContext(), storeOp->getLoc(), - kVectorTransferWriteOpName, operands, {}); - return b.createOperation(opState); + auto transfer = b.create<VectorTransferWriteOp>( + storeOp->getLoc(), vectorValue, memRef, indices, + makePermutationMap(memRefType, vectorType)); + return cast<OperationStmt>(transfer->getOperation()); } /// Encodes OperationStmt-specific behavior for vectorization. In general we @@ -1134,9 +1126,9 @@ static OperationStmt *vectorizeOneOperationStmt(MLFuncBuilder *b, // Sanity checks. assert(!stmt->isa<LoadOp>() && "all loads must have already been fully vectorized independently"); - assert(!isaVectorTransferRead(*stmt) && + assert(!stmt->isa<VectorTransferReadOp>() && "vector_transfer_read cannot be further vectorized"); - assert(!isaVectorTransferWrite(*stmt) && + assert(!stmt->isa<VectorTransferWriteOp>() && "vector_transfer_write cannot be further vectorized"); if (stmt->isa<StoreOp>()) { diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir index f9e0e5f9404..9a056994c6b 100644 --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -9,6 +9,9 @@ // CHECK: #map2 = (d0, d1)[s0, s1] -> (d0 + s1, d1 + s0) // CHECK: #map3 = ()[s0] -> (s0 + 1) +// CHECK-DAG: #[[map_proj_d0d1_d0:map[0-9]+]] = (d0, d1) -> (d0) +// CHECK-DAG: #[[map_proj_d0d1_d1:map[0-9]+]] = (d0, d1) -> (d1) +// CHECK-DAG: #[[map_proj_d0d1_d1d0:map[0-9]+]] = (d0, d1) -> (d1, d0) // CHECK-LABEL: cfgfunc @cfgfunc_with_ops(f32) { cfgfunc @cfgfunc_with_ops(f32) { @@ -259,3 +262,23 @@ mlfunc @test_dimop(%arg0 : tensor<4x4x?xf32>) { return } + +// CHECK-LABEL: mlfunc @test_vector_transfer_ops(%arg0 +mlfunc @test_vector_transfer_ops(%arg0 : memref<?x?xf32>) { + %c3 = constant 3 : index + %cst = constant 3.0 : f32 + // CHECK: %0 = vector_transfer_read %arg0, %c3, %c3 {permutation_map: #[[map_proj_d0d1_d0]]} : (memref<?x?xf32>, index, index) -> vector<128xf32> + %0 = vector_transfer_read %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0)} : (memref<?x?xf32>, index, index) -> vector<128xf32> + // CHECK: %1 = vector_transfer_read %arg0, %c3, %c3 {permutation_map: #[[map_proj_d0d1_d1d0]]} : (memref<?x?xf32>, index, index) -> vector<3x7xf32> + %1 = vector_transfer_read %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d1, d0)} : (memref<?x?xf32>, index, index) -> vector<3x7xf32> + // CHECK: %2 = vector_transfer_read %arg0, %c3, %c3, %cst {permutation_map: #[[map_proj_d0d1_d0]]} : (memref<?x?xf32>, index, index, f32) -> vector<128xf32> + %2 = vector_transfer_read %arg0, %c3, %c3, %cst {permutation_map: (d0, d1)->(d0)} : (memref<?x?xf32>, index, index, f32) 
-> vector<128xf32> + // CHECK: %3 = vector_transfer_read %arg0, %c3, %c3, %cst {permutation_map: #[[map_proj_d0d1_d1]]} : (memref<?x?xf32>, index, index, f32) -> vector<128xf32> + %3 = vector_transfer_read %arg0, %c3, %c3, %cst {permutation_map: (d0, d1)->(d1)} : (memref<?x?xf32>, index, index, f32) -> vector<128xf32> + // + // CHECK: vector_transfer_write %0, %arg0, %c3, %c3 {permutation_map: #[[map_proj_d0d1_d0]]} : vector<128xf32>, memref<?x?xf32>, index, index + vector_transfer_write %0, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index + // CHECK: vector_transfer_write %1, %arg0, %c3, %c3 {permutation_map: #[[map_proj_d0d1_d1d0]]} : vector<3x7xf32>, memref<?x?xf32>, index, index + vector_transfer_write %1, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d1, d0)} : vector<3x7xf32>, memref<?x?xf32>, index, index + return +} diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index 75c812b9cd4..03474a8be57 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -286,3 +286,184 @@ bb0(%cond : tensor<?xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>): // expected-error@+1 {{requires the condition to have the same shape as arguments}} %r = "select"(%cond, %t, %f) : (tensor<?xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32> } + +// ----- + +cfgfunc @test_vector_transfer_read(memref<?x?xf32>) { +bb0(%arg0 : memref<?x?xf32>): + %c3 = constant 3 : index + %cst = constant 3.0 : f32 + // expected-error@+1 {{expected 4 operand types but had 3}} + %0 = "vector_transfer_read"(%arg0, %c3, %c3, %c3) : (memref<?x?xf32>, index, index) -> vector<128xf32> +} + +// ----- + +cfgfunc @test_vector_transfer_read(memref<?x?xf32>) { +bb0(%arg0 : memref<?x?xf32>): + %c3 = constant 3 : index + %cst = constant 3.0 : f32 + // expected-error@+1 {{requires 3 operands}} + %0 = vector_transfer_read %arg0, %c3, %c3, %c3 : (memref<?x?xf32>, index, index) -> vector<128xf32> +} + +// ----- + +cfgfunc @test_vector_transfer_read(memref<?x?xf32>) { +bb0(%arg0 : memref<?x?xf32>): + %c3 = constant 3 : index + %cst = constant 3.0 : f32 + // expected-error@+1 {{requires an AffineMapAttr named 'permutation_map'}} + %0 = vector_transfer_read %arg0, %c3, %c3 : (memref<?x?xf32>, index, index) -> vector<128xf32> +} + +// ----- + +cfgfunc @test_vector_transfer_read(memref<?x?xf32>) { +bb0(%arg0 : memref<?x?xf32>): + %c3 = constant 3 : index + %cst = constant 3.0 : f32 + // expected-error@+1 {{requires an AffineMapAttr named 'permutation_map'}} + %0 = vector_transfer_read %arg0, %c3, %c3 {perm: (d0)->(d0)} : (memref<?x?xf32>, index, index) -> vector<128xf32> +} + +// ----- + +cfgfunc @test_vector_transfer_read(memref<?x?xf32>) { +bb0(%arg0 : memref<?x?xf32>): + %c3 = constant 3 : index + %cst = constant 3.0 : f32 + // expected-error@+1 {{requires a permutation_map with input dims of the same rank as the memref type}} + %0 = vector_transfer_read %arg0, %c3, %c3 {permutation_map: (d0)->(d0)} : (memref<?x?xf32>, index, index) -> vector<128xf32> +} + +// ----- + +cfgfunc @test_vector_transfer_read(memref<?x?xf32>) { +bb0(%arg0 : memref<?x?xf32>): + %c3 = constant 3 : index + %cst = constant 3.0 : f32 + // expected-error@+1 {{requires a permutation_map with result dims of the same rank as the vector type}} + %0 = vector_transfer_read %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0, d1)} : (memref<?x?xf32>, index, index) -> vector<128xf32> +} + +// ----- + +cfgfunc @test_vector_transfer_read(memref<?x?xf32>) { +bb0(%arg0 : memref<?x?xf32>): + %c3 
= constant 3 : index + %cst = constant 3.0 : f32 + // expected-error@+1 {{requires a permutation_map that is an actual permutation}} + %0 = vector_transfer_read %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0 + d1)} : (memref<?x?xf32>, index, index) -> vector<128xf32> +} + +// ----- + +cfgfunc @test_vector_transfer_read(memref<?x?xf32>) { +bb0(%arg0 : memref<?x?xf32>): + %c3 = constant 3 : index + %cst = constant 3.0 : f32 + // expected-error@+1 {{requires a permutation_map that is an actual permutation}} + %0 = vector_transfer_read %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0 + 1)} : (memref<?x?xf32>, index, index) -> vector<128xf32> +} +// ----- + +cfgfunc @test_vector_transfer_read(memref<?x?x?xf32>) { +bb0(%arg0 : memref<?x?x?xf32>): + %c3 = constant 3 : index + %cst = constant 3.0 : f32 + // expected-error@+1 {{requires a permutation_map that is a full column-rank permutation}} + %0 = vector_transfer_read %arg0, %c3, %c3, %c3 {permutation_map: (d0, d1, d2)->(d0, d0)} : (memref<?x?x?xf32>, index, index, index) -> vector<3x7xf32> +} + +// ----- + +cfgfunc @test_vector_transfer_write(memref<?x?xf32>) { +bb0(%arg0 : memref<?x?xf32>): + %c3 = constant 3 : index + %cst = constant splat<vector<128 x f32>, 3.0> : vector<128 x f32> + // expected-error@+1 {{expected 5 operand types but had 4}} + %0 = "vector_transfer_write"(%cst, %arg0, %c3, %c3, %c3) : (vector<128xf32>, memref<?x?xf32>, index, index) -> () +} + +// ----- + +cfgfunc @test_vector_transfer_write(memref<?x?xf32>) { +bb0(%arg0 : memref<?x?xf32>): + %c3 = constant 3 : index + %cst = constant splat<vector<128 x f32>, 3.0> : vector<128 x f32> + // expected-error@+1 {{requires number of operands and input types to match}} + vector_transfer_write %cst, %arg0, %c3, %c3, %c3 : vector<128xf32>, memref<?x?xf32>, index, index +} + +// ----- + +cfgfunc @test_vector_transfer_write(memref<?x?xf32>) { +bb0(%arg0 : memref<?x?xf32>): + %c3 = constant 3 : index + %cst = constant splat<vector<128 x f32>, 3.0> : vector<128 x f32> + // expected-error@+1 {{requires an AffineMapAttr named 'permutation_map'}} + vector_transfer_write %cst, %arg0, %c3, %c3 : vector<128xf32>, memref<?x?xf32>, index, index +} + +// ----- + +cfgfunc @test_vector_transfer_write(memref<?x?xf32>) { +bb0(%arg0 : memref<?x?xf32>): + %c3 = constant 3 : index + %cst = constant splat<vector<128 x f32>, 3.0> : vector<128 x f32> + // expected-error@+1 {{requires an AffineMapAttr named 'permutation_map'}} + vector_transfer_write %cst, %arg0, %c3, %c3 {perm: (d0)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index +} + +// ----- + +cfgfunc @test_vector_transfer_write(memref<?x?xf32>) { +bb0(%arg0 : memref<?x?xf32>): + %c3 = constant 3 : index + %cst = constant splat<vector<128 x f32>, 3.0> : vector<128 x f32> + // expected-error@+1 {{requires a permutation_map with input dims of the same rank as the memref type}} + vector_transfer_write %cst, %arg0, %c3, %c3 {permutation_map: (d0)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index +} + +// ----- + +cfgfunc @test_vector_transfer_write(memref<?x?xf32>) { +bb0(%arg0 : memref<?x?xf32>): + %c3 = constant 3 : index + %cst = constant splat<vector<128 x f32>, 3.0> : vector<128 x f32> + // expected-error@+1 {{requires a permutation_map with result dims of the same rank as the vector type}} + vector_transfer_write %cst, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0, d1)} : vector<128xf32>, memref<?x?xf32>, index, index +} + +// ----- + +cfgfunc @test_vector_transfer_write(memref<?x?xf32>) { +bb0(%arg0 : memref<?x?xf32>): + %c3 = 
constant 3 : index + %cst = constant splat<vector<128 x f32>, 3.0> : vector<128 x f32> + // expected-error@+1 {{requires a permutation_map that is an actual permutation}} + vector_transfer_write %cst, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0 + d1)} : vector<128xf32>, memref<?x?xf32>, index, index +} + +// ----- + +cfgfunc @test_vector_transfer_write(memref<?x?xf32>) { +bb0(%arg0 : memref<?x?xf32>): + %c3 = constant 3 : index + %cst = constant splat<vector<128 x f32>, 3.0> : vector<128 x f32> + // expected-error@+1 {{requires a permutation_map that is an actual permutation}} + vector_transfer_write %cst, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0 + 1)} : vector<128xf32>, memref<?x?xf32>, index, index +} +// ----- + +cfgfunc @test_vector_transfer_write(memref<?x?x?xf32>) { +bb0(%arg0 : memref<?x?x?xf32>): + %c3 = constant 3 : index + %cst = constant splat<vector<3 x 7 x f32>, 3.0> : vector<3 x 7 x f32> + // expected-error@+1 {{requires a permutation_map that is a full column-rank permutation}} + vector_transfer_write %cst, %arg0, %c3, %c3, %c3 {permutation_map: (d0, d1, d2)->(d0, d0)} : vector<3x7xf32>, memref<?x?x?xf32>, index, index, index +} + + + diff --git a/mlir/test/Transforms/materialize_vectors.mlir b/mlir/test/Transforms/materialize_vectors.mlir index cc38442cb1a..93d17ea10d7 100644 --- a/mlir/test/Transforms/materialize_vectors.mlir +++ b/mlir/test/Transforms/materialize_vectors.mlir @@ -2,21 +2,25 @@ // RUN: mlir-opt %s -vectorize -virtual-vector-size 3 -virtual-vector-size 16 --test-fastest-varying=1 --test-fastest-varying=0 -materialize-vectors -vector-size=8 | FileCheck %s -check-prefix=VEC2DTO1D // RUN: mlir-opt %s -vectorize -virtual-vector-size 3 -virtual-vector-size 32 --test-fastest-varying=1 --test-fastest-varying=0 -materialize-vectors -vector-size=3 -vector-size=16 | FileCheck %s -check-prefix=VEC2DTO2D +// Capture permutation maps used in vectorization. 
+// VEC1DTO1D-DAG: #[[map_proj_d0d1_d1:map[0-9]+]] = (d0, d1) -> (d1) +// VEC2DTO1D-DAG: #[[map_proj_d0d1_d1:map[0-9]+]] = (d0, d1) -> (d1) +// VEC2DTO2D-DAG: #[[map_proj_d0d1_d0d1:map[0-9]+]] = (d0, d1) -> (d0, d1) + // vector<32xf32> -> vector<8xf32> -// VEC1DTO1D: [[MAP0:#.*]] = (d0, d1) -> (d0, d1) -// VEC1DTO1D: [[MAP1:#.*]] = (d0, d1) -> (d0, d1 + 8) -// VEC1DTO1D: [[MAP2:#.*]] = (d0, d1) -> (d0, d1 + 16) -// VEC1DTO1D: [[MAP3:#.*]] = (d0, d1) -> (d0, d1 + 24) +// VEC1DTO1D-DAG: [[MAP0:#.*]] = (d0, d1) -> (d0, d1) +// VEC1DTO1D-DAG: [[MAP1:#.*]] = (d0, d1) -> (d0, d1 + 8) +// VEC1DTO1D-DAG: [[MAP2:#.*]] = (d0, d1) -> (d0, d1 + 16) +// VEC1DTO1D-DAG: [[MAP3:#.*]] = (d0, d1) -> (d0, d1 + 24) // vector<3x16xf32> -> vector<8xf32> -// VEC2DTO1D: [[MAP0:#.*]] = (d0, d1) -> (d0, d1) -// VEC2DTO1D: [[MAP1:#.*]] = (d0, d1) -> (d0, d1 + 8) -// VEC2DTO1D: [[MAP2:#.*]] = (d0, d1) -> (d0 + 1, d1) -// VEC2DTO1D: [[MAP3:#.*]] = (d0, d1) -> (d0 + 1, d1 + 8) -// VEC2DTO1D: [[MAP4:#.*]] = (d0, d1) -> (d0 + 2, d1) -// VEC2DTO1D: [[MAP5:#.*]] = (d0, d1) -> (d0 + 2, d1 + 8) +// VEC2DTO1D-DAG: [[MAP0:#.*]] = (d0, d1) -> (d0, d1) +// VEC2DTO1D-DAG: [[MAP1:#.*]] = (d0, d1) -> (d0, d1 + 8) +// VEC2DTO1D-DAG: [[MAP2:#.*]] = (d0, d1) -> (d0 + 1, d1) +// VEC2DTO1D-DAG: [[MAP3:#.*]] = (d0, d1) -> (d0 + 1, d1 + 8) +// VEC2DTO1D-DAG: [[MAP4:#.*]] = (d0, d1) -> (d0 + 2, d1) +// VEC2DTO1D-DAG: [[MAP5:#.*]] = (d0, d1) -> (d0 + 2, d1 + 8) // vector<3x32xf32> -> vector<3x16xf32> -// VEC2DTO2D: [[MAP0:#.*]] = (d0, d1) -> (d0, d1) -// VEC2DTO2D: [[MAP1:#.*]] = (d0, d1) -> (d0, d1 + 16) +// VEC2DTO2D-DAG: [[MAP1:#.*]] = (d0, d1) -> (d0, d1 + 16) mlfunc @vector_add_2d(%M : index, %N : index) -> f32 { %A = alloc (%M, %N) : memref<?x?xf32, 0> %B = alloc (%M, %N) : memref<?x?xf32, 0> @@ -32,13 +36,13 @@ mlfunc @vector_add_2d(%M : index, %N : index) -> f32 { // VEC1DTO1D: [[CST2:%.*]] = constant splat<vector<8xf32>, 1.000000e+00> : vector<8xf32> // VEC1DTO1D: [[CST3:%.*]] = constant splat<vector<8xf32>, 1.000000e+00> : vector<8xf32> // VEC1DTO1D: [[VAL0:%.*]] = affine_apply [[MAP0]]{{.*}} - // VEC1DTO1D: "vector_transfer_write"([[CST0]], {{.*}}, [[VAL0]]#0, [[VAL0]]#1) : (vector<8xf32> + // VEC1DTO1D: vector_transfer_write [[CST0]], {{.*}}, [[VAL0]]#0, [[VAL0]]#1 {permutation_map: #[[map_proj_d0d1_d1]]} : vector<8xf32> // VEC1DTO1D: [[VAL1:%.*]] = affine_apply [[MAP1]]{{.*}} - // VEC1DTO1D: "vector_transfer_write"([[CST1]], {{.*}}, [[VAL1]]#0, [[VAL1]]#1) : (vector<8xf32> + // VEC1DTO1D: vector_transfer_write [[CST1]], {{.*}}, [[VAL1]]#0, [[VAL1]]#1 {permutation_map: #[[map_proj_d0d1_d1]]} : vector<8xf32> // VEC1DTO1D: [[VAL2:%.*]] = affine_apply [[MAP2]]{{.*}} - // VEC1DTO1D:"vector_transfer_write"([[CST2]], {{.*}}, [[VAL2]]#0, [[VAL2]]#1) : (vector<8xf32> + // VEC1DTO1D:vector_transfer_write [[CST2]], {{.*}}, [[VAL2]]#0, [[VAL2]]#1 {permutation_map: #[[map_proj_d0d1_d1]]} : vector<8xf32> // VEC1DTO1D: [[VAL3:%.*]] = affine_apply [[MAP3]]{{.*}} - // VEC1DTO1D:"vector_transfer_write"([[CST3]], {{.*}}, [[VAL3]]#0, [[VAL3]]#1) : (vector<8xf32> + // VEC1DTO1D:vector_transfer_write [[CST3]], {{.*}}, [[VAL3]]#0, [[VAL3]]#1 {permutation_map: #[[map_proj_d0d1_d1]]} : vector<8xf32> // store %f1, %A[%i0, %i1] : memref<?x?xf32, 0> } @@ -49,10 +53,10 @@ mlfunc @vector_add_2d(%M : index, %N : index) -> f32 { // VEC2DTO1D does (3x4)x unrolling. 
// VEC2DTO1D-COUNT-6: {{.*}} = constant splat<vector<8xf32>, 1.000000e+00> : vector<8xf32> // VEC2DTO1D: [[VAL0:%.*]] = affine_apply [[MAP0]]{{.*}} - // VEC2DTO1D: "vector_transfer_write"({{.*}}, [[VAL0]]#0, [[VAL0]]#1) : (vector<8xf32> + // VEC2DTO1D: vector_transfer_write {{.*}}, [[VAL0]]#0, [[VAL0]]#1 {permutation_map: #[[map_proj_d0d1_d1]]} : vector<8xf32> // ... 4 other interleaved affine_apply, vector_transfer_write // VEC2DTO1D: [[VAL5:%.*]] = affine_apply [[MAP5]]{{.*}} - // VEC2DTO1D: "vector_transfer_write"({{.*}}, [[VAL5]]#0, [[VAL5]]#1) : (vector<8xf32> + // VEC2DTO1D: vector_transfer_write {{.*}}, [[VAL5]]#0, [[VAL5]]#1 {permutation_map: #[[map_proj_d0d1_d1]]} : vector<8xf32> // store %f2, %B[%i2, %i3] : memref<?x?xf32, 0> } @@ -60,19 +64,19 @@ mlfunc @vector_add_2d(%M : index, %N : index) -> f32 { for %i4 = 0 to %M { for %i5 = 0 to %N { // VEC2DTO2D: %7 = affine_apply #map0(%i4, %i5) - // VEC2DTO2D: %8 = "vector_transfer_read"(%0, %7#0, %7#1) : (memref<?x?xf32>, index, index) -> vector<3x16xf32> + // VEC2DTO2D: %8 = vector_transfer_read %0, %7#0, %7#1 {permutation_map: #[[map_proj_d0d1_d0d1]]} : (memref<?x?xf32>, index, index) -> vector<3x16xf32> // VEC2DTO2D: %9 = affine_apply #map1(%i4, %i5) - // VEC2DTO2D: %10 = "vector_transfer_read"(%0, %9#0, %9#1) : (memref<?x?xf32>, index, index) -> vector<3x16xf32> + // VEC2DTO2D: %10 = vector_transfer_read %0, %9#0, %9#1 {permutation_map: #[[map_proj_d0d1_d0d1]]} : (memref<?x?xf32>, index, index) -> vector<3x16xf32> // VEC2DTO2D: %11 = affine_apply #map0(%i4, %i5) - // VEC2DTO2D: %12 = "vector_transfer_read"(%1, %11#0, %11#1) : (memref<?x?xf32>, index, index) -> vector<3x16xf32> + // VEC2DTO2D: %12 = vector_transfer_read %1, %11#0, %11#1 {permutation_map: #[[map_proj_d0d1_d0d1]]} : (memref<?x?xf32>, index, index) -> vector<3x16xf32> // VEC2DTO2D: %13 = affine_apply #map1(%i4, %i5) - // VEC2DTO2D: %14 = "vector_transfer_read"(%1, %13#0, %13#1) : (memref<?x?xf32>, index, index) -> vector<3x16xf32> + // VEC2DTO2D: %14 = vector_transfer_read %1, %13#0, %13#1 {permutation_map: #[[map_proj_d0d1_d0d1]]} : (memref<?x?xf32>, index, index) -> vector<3x16xf32> // VEC2DTO2D: %15 = addf %8, %12 : vector<3x16xf32> // VEC2DTO2D: %16 = addf %10, %14 : vector<3x16xf32> // VEC2DTO2D: %17 = affine_apply #map0(%i4, %i5) - // VEC2DTO2D: "vector_transfer_write"(%15, %2, %17#0, %17#1) : (vector<3x16xf32>, memref<?x?xf32>, index, index) -> () + // VEC2DTO2D: vector_transfer_write %15, %2, %17#0, %17#1 {permutation_map: #[[map_proj_d0d1_d0d1]]} : vector<3x16xf32>, memref<?x?xf32>, index, index // VEC2DTO2D: %18 = affine_apply #map1(%i4, %i5) - // VEC2DTO2D: "vector_transfer_write"(%16, %2, %18#0, %18#1) : (vector<3x16xf32>, memref<?x?xf32>, index, index) -> () + // VEC2DTO2D: vector_transfer_write %16, %2, %18#0, %18#1 {permutation_map: #[[map_proj_d0d1_d0d1]]} : vector<3x16xf32>, memref<?x?xf32>, index, index // %a5 = load %A[%i4, %i5] : memref<?x?xf32, 0> %b5 = load %B[%i4, %i5] : memref<?x?xf32, 0> diff --git a/mlir/test/Transforms/vectorize.mlir b/mlir/test/Transforms/vectorize.mlir index 824e8167b06..3533ef76013 100644 --- a/mlir/test/Transforms/vectorize.mlir +++ b/mlir/test/Transforms/vectorize.mlir @@ -5,6 +5,14 @@ // RUN: mlir-opt %s -vectorize -virtual-vector-size 32 -virtual-vector-size 256 --test-fastest-varying=0 --test-fastest-varying=2 | FileCheck %s -check-prefix=VEC2D_OT // RUN: mlir-opt %s -vectorize -virtual-vector-size 32 -virtual-vector-size 64 -virtual-vector-size 256 --test-fastest-varying=2 --test-fastest-varying=1 
--test-fastest-varying=0 | FileCheck %s -check-prefix=VEC3D +// Permutation maps used in vectorization. +// VEC1D: #[[map_proj_d0d1_d1:map[0-9]+]] = (d0, d1) -> (d1) +// VEC2D: #[[map_proj_d0d1_d0d1:map[0-9]+]] = (d0, d1) -> (d0, d1) +// VEC2D_T: #[[map_proj_d0d1d2_d1d2:map[0-9]+]] = (d0, d1, d2) -> (d1, d2) +// VEC2D_O: #[[map_proj_d0d1d2_d1d2:map[0-9]+]] = (d0, d1, d2) -> (d1, d2) +// VEC2D_OT: #[[map_proj_d0d1d2_d1d2:map[0-9]+]] = (d0, d1, d2) -> (d1, d2) +// VEC3D: #[[map_proj_d0d1d2_d0d1d2:map[0-9]+]] = (d0, d1, d2) -> (d0, d1, d2) + #map0 = (d0) -> (d0) #map1 = (d0, d1) -> (d0, d1) #map1_t = (d0, d1) -> (d1, d0) @@ -15,6 +23,7 @@ #mapadd2 = (d0) -> (d0 + 2) #mapadd3 = (d0) -> (d0 + 3) #set0 = (i) : (i >= 0) + // Maps introduced to vectorize fastest varying memory index. mlfunc @vec1d(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) { // VEC1D-DAG: [[C0:%[a-z0-9_]+]] = constant 0 : index @@ -26,26 +35,26 @@ mlfunc @vec1d(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) { %P = dim %B, 2 : memref<?x?x?xf32> %cst0 = constant 0 : index // VEC1D:for [[IV0:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128 -// VEC1D-NEXT: {{.*}} = "vector_transfer_read"(%arg0, [[C0]], [[C0]]) : (memref<?x?xf32>, index, index) -> vector<128xf32> +// VEC1D-NEXT: {{.*}} = vector_transfer_read %arg0, [[C0]], [[C0]] {permutation_map: #[[map_proj_d0d1_d1]]} : (memref<?x?xf32>, index, index) -> vector<128xf32> // For this simple loop, the current transformation generates: // for %i0 = 0 to %0 step 128 { -// %3 = "vector_transfer_read"(%arg0, %c0_0, %c0_0) : (memref<?x?xf32>, index, index) -> vector<128xf32> +// %3 = vector_transfer_read %arg0, %c0_0, %c0_0 : (memref<?x?xf32>, index, index) -> vector<128xf32> // } - for %i0 = 0 to %M { // vectorized due to scalar -> vector + for %i0 = 0 to %M { // vectorized due to scalar -> vector %a0 = load %A[%cst0, %cst0] : memref<?x?xf32> } // VEC1D:for {{.*}} [[ARG_M]] { - for %i1 = 0 to %M { // not vectorized + for %i1 = 0 to %M { // not vectorized %a1 = load %A[%i1, %i1] : memref<?x?xf32> } // VEC1D: for %i{{[0-9]*}} = 0 to [[ARG_M]] { - for %i2 = 0 to %M { // not vectorized, would vectorize with --test-fastest-varying=1 + for %i2 = 0 to %M { // not vectorized, would vectorize with --test-fastest-varying=1 %r2 = affine_apply (d0) -> (d0) (%i2) %a2 = load %A[%r2#0, %cst0] : memref<?x?xf32> } // VEC1D:for [[IV3:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128 // VEC1D-NEXT: [[APP3:%[a-zA-Z0-9]+]] = affine_apply {{.*}}[[IV3]] -// VEC1D-NEXT: {{.*}} = "vector_transfer_read"(%arg0, [[C0]], [[APP3]]) : {{.*}} -> vector<128xf32> +// VEC1D-NEXT: {{.*}} = vector_transfer_read %arg0, [[C0]], [[APP3]] {permutation_map: #[[map_proj_d0d1_d1]]} : {{.*}} -> vector<128xf32> for %i3 = 0 to %M { // vectorized %r3 = affine_apply (d0) -> (d0) (%i3) %a3 = load %A[%cst0, %r3#0] : memref<?x?xf32> @@ -53,8 +62,8 @@ mlfunc @vec1d(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) { // VEC1D:for [[IV4:%[i0-9]+]] = 0 to [[ARG_M]] step 128 { // VEC1D-NEXT: for [[IV5:%[i0-9]*]] = 0 to [[ARG_N]] { // VEC1D-NEXT: [[APP5:%[0-9]+]] = affine_apply {{.*}}([[IV4]], [[IV5]]) -// VEC1D-NEXT: {{.*}} = "vector_transfer_read"(%arg0, [[APP5]]#0, [[APP5]]#1) : {{.*}} -> vector<128xf32> - for %i4 = 0 to %M { // vectorized +// VEC1D-NEXT: {{.*}} = vector_transfer_read %arg0, [[APP5]]#0, [[APP5]]#1 {permutation_map: #[[map_proj_d0d1_d1]]} : {{.*}} -> vector<128xf32> + for %i4 = 0 to %M { // vectorized for %i5 = 0 to %N { // not vectorized, would vectorize with --test-fastest-varying=1 %r5 = affine_apply #map1_t (%i4, %i5) %a5 = load %A[%r5#0, 
%r5#1] : memref<?x?xf32> @@ -71,7 +80,7 @@ mlfunc @vec1d(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) { // VEC1D:for [[IV8:%[i0-9]+]] = 0 to [[ARG_M]] step 128 // VEC1D-NEXT: for [[IV9:%[i0-9]*]] = 0 to [[ARG_N]] { // VEC1D-NEXT: [[APP9:%[0-9]+]] = affine_apply {{.*}}([[IV8]], [[IV9]]) -// VEC1D-NEXT: {{.*}} = "vector_transfer_read"(%arg0, [[APP9]]#0, [[APP9]]#1) : {{.*}} -> vector<128xf32> +// VEC1D-NEXT: {{.*}} = vector_transfer_read %arg0, [[APP9]]#0, [[APP9]]#1 {permutation_map: #[[map_proj_d0d1_d1]]} : {{.*}} -> vector<128xf32> for %i8 = 0 to %M { // vectorized for %i9 = 0 to %N { %r9 = affine_apply #map3 (%i8, %i9) @@ -80,8 +89,8 @@ mlfunc @vec1d(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) { } // VEC1D: for [[IV10:%[i0-9]*]] = 0 to %{{[0-9]*}} { // VEC1D: for [[IV11:%[i0-9]*]] = 0 to %{{[0-9]*}} { - for %i10 = 0 to %M { // not vectorized, need per load transposes - for %i11 = 0 to %N { // not vectorized, need per load transposes + for %i10 = 0 to %M { // not vectorized, need per load transposes + for %i11 = 0 to %N { // not vectorized, need per load transposes %r11 = affine_apply #map1 (%i10, %i11) %a11 = load %A[%r11#0, %r11#1] : memref<?x?xf32> %r12 = affine_apply #map1_t (%i10, %i11) @@ -112,7 +121,7 @@ mlfunc @vec1d(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) { } // VEC1D: for %i{{[0-9]*}} = 0 to %{{[0-9]*}} { // VEC1D: for [[IV18:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128 -// VEC1D: {{.*}} = "vector_transfer_read"(%arg0, [[C0]], [[C0]]) : {{.*}} -> vector<128xf32> +// VEC1D: {{.*}} = vector_transfer_read %arg0, [[C0]], [[C0]] {permutation_map: #[[map_proj_d0d1_d1]]} : {{.*}} -> vector<128xf32> for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %i18 in DFS post-order prevents vectorizing %i17 for %i18 = 0 to %M { // vectorized due to scalar -> vector %a18 = load %A[%cst0, %cst0] : memref<?x?xf32> @@ -211,22 +220,22 @@ mlfunc @vec2d_imperfectly_nested(%A : memref<?x?x?xf32>) { // VEC2D_T: for %i0 = 0 to %0 step 32 { // VEC2D_T: for %i1 = 0 to %1 step 256 { // VEC2D_T: for %i2 = 0 to %2 { - // VEC2D_T: %3 = "vector_transfer_read"(%arg0, %i2, %i1, %i0) : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32> + // VEC2D_T: %3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: #[[map_proj_d0d1d2_d1d2]]} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32> // VEC2D_T: for %i3 = 0 to %1 { // VEC2D_T: for %i4 = 0 to %2 step 256 { - // VEC2D_T: %4 = "vector_transfer_read"(%arg0, %i3, %i4, %i0) : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32> + // VEC2D_T: %4 = vector_transfer_read %arg0, %i3, %i4, %i0 {permutation_map: #[[map_proj_d0d1d2_d1d2]]} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32> // VEC2D_T: for %i5 = 0 to %2 step 256 { - // VEC2D_T: %5 = "vector_transfer_read"(%arg0, %i3, %i5, %i0) : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32> + // VEC2D_T: %5 = vector_transfer_read %arg0, %i3, %i5, %i0 {permutation_map: #[[map_proj_d0d1d2_d1d2]]} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32> // // VEC2D_OT: for %i0 = 0 to %0 step 32 { // VEC2D_OT: for %i1 = 0 to %1 { // VEC2D_OT: for %i2 = 0 to %2 step 256 { - // VEC2D_OT: %3 = "vector_transfer_read"(%arg0, %i2, %i1, %i0) : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32> + // VEC2D_OT: %3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: #[[map_proj_d0d1d2_d1d2]]} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32> // VEC2D_OT: for %i3 = 0 to %1 step 256 { // VEC2D_OT: for %i4 = 
0 to %2 { - // VEC2D_OT: %4 = "vector_transfer_read"(%arg0, %i3, %i4, %i0) : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32> + // VEC2D_OT: %4 = vector_transfer_read %arg0, %i3, %i4, %i0 {permutation_map: #[[map_proj_d0d1d2_d1d2]]} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32> // VEC2D_OT: for %i5 = 0 to %2 { - // VEC2D_OT: %5 = "vector_transfer_read"(%arg0, %i3, %i5, %i0) : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32> + // VEC2D_OT: %5 = vector_transfer_read %arg0, %i3, %i5, %i0 {permutation_map: #[[map_proj_d0d1d2_d1d2]]} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32> for %i0 = 0 to %0 { for %i1 = 0 to %1 { for %i2 = 0 to %2 { @@ -254,7 +263,7 @@ mlfunc @vec3d(%A : memref<?x?x?xf32>) { // VEC3D: for %i2 = 0 to %0 step 32 { // VEC3D: for %i3 = 0 to %1 step 64 { // VEC3D: for %i4 = 0 to %2 step 256 { - // VEC3D: %3 = "vector_transfer_read"(%arg0, %i2, %i3, %i4) : (memref<?x?x?xf32>, index, index, index) -> vector<32x64x256xf32> + // VEC3D: %3 = vector_transfer_read %arg0, %i2, %i3, %i4 {permutation_map: #[[map_proj_d0d1d2_d0d1d2]]} : (memref<?x?x?xf32>, index, index, index) -> vector<32x64x256xf32> for %t0 = 0 to %0 { for %t1 = 0 to %0 { for %i0 = 0 to %0 { @@ -278,9 +287,9 @@ mlfunc @vector_add_2d(%M : index, %N : index) -> f32 { for %i0 = 0 to %M { for %i1 = 0 to %N { // VEC1D: [[C1:%.*]] = constant splat<vector<128xf32>, 1.000000e+00> : vector<128xf32> - // VEC1D: "vector_transfer_write"([[C1]], {{.*}}) : (vector<128xf32>, memref<?x?xf32>, index, index) -> () + // VEC1D: vector_transfer_write [[C1]], {{.*}} {permutation_map: #[[map_proj_d0d1_d1]]} : vector<128xf32>, memref<?x?xf32>, index, index // VEC2D: [[C1:%.*]] = constant splat<vector<32x256xf32>, 1.000000e+00> : vector<32x256xf32> - // VEC2D: "vector_transfer_write"([[C1]], {{.*}}) : (vector<32x256xf32>, memref<?x?xf32>, index, index) -> () + // VEC2D: vector_transfer_write [[C1]], {{.*}} {permutation_map: #[[map_proj_d0d1_d0d1]]} : vector<32x256xf32>, memref<?x?xf32>, index, index // non-scoped %f1 store %f1, %A[%i0, %i1] : memref<?x?xf32, 0> } @@ -288,9 +297,9 @@ mlfunc @vector_add_2d(%M : index, %N : index) -> f32 { for %i2 = 0 to %M { for %i3 = 0 to %N { // VEC1D: [[C3:%.*]] = constant splat<vector<128xf32>, 2.000000e+00> : vector<128xf32> - // VEC1D: "vector_transfer_write"([[C3]], {{.*}}) : (vector<128xf32>, memref<?x?xf32>, index, index) -> () + // VEC1D: vector_transfer_write [[C3]], {{.*}} {permutation_map: #[[map_proj_d0d1_d1]]} : vector<128xf32>, memref<?x?xf32>, index, index // VEC2D: [[C3:%.*]] = constant splat<vector<32x256xf32>, 2.000000e+00> : vector<32x256xf32> - // VEC2D: "vector_transfer_write"([[C3]], {{.*}}) : (vector<32x256xf32>, memref<?x?xf32>, index, index) -> () + // VEC2D: vector_transfer_write [[C3]], {{.*}} {permutation_map: #[[map_proj_d0d1_d0d1]]} : vector<32x256xf32>, memref<?x?xf32>, index, index // non-scoped %f2 store %f2, %B[%i2, %i3] : memref<?x?xf32, 0> } @@ -298,25 +307,25 @@ mlfunc @vector_add_2d(%M : index, %N : index) -> f32 { for %i4 = 0 to %M { for %i5 = 0 to %N { // - // VEC1D: [[A5:%.*]] = "vector_transfer_read"(%0, {{.*}}) : (memref<?x?xf32>, index, index) -> vector<128xf32> - // VEC1D: [[B5:%.*]] = "vector_transfer_read"(%1, {{.*}}) : (memref<?x?xf32>, index, index) -> vector<128xf32> + // VEC1D: [[A5:%.*]] = vector_transfer_read %0, {{.*}} {permutation_map: #[[map_proj_d0d1_d1]]} : (memref<?x?xf32>, index, index) -> vector<128xf32> + // VEC1D: [[B5:%.*]] = vector_transfer_read %1, {{.*}} {permutation_map: 
#[[map_proj_d0d1_d1]]} : (memref<?x?xf32>, index, index) -> vector<128xf32> // VEC1D: [[S5:%.*]] = addf [[A5]], [[B5]] : vector<128xf32> // VEC1D: [[SPLAT1:%.*]] = constant splat<vector<128xf32>, 1.000000e+00> : vector<128xf32> // VEC1D: [[S6:%.*]] = addf [[S5]], [[SPLAT1]] : vector<128xf32> // VEC1D: [[SPLAT2:%.*]] = constant splat<vector<128xf32>, 2.000000e+00> : vector<128xf32> // VEC1D: [[S7:%.*]] = addf [[S5]], [[SPLAT2]] : vector<128xf32> // VEC1D: [[S8:%.*]] = addf [[S7]], [[S6]] : vector<128xf32> - // VEC1D: "vector_transfer_write"([[S8]], {{.*}}) : (vector<128xf32>, memref<?x?xf32>, index, index) -> () + // VEC1D: vector_transfer_write [[S8]], {{.*}} {permutation_map: #[[map_proj_d0d1_d1]]} : vector<128xf32>, memref<?x?xf32>, index, index // - // VEC2D: [[A5:%.*]] = "vector_transfer_read"(%0, {{.*}}) : (memref<?x?xf32>, index, index) -> vector<32x256xf32> - // VEC2D: [[B5:%.*]] = "vector_transfer_read"(%1, {{.*}}) : (memref<?x?xf32>, index, index) -> vector<32x256xf32> + // VEC2D: [[A5:%.*]] = vector_transfer_read %0, {{.*}} {permutation_map: #[[map_proj_d0d1_d0d1]]} : (memref<?x?xf32>, index, index) -> vector<32x256xf32> + // VEC2D: [[B5:%.*]] = vector_transfer_read %1, {{.*}} {permutation_map: #[[map_proj_d0d1_d0d1]]} : (memref<?x?xf32>, index, index) -> vector<32x256xf32> // VEC2D: [[S5:%.*]] = addf [[A5]], [[B5]] : vector<32x256xf32> // VEC2D: [[SPLAT1:%.*]] = constant splat<vector<32x256xf32>, 1.000000e+00> : vector<32x256xf32> // VEC2D: [[S6:%.*]] = addf [[S5]], [[SPLAT1]] : vector<32x256xf32> // VEC2D: [[SPLAT2:%.*]] = constant splat<vector<32x256xf32>, 2.000000e+00> : vector<32x256xf32> // VEC2D: [[S7:%.*]] = addf [[S5]], [[SPLAT2]] : vector<32x256xf32> // VEC2D: [[S8:%.*]] = addf [[S7]], [[S6]] : vector<32x256xf32> - // VEC2D: "vector_transfer_write"([[S8]], {{.*}}) : (vector<32x256xf32>, memref<?x?xf32>, index, index) -> () + // VEC2D: vector_transfer_write [[S8]], {{.*}} {permutation_map: #[[map_proj_d0d1_d0d1]]} : vector<32x256xf32>, memref<?x?xf32>, index, index // %a5 = load %A[%i4, %i5] : memref<?x?xf32, 0> %b5 = load %B[%i4, %i5] : memref<?x?xf32, 0> |
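The invalid-ops cases above enumerate, through their expected-error messages, the constraints the new verifier places on permutation_map: the map's input dims must match the memref rank, its result dims must match the vector rank, every result must be a plain dimension (an actual permutation, so d0 + d1 and d0 + 1 are rejected), and no dimension may be repeated (full column-rank). The following is a minimal standalone sketch of those four checks over a simplified map representation; it is illustrative only, not the verifier code added in StandardOps.cpp, and the struct and function names are invented.

// Standalone sketch of the permutation_map checks exercised by the
// invalid-ops tests above. A result expression is modeled as an optional
// dim index: a plain dN becomes N, any compound expression (d0 + d1,
// d0 + 1, ...) becomes nullopt.
#include <cstdio>
#include <optional>
#include <set>
#include <vector>

struct PermutationMap {
  unsigned numInputDims;                        // must equal the memref rank
  std::vector<std::optional<unsigned>> results; // one entry per vector dim
};

const char *verifyPermutationMap(const PermutationMap &map,
                                 unsigned memrefRank, unsigned vectorRank) {
  if (map.numInputDims != memrefRank)
    return "requires a permutation_map with input dims of the same rank as "
           "the memref type";
  if (map.results.size() != vectorRank)
    return "requires a permutation_map with result dims of the same rank as "
           "the vector type";
  std::set<unsigned> seen;
  for (const auto &res : map.results) {
    if (!res) // compound expression such as d0 + d1 or d0 + 1
      return "requires a permutation_map that is an actual permutation";
    seen.insert(*res);
  }
  if (seen.size() != map.results.size()) // repeated dim, e.g. (d0, d0)
    return "requires a permutation_map that is a full column-rank permutation";
  return nullptr; // valid
}

int main() {
  // (d0, d1, d2) -> (d0, d0) on a rank-3 memref and a rank-2 vector is
  // rejected as not full column-rank, mirroring the last write test above.
  PermutationMap map{3, {0u, 0u}};
  if (const char *err = verifyPermutationMap(map, /*memrefRank=*/3,
                                             /*vectorRank=*/2))
    std::printf("error: %s\n", err);
}

On the materialize_vectors side, the captured affine maps follow directly from the shape arithmetic: splitting vector<32xf32> into hardware vector<8xf32> pieces gives four offsets (MAP0 through MAP3), splitting vector<3x16xf32> gives a 3x2 grid of six offsets (MAP0 through MAP5, matching the COUNT-6 directive), and splitting vector<3x32xf32> into vector<3x16xf32> gives two pieces per access, the identity map and the (d0, d1 + 16) offset map.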