| author | Nicolas Vasilache <ntv@google.com> | 2019-12-04 14:15:24 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-04 14:16:00 -0800 |
| commit | b3f7cf80a7dc7e9edd5b53827a942bada4a6aeb2 (patch) | |
| tree | c87f323372730ace55cb079a0812040abdcb3707 | |
| parent | d20d763241020161ea173efe358d207b93310a34 (diff) | |
Add a CL option to Standard to LLVM lowering to use alloca instead of malloc/free.
In the future, a more configurable malloc and free interface should be used and exposed via
extra parameters to `createLowerToLLVMPass`. Until those requirements are gathered, a simple CL flag allows generating code that runs successfully on hardware that cannot use the stdlib.
PiperOrigin-RevId: 283833424
3 files changed, 116 insertions, 41 deletions
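The new option can be exercised from the command line via the flag added below, or programmatically through the new `useAlloca` parameter. A minimal sketch of the programmatic form (the pass-manager plumbing and the helper name are illustrative only, not part of this patch):

```cpp
// Command-line form, as exercised by the updated test:
//   mlir-opt -convert-std-to-llvm -convert-std-to-llvm-use-alloca=1 input.mlir

#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Pass/PassManager.h"

// Lower Standard to the LLVM dialect, emitting llvm.alloca instead of
// malloc/free calls for MemRef payload buffers.
void addStdToLLVMUsingAlloca(mlir::PassManager &pm) {
  pm.addPass(mlir::createLowerToLLVMPass(/*useAlloca=*/true));
}
```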
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
index 98e105aa2b5..c5c17b36f5e 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
@@ -57,25 +57,40 @@ void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                          OwningRewritePatternList &patterns);
 
 /// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
-std::unique_ptr<OpPassBase<ModuleOp>> createLowerToLLVMPass();
+/// By default stdlib malloc/free are used for allocating MemRef payloads.
+/// Specifying `useAlloca=true` emits stack allocations instead. In the future
+/// this may become an enum when we have concrete uses for other options.
+std::unique_ptr<OpPassBase<ModuleOp>>
+createLowerToLLVMPass(bool useAlloca = false);
 
 /// Creates a pass to convert operations to the LLVMIR dialect. The conversion
 /// is defined by a list of patterns and a type converter that will be obtained
 /// during the pass using the provided callbacks.
+/// By default stdlib malloc/free are used for allocating MemRef payloads.
+/// Specifying `useAlloca=true` emits stack allocations instead. In the future
+/// this may become an enum when we have concrete uses for other options.
 std::unique_ptr<OpPassBase<ModuleOp>>
 createLowerToLLVMPass(LLVMPatternListFiller patternListFiller,
-                      LLVMTypeConverterMaker typeConverterMaker);
+                      LLVMTypeConverterMaker typeConverterMaker,
+                      bool useAlloca = false);
 
 /// Creates a pass to convert operations to the LLVMIR dialect. The conversion
 /// is defined by a list of patterns obtained during the pass using the provided
 /// callback and an optional type conversion class, an instance is created
 /// during the pass.
+/// By default stdlib malloc/free are used for allocating MemRef payloads.
+/// Specifying `useAlloca=true` emits stack allocations instead. In the future
+/// this may become an enum when we have concrete uses for other options.
 template <typename TypeConverter = LLVMTypeConverter>
 std::unique_ptr<OpPassBase<ModuleOp>>
-createLowerToLLVMPass(LLVMPatternListFiller patternListFiller) {
-  return createLowerToLLVMPass(patternListFiller, [](MLIRContext *context) {
-    return std::make_unique<TypeConverter>(context);
-  });
+createLowerToLLVMPass(LLVMPatternListFiller patternListFiller,
+                      bool useAlloca = false) {
+  return createLowerToLLVMPass(
+      patternListFiller,
+      [](MLIRContext *context) {
+        return std::make_unique<TypeConverter>(context);
+      },
+      useAlloca);
 }
 
 namespace LLVM {
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 793997e9045..23c7be310a9 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -38,9 +38,20 @@
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Type.h"
+#include "llvm/Support/CommandLine.h"
 
 using namespace mlir;
 
+#define PASS_NAME "convert-std-to-llvm"
+
+static llvm::cl::OptionCategory
+    clOptionsCategory("Standard to LLVM lowering options");
+
+static llvm::cl::opt<bool>
+    clUseAlloca(PASS_NAME "-use-alloca",
+                llvm::cl::desc("Replace emission of malloc/free by alloca"),
+                llvm::cl::init(false));
+
 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
     : llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()) {
   assert(llvmDialect && "LLVM IR dialect is not registered");
@@ -764,6 +775,11 @@ static bool isSupportedMemRefType(MemRefType type) {
 struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
   using LLVMLegalizationPattern<AllocOp>::LLVMLegalizationPattern;
 
+  AllocOpLowering(LLVM::LLVMDialect &dialect_, LLVMTypeConverter &converter,
+                  bool useAlloca = false)
+      : LLVMLegalizationPattern<AllocOp>(dialect_, converter),
+        useAlloca(useAlloca) {}
+
   PatternMatchResult match(Operation *op) const override {
     MemRefType type = cast<AllocOp>(op).getType();
     if (isSupportedMemRefType(type))
@@ -825,32 +841,43 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
     cumulativeSize = rewriter.create<LLVM::MulOp>(
        loc, getIndexType(), ArrayRef<Value *>{cumulativeSize, elementSize});
 
-    // Insert the `malloc` declaration if it is not already present.
-    auto module = op->getParentOfType<ModuleOp>();
-    auto mallocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc");
-    if (!mallocFunc) {
-      OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
-      mallocFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
-          rewriter.getUnknownLoc(), "malloc",
-          LLVM::LLVMType::getFunctionTy(getVoidPtrType(), getIndexType(),
-                                        /*isVarArg=*/false));
-    }
-
     // Allocate the underlying buffer and store a pointer to it in the MemRef
     // descriptor.
-    Value *align = nullptr;
-    if (auto alignAttr = allocOp.alignment()) {
-      align = createIndexConstant(rewriter, loc,
-                                  alignAttr.getValue().getSExtValue());
-      cumulativeSize = rewriter.create<LLVM::SubOp>(
-          loc, rewriter.create<LLVM::AddOp>(loc, cumulativeSize, align), one);
+    Value *allocated = nullptr;
+    int alignment = 0;
+    Value *alignmentValue = nullptr;
+    if (auto alignAttr = allocOp.alignment())
+      alignment = alignAttr.getValue().getSExtValue();
+
+    if (useAlloca) {
+      allocated = rewriter.create<LLVM::AllocaOp>(loc, getVoidPtrType(),
+                                                  cumulativeSize, alignment);
+    } else {
+      // Insert the `malloc` declaration if it is not already present.
+      auto module = op->getParentOfType<ModuleOp>();
+      auto mallocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc");
+      if (!mallocFunc) {
+        OpBuilder moduleBuilder(
+            op->getParentOfType<ModuleOp>().getBodyRegion());
+        mallocFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
+            rewriter.getUnknownLoc(), "malloc",
+            LLVM::LLVMType::getFunctionTy(getVoidPtrType(), getIndexType(),
+                                          /*isVarArg=*/false));
+      }
+      if (alignment != 0) {
+        alignmentValue = createIndexConstant(rewriter, loc, alignment);
+        cumulativeSize = rewriter.create<LLVM::SubOp>(
+            loc,
+            rewriter.create<LLVM::AddOp>(loc, cumulativeSize, alignmentValue),
+            one);
+      }
+      allocated = rewriter
+                      .create<LLVM::CallOp>(
+                          loc, getVoidPtrType(),
+                          rewriter.getSymbolRefAttr(mallocFunc), cumulativeSize)
+                      .getResult(0);
     }
-    Value *allocated =
-        rewriter
-            .create<LLVM::CallOp>(loc, getVoidPtrType(),
-                                  rewriter.getSymbolRefAttr(mallocFunc),
-                                  cumulativeSize)
-            .getResult(0);
+
     auto structElementType = lowering.convertType(elementType);
     auto elementPtrType = structElementType.cast<LLVM::LLVMType>().getPointerTo(
         type.getMemorySpace());
@@ -878,13 +905,17 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
 
     // Field 2: Actual aligned pointer to payload.
     Value *bitcastAligned = bitcastAllocated;
-    if (align) {
+    if (!useAlloca && alignment != 0) {
+      assert(alignmentValue);
       // offset = (align - (ptr % align))% align
       Value *intVal = rewriter.create<LLVM::PtrToIntOp>(
           loc, this->getIndexType(), allocated);
-      Value *ptrModAlign = rewriter.create<LLVM::URemOp>(loc, intVal, align);
-      Value *subbed = rewriter.create<LLVM::SubOp>(loc, align, ptrModAlign);
-      Value *offset = rewriter.create<LLVM::URemOp>(loc, subbed, align);
+      Value *ptrModAlign =
+          rewriter.create<LLVM::URemOp>(loc, intVal, alignmentValue);
+      Value *subbed =
+          rewriter.create<LLVM::SubOp>(loc, alignmentValue, ptrModAlign);
+      Value *offset =
+          rewriter.create<LLVM::URemOp>(loc, subbed, alignmentValue);
       Value *aligned = rewriter.create<LLVM::GEPOp>(loc, allocated->getType(),
                                                     allocated, offset);
       bitcastAligned = rewriter.create<LLVM::BitcastOp>(
@@ -930,6 +961,8 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
     // Return the final value of the descriptor.
     rewriter.replaceOp(op, {memRefDescriptor});
   }
+
+  bool useAlloca;
 };
 
 // A CallOp automatically promotes MemRefType to a sequence of alloca/store and
@@ -1001,9 +1034,17 @@ struct CallIndirectOpLowering : public CallOpInterfaceLowering<CallIndirectOp> {
 
 struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
   using LLVMLegalizationPattern<DeallocOp>::LLVMLegalizationPattern;
 
+  DeallocOpLowering(LLVM::LLVMDialect &dialect_, LLVMTypeConverter &converter,
+                    bool useAlloca = false)
+      : LLVMLegalizationPattern<DeallocOp>(dialect_, converter),
+        useAlloca(useAlloca) {}
+
   PatternMatchResult
   matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override {
+    if (useAlloca)
+      return rewriter.eraseOp(op), matchSuccess();
+
     assert(operands.size() == 1 && "dealloc takes one operand");
     OperandAdaptor<DeallocOp> transformed(operands);
 
@@ -1026,6 +1067,8 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
         op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
     return matchSuccess();
   }
+
+  bool useAlloca;
 };
 
 struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
@@ -1759,7 +1802,6 @@ void mlir::populateStdToLLVMConversionPatterns(
   patterns.insert<
       AddFOpLowering,
       AddIOpLowering,
-      AllocOpLowering,
       AndOpLowering,
       BranchOpLowering,
       CallIndirectOpLowering,
@@ -1768,7 +1810,6 @@ void mlir::populateStdToLLVMConversionPatterns(
       CmpIOpLowering,
      CondBranchOpLowering,
       ConstLLVMOpLowering,
-      DeallocOpLowering,
       DimOpLowering,
       DivFOpLowering,
       DivISOpLowering,
@@ -1800,6 +1841,10 @@ void mlir::populateStdToLLVMConversionPatterns(
       ViewOpLowering,
       XOrOpLowering,
       ZeroExtendIOpLowering>(*converter.getDialect(), converter);
+  patterns.insert<
+      AllocOpLowering,
+      DeallocOpLowering>(
+      *converter.getDialect(), converter, clUseAlloca.getValue());
   // clang-format on
 }
 
@@ -1873,6 +1918,7 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
   // By default, the patterns are those converting Standard operations to the
   // LLVMIR dialect.
   explicit LLVMLoweringPass(
+      bool useAlloca = false,
       LLVMPatternListFiller patternListFiller =
           populateStdToLLVMConversionPatterns,
       LLVMTypeConverterMaker converterBuilder = makeStandardToLLVMTypeConverter)
@@ -1911,17 +1957,25 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
 };
 } // end namespace
 
-std::unique_ptr<OpPassBase<ModuleOp>> mlir::createLowerToLLVMPass() {
-  return std::make_unique<LLVMLoweringPass>();
+std::unique_ptr<OpPassBase<ModuleOp>>
+mlir::createLowerToLLVMPass(bool useAlloca) {
+  return std::make_unique<LLVMLoweringPass>(useAlloca);
 }
 
 std::unique_ptr<OpPassBase<ModuleOp>>
 mlir::createLowerToLLVMPass(LLVMPatternListFiller patternListFiller,
-                            LLVMTypeConverterMaker typeConverterMaker) {
-  return std::make_unique<LLVMLoweringPass>(patternListFiller,
+                            LLVMTypeConverterMaker typeConverterMaker,
+                            bool useAlloca) {
+  return std::make_unique<LLVMLoweringPass>(useAlloca, patternListFiller,
                                             typeConverterMaker);
 }
 
 static PassRegistration<LLVMLoweringPass>
-    pass("convert-std-to-llvm", "Convert scalar and vector operations from the "
-                                "Standard to the LLVM dialect");
+    pass("convert-std-to-llvm",
+         "Convert scalar and vector operations from the "
+         "Standard to the LLVM dialect",
+         [] {
+           return std::make_unique<LLVMLoweringPass>(
+               clUseAlloca.getValue(), populateStdToLLVMConversionPatterns,
+               makeStandardToLLVMTypeConverter);
+         });
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir
index b18e84f6363..375c1ac4b17 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s
+// RUN: mlir-opt -convert-std-to-llvm -convert-std-to-llvm-use-alloca=1 %s | FileCheck %s --check-prefix=ALLOCA
 
 // CHECK-LABEL: func @check_arguments(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %arg1: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %arg2: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">)
 func @check_arguments(%static: memref<10x20xf32>, %dynamic : memref<?x?xf32>, %mixed : memref<10x?xf32>) {
@@ -20,6 +21,7 @@ func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> {
 }
 
 // CHECK-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> {
+// ALLOCA-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> {
 func @zero_d_alloc() -> memref<f32> {
 // CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64
 // CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
@@ -34,6 +36,10 @@ func @zero_d_alloc() -> memref<f32> {
 // CHECK-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
 // CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
 // CHECK-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm<"{ float*, float*, i64 }">
+
+// ALLOCA-NOT: malloc
+// ALLOCA: alloca
+// ALLOCA-NOT: malloc
   %0 = alloc() : memref<f32>
   return %0 : memref<f32>
 }
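A note on the malloc path retained above: when an alignment attribute is present, the lowering over-allocates by `align - 1` bytes (the `add`/`sub one` applied to `cumulativeSize`) and then rounds the returned pointer up using `offset = (align - (ptr % align)) % align`. The same arithmetic in plain C++, purely as an illustration of what the emitted LLVM dialect ops compute (this host code is not part of the patch):

```cpp
#include <cstdint>
#include <cstdlib>

// Illustrative sketch of AllocOpLowering's malloc-path alignment arithmetic;
// the pass emits LLVM dialect ops, not this code.
void *allocAligned(std::size_t size, std::size_t align) {
  // Over-allocate so that an aligned pointer is guaranteed to fit.
  char *allocated = static_cast<char *>(std::malloc(size + align - 1));
  if (!allocated)
    return nullptr;
  // offset = (align - (ptr % align)) % align, as in the lowering.
  std::uintptr_t intVal = reinterpret_cast<std::uintptr_t>(allocated);
  std::uintptr_t offset = (align - (intVal % align)) % align;
  return allocated + offset; // the aligned pointer recorded in the descriptor
}
```

With `-use-alloca`, none of this is emitted: `llvm.alloca` takes the alignment directly, and the matching `dealloc` is simply erased, since stack allocations are reclaimed on function exit.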

