diff options
| -rw-r--r-- | mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp | 58 | ||||
| -rw-r--r-- | mlir/test/Dialect/Linalg/llvm.mlir | 26 |
2 files changed, 81 insertions, 3 deletions
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp index ebb0fd75753..ff516d7ef29 100644 --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -340,6 +340,28 @@ public: } }; +template <typename LinalgOp> +static SmallVector<Type, 4> ExtractOperandTypes(Operation *op) { + return SmallVector<Type, 4>{op->getOperandTypes()}; +} + +template <> +SmallVector<Type, 4> ExtractOperandTypes<IndexedGenericOp>(Operation *op) { + auto ctx = op->getContext(); + auto indexedGenericOp = cast<IndexedGenericOp>(op); + auto numLoops = indexedGenericOp.getNumLoops(); + + SmallVector<Type, 4> result; + result.reserve(numLoops + op->getNumOperands()); + for (unsigned i = 0; i < numLoops; ++i) { + result.push_back(IndexType::get(ctx)); + } + for (auto type : op->getOperandTypes()) { + result.push_back(type); + } + return result; +} + // Get a SymbolRefAttr containing the library function name for the LinalgOp. // If the library function does not exist, insert a declaration. template <typename LinalgOp> @@ -359,7 +381,7 @@ static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op, return fnNameAttr; } - SmallVector<Type, 4> inputTypes(op->getOperandTypes()); + SmallVector<Type, 4> inputTypes(ExtractOperandTypes<LinalgOp>(op)); assert(op->getNumResults() == 0 && "Library call for linalg operation can be generated only for ops that " "have void return types"); @@ -430,6 +452,40 @@ public: } }; +/// Conversion pattern specialization for IndexedGenericOp. +template <> +class LinalgOpConversion<IndexedGenericOp> + : public OpRewritePattern<IndexedGenericOp> { +public: + using OpRewritePattern<IndexedGenericOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(IndexedGenericOp op, + PatternRewriter &rewriter) const override { + auto libraryCallName = + getLibraryCallSymbolRef<IndexedGenericOp>(op, rewriter); + if (!libraryCallName) + return this->matchFailure(); + + // TODO(pifon, ntv): Use induction variables values instead of zeros, when + // IndexedGenericOp is tiled. + auto zero = rewriter.create<mlir::ConstantOp>( + op.getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); + auto indexedGenericOp = cast<IndexedGenericOp>(op); + auto numLoops = indexedGenericOp.getNumLoops(); + SmallVector<Value *, 4> operands; + operands.reserve(numLoops + op.getNumOperands()); + for (unsigned i = 0; i < numLoops; ++i) { + operands.push_back(zero); + } + for (auto operand : op.getOperands()) { + operands.push_back(operand); + } + rewriter.replaceOpWithNewOp<mlir::CallOp>(op, libraryCallName.getValue(), + ArrayRef<Type>{}, operands); + return this->matchSuccess(); + } +}; + /// A non-conversion rewrite pattern kicks in to convert CopyOp with /// permutations into a sequence of TransposeOp and permutation-free CopyOp. /// This interplays together with TransposeOpConversion and diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir index 24ce8e36e8a..9b477a3b901 100644 --- a/mlir/test/Dialect/Linalg/llvm.mlir +++ b/mlir/test/Dialect/Linalg/llvm.mlir @@ -141,7 +141,7 @@ func @copy_transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %a n_views = [2, 1], iterator_types = ["parallel", "parallel", "reduction"], indexing_maps = #matmul_accesses, - library_call = "some_external_function_name_for_vector_outerproduct_matmul" + library_call = "external_outerproduct_matmul" } !vector_type_A = type vector<4xf32> @@ -162,7 +162,7 @@ func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C return } // CHECK-LABEL: func @matmul_vec_impl( -// CHECK: llvm.call @some_external_function_name_for_vector_outerproduct_matmul(%{{.*}}) : (!llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ [4 x <4 x float>]*, [4 x <4 x float>]*, i64, [2 x i64], [2 x i64] }*">) -> () +// CHECK: llvm.call @external_outerproduct_matmul(%{{.*}}) : (!llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ [4 x <4 x float>]*, [4 x <4 x float>]*, i64, [2 x i64], [2 x i64] }*">) -> () // LLVM-LOOPS-LABEL: func @matmul_vec_impl( // LLVM-LOOPS: llvm.shufflevector {{.*}} [0 : i32, 0 : i32, 0 : i32, 0 : i32] : !llvm<"<4 x float>">, !llvm<"<4 x float>"> @@ -172,3 +172,25 @@ func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C // LLVM-LOOPS-NEXT: llvm.extractvalue {{.*}}[3] : !llvm<"[4 x <4 x float>]"> // LLVM-LOOPS-NEXT: "llvm.intr.fmuladd"({{.*}}) : (!llvm<"<4 x float>">, !llvm<"<4 x float>">, !llvm<"<4 x float>">) -> !llvm<"<4 x float>"> // LLVM-LOOPS-NEXT: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x <4 x float>]"> + + +#indexed_matmul_trait = { + n_views = [2, 1], + iterator_types = ["parallel", "parallel", "reduction"], + indexing_maps = #matmul_accesses, + library_call = "external_indexed_outerproduct_matmul" +} +func @matmul_vec_indexed(%A: !matrix_type_A, + %B: !matrix_type_B, + %C: !matrix_type_C) { + linalg.indexed_generic #indexed_matmul_trait %A, %B, %C { + ^bb0(%i: index, %j: index, %k: index, + %a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C): + %d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B + linalg.yield %d: !vector_type_C + } : !matrix_type_A, !matrix_type_B, !matrix_type_C + return +} +// CHECK-LABEL: func @matmul_vec_indexed( +// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 +// CHECK: llvm.call @external_indexed_outerproduct_matmul(%[[ZERO]], %[[ZERO]], %[[ZERO]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.i64, !llvm.i64, !llvm.i64, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ <4 x float>*, <4 x float>*, i64, [2 x i64], [2 x i64] }*">, !llvm<"{ [4 x <4 x float>]*, [4 x <4 x float>]*, i64, [2 x i64], [2 x i64] }*">) -> () |

