summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/Dialect/VectorOps/VectorOps.td32
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp126
-rw-r--r--mlir/lib/Dialect/VectorOps/VectorOps.cpp17
-rw-r--r--mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir38
-rw-r--r--mlir/test/Dialect/VectorOps/invalid.mlir8
-rw-r--r--mlir/test/Dialect/VectorOps/ops.mlir7
-rw-r--r--mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h8
-rw-r--r--mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp11
8 files changed, 245 insertions, 2 deletions
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index e031d7cfb8c..401e424c862 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -987,4 +987,36 @@ def Vector_TupleGetOp :
}];
}
+def Vector_PrintOp :
+ Vector_Op<"print", []>, Arguments<(ins AnyType:$source)> {
+ let summary = "print operation (for testing and debugging)";
+ let description = [{
+ Prints the source vector (or scalar) to stdout in human readable
+ format (for testing and debugging). No return value.
+
+ Examples:
+ ```
+ %0 = constant 0.0 : f32
+ %1 = vector.broadcast %0 : f32 to vector<4xf32>
+ vector.print %1 : vector<4xf32>
+
+ when lowered to LLVM, the vector print is unrolled into
+ elementary printing method calls that at runtime will yield
+
+ ( 0.0, 0.0, 0.0, 0.0 )
+
+ on stdout when linked with a small runtime support library,
+ which only needs to provide a few printing methods (single
+ value for all data types, opening/closing bracket, comma,
+ newline).
+ ```
+ }];
+ let verifier = ?;
+ let extraClassDeclaration = [{
+ Type getPrintType() {
+ return source()->getType();
+ }
+ }];
+}
+
#endif // VECTOR_OPS
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 71bed9516ca..416cb4c99a3 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -612,14 +612,136 @@ public:
}
};
+class VectorPrintOpConversion : public LLVMOpLowering {
+public:
+ explicit VectorPrintOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : LLVMOpLowering(vector::PrintOp::getOperationName(), context,
+ typeConverter) {}
+
+ // Proof-of-concept lowering implementation that relies on a small
+ // runtime support library, which only needs to provide a few
+ // printing methods (single value for all data types, opening/closing
+ // bracket, comma, newline). The lowering fully unrolls a vector
+ // in terms of these elementary printing operations. The advantage
+ // of this approach is that the library can remain unaware of all
+ // low-level implementation details of vectors while still supporting
+ // output of any shaped and dimensioned vector. Due to full unrolling,
+ // this approach is less suited for very large vectors though.
+ //
+ // TODO(ajcbik): rely solely on libc in future? something else?
+ //
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto printOp = cast<vector::PrintOp>(op);
+ auto adaptor = vector::PrintOpOperandAdaptor(operands);
+ Type printType = printOp.getPrintType();
+
+ if (lowering.convertType(printType) == nullptr)
+ return matchFailure();
+
+ // Make sure element type has runtime support (currently just Float/Double).
+ VectorType vectorType = printType.dyn_cast<VectorType>();
+ Type eltType = vectorType ? vectorType.getElementType() : printType;
+ int64_t rank = vectorType ? vectorType.getRank() : 0;
+ Operation *printer;
+ if (eltType.isF32())
+ printer = getPrintFloat(op);
+ else if (eltType.isF64())
+ printer = getPrintDouble(op);
+ else
+ return matchFailure();
+
+ // Unroll vector into elementary print calls.
+ emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
+ emitCall(rewriter, op->getLoc(), getPrintNewline(op));
+ rewriter.eraseOp(op);
+ return matchSuccess();
+ }
+
+private:
+ void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
+ Value *value, VectorType vectorType, Operation *printer,
+ int64_t rank) const {
+ Location loc = op->getLoc();
+ if (rank == 0) {
+ emitCall(rewriter, loc, printer, value);
+ return;
+ }
+
+ emitCall(rewriter, loc, getPrintOpen(op));
+ Operation *printComma = getPrintComma(op);
+ int64_t dim = vectorType.getDimSize(0);
+ for (int64_t d = 0; d < dim; ++d) {
+ auto reducedType =
+ rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
+ auto llvmType = lowering.convertType(
+ rank > 1 ? reducedType : vectorType.getElementType());
+ Value *nestedVal =
+ extractOne(rewriter, lowering, loc, value, llvmType, rank, d);
+ emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
+ if (d != dim - 1)
+ emitCall(rewriter, loc, printComma);
+ }
+ emitCall(rewriter, loc, getPrintClose(op));
+ }
+
+ // Helper to emit a call.
+ static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
+ Operation *ref, ValueRange params = ValueRange()) {
+ rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
+ rewriter.getSymbolRefAttr(ref), params);
+ }
+
+ // Helper for printer method declaration (first hit) and lookup.
+ static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect,
+ StringRef name, ArrayRef<LLVM::LLVMType> params) {
+ auto module = op->getParentOfType<ModuleOp>();
+ auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
+ if (func)
+ return func;
+ OpBuilder moduleBuilder(module.getBodyRegion());
+ return moduleBuilder.create<LLVM::LLVMFuncOp>(
+ op->getLoc(), name,
+ LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect),
+ params, /*isVarArg=*/false));
+ }
+
+ // Helpers for method names.
+ Operation *getPrintFloat(Operation *op) const {
+ LLVM::LLVMDialect *dialect = lowering.getDialect();
+ return getPrint(op, dialect, "print_f32",
+ LLVM::LLVMType::getFloatTy(dialect));
+ }
+ Operation *getPrintDouble(Operation *op) const {
+ LLVM::LLVMDialect *dialect = lowering.getDialect();
+ return getPrint(op, dialect, "print_f64",
+ LLVM::LLVMType::getDoubleTy(dialect));
+ }
+ Operation *getPrintOpen(Operation *op) const {
+ return getPrint(op, lowering.getDialect(), "print_open", {});
+ }
+ Operation *getPrintClose(Operation *op) const {
+ return getPrint(op, lowering.getDialect(), "print_close", {});
+ }
+ Operation *getPrintComma(Operation *op) const {
+ return getPrint(op, lowering.getDialect(), "print_comma", {});
+ }
+ Operation *getPrintNewline(Operation *op) const {
+ return getPrint(op, lowering.getDialect(), "print_newline", {});
+ }
+};
+
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
VectorExtractElementOpConversion, VectorExtractOpConversion,
VectorInsertElementOpConversion, VectorInsertOpConversion,
- VectorOuterProductOpConversion, VectorTypeCastOpConversion>(
- converter.getDialect()->getContext(), converter);
+ VectorOuterProductOpConversion, VectorTypeCastOpConversion,
+ VectorPrintOpConversion>(converter.getDialect()->getContext(),
+ converter);
}
namespace {
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index ff4ff2cb540..4ed0902b292 100644
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -1587,6 +1587,23 @@ static LogicalResult verify(CreateMaskOp op) {
return success();
}
+//===----------------------------------------------------------------------===//
+// PrintOp
+//===----------------------------------------------------------------------===//
+
+ParseResult parsePrintOp(OpAsmParser &parser, OperationState &result) {
+ OpAsmParser::OperandType source;
+ Type sourceType;
+ return failure(parser.parseOperand(source) ||
+ parser.parseColonType(sourceType) ||
+ parser.resolveOperand(source, sourceType, result.operands));
+}
+
+static void print(OpAsmPrinter &p, PrintOp op) {
+ p << op.getOperationName() << ' ' << *op.source() << " : "
+ << op.getPrintType();
+}
+
namespace {
// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 73aba05b3b3..d3b1d409045 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -385,3 +385,41 @@ func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32>> {
// CHECK: llvm.insertvalue %[[alignedBit]], {{.*}}[1] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
// CHECK: llvm.mlir.constant(0 : index
// CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
+
+func @vector_print_scalar(%arg0: f32) {
+ vector.print %arg0 : f32
+ return
+}
+// CHECK-LABEL: vector_print_scalar(%arg0: !llvm.float)
+// CHECK: llvm.call @print_f32(%arg0) : (!llvm.float) -> ()
+// CHECK: llvm.call @print_newline() : () -> ()
+
+func @vector_print_vector(%arg0: vector<2x2xf32>) {
+ vector.print %arg0 : vector<2x2xf32>
+ return
+}
+// CHECK-LABEL: vector_print_vector(%arg0: !llvm<"[2 x <2 x float>]">)
+// CHECK: llvm.call @print_open() : () -> ()
+// CHECK: %[[x0:.*]] = llvm.extractvalue %arg0[0] : !llvm<"[2 x <2 x float>]">
+// CHECK: llvm.call @print_open() : () -> ()
+// CHECK: %[[x1:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: %[[x2:.*]] = llvm.extractelement %[[x0]][%[[x1]] : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK: llvm.call @print_f32(%[[x2]]) : (!llvm.float) -> ()
+// CHECK: llvm.call @print_comma() : () -> ()
+// CHECK: %[[x3:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: %[[x4:.*]] = llvm.extractelement %[[x0]][%[[x3]] : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK: llvm.call @print_f32(%[[x4]]) : (!llvm.float) -> ()
+// CHECK: llvm.call @print_close() : () -> ()
+// CHECK: llvm.call @print_comma() : () -> ()
+// CHECK: %[[x5:.*]] = llvm.extractvalue %arg0[1] : !llvm<"[2 x <2 x float>]">
+// CHECK: llvm.call @print_open() : () -> ()
+// CHECK: %[[x6:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: %[[x7:.*]] = llvm.extractelement %[[x5]][%[[x6]] : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK: llvm.call @print_f32(%[[x7]]) : (!llvm.float) -> ()
+// CHECK: llvm.call @print_comma() : () -> ()
+// CHECK: %[[x8:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: %[[x9:.*]] = llvm.extractelement %[[x5]][%[[x8]] : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK: llvm.call @print_f32(%[[x9]]) : (!llvm.float) -> ()
+// CHECK: llvm.call @print_close() : () -> ()
+// CHECK: llvm.call @print_close() : () -> ()
+// CHECK: llvm.call @print_newline() : () -> ()
diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir
index 3c2dd6075c8..7e8fce93294 100644
--- a/mlir/test/Dialect/VectorOps/invalid.mlir
+++ b/mlir/test/Dialect/VectorOps/invalid.mlir
@@ -818,3 +818,11 @@ func @insert_slices_invalid_tuple_element_type(%arg0 : tuple<vector<2x2xf32>, ve
: tuple<vector<2x2xf32>, vector<4x2xf32>> into vector<4x2xf32>
return
}
+
+// -----
+
+func @print_no_result(%arg0 : f32) -> i32 {
+ // expected-error@+1 {{cannot name an operation with no results}}
+ %0 = vector.print %arg0 : f32
+ return %0
+}
diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir
index f1db45e2716..b43c675893e 100644
--- a/mlir/test/Dialect/VectorOps/ops.mlir
+++ b/mlir/test/Dialect/VectorOps/ops.mlir
@@ -198,3 +198,10 @@ func @insert_slices(%arg0 : tuple<vector<2x2xf32>, vector<2x2xf32>>)
: tuple<vector<2x2xf32>, vector<2x2xf32>> into vector<4x2xf32>
return %0 : vector<4x2xf32>
}
+
+// CHECK-LABEL: @vector_print
+func @vector_print(%arg0: vector<8x4xf32>) {
+ // CHECK: vector.print %{{.*}} : vector<8x4xf32>
+ vector.print %arg0 : vector<8x4xf32>
+ return
+}
diff --git a/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h b/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h
index ba6829503a7..7671db9f34f 100644
--- a/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h
+++ b/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h
@@ -285,4 +285,12 @@ print_memref_4d_f32(StridedMemRefType<float, 4> *M);
extern "C" MLIR_RUNNER_UTILS_EXPORT void
print_memref_vector_4x4xf32(StridedMemRefType<Vector2D<4, 4, float>, 2> *M);
+// Small runtime support "lib" for vector.print lowering.
+extern "C" MLIR_RUNNER_UTILS_EXPORT void print_f32(float f);
+extern "C" MLIR_RUNNER_UTILS_EXPORT void print_f64(double d);
+extern "C" MLIR_RUNNER_UTILS_EXPORT void print_open();
+extern "C" MLIR_RUNNER_UTILS_EXPORT void print_close();
+extern "C" MLIR_RUNNER_UTILS_EXPORT void print_comma();
+extern "C" MLIR_RUNNER_UTILS_EXPORT void print_newline();
+
#endif // MLIR_CPU_RUNNER_MLIRUTILS_H_
diff --git a/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp b/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp
index f8007d79de4..c2a4cf452ed 100644
--- a/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp
+++ b/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp
@@ -63,3 +63,14 @@ extern "C" void print_memref_3d_f32(StridedMemRefType<float, 3> *M) {
extern "C" void print_memref_4d_f32(StridedMemRefType<float, 4> *M) {
impl::printMemRef(*M);
}
+
+// Small runtime support "lib" for vector.print lowering.
+// By providing elementary printing methods only, this
+// library can remain fully unaware of low-level implementation
+// details of our vectors.
+extern "C" void print_f32(float f) { std::cout << f; }
+extern "C" void print_f64(double d) { std::cout << d; }
+extern "C" void print_open() { std::cout << "( "; }
+extern "C" void print_close() { std::cout << " )"; }
+extern "C" void print_comma() { std::cout << ", "; }
+extern "C" void print_newline() { std::cout << "\n"; }
OpenPOWER on IntegriCloud