summaryrefslogtreecommitdiffstats
path: root/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp')
-rw-r--r--mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp219
1 files changed, 199 insertions, 20 deletions
diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
index 9fdfe1ed112..38838288015 100644
--- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
+++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
@@ -296,11 +296,43 @@ public:
// Get the LLVM module in which the types are constructed.
llvm::Module &getModule() const { return dialect.getLLVMModule(); }
- // Get the MLIR integer type whose bit width is defined by the pointer size
- // used in the LLVM module.
- IntegerType getIndexType() const {
- return IntegerType::get(getModule().getDataLayout().getPointerSizeInBits(),
- dialect.getContext());
+ // Get the MLIR type wrapping the LLVM integer type whose bit width is defined
+ // by the pointer size used in the LLVM module.
+ LLVM::LLVMType getIndexType() const {
+ llvm::Type *llvmType = llvm::Type::getIntNTy(
+ getContext(), getModule().getDataLayout().getPointerSizeInBits());
+ return LLVM::LLVMType::get(dialect.getContext(), llvmType);
+ }
+
+ // Get the MLIR type wrapping the LLVM i8* type.
+ LLVM::LLVMType getVoidPtrType() const {
+ return LLVM::LLVMType::get(dialect.getContext(),
+ llvm::Type::getInt8PtrTy(getContext()));
+ }
+
+ // Create an LLVM IR pseudo-operation defining the given index constant.
+ Value *createIndexConstant(FuncBuilder &builder, Location loc,
+ uint64_t value) const {
+ auto attr = builder.getIntegerAttr(builder.getIndexType(), value);
+ auto attrId = builder.getIdentifier("value");
+ auto namedAttr = NamedAttribute{attrId, attr};
+ return builder.create<LLVM::ConstantOp>(
+ loc, getIndexType(), ArrayRef<Value *>{},
+ ArrayRef<NamedAttribute>{namedAttr});
+ }
+
+ // Get the array attribute named "position" containing the given list of
+ // integers as integer attribute elements.
+ static NamedAttribute getPositionAttribute(FuncBuilder &builder,
+ ArrayRef<int64_t> positions) {
+ SmallVector<Attribute, 4> attrPositions;
+ attrPositions.reserve(positions.size());
+ for (int64_t pos : positions)
+ attrPositions.push_back(
+ builder.getIntegerAttr(builder.getIndexType(), pos));
+ auto attr = builder.getArrayAttr(attrPositions);
+ auto attrId = builder.getIdentifier("position");
+ return {attrId, attr};
}
protected:
@@ -356,10 +388,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
SmallVector<Value *, 4> results;
results.reserve(numResults);
for (unsigned i = 0; i < numResults; ++i) {
- auto positionAttr = ArrayAttr::get(
- IntegerAttr::get(this->getIndexType(), i), mlirContext);
- auto positionAttrID = Identifier::get("position", mlirContext);
- auto positionNamedAttr = NamedAttribute{positionAttrID, positionAttr};
+ auto positionNamedAttr = this->getPositionAttribute(rewriter, i);
auto type = TypeConverter::convert(op->getResult(i)->getType(),
this->dialect.getLLVMModule());
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
@@ -436,6 +465,160 @@ struct ConstLLVMOpLowering
using Super::Super;
};
+// Check if the MemRefType `type` is supported by the lowering. We currently do
+// not support memrefs with affine maps and non-default memory spaces.
+static bool isSupportedMemRefType(MemRefType type) {
+ if (!type.getAffineMaps().empty())
+ return false;
+ if (type.getMemorySpace() != 0)
+ return false;
+ return true;
+}
+
+// An `alloc` is converted into a definition of a memref descriptor value and
+// a call to `malloc` to allocate the underlying data buffer. The memref
+// descriptor is of the LLVM structure type where the first element is a pointer
+// to the (typed) data buffer, and the remaining elements serve to store
+// dynamic sizes of the memref using LLVM-converted `index` type.
+struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
+ using LLVMLegalizationPattern<AllocOp>::LLVMLegalizationPattern;
+
+ PatternMatchResult match(Instruction *op) const override {
+ if (!LLVMLegalizationPattern<AllocOp>::match(op))
+ return matchFailure();
+ auto allocOp = op->cast<AllocOp>();
+ MemRefType type = allocOp->getType();
+ return isSupportedMemRefType(type) ? matchSuccess() : matchFailure();
+ }
+
+ SmallVector<Value *, 4> rewrite(Instruction *op, ArrayRef<Value *> operands,
+ FuncBuilder &rewriter) const override {
+ auto allocOp = op->cast<AllocOp>();
+ MemRefType type = allocOp->getType();
+
+ // Get actual sizes of the memref as values: static sizes are constant
+ // values and dynamic sizes are passed to 'alloc' as operands.
+ SmallVector<Value *, 4> sizes;
+ sizes.reserve(allocOp->getNumOperands());
+ unsigned i = 0;
+ for (int64_t s : type.getShape())
+ sizes.push_back(s == -1 ? operands[i++]
+ : createIndexConstant(rewriter, op->getLoc(), s));
+ assert(!sizes.empty() && "zero-dimensional allocation");
+
+ // Compute the total number of memref elements.
+ Value *cumulativeSize = sizes.front();
+ for (unsigned i = 1, e = sizes.size(); i < e; ++i)
+ cumulativeSize = rewriter.create<LLVM::MulOp>(
+ op->getLoc(), getIndexType(),
+ ArrayRef<Value *>{cumulativeSize, sizes[i]});
+
+ // Create the MemRef descriptor.
+ auto structType = TypeConverter::convert(type, getModule());
+ Value *memRefDescriptor = rewriter.create<LLVM::UndefOp>(
+ op->getLoc(), structType, ArrayRef<Value *>{});
+
+ // Compute the total amount of bytes to allocate.
+ auto elementType = type.getElementType();
+ assert((elementType.isIntOrFloat() || elementType.isa<VectorType>()) &&
+ "invalid memref element type");
+ uint64_t elementSize = 0;
+ if (auto vectorType = elementType.dyn_cast<VectorType>())
+ elementSize = vectorType.getNumElements() *
+ llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
+ else
+ elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
+ cumulativeSize = rewriter.create<LLVM::MulOp>(
+ op->getLoc(), getIndexType(),
+ ArrayRef<Value *>{
+ cumulativeSize,
+ createIndexConstant(rewriter, op->getLoc(), elementSize)});
+
+ // Insert the `malloc` declaration if it is not already present.
+ Function *mallocFunc =
+ op->getFunction()->getModule()->getNamedFunction("malloc");
+ if (!mallocFunc) {
+ auto mallocType =
+ rewriter.getFunctionType(getIndexType(), getVoidPtrType());
+ mallocFunc = new Function(rewriter.getUnknownLoc(), "malloc", mallocType);
+ op->getFunction()->getModule()->getFunctions().push_back(mallocFunc);
+ }
+
+ // Allocate the underlying buffer and store a pointer to it in the MemRef
+ // descriptor.
+ auto mallocNamedAttr = NamedAttribute{rewriter.getIdentifier("callee"),
+ rewriter.getFunctionAttr(mallocFunc)};
+ Value *allocated = rewriter.create<LLVM::CallOp>(
+ op->getLoc(), getVoidPtrType(), ArrayRef<Value *>(cumulativeSize),
+ llvm::makeArrayRef(mallocNamedAttr));
+ auto structElementType = TypeConverter::convert(elementType, getModule());
+ auto elementPtrType = LLVM::LLVMType::get(
+ op->getContext(), structElementType.cast<LLVM::LLVMType>()
+ .getUnderlyingType()
+ ->getPointerTo());
+ allocated = rewriter.create<LLVM::BitcastOp>(op->getLoc(), elementPtrType,
+ ArrayRef<Value *>(allocated));
+ auto namedPositionAttr = getPositionAttribute(rewriter, 0);
+ memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
+ op->getLoc(), structType,
+ ArrayRef<Value *>{memRefDescriptor, allocated},
+ llvm::makeArrayRef(namedPositionAttr));
+
+ // Store dynamically allocated sizes in the descriptor. Dynamic sizes are
+ // passed in as operands.
+ for (auto indexedSize : llvm::enumerate(operands)) {
+ auto positionAttr =
+ getPositionAttribute(rewriter, 1 + indexedSize.index());
+ memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
+ op->getLoc(), structType,
+ ArrayRef<Value *>{memRefDescriptor, indexedSize.value()},
+ llvm::makeArrayRef(positionAttr));
+ }
+
+ // Return the final value of the descriptor.
+ return {memRefDescriptor};
+ }
+};
+
+// A `dealloc` is converted into a call to `free` on the underlying data buffer.
+// The memref descriptor being an SSA value, there is no need to clean it up
+// in any way.
+struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
+ using LLVMLegalizationPattern<DeallocOp>::LLVMLegalizationPattern;
+
+ SmallVector<Value *, 4> rewrite(Instruction *op, ArrayRef<Value *> operands,
+ FuncBuilder &rewriter) const override {
+ assert(operands.size() == 1 && "dealloc takes one operand");
+
+ // Insert the `free` declaration if it is not already present.
+ Function *freeFunc =
+ op->getFunction()->getModule()->getNamedFunction("free");
+ if (!freeFunc) {
+ auto freeType = rewriter.getFunctionType(getVoidPtrType(), {});
+ freeFunc = new Function(rewriter.getUnknownLoc(), "free", freeType);
+ op->getFunction()->getModule()->getFunctions().push_back(freeFunc);
+ }
+
+ // Obtain the MLIR-wrapped LLVM IR element pointer type.
+ llvm::Type *structType = cast<llvm::StructType>(
+ operands[0]->getType().cast<LLVM::LLVMType>().getUnderlyingType());
+ auto elementPtrType =
+ rewriter.getType<LLVM::LLVMType>(structType->getStructElementType(0));
+
+ // Extract the pointer to the data buffer and pass it to `free`.
+ Value *bufferPtr = rewriter.create<LLVM::ExtractValueOp>(
+ op->getLoc(), elementPtrType, operands[0],
+ llvm::makeArrayRef(getPositionAttribute(rewriter, 0)));
+ Value *casted = rewriter.create<LLVM::BitcastOp>(
+ op->getLoc(), getVoidPtrType(), bufferPtr);
+ auto freeNamedAttr = NamedAttribute{rewriter.getIdentifier("callee"),
+ rewriter.getFunctionAttr(freeFunc)};
+ rewriter.create<LLVM::Call0Op>(op->getLoc(), casted,
+ llvm::makeArrayRef(freeNamedAttr));
+ return {};
+ }
+};
+
// Base class for LLVM IR lowering terminator operations with successors.
template <typename SourceOp, typename TargetOp>
struct OneToOneLLVMTerminatorLowering
@@ -488,11 +671,7 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
Value *packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType);
for (unsigned i = 0; i < numArguments; ++i) {
- // FIXME: introduce builder::getNamedAttr
- auto positionAttr = ArrayAttr::get(
- IntegerAttr::get(this->getIndexType(), i), mlirContext);
- auto positionAttrID = Identifier::get("position", mlirContext);
- auto positionNamedAttr = NamedAttribute{positionAttrID, positionAttr};
+ auto positionNamedAttr = getPositionAttribute(rewriter, i);
packed = rewriter.create<LLVM::InsertValueOp>(
op->getLoc(), packedType,
llvm::ArrayRef<Value *>{packed, operands[i]},
@@ -544,12 +723,12 @@ protected:
// FIXME: this should be tablegen'ed
return ConversionListBuilder<
- AddIOpLowering, SubIOpLowering, MulIOpLowering, DivISOpLowering,
- DivIUOpLowering, RemISOpLowering, RemIUOpLowering, AddFOpLowering,
- SubFOpLowering, MulFOpLowering, CmpIOpLowering, CallOpLowering,
- Call0OpLowering, BranchOpLowering, CondBranchOpLowering,
- ReturnOpLowering, ConstLLVMOpLowering>::build(&converterStorage,
- *llvmDialect);
+ AllocOpLowering, DeallocOpLowering, AddIOpLowering, SubIOpLowering,
+ MulIOpLowering, DivISOpLowering, DivIUOpLowering, RemISOpLowering,
+ RemIUOpLowering, AddFOpLowering, SubFOpLowering, MulFOpLowering,
+ CmpIOpLowering, CallOpLowering, Call0OpLowering, BranchOpLowering,
+ CondBranchOpLowering, ReturnOpLowering,
+ ConstLLVMOpLowering>::build(&converterStorage, *llvmDialect);
}
// Convert types using the stored LLVM IR module.
OpenPOWER on IntegriCloud