summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNicolas Vasilache <ntv@google.com>2019-12-04 14:15:24 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-04 14:16:00 -0800
commitb3f7cf80a7dc7e9edd5b53827a942bada4a6aeb2 (patch)
treec87f323372730ace55cb079a0812040abdcb3707
parentd20d763241020161ea173efe358d207b93310a34 (diff)
downloadbcm5719-llvm-b3f7cf80a7dc7e9edd5b53827a942bada4a6aeb2.tar.gz
bcm5719-llvm-b3f7cf80a7dc7e9edd5b53827a942bada4a6aeb2.zip
Add a CL option to Standard to LLVM lowering to use alloca instead of malloc/free.
In the future, a more configurable malloc and free interface should be used and exposed via extra parameters to the `createLowerToLLVMPass`. Until requirements are gathered, a simple CL flag allows generating code that runs successfully on hardware that cannot use the stdlib. PiperOrigin-RevId: 283833424
-rw-r--r--mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h27
-rw-r--r--mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp124
-rw-r--r--mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir6
3 files changed, 116 insertions, 41 deletions
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
index 98e105aa2b5..c5c17b36f5e 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
@@ -57,25 +57,40 @@ void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
OwningRewritePatternList &patterns);
/// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
-std::unique_ptr<OpPassBase<ModuleOp>> createLowerToLLVMPass();
+/// By default stdlib malloc/free are used for allocating MemRef payloads.
+/// Specifying `useAlloca=true` emits stack allocations instead. In the future
+/// this may become an enum when we have concrete uses for other options.
+std::unique_ptr<OpPassBase<ModuleOp>>
+createLowerToLLVMPass(bool useAlloca = false);
/// Creates a pass to convert operations to the LLVMIR dialect. The conversion
/// is defined by a list of patterns and a type converter that will be obtained
/// during the pass using the provided callbacks.
+/// By default stdlib malloc/free are used for allocating MemRef payloads.
+/// Specifying `useAlloca=true` emits stack allocations instead. In the future
+/// this may become an enum when we have concrete uses for other options.
std::unique_ptr<OpPassBase<ModuleOp>>
createLowerToLLVMPass(LLVMPatternListFiller patternListFiller,
- LLVMTypeConverterMaker typeConverterMaker);
+ LLVMTypeConverterMaker typeConverterMaker,
+ bool useAlloca = false);
/// Creates a pass to convert operations to the LLVMIR dialect. The conversion
/// is defined by a list of patterns obtained during the pass using the provided
/// callback and an optional type conversion class, an instance is created
/// during the pass.
+/// By default stdlib malloc/free are used for allocating MemRef payloads.
+/// Specifying `useAlloca=true` emits stack allocations instead. In the future
+/// this may become an enum when we have concrete uses for other options.
template <typename TypeConverter = LLVMTypeConverter>
std::unique_ptr<OpPassBase<ModuleOp>>
-createLowerToLLVMPass(LLVMPatternListFiller patternListFiller) {
- return createLowerToLLVMPass(patternListFiller, [](MLIRContext *context) {
- return std::make_unique<TypeConverter>(context);
- });
+createLowerToLLVMPass(LLVMPatternListFiller patternListFiller,
+ bool useAlloca = false) {
+ return createLowerToLLVMPass(
+ patternListFiller,
+ [](MLIRContext *context) {
+ return std::make_unique<TypeConverter>(context);
+ },
+ useAlloca);
}
namespace LLVM {
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 793997e9045..23c7be310a9 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -38,9 +38,20 @@
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Type.h"
+#include "llvm/Support/CommandLine.h"
using namespace mlir;
+#define PASS_NAME "convert-std-to-llvm"
+
+static llvm::cl::OptionCategory
+ clOptionsCategory("Standard to LLVM lowering options");
+
+static llvm::cl::opt<bool>
+ clUseAlloca(PASS_NAME "-use-alloca",
+ llvm::cl::desc("Replace emission of malloc/free by alloca"),
+ llvm::cl::init(false));
+
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
: llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()) {
assert(llvmDialect && "LLVM IR dialect is not registered");
@@ -764,6 +775,11 @@ static bool isSupportedMemRefType(MemRefType type) {
struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
using LLVMLegalizationPattern<AllocOp>::LLVMLegalizationPattern;
+ AllocOpLowering(LLVM::LLVMDialect &dialect_, LLVMTypeConverter &converter,
+ bool useAlloca = false)
+ : LLVMLegalizationPattern<AllocOp>(dialect_, converter),
+ useAlloca(useAlloca) {}
+
PatternMatchResult match(Operation *op) const override {
MemRefType type = cast<AllocOp>(op).getType();
if (isSupportedMemRefType(type))
@@ -825,32 +841,43 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
cumulativeSize = rewriter.create<LLVM::MulOp>(
loc, getIndexType(), ArrayRef<Value *>{cumulativeSize, elementSize});
- // Insert the `malloc` declaration if it is not already present.
- auto module = op->getParentOfType<ModuleOp>();
- auto mallocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc");
- if (!mallocFunc) {
- OpBuilder moduleBuilder(op->getParentOfType<ModuleOp>().getBodyRegion());
- mallocFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
- rewriter.getUnknownLoc(), "malloc",
- LLVM::LLVMType::getFunctionTy(getVoidPtrType(), getIndexType(),
- /*isVarArg=*/false));
- }
-
// Allocate the underlying buffer and store a pointer to it in the MemRef
// descriptor.
- Value *align = nullptr;
- if (auto alignAttr = allocOp.alignment()) {
- align = createIndexConstant(rewriter, loc,
- alignAttr.getValue().getSExtValue());
- cumulativeSize = rewriter.create<LLVM::SubOp>(
- loc, rewriter.create<LLVM::AddOp>(loc, cumulativeSize, align), one);
+ Value *allocated = nullptr;
+ int alignment = 0;
+ Value *alignmentValue = nullptr;
+ if (auto alignAttr = allocOp.alignment())
+ alignment = alignAttr.getValue().getSExtValue();
+
+ if (useAlloca) {
+ allocated = rewriter.create<LLVM::AllocaOp>(loc, getVoidPtrType(),
+ cumulativeSize, alignment);
+ } else {
+ // Insert the `malloc` declaration if it is not already present.
+ auto module = op->getParentOfType<ModuleOp>();
+ auto mallocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc");
+ if (!mallocFunc) {
+ OpBuilder moduleBuilder(
+ op->getParentOfType<ModuleOp>().getBodyRegion());
+ mallocFunc = moduleBuilder.create<LLVM::LLVMFuncOp>(
+ rewriter.getUnknownLoc(), "malloc",
+ LLVM::LLVMType::getFunctionTy(getVoidPtrType(), getIndexType(),
+ /*isVarArg=*/false));
+ }
+ if (alignment != 0) {
+ alignmentValue = createIndexConstant(rewriter, loc, alignment);
+ cumulativeSize = rewriter.create<LLVM::SubOp>(
+ loc,
+ rewriter.create<LLVM::AddOp>(loc, cumulativeSize, alignmentValue),
+ one);
+ }
+ allocated = rewriter
+ .create<LLVM::CallOp>(
+ loc, getVoidPtrType(),
+ rewriter.getSymbolRefAttr(mallocFunc), cumulativeSize)
+ .getResult(0);
}
- Value *allocated =
- rewriter
- .create<LLVM::CallOp>(loc, getVoidPtrType(),
- rewriter.getSymbolRefAttr(mallocFunc),
- cumulativeSize)
- .getResult(0);
+
auto structElementType = lowering.convertType(elementType);
auto elementPtrType = structElementType.cast<LLVM::LLVMType>().getPointerTo(
type.getMemorySpace());
@@ -878,13 +905,17 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
// Field 2: Actual aligned pointer to payload.
Value *bitcastAligned = bitcastAllocated;
- if (align) {
+ if (!useAlloca && alignment != 0) {
+ assert(alignmentValue);
// offset = (align - (ptr % align))% align
Value *intVal = rewriter.create<LLVM::PtrToIntOp>(
loc, this->getIndexType(), allocated);
- Value *ptrModAlign = rewriter.create<LLVM::URemOp>(loc, intVal, align);
- Value *subbed = rewriter.create<LLVM::SubOp>(loc, align, ptrModAlign);
- Value *offset = rewriter.create<LLVM::URemOp>(loc, subbed, align);
+ Value *ptrModAlign =
+ rewriter.create<LLVM::URemOp>(loc, intVal, alignmentValue);
+ Value *subbed =
+ rewriter.create<LLVM::SubOp>(loc, alignmentValue, ptrModAlign);
+ Value *offset =
+ rewriter.create<LLVM::URemOp>(loc, subbed, alignmentValue);
Value *aligned = rewriter.create<LLVM::GEPOp>(loc, allocated->getType(),
allocated, offset);
bitcastAligned = rewriter.create<LLVM::BitcastOp>(
@@ -930,6 +961,8 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
// Return the final value of the descriptor.
rewriter.replaceOp(op, {memRefDescriptor});
}
+
+ bool useAlloca;
};
// A CallOp automatically promotes MemRefType to a sequence of alloca/store and
@@ -1001,9 +1034,17 @@ struct CallIndirectOpLowering : public CallOpInterfaceLowering<CallIndirectOp> {
struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
using LLVMLegalizationPattern<DeallocOp>::LLVMLegalizationPattern;
+ DeallocOpLowering(LLVM::LLVMDialect &dialect_, LLVMTypeConverter &converter,
+ bool useAlloca = false)
+ : LLVMLegalizationPattern<DeallocOp>(dialect_, converter),
+ useAlloca(useAlloca) {}
+
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
+ if (useAlloca)
+ return rewriter.eraseOp(op), matchSuccess();
+
assert(operands.size() == 1 && "dealloc takes one operand");
OperandAdaptor<DeallocOp> transformed(operands);
@@ -1026,6 +1067,8 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
op, ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
return matchSuccess();
}
+
+ bool useAlloca;
};
struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
@@ -1759,7 +1802,6 @@ void mlir::populateStdToLLVMConversionPatterns(
patterns.insert<
AddFOpLowering,
AddIOpLowering,
- AllocOpLowering,
AndOpLowering,
BranchOpLowering,
CallIndirectOpLowering,
@@ -1768,7 +1810,6 @@ void mlir::populateStdToLLVMConversionPatterns(
CmpIOpLowering,
CondBranchOpLowering,
ConstLLVMOpLowering,
- DeallocOpLowering,
DimOpLowering,
DivFOpLowering,
DivISOpLowering,
@@ -1800,6 +1841,10 @@ void mlir::populateStdToLLVMConversionPatterns(
ViewOpLowering,
XOrOpLowering,
ZeroExtendIOpLowering>(*converter.getDialect(), converter);
+ patterns.insert<
+ AllocOpLowering,
+ DeallocOpLowering>(
+ *converter.getDialect(), converter, clUseAlloca.getValue());
// clang-format on
}
@@ -1873,6 +1918,7 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
// By default, the patterns are those converting Standard operations to the
// LLVMIR dialect.
explicit LLVMLoweringPass(
+ bool useAlloca = false,
LLVMPatternListFiller patternListFiller =
populateStdToLLVMConversionPatterns,
LLVMTypeConverterMaker converterBuilder = makeStandardToLLVMTypeConverter)
@@ -1911,17 +1957,25 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
};
} // end namespace
-std::unique_ptr<OpPassBase<ModuleOp>> mlir::createLowerToLLVMPass() {
- return std::make_unique<LLVMLoweringPass>();
+std::unique_ptr<OpPassBase<ModuleOp>>
+mlir::createLowerToLLVMPass(bool useAlloca) {
+ return std::make_unique<LLVMLoweringPass>(useAlloca);
}
std::unique_ptr<OpPassBase<ModuleOp>>
mlir::createLowerToLLVMPass(LLVMPatternListFiller patternListFiller,
- LLVMTypeConverterMaker typeConverterMaker) {
- return std::make_unique<LLVMLoweringPass>(patternListFiller,
+ LLVMTypeConverterMaker typeConverterMaker,
+ bool useAlloca) {
+ return std::make_unique<LLVMLoweringPass>(useAlloca, patternListFiller,
typeConverterMaker);
}
static PassRegistration<LLVMLoweringPass>
- pass("convert-std-to-llvm", "Convert scalar and vector operations from the "
- "Standard to the LLVM dialect");
+ pass("convert-std-to-llvm",
+ "Convert scalar and vector operations from the "
+ "Standard to the LLVM dialect",
+ [] {
+ return std::make_unique<LLVMLoweringPass>(
+ clUseAlloca.getValue(), populateStdToLLVMConversionPatterns,
+ makeStandardToLLVMTypeConverter);
+ });
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir
index b18e84f6363..375c1ac4b17 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s
+// RUN: mlir-opt -convert-std-to-llvm -convert-std-to-llvm-use-alloca=1 %s | FileCheck %s --check-prefix=ALLOCA
// CHECK-LABEL: func @check_arguments(%arg0: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %arg1: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">, %arg2: !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }*">)
func @check_arguments(%static: memref<10x20xf32>, %dynamic : memref<?x?xf32>, %mixed : memref<10x?xf32>) {
@@ -20,6 +21,7 @@ func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> {
}
// CHECK-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> {
+// ALLOCA-LABEL: func @zero_d_alloc() -> !llvm<"{ float*, float*, i64 }"> {
func @zero_d_alloc() -> memref<f32> {
// CHECK-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm<"float*">
@@ -34,6 +36,10 @@ func @zero_d_alloc() -> memref<f32> {
// CHECK-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[1] : !llvm<"{ float*, float*, i64 }">
// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm<"{ float*, float*, i64 }">
+
+// ALLOCA-NOT: malloc
+// ALLOCA: alloca
+// ALLOCA-NOT: malloc
%0 = alloc() : memref<f32>
return %0 : memref<f32>
}
OpenPOWER on IntegriCloud