summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Conversion
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion')
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp126
1 files changed, 124 insertions, 2 deletions
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 71bed9516ca..416cb4c99a3 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -612,14 +612,136 @@ public:
}
};
+class VectorPrintOpConversion : public LLVMOpLowering {
+public:
+ explicit VectorPrintOpConversion(MLIRContext *context,
+ LLVMTypeConverter &typeConverter)
+ : LLVMOpLowering(vector::PrintOp::getOperationName(), context,
+ typeConverter) {}
+
+ // Proof-of-concept lowering implementation that relies on a small
+ // runtime support library, which only needs to provide a few
+ // printing methods (single value for all data types, opening/closing
+ // bracket, comma, newline). The lowering fully unrolls a vector
+ // in terms of these elementary printing operations. The advantage
+ // of this approach is that the library can remain unaware of all
+ // low-level implementation details of vectors while still supporting
+ // output of any shaped and dimensioned vector. Due to full unrolling,
+ // this approach is less suited for very large vectors though.
+ //
+ // TODO(ajcbik): rely solely on libc in future? something else?
+ //
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto printOp = cast<vector::PrintOp>(op);
+ auto adaptor = vector::PrintOpOperandAdaptor(operands);
+ Type printType = printOp.getPrintType();
+
+ if (lowering.convertType(printType) == nullptr)
+ return matchFailure();
+
+ // Make sure element type has runtime support (currently just Float/Double).
+ VectorType vectorType = printType.dyn_cast<VectorType>();
+ Type eltType = vectorType ? vectorType.getElementType() : printType;
+ int64_t rank = vectorType ? vectorType.getRank() : 0;
+ Operation *printer;
+ if (eltType.isF32())
+ printer = getPrintFloat(op);
+ else if (eltType.isF64())
+ printer = getPrintDouble(op);
+ else
+ return matchFailure();
+
+ // Unroll vector into elementary print calls.
+ emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank);
+ emitCall(rewriter, op->getLoc(), getPrintNewline(op));
+ rewriter.eraseOp(op);
+ return matchSuccess();
+ }
+
+private:
+ void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
+ Value *value, VectorType vectorType, Operation *printer,
+ int64_t rank) const {
+ Location loc = op->getLoc();
+ if (rank == 0) {
+ emitCall(rewriter, loc, printer, value);
+ return;
+ }
+
+ emitCall(rewriter, loc, getPrintOpen(op));
+ Operation *printComma = getPrintComma(op);
+ int64_t dim = vectorType.getDimSize(0);
+ for (int64_t d = 0; d < dim; ++d) {
+ auto reducedType =
+ rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr;
+ auto llvmType = lowering.convertType(
+ rank > 1 ? reducedType : vectorType.getElementType());
+ Value *nestedVal =
+ extractOne(rewriter, lowering, loc, value, llvmType, rank, d);
+ emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1);
+ if (d != dim - 1)
+ emitCall(rewriter, loc, printComma);
+ }
+ emitCall(rewriter, loc, getPrintClose(op));
+ }
+
+ // Helper to emit a call.
+ static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
+ Operation *ref, ValueRange params = ValueRange()) {
+ rewriter.create<LLVM::CallOp>(loc, ArrayRef<Type>{},
+ rewriter.getSymbolRefAttr(ref), params);
+ }
+
+ // Helper for printer method declaration (first hit) and lookup.
+ static Operation *getPrint(Operation *op, LLVM::LLVMDialect *dialect,
+ StringRef name, ArrayRef<LLVM::LLVMType> params) {
+ auto module = op->getParentOfType<ModuleOp>();
+ auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
+ if (func)
+ return func;
+ OpBuilder moduleBuilder(module.getBodyRegion());
+ return moduleBuilder.create<LLVM::LLVMFuncOp>(
+ op->getLoc(), name,
+ LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(dialect),
+ params, /*isVarArg=*/false));
+ }
+
+ // Helpers for method names.
+ Operation *getPrintFloat(Operation *op) const {
+ LLVM::LLVMDialect *dialect = lowering.getDialect();
+ return getPrint(op, dialect, "print_f32",
+ LLVM::LLVMType::getFloatTy(dialect));
+ }
+ Operation *getPrintDouble(Operation *op) const {
+ LLVM::LLVMDialect *dialect = lowering.getDialect();
+ return getPrint(op, dialect, "print_f64",
+ LLVM::LLVMType::getDoubleTy(dialect));
+ }
+ Operation *getPrintOpen(Operation *op) const {
+ return getPrint(op, lowering.getDialect(), "print_open", {});
+ }
+ Operation *getPrintClose(Operation *op) const {
+ return getPrint(op, lowering.getDialect(), "print_close", {});
+ }
+ Operation *getPrintComma(Operation *op) const {
+ return getPrint(op, lowering.getDialect(), "print_comma", {});
+ }
+ Operation *getPrintNewline(Operation *op) const {
+ return getPrint(op, lowering.getDialect(), "print_newline", {});
+ }
+};
+
/// Populate the given list with patterns that convert from Vector to LLVM.
void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
patterns.insert<VectorBroadcastOpConversion, VectorShuffleOpConversion,
VectorExtractElementOpConversion, VectorExtractOpConversion,
VectorInsertElementOpConversion, VectorInsertOpConversion,
- VectorOuterProductOpConversion, VectorTypeCastOpConversion>(
- converter.getDialect()->getContext(), converter);
+ VectorOuterProductOpConversion, VectorTypeCastOpConversion,
+ VectorPrintOpConversion>(converter.getDialect()->getContext(),
+ converter);
}
namespace {
OpenPOWER on IntegriCloud