diff options
Diffstat (limited to 'mlir')
| -rw-r--r-- | mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp | 219 | ||||
| -rw-r--r-- | mlir/test/LLVMIR/convert-to-llvmir.mlir | 28 |
2 files changed, 227 insertions, 20 deletions
diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp index 9fdfe1ed112..38838288015 100644 --- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp +++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp @@ -296,11 +296,43 @@ public: // Get the LLVM module in which the types are constructed. llvm::Module &getModule() const { return dialect.getLLVMModule(); } - // Get the MLIR integer type whose bit width is defined by the pointer size - // used in the LLVM module. - IntegerType getIndexType() const { - return IntegerType::get(getModule().getDataLayout().getPointerSizeInBits(), - dialect.getContext()); + // Get the MLIR type wrapping the LLVM integer type whose bit width is defined + // by the pointer size used in the LLVM module. + LLVM::LLVMType getIndexType() const { + llvm::Type *llvmType = llvm::Type::getIntNTy( + getContext(), getModule().getDataLayout().getPointerSizeInBits()); + return LLVM::LLVMType::get(dialect.getContext(), llvmType); + } + + // Get the MLIR type wrapping the LLVM i8* type. + LLVM::LLVMType getVoidPtrType() const { + return LLVM::LLVMType::get(dialect.getContext(), + llvm::Type::getInt8PtrTy(getContext())); + } + + // Create an LLVM IR pseudo-operation defining the given index constant. + Value *createIndexConstant(FuncBuilder &builder, Location loc, + uint64_t value) const { + auto attr = builder.getIntegerAttr(builder.getIndexType(), value); + auto attrId = builder.getIdentifier("value"); + auto namedAttr = NamedAttribute{attrId, attr}; + return builder.create<LLVM::ConstantOp>( + loc, getIndexType(), ArrayRef<Value *>{}, + ArrayRef<NamedAttribute>{namedAttr}); + } + + // Get the array attribute named "position" containing the given list of + // integers as integer attribute elements. + static NamedAttribute getPositionAttribute(FuncBuilder &builder, + ArrayRef<int64_t> positions) { + SmallVector<Attribute, 4> attrPositions; + attrPositions.reserve(positions.size()); + for (int64_t pos : positions) + attrPositions.push_back( + builder.getIntegerAttr(builder.getIndexType(), pos)); + auto attr = builder.getArrayAttr(attrPositions); + auto attrId = builder.getIdentifier("position"); + return {attrId, attr}; } protected: @@ -356,10 +388,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> { SmallVector<Value *, 4> results; results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { - auto positionAttr = ArrayAttr::get( - IntegerAttr::get(this->getIndexType(), i), mlirContext); - auto positionAttrID = Identifier::get("position", mlirContext); - auto positionNamedAttr = NamedAttribute{positionAttrID, positionAttr}; + auto positionNamedAttr = this->getPositionAttribute(rewriter, i); auto type = TypeConverter::convert(op->getResult(i)->getType(), this->dialect.getLLVMModule()); results.push_back(rewriter.create<LLVM::ExtractValueOp>( @@ -436,6 +465,160 @@ struct ConstLLVMOpLowering using Super::Super; }; +// Check if the MemRefType `type` is supported by the lowering. We currently do +// not support memrefs with affine maps and non-default memory spaces. +static bool isSupportedMemRefType(MemRefType type) { + if (!type.getAffineMaps().empty()) + return false; + if (type.getMemorySpace() != 0) + return false; + return true; +} + +// An `alloc` is converted into a definition of a memref descriptor value and +// a call to `malloc` to allocate the underlying data buffer. The memref +// descriptor is of the LLVM structure type where the first element is a pointer +// to the (typed) data buffer, and the remaining elements serve to store +// dynamic sizes of the memref using LLVM-converted `index` type. +struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> { + using LLVMLegalizationPattern<AllocOp>::LLVMLegalizationPattern; + + PatternMatchResult match(Instruction *op) const override { + if (!LLVMLegalizationPattern<AllocOp>::match(op)) + return matchFailure(); + auto allocOp = op->cast<AllocOp>(); + MemRefType type = allocOp->getType(); + return isSupportedMemRefType(type) ? matchSuccess() : matchFailure(); + } + + SmallVector<Value *, 4> rewrite(Instruction *op, ArrayRef<Value *> operands, + FuncBuilder &rewriter) const override { + auto allocOp = op->cast<AllocOp>(); + MemRefType type = allocOp->getType(); + + // Get actual sizes of the memref as values: static sizes are constant + // values and dynamic sizes are passed to 'alloc' as operands. + SmallVector<Value *, 4> sizes; + sizes.reserve(allocOp->getNumOperands()); + unsigned i = 0; + for (int64_t s : type.getShape()) + sizes.push_back(s == -1 ? operands[i++] + : createIndexConstant(rewriter, op->getLoc(), s)); + assert(!sizes.empty() && "zero-dimensional allocation"); + + // Compute the total number of memref elements. + Value *cumulativeSize = sizes.front(); + for (unsigned i = 1, e = sizes.size(); i < e; ++i) + cumulativeSize = rewriter.create<LLVM::MulOp>( + op->getLoc(), getIndexType(), + ArrayRef<Value *>{cumulativeSize, sizes[i]}); + + // Create the MemRef descriptor. + auto structType = TypeConverter::convert(type, getModule()); + Value *memRefDescriptor = rewriter.create<LLVM::UndefOp>( + op->getLoc(), structType, ArrayRef<Value *>{}); + + // Compute the total amount of bytes to allocate. + auto elementType = type.getElementType(); + assert((elementType.isIntOrFloat() || elementType.isa<VectorType>()) && + "invalid memref element type"); + uint64_t elementSize = 0; + if (auto vectorType = elementType.dyn_cast<VectorType>()) + elementSize = vectorType.getNumElements() * + llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8); + else + elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8); + cumulativeSize = rewriter.create<LLVM::MulOp>( + op->getLoc(), getIndexType(), + ArrayRef<Value *>{ + cumulativeSize, + createIndexConstant(rewriter, op->getLoc(), elementSize)}); + + // Insert the `malloc` declaration if it is not already present. + Function *mallocFunc = + op->getFunction()->getModule()->getNamedFunction("malloc"); + if (!mallocFunc) { + auto mallocType = + rewriter.getFunctionType(getIndexType(), getVoidPtrType()); + mallocFunc = new Function(rewriter.getUnknownLoc(), "malloc", mallocType); + op->getFunction()->getModule()->getFunctions().push_back(mallocFunc); + } + + // Allocate the underlying buffer and store a pointer to it in the MemRef + // descriptor. + auto mallocNamedAttr = NamedAttribute{rewriter.getIdentifier("callee"), + rewriter.getFunctionAttr(mallocFunc)}; + Value *allocated = rewriter.create<LLVM::CallOp>( + op->getLoc(), getVoidPtrType(), ArrayRef<Value *>(cumulativeSize), + llvm::makeArrayRef(mallocNamedAttr)); + auto structElementType = TypeConverter::convert(elementType, getModule()); + auto elementPtrType = LLVM::LLVMType::get( + op->getContext(), structElementType.cast<LLVM::LLVMType>() + .getUnderlyingType() + ->getPointerTo()); + allocated = rewriter.create<LLVM::BitcastOp>(op->getLoc(), elementPtrType, + ArrayRef<Value *>(allocated)); + auto namedPositionAttr = getPositionAttribute(rewriter, 0); + memRefDescriptor = rewriter.create<LLVM::InsertValueOp>( + op->getLoc(), structType, + ArrayRef<Value *>{memRefDescriptor, allocated}, + llvm::makeArrayRef(namedPositionAttr)); + + // Store dynamically allocated sizes in the descriptor. Dynamic sizes are + // passed in as operands. + for (auto indexedSize : llvm::enumerate(operands)) { + auto positionAttr = + getPositionAttribute(rewriter, 1 + indexedSize.index()); + memRefDescriptor = rewriter.create<LLVM::InsertValueOp>( + op->getLoc(), structType, + ArrayRef<Value *>{memRefDescriptor, indexedSize.value()}, + llvm::makeArrayRef(positionAttr)); + } + + // Return the final value of the descriptor. + return {memRefDescriptor}; + } +}; + +// A `dealloc` is converted into a call to `free` on the underlying data buffer. +// The memref descriptor being an SSA value, there is no need to clean it up +// in any way. +struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> { + using LLVMLegalizationPattern<DeallocOp>::LLVMLegalizationPattern; + + SmallVector<Value *, 4> rewrite(Instruction *op, ArrayRef<Value *> operands, + FuncBuilder &rewriter) const override { + assert(operands.size() == 1 && "dealloc takes one operand"); + + // Insert the `free` declaration if it is not already present. + Function *freeFunc = + op->getFunction()->getModule()->getNamedFunction("free"); + if (!freeFunc) { + auto freeType = rewriter.getFunctionType(getVoidPtrType(), {}); + freeFunc = new Function(rewriter.getUnknownLoc(), "free", freeType); + op->getFunction()->getModule()->getFunctions().push_back(freeFunc); + } + + // Obtain the MLIR-wrapped LLVM IR element pointer type. + llvm::Type *structType = cast<llvm::StructType>( + operands[0]->getType().cast<LLVM::LLVMType>().getUnderlyingType()); + auto elementPtrType = + rewriter.getType<LLVM::LLVMType>(structType->getStructElementType(0)); + + // Extract the pointer to the data buffer and pass it to `free`. + Value *bufferPtr = rewriter.create<LLVM::ExtractValueOp>( + op->getLoc(), elementPtrType, operands[0], + llvm::makeArrayRef(getPositionAttribute(rewriter, 0))); + Value *casted = rewriter.create<LLVM::BitcastOp>( + op->getLoc(), getVoidPtrType(), bufferPtr); + auto freeNamedAttr = NamedAttribute{rewriter.getIdentifier("callee"), + rewriter.getFunctionAttr(freeFunc)}; + rewriter.create<LLVM::Call0Op>(op->getLoc(), casted, + llvm::makeArrayRef(freeNamedAttr)); + return {}; + } +}; + // Base class for LLVM IR lowering terminator operations with successors. template <typename SourceOp, typename TargetOp> struct OneToOneLLVMTerminatorLowering @@ -488,11 +671,7 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> { Value *packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType); for (unsigned i = 0; i < numArguments; ++i) { - // FIXME: introduce builder::getNamedAttr - auto positionAttr = ArrayAttr::get( - IntegerAttr::get(this->getIndexType(), i), mlirContext); - auto positionAttrID = Identifier::get("position", mlirContext); - auto positionNamedAttr = NamedAttribute{positionAttrID, positionAttr}; + auto positionNamedAttr = getPositionAttribute(rewriter, i); packed = rewriter.create<LLVM::InsertValueOp>( op->getLoc(), packedType, llvm::ArrayRef<Value *>{packed, operands[i]}, @@ -544,12 +723,12 @@ protected: // FIXME: this should be tablegen'ed return ConversionListBuilder< - AddIOpLowering, SubIOpLowering, MulIOpLowering, DivISOpLowering, - DivIUOpLowering, RemISOpLowering, RemIUOpLowering, AddFOpLowering, - SubFOpLowering, MulFOpLowering, CmpIOpLowering, CallOpLowering, - Call0OpLowering, BranchOpLowering, CondBranchOpLowering, - ReturnOpLowering, ConstLLVMOpLowering>::build(&converterStorage, - *llvmDialect); + AllocOpLowering, DeallocOpLowering, AddIOpLowering, SubIOpLowering, + MulIOpLowering, DivISOpLowering, DivIUOpLowering, RemISOpLowering, + RemIUOpLowering, AddFOpLowering, SubFOpLowering, MulFOpLowering, + CmpIOpLowering, CallOpLowering, Call0OpLowering, BranchOpLowering, + CondBranchOpLowering, ReturnOpLowering, + ConstLLVMOpLowering>::build(&converterStorage, *llvmDialect); } // Convert types using the stored LLVM IR module. diff --git a/mlir/test/LLVMIR/convert-to-llvmir.mlir b/mlir/test/LLVMIR/convert-to-llvmir.mlir index ceebda618d5..72ee2ba03a7 100644 --- a/mlir/test/LLVMIR/convert-to-llvmir.mlir +++ b/mlir/test/LLVMIR/convert-to-llvmir.mlir @@ -417,3 +417,31 @@ func @dfs_block_order() -> (i32) { br ^bb1 } +// CHECK-LABEL: func @alloc(%arg0: !llvm<"i64">, %arg1: !llvm<"i64">) -> !llvm<"{ float*, i64, i64 }"> { +func @alloc(%arg0: index, %arg1: index) -> memref<?x42x?xf32> { +// CHECK-NEXT: %0 = "llvm.constant"() {value: 42 : index} : () -> !llvm<"i64"> +// CHECK-NEXT: %1 = "llvm.mul"(%arg0, %0) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64"> +// CHECK-NEXT: %2 = "llvm.mul"(%1, %arg1) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64"> +// CHECK-NEXT: %3 = "llvm.undef"() : () -> !llvm<"{ float*, i64, i64 }"> +// CHECK-NEXT: %4 = "llvm.constant"() {value: 4 : index} : () -> !llvm<"i64"> +// CHECK-NEXT: %5 = "llvm.mul"(%2, %4) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64"> +// CHECK-NEXT: %6 = "llvm.call"(%5) {callee: @malloc : (!llvm<"i64">) -> !llvm<"i8*">} : (!llvm<"i64">) -> !llvm<"i8*"> +// CHECK-NEXT: %7 = "llvm.bitcast"(%6) : (!llvm<"i8*">) -> !llvm<"float*"> +// CHECK-NEXT: %8 = "llvm.insertvalue"(%3, %7) {position: [0]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"float*">) -> !llvm<"{ float*, i64, i64 }"> +// CHECK-NEXT: %9 = "llvm.insertvalue"(%8, %arg0) {position: [1]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"i64">) -> !llvm<"{ float*, i64, i64 }"> +// CHECK-NEXT: %10 = "llvm.insertvalue"(%9, %arg1) {position: [2]} : (!llvm<"{ float*, i64, i64 }">, !llvm<"i64">) -> !llvm<"{ float*, i64, i64 }"> + %0 = alloc(%arg0, %arg1) : memref<?x42x?xf32> +// CHECK-NEXT: "llvm.return"(%10) : (!llvm<"{ float*, i64, i64 }">) -> () + return %0 : memref<?x42x?xf32> +} + + +// CHECK-LABEL: func @dealloc(%arg0: !llvm<"{ float*, i64, i64 }">) { +func @dealloc(%arg0: memref<?x42x?xf32>) { +// CHECK-NEXT: %0 = "llvm.extractvalue"(%arg0) {position: [0]} : (!llvm<"{ float*, i64, i64 }">) -> !llvm<"float*"> +// CHECK-NEXT: %1 = "llvm.bitcast"(%0) : (!llvm<"float*">) -> !llvm<"i8*"> +// CHECK-NEXT: "llvm.call0"(%1) {callee: @free : (!llvm<"i8*">) -> ()} : (!llvm<"i8*">) -> () + dealloc %arg0 : memref<?x42x?xf32> +// CHECK-NEXT: "llvm.return"() : () -> () + return +} |

