summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAlexander Belyaev <pifon@google.com>2019-12-02 06:30:19 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-02 06:30:52 -0800
commit9630fcbc52dbd83dfa5be7d757a5abd41a30a652 (patch)
tree41b22adb50cb51b306aebc094f0568fd49c98ea8
parentd5e627f84b440cb4dd30802930629ea970dd4342 (diff)
downloadbcm5719-llvm-9630fcbc52dbd83dfa5be7d757a5abd41a30a652.tar.gz
bcm5719-llvm-9630fcbc52dbd83dfa5be7d757a5abd41a30a652.zip
Lower linalg.indexed_generic with libcall to LLVM.
PiperOrigin-RevId: 283328994
-rw-r--r--mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp58
-rw-r--r--mlir/test/Dialect/Linalg/llvm.mlir26
2 files changed, 81 insertions, 3 deletions
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index ebb0fd75753..ff516d7ef29 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -340,6 +340,28 @@ public:
}
};
+template <typename LinalgOp>
+static SmallVector<Type, 4> ExtractOperandTypes(Operation *op) {
+ return SmallVector<Type, 4>{op->getOperandTypes()};
+}
+
+template <>
+SmallVector<Type, 4> ExtractOperandTypes<IndexedGenericOp>(Operation *op) {
+ auto ctx = op->getContext();
+ auto indexedGenericOp = cast<IndexedGenericOp>(op);
+ auto numLoops = indexedGenericOp.getNumLoops();
+
+ SmallVector<Type, 4> result;
+ result.reserve(numLoops + op->getNumOperands());
+ for (unsigned i = 0; i < numLoops; ++i) {
+ result.push_back(IndexType::get(ctx));
+ }
+ for (auto type : op->getOperandTypes()) {
+ result.push_back(type);
+ }
+ return result;
+}
+
// Get a SymbolRefAttr containing the library function name for the LinalgOp.
// If the library function does not exist, insert a declaration.
template <typename LinalgOp>
@@ -359,7 +381,7 @@ static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
return fnNameAttr;
}
- SmallVector<Type, 4> inputTypes(op->getOperandTypes());
+ SmallVector<Type, 4> inputTypes(ExtractOperandTypes<LinalgOp>(op));
assert(op->getNumResults() == 0 &&
"Library call for linalg operation can be generated only for ops that "
"have void return types");
@@ -430,6 +452,40 @@ public:
}
};
+/// Conversion pattern specialization for IndexedGenericOp.
+template <>
+class LinalgOpConversion<IndexedGenericOp>
+ : public OpRewritePattern<IndexedGenericOp> {
+public:
+ using OpRewritePattern<IndexedGenericOp>::OpRewritePattern;
+
+ PatternMatchResult matchAndRewrite(IndexedGenericOp op,
+ PatternRewriter &rewriter) const override {
+ auto libraryCallName =
+ getLibraryCallSymbolRef<IndexedGenericOp>(op, rewriter);
+ if (!libraryCallName)
+ return this->matchFailure();
+
+ // TODO(pifon, ntv): Use induction variables values instead of zeros, when
+ // IndexedGenericOp is tiled.
+ auto zero = rewriter.create<mlir::ConstantOp>(
+ op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
+ auto indexedGenericOp = cast<IndexedGenericOp>(op);
+ auto numLoops = indexedGenericOp.getNumLoops();
+ SmallVector<Value *, 4> operands;
+ operands.reserve(numLoops + op.getNumOperands());
+ for (unsigned i = 0; i < numLoops; ++i) {
+ operands.push_back(zero);
+ }
+ for (auto operand : op.getOperands()) {
+ operands.push_back(operand);
+ }
+ rewriter.replaceOpWithNewOp<mlir::CallOp>(op, libraryCallName.getValue(),
+ ArrayRef<Type>{}, operands);
+ return this->matchSuccess();
+ }
+};
+
/// A non-conversion rewrite pattern kicks in to convert CopyOp with
/// permutations into a sequence of TransposeOp and permutation-free CopyOp.
/// This interplays together with TransposeOpConversion and
diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir
index 24ce8e36e8a..9b477a3b901 100644
--- a/mlir/test/Dialect/Linalg/llvm.mlir
+++ b/mlir/test/Dialect/Linalg/llvm.mlir
@@ -141,7 +141,7 @@ func @copy_transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %a
n_views = [2, 1],
iterator_types = ["parallel", "parallel", "reduction"],
indexing_maps = #matmul_accesses,
- library_call = "some_external_function_name_for_vector_outerproduct_matmul"
+ library_call = "external_outerproduct_matmul"
}
!vector_type_A = type vector<4xf32>
@@ -162,7 +162,7 @@ func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C
return
}
// CHECK-LABEL: func @matmul_vec_impl(
-// CHECK: llvm.call @some_external_function_name_for_vector_outerproduct_matmul(%{{.*}}) : (!llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ [4 x <4 x float>]*, [4 x <4 x float>]*, i64, [2 x i64], [2 x i64] }*">) -> ()
+// CHECK: llvm.call @external_outerproduct_matmul(%{{.*}}) : (!llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ [4 x <4 x float>]*, [4 x <4 x float>]*, i64, [2 x i64], [2 x i64] }*">) -> ()
// LLVM-LOOPS-LABEL: func @matmul_vec_impl(
// LLVM-LOOPS: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>">
@@ -172,3 +172,25 @@ func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C
// LLVM-LOOPS-NEXT: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <4 x float>]">
// LLVM-LOOPS-NEXT: "llvm.intr.fmuladd"({{.*}}) : (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>">
// LLVM-LOOPS-NEXT: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x <4 x float>]">
+
+
+#indexed_matmul_trait = {
+ n_views = [2, 1],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ indexing_maps = #matmul_accesses,
+ library_call = "external_indexed_outerproduct_matmul"
+}
+func @matmul_vec_indexed(%A: !matrix_type_A,
+ %B: !matrix_type_B,
+ %C: !matrix_type_C) {
+ linalg.indexed_generic #indexed_matmul_trait %A, %B, %C {
+ ^bb0(%i: index, %j: index, %k: index,
+ %a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C):
+ %d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B
+ linalg.yield %d: !vector_type_C
+ } : !matrix_type_A, !matrix_type_B, !matrix_type_C
+ return
+}
+// CHECK-LABEL: func @matmul_vec_indexed(
+// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: llvm.call @external_indexed_outerproduct_matmul(%[[ZERO]], %[[ZERO]], %[[ZERO]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.i64, !llvm.i64, !llvm.i64, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ [4 x <4 x float>]*, [4 x <4 x float>]*, i64, [2 x i64], [2 x i64] }*">) -> ()
OpenPOWER on IntegriCloud