diff options
Diffstat (limited to 'mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp')
| -rw-r--r-- | mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp | 58 |
1 files changed, 57 insertions, 1 deletions
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp index ebb0fd75753..ff516d7ef29 100644 --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -340,6 +340,28 @@ public: } }; +template <typename LinalgOp> +static SmallVector<Type, 4> ExtractOperandTypes(Operation *op) { + return SmallVector<Type, 4>{op->getOperandTypes()}; +} + +template <> +SmallVector<Type, 4> ExtractOperandTypes<IndexedGenericOp>(Operation *op) { + auto ctx = op->getContext(); + auto indexedGenericOp = cast<IndexedGenericOp>(op); + auto numLoops = indexedGenericOp.getNumLoops(); + + SmallVector<Type, 4> result; + result.reserve(numLoops + op->getNumOperands()); + for (unsigned i = 0; i < numLoops; ++i) { + result.push_back(IndexType::get(ctx)); + } + for (auto type : op->getOperandTypes()) { + result.push_back(type); + } + return result; +} + // Get a SymbolRefAttr containing the library function name for the LinalgOp. // If the library function does not exist, insert a declaration. template <typename LinalgOp> @@ -359,7 +381,7 @@ static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op, return fnNameAttr; } - SmallVector<Type, 4> inputTypes(op->getOperandTypes()); + SmallVector<Type, 4> inputTypes(ExtractOperandTypes<LinalgOp>(op)); assert(op->getNumResults() == 0 && "Library call for linalg operation can be generated only for ops that " "have void return types"); @@ -430,6 +452,40 @@ public: } }; +/// Conversion pattern specialization for IndexedGenericOp. +template <> +class LinalgOpConversion<IndexedGenericOp> + : public OpRewritePattern<IndexedGenericOp> { +public: + using OpRewritePattern<IndexedGenericOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(IndexedGenericOp op, + PatternRewriter &rewriter) const override { + auto libraryCallName = + getLibraryCallSymbolRef<IndexedGenericOp>(op, rewriter); + if (!libraryCallName) + return this->matchFailure(); + + // TODO(pifon, ntv): Use induction variables values instead of zeros, when + // IndexedGenericOp is tiled. + auto zero = rewriter.create<mlir::ConstantOp>( + op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); + auto indexedGenericOp = cast<IndexedGenericOp>(op); + auto numLoops = indexedGenericOp.getNumLoops(); + SmallVector<Value *, 4> operands; + operands.reserve(numLoops + op.getNumOperands()); + for (unsigned i = 0; i < numLoops; ++i) { + operands.push_back(zero); + } + for (auto operand : op.getOperands()) { + operands.push_back(operand); + } + rewriter.replaceOpWithNewOp<mlir::CallOp>(op, libraryCallName.getValue(), + ArrayRef<Type>{}, operands); + return this->matchSuccess(); + } +}; + /// A non-conversion rewrite pattern kicks in to convert CopyOp with /// permutations into a sequence of TransposeOp and permutation-free CopyOp. /// This interplays together with TransposeOpConversion and |

