diff options
author | Aart Bik <ajcbik@google.com> | 2019-12-18 11:23:16 -0800 |
---|---|---|
committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-18 11:31:34 -0800 |
commit | d9b500d3bb151bfb96073b0d66e8338a5c0186d5 (patch) | |
tree | fd385a47ba52f5c4f274882f50cad6aecb3a2d0b /mlir/lib/Conversion/VectorToLLVM | |
parent | c169852fc5c5efb4b01600477da00e6ef2517231 (diff) | |
download | bcm5719-llvm-d9b500d3bb151bfb96073b0d66e8338a5c0186d5.tar.gz bcm5719-llvm-d9b500d3bb151bfb96073b0d66e8338a5c0186d5.zip |
[VectorOps] Add vector.print definition, with lowering support
Examples:
vector.print %f : f32
vector.print %x : vector<4xf32>
vector.print %y : vector<3x4xf32>
vector.print %z : vector<2x3x4xf32>
LLVM lowering replaces these with fully unrolled calls
into a small runtime support library that provides some
basic printing operations (single value, opening closing
bracket, comma, newline).
PiperOrigin-RevId: 286230325
Diffstat (limited to 'mlir/lib/Conversion/VectorToLLVM')
-rw-r--r-- | mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 126 |
1 files changed, 124 insertions, 2 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 71bed9516ca..416cb4c99a3 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -612,14 +612,136 @@ public: } }; +class VectorPrintOpConversion : public LLVMOpLowering { +public: + explicit VectorPrintOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : LLVMOpLowering(vector::PrintOp::getOperationName(), context, + typeConverter) {} + + // Proof-of-concept lowering implementation that relies on a small + // runtime support library, which only needs to provide a few + // printing methods (single value for all data types, opening/closing + // bracket, comma, newline). The lowering fully unrolls a vector + // in terms of these elementary printing operations. The advantage + // of this approach is that the library can remain unaware of all + // low-level implementation details of vectors while still supporting + // output of any shaped and dimensioned vector. Due to full unrolling, + // this approach is less suited for very large vectors though. + // + // TODO(ajcbik): rely solely on libc in future? something else? + // + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto printOp = cast<vector::PrintOp>(op); + auto adaptor = vector::PrintOpOperandAdaptor(operands); + Type printType = printOp.getPrintType(); + + if (lowering.convertType(printType) == nullptr) + return matchFailure(); + + // Make sure element type has runtime support (currently just Float/Double). + VectorType vectorType = printType.dyn_cast<VectorType>(); + Type eltType = vectorType ? vectorType.getElementType() : printType; + int64_t rank = vectorType ? vectorType.getRank() : 0; + Operation *printer; + if (eltType.isF32()) + printer = getPrintFloat(op); + else if (eltType.isF64()) + printer = getPrintDouble(op); + else + return matchFailure(); + + // Unroll vector into elementary print calls. + emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank); + emitCall(rewriter, op->getLoc(), getPrintNewline(op)); + rewriter.eraseOp(op); + return matchSuccess(); + } + +private: + void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, + Value *value, VectorType vectorType, Operation *printer, + int64_t rank) const { + Location loc = op->getLoc(); + if (rank == 0) { + emitCall(rewriter, loc, printer, value); + return; + } + + emitCall(rewriter, loc, getPrintOpen(op)); + Operation *printComma = getPrintComma(op); + int64_t dim = vectorType.getDimSize(0); + for (int64_t d = 0; d < dim; ++d) { + auto reducedType = + rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr; + auto llvmType = lowering.convertType( + rank > 1 ? reducedType : vectorType.getElementType()); + Value *nestedVal = + extractOne(rewriter, lowering, loc, value, llvmType, rank, d); + emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1); + if (d != dim - 1) + emitCall(rewriter, loc, printComma); + } + emitCall(rewriter, loc, getPrintClose(op)); + } + + // Helper to emit a call. + static void emitCall(ConversionPatternRewriter &rewriter, Location loc, + Operation *ref, ValueRange params = ValueRange()) { + rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{}, + rewriter.getSymbolRefAttr(ref), params); + } + + // Helper for printer method declaration (first hit) and lookup. + static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect, + StringRef name, ArrayRef<LLVM::LLVMType> params) { + auto module = op->getParentOfType<ModuleOp>(); + auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name); + if (func) + return func; + OpBuilder moduleBuilder(module.getBodyRegion()); + return moduleBuilder.create<LLVM::LLVMFuncOp>( + op->getLoc(), name, + LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect), + params, /*isVarArg=*/false)); + } + + // Helpers for method names. + Operation *getPrintFloat(Operation *op) const { + LLVM::LLVMDialect *dialect = lowering.getDialect(); + return getPrint(op, dialect, "print_f32", + LLVM::LLVMType::getFloatTy(dialect)); + } + Operation *getPrintDouble(Operation *op) const { + LLVM::LLVMDialect *dialect = lowering.getDialect(); + return getPrint(op, dialect, "print_f64", + LLVM::LLVMType::getDoubleTy(dialect)); + } + Operation *getPrintOpen(Operation *op) const { + return getPrint(op, lowering.getDialect(), "print_open", {}); + } + Operation *getPrintClose(Operation *op) const { + return getPrint(op, lowering.getDialect(), "print_close", {}); + } + Operation *getPrintComma(Operation *op) const { + return getPrint(op, lowering.getDialect(), "print_comma", {}); + } + Operation *getPrintNewline(Operation *op) const { + return getPrint(op, lowering.getDialect(), "print_newline", {}); + } +}; + /// Populate the given list with patterns that convert from Vector to LLVM. void mlir::populateVectorToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion, VectorExtractElementOpConversion, VectorExtractOpConversion, VectorInsertElementOpConversion, VectorInsertOpConversion, - VectorOuterProductOpConversion, VectorTypeCastOpConversion>( - converter.getDialect()->getContext(), converter); + VectorOuterProductOpConversion, VectorTypeCastOpConversion, + VectorPrintOpConversion>(converter.getDialect()->getContext(), + converter); } namespace { |