diff options
-rw-r--r-- | mlir/include/mlir/Dialect/VectorOps/VectorOps.td | 32 | ||||
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 126 | ||||
-rw-r--r-- | mlir/lib/Dialect/VectorOps/VectorOps.cpp | 17 | ||||
-rw-r--r-- | mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 38 | ||||
-rw-r--r-- | mlir/test/Dialect/VectorOps/invalid.mlir | 8 | ||||
-rw-r--r-- | mlir/test/Dialect/VectorOps/ops.mlir | 7 | ||||
-rw-r--r-- | mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h | 8 | ||||
-rw-r--r-- | mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp | 11 |
8 files changed, 245 insertions, 2 deletions
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
index e031d7cfb8c..401e424c862 100644
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -987,4 +987,36 @@ def Vector_TupleGetOp :
   }];
 }
 
+def Vector_PrintOp :
+  Vector_Op<"print", []>, Arguments<(ins AnyType:$source)> {
+  let summary = "print operation (for testing and debugging)";
+  let description = [{
+    Prints the source vector (or scalar) to stdout in human readable
+    format (for testing and debugging). No return value.
+
+    Examples:
+    ```
+    %0 = constant 0.0 : f32
+    %1 = vector.broadcast %0 : f32 to vector<4xf32>
+    vector.print %1 : vector<4xf32>
+    ```
+
+    When lowered to LLVM, the vector print is unrolled into
+    elementary printing method calls that at runtime will yield
+
+    ( 0.0, 0.0, 0.0, 0.0 )
+
+    on stdout when linked with a small runtime support library,
+    which only needs to provide a few printing methods (single
+    value for all data types, opening/closing bracket, comma,
+    newline).
+  }];
+  let verifier = ?;
+  let extraClassDeclaration = [{
+    Type getPrintType() {
+      return source()->getType();
+    }
+  }];
+}
+
 #endif // VECTOR_OPS
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 71bed9516ca..416cb4c99a3 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -612,14 +612,136 @@ public:
   }
 };
 
+class VectorPrintOpConversion : public LLVMOpLowering {
+public:
+  explicit VectorPrintOpConversion(MLIRContext *context,
+                                   LLVMTypeConverter &typeConverter)
+      : LLVMOpLowering(vector::PrintOp::getOperationName(), context,
+                       typeConverter) {}
+
+  // Proof-of-concept lowering implementation that relies on a small
+  // runtime support library, which only needs to provide a few
+  // printing methods (single value for all data types, opening/closing
+  // bracket, comma, newline). The lowering fully unrolls a vector
+  // in terms of these elementary printing operations. The advantage
+  // of this approach is that the library can remain unaware of all
+  // low-level implementation details of vectors while still supporting
+  // output of any shaped and dimensioned vector. Due to full unrolling,
+  // this approach is less suited for very large vectors though.
+  //
+  // TODO(ajcbik): rely solely on libc in future? something else?
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto printOp = cast<vector::PrintOp>(op);
+    auto adaptor = vector::PrintOpOperandAdaptor(operands);
+    Type printType = printOp.getPrintType();
+
+    if (lowering.convertType(printType) == nullptr)
+      return matchFailure();
+
+    // Make sure element type has runtime support (currently just Float/Double).
+    VectorType vectorType = printType.dyn_cast<VectorType>();
+    Type eltType = vectorType ? vectorType.getElementType() : printType;
+    int64_t rank = vectorType ? vectorType.getRank() : 0;
+    Operation *printer;
+    if (eltType.isF32())
+      printer = getPrintFloat(op);
+    else if (eltType.isF64())
+      printer = getPrintDouble(op);
+    else
+      return matchFailure();
+
+    // Unroll vector into elementary print calls.
+    emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
+    emitCall(rewriter, op->getLoc(), getPrintNewline(op));
+    rewriter.eraseOp(op);
+    return matchSuccess();
+  }
+
+private:
+  // Recursively emits print calls: scalar at rank 0, else open/elts/close.
+  void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
+                 Value *value, VectorType vectorType, Operation *printer,
+                 int64_t rank) const {
+    Location loc = op->getLoc();
+    if (rank == 0) {
+      emitCall(rewriter, loc, printer, value);
+      return;
+    }
+
+    emitCall(rewriter, loc, getPrintOpen(op));
+    Operation *printComma = getPrintComma(op);
+    int64_t dim = vectorType.getDimSize(0);
+    // Loop-invariant: both types depend only on vectorType and rank.
+    auto reducedType = rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
+    auto llvmType = lowering.convertType(
+        rank > 1 ? reducedType : vectorType.getElementType());
+    for (int64_t d = 0; d < dim; ++d) {
+      Value *nestedVal =
+          extractOne(rewriter, lowering, loc, value, llvmType, rank, d);
+      emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
+      if (d != dim - 1)
+        emitCall(rewriter, loc, printComma);
+    }
+    emitCall(rewriter, loc, getPrintClose(op));
+  }
+
+  // Emits a call (no results) to the function `ref` with arguments `params`.
+  static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
+                       Operation *ref, ValueRange params = ValueRange()) {
+    rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
+                                  rewriter.getSymbolRefAttr(ref), params);
+  }
+
+  // Returns the module-level function `name`, declaring it on first use.
+  static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect,
+                             StringRef name, ArrayRef<LLVM::LLVMType> params) {
+    auto module = op->getParentOfType<ModuleOp>();
+    auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
+    if (func)
+      return func;
+    OpBuilder moduleBuilder(module.getBodyRegion());
+    return moduleBuilder.create<LLVM::LLVMFuncOp>(
+        op->getLoc(), name,
+        LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect),
+                                      params, /*isVarArg=*/false));
+  }
+
+  // Accessors for the individual runtime support functions.
+  Operation *getPrintFloat(Operation *op) const {
+    LLVM::LLVMDialect *dialect = lowering.getDialect();
+    return getPrint(op, dialect, "print_f32",
+                    LLVM::LLVMType::getFloatTy(dialect));
+  }
+  Operation *getPrintDouble(Operation *op) const {
+    LLVM::LLVMDialect *dialect = lowering.getDialect();
+    return getPrint(op, dialect, "print_f64",
+                    LLVM::LLVMType::getDoubleTy(dialect));
+  }
+  Operation *getPrintOpen(Operation *op) const {
+    return getPrint(op, lowering.getDialect(), "print_open", {});
+  }
+  Operation *getPrintClose(Operation *op) const {
+    return getPrint(op, lowering.getDialect(), "print_close", {});
+  }
+  Operation *getPrintComma(Operation *op) const {
+    return getPrint(op, lowering.getDialect(), "print_comma", {});
+  }
+  Operation *getPrintNewline(Operation *op) const {
+    return getPrint(op, lowering.getDialect(), "print_newline", {});
+  }
+};
+
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
   patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
                   VectorExtractElementOpConversion, VectorExtractOpConversion,
                   VectorInsertElementOpConversion, VectorInsertOpConversion,
-                  VectorOuterProductOpConversion, VectorTypeCastOpConversion>(
-      converter.getDialect()->getContext(), converter);
+                  VectorOuterProductOpConversion, VectorTypeCastOpConversion,
+                  VectorPrintOpConversion>(converter.getDialect()->getContext(),
+                                           converter);
 }
 
 namespace {
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
index ff4ff2cb540..4ed0902b292 100644
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -1587,6 +1587,23 @@ static LogicalResult verify(CreateMaskOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// PrintOp (custom assembly: `vector.print %source : type`)
+//===----------------------------------------------------------------------===//
+
+static ParseResult parsePrintOp(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::OperandType source;
+  Type sourceType;
+  return failure(parser.parseOperand(source) ||
+                 parser.parseColonType(sourceType) ||
+                 parser.resolveOperand(source, sourceType, result.operands));
+}
+
+static void print(OpAsmPrinter &p, PrintOp op) {
+  p << op.getOperationName() << ' ' << *op.source() << " : "
+    << op.getPrintType();
+}
+
 namespace {
 
 // Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 73aba05b3b3..d3b1d409045 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -385,3 +385,41 @@ func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32>> {
 // CHECK: llvm.insertvalue %[[alignedBit]], {{.*}}[1] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
 // CHECK: llvm.mlir.constant(0 : index
 // CHECK: llvm.insertvalue {{.*}}[2] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
+
+func @vector_print_scalar(%arg0: f32) {
+  vector.print %arg0 : f32
+  return
+}
+// CHECK-LABEL: vector_print_scalar(%arg0: !llvm.float)
+// CHECK: llvm.call @print_f32(%arg0) : (!llvm.float) -> ()
+// CHECK: llvm.call @print_newline() : () -> ()
+
+func @vector_print_vector(%arg0: vector<2x2xf32>) {
+  vector.print %arg0 : vector<2x2xf32>
+  return
+}
+// CHECK-LABEL: vector_print_vector(%arg0: !llvm<"[2 x <2 x float>]">)
+// CHECK: llvm.call @print_open() : () -> ()
+// CHECK: %[[x0:.*]] = llvm.extractvalue %arg0[0] : !llvm<"[2 x <2 x float>]">
+// CHECK: llvm.call @print_open() : () -> ()
+// CHECK: %[[x1:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: %[[x2:.*]] = llvm.extractelement %[[x0]][%[[x1]] : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK: llvm.call @print_f32(%[[x2]]) : (!llvm.float) -> ()
+// CHECK: llvm.call @print_comma() : () -> ()
+// CHECK: %[[x3:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: %[[x4:.*]] = llvm.extractelement %[[x0]][%[[x3]] : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK: llvm.call @print_f32(%[[x4]]) : (!llvm.float) -> ()
+// CHECK: llvm.call @print_close() : () -> ()
+// CHECK: llvm.call @print_comma() : () -> ()
+// CHECK: %[[x5:.*]] = llvm.extractvalue %arg0[1] : !llvm<"[2 x <2 x float>]">
+// CHECK: llvm.call @print_open() : () -> ()
+// CHECK: %[[x6:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: %[[x7:.*]] = llvm.extractelement %[[x5]][%[[x6]] : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK: llvm.call @print_f32(%[[x7]]) : (!llvm.float) -> ()
+// CHECK: llvm.call @print_comma() : () -> ()
+// CHECK: %[[x8:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: %[[x9:.*]] = llvm.extractelement %[[x5]][%[[x8]] : !llvm.i64] : !llvm<"<2 x float>">
+// CHECK: llvm.call @print_f32(%[[x9]]) : (!llvm.float) -> ()
+// CHECK: llvm.call @print_close() : () -> ()
+// CHECK: llvm.call @print_close() : () -> ()
+// CHECK: llvm.call @print_newline() : () -> ()
diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir
index 3c2dd6075c8..7e8fce93294 100644
--- a/mlir/test/Dialect/VectorOps/invalid.mlir
+++ b/mlir/test/Dialect/VectorOps/invalid.mlir
@@ -818,3 +818,11 @@ func @insert_slices_invalid_tuple_element_type(%arg0 : tuple<vector<2x2xf32>, ve
     : tuple<vector<2x2xf32>, vector<4x2xf32>> into vector<4x2xf32>
   return
 }
+
+// -----
+
+func @print_no_result(%arg0 : f32) -> i32 {
+  // expected-error@+1 {{cannot name an operation with no results}}
+  %0 = vector.print %arg0 : f32
+  return %0
+}
diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir
index f1db45e2716..b43c675893e 100644
--- a/mlir/test/Dialect/VectorOps/ops.mlir
+++ b/mlir/test/Dialect/VectorOps/ops.mlir
@@ -198,3 +198,10 @@ func @insert_slices(%arg0 : tuple<vector<2x2xf32>, vector<2x2xf32>>)
     : tuple<vector<2x2xf32>, vector<2x2xf32>> into vector<4x2xf32>
   return %0 : vector<4x2xf32>
 }
+
+// CHECK-LABEL: @vector_print
+func @vector_print(%arg0: vector<8x4xf32>) {
+  // CHECK: vector.print %{{.*}} : vector<8x4xf32>
+  vector.print %arg0 : vector<8x4xf32>
+  return
+}
diff --git a/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h b/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h
index ba6829503a7..7671db9f34f 100644
--- a/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h
+++ b/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h
@@ -285,4 +285,12 @@ print_memref_4d_f32(StridedMemRefType<float, 4> *M);
 extern "C" MLIR_RUNNER_UTILS_EXPORT void
 print_memref_vector_4x4xf32(StridedMemRefType<Vector2D<4, 4, float>, 2> *M);
 
+// Runtime support "lib" for the vector.print lowering; one token per call.
+extern "C" MLIR_RUNNER_UTILS_EXPORT void print_f32(float f);
+extern "C" MLIR_RUNNER_UTILS_EXPORT void print_f64(double d);
+extern "C" MLIR_RUNNER_UTILS_EXPORT void print_open();
+extern "C" MLIR_RUNNER_UTILS_EXPORT void print_close();
+extern "C" MLIR_RUNNER_UTILS_EXPORT void print_comma();
+extern "C" MLIR_RUNNER_UTILS_EXPORT void print_newline();
+
 #endif // MLIR_CPU_RUNNER_MLIRUTILS_H_
diff --git a/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp b/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp
index f8007d79de4..c2a4cf452ed 100644
--- a/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp
+++ b/mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp
@@ -63,3 +63,14 @@ extern "C" void print_memref_3d_f32(StridedMemRefType<float, 3> *M) {
 extern "C" void print_memref_4d_f32(StridedMemRefType<float, 4> *M) {
   impl::printMemRef(*M);
 }
+
+// Small runtime support "lib" for the vector.print lowering. Each function
+// writes exactly one token to std::cout: a single value, "( ", " )", ", " or
+// a newline. The lowering decides the order of the calls, so this library
+// can remain fully unaware of low-level implementation details of our vectors.
+extern "C" void print_f32(float f) { std::cout << f; }
+extern "C" void print_f64(double d) { std::cout << d; }
+extern "C" void print_open() { std::cout << "( "; }
+extern "C" void print_close() { std::cout << " )"; }
+extern "C" void print_comma() { std::cout << ", "; }
+extern "C" void print_newline() { std::cout << "\n"; }
|