summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp')
-rw-r--r--mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp58
1 files changed, 57 insertions, 1 deletions
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index ebb0fd75753..ff516d7ef29 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -340,6 +340,28 @@ public:
}
};
+template <typename LinalgOp>
+static SmallVector<Type, 4> ExtractOperandTypes(Operation *op) {
+ return SmallVector<Type, 4>{op->getOperandTypes()};
+}
+
+template <>
+SmallVector<Type, 4> ExtractOperandTypes<IndexedGenericOp>(Operation *op) {
+ auto ctx = op->getContext();
+ auto indexedGenericOp = cast<IndexedGenericOp>(op);
+ auto numLoops = indexedGenericOp.getNumLoops();
+
+ SmallVector<Type, 4> result;
+ result.reserve(numLoops + op->getNumOperands());
+ for (unsigned i = 0; i < numLoops; ++i) {
+ result.push_back(IndexType::get(ctx));
+ }
+ for (auto type : op->getOperandTypes()) {
+ result.push_back(type);
+ }
+ return result;
+}
+
// Get a SymbolRefAttr containing the library function name for the LinalgOp.
// If the library function does not exist, insert a declaration.
template <typename LinalgOp>
@@ -359,7 +381,7 @@ static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
return fnNameAttr;
}
- SmallVector<Type, 4> inputTypes(op->getOperandTypes());
+ SmallVector<Type, 4> inputTypes(ExtractOperandTypes<LinalgOp>(op));
assert(op->getNumResults() == 0 &&
"Library call for linalg operation can be generated only for ops that "
"have void return types");
@@ -430,6 +452,40 @@ public:
}
};
+/// Conversion pattern specialization for IndexedGenericOp.
+template <>
+class LinalgOpConversion<IndexedGenericOp>
+ : public OpRewritePattern<IndexedGenericOp> {
+public:
+ using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(IndexedGenericOp op,
+ PatternRewriter &rewriter) const override {
+ auto libraryCallName =
+ getLibraryCallSymbolRef<IndexedGenericOp>(op, rewriter);
+ if (!libraryCallName)
+ return this->matchFailure();
+
+ // TODO(pifon, ntv): Use induction variables values instead of zeros, when
+ // IndexedGenericOp is tiled.
+ auto zero = rewriter.create<mlir::ConstantOp>(
+ op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
+ auto indexedGenericOp = cast<IndexedGenericOp>(op);
+ auto numLoops = indexedGenericOp.getNumLoops();
+ SmallVector<Value *, 4> operands;
+ operands.reserve(numLoops + op.getNumOperands());
+ for (unsigned i = 0; i < numLoops; ++i) {
+ operands.push_back(zero);
+ }
+ for (auto operand : op.getOperands()) {
+ operands.push_back(operand);
+ }
+ rewriter.replaceOpWithNewOp<mlir::CallOp>(op, libraryCallName.getValue(),
+ ArrayRef<Type>{}, operands);
+ return this->matchSuccess();
+ }
+};
+
/// A non-conversion rewrite pattern kicks in to convert CopyOp with
/// permutations into a sequence of TransposeOp and permutation-free CopyOp.
/// This interplays together with TransposeOpConversion and
OpenPOWER on IntegriCloud