summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp97
1 files changed, 39 insertions, 58 deletions
diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
index 332d6324879..8150defbd7d 100644
--- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
+++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp
@@ -375,37 +375,32 @@ public:
Value *createIndexConstant(FuncBuilder &builder, Location loc,
uint64_t value) const {
auto attr = builder.getIntegerAttr(builder.getIndexType(), value);
- auto namedAttr = builder.getNamedAttr("value", attr);
- return builder.create<LLVM::ConstantOp>(
- loc, getIndexType(), ArrayRef<Value *>{},
- ArrayRef<NamedAttribute>{namedAttr});
+ return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr);
}
// Get the array attribute named "position" containing the given list of
// integers as integer attribute elements.
- static NamedAttribute getPositionAttribute(FuncBuilder &builder,
- ArrayRef<int64_t> positions) {
- SmallVector<Attribute, 4> attrPositions;
- attrPositions.reserve(positions.size());
- for (int64_t pos : positions)
- attrPositions.push_back(
- builder.getIntegerAttr(builder.getIndexType(), pos));
- return builder.getNamedAttr("position",
- builder.getArrayAttr(attrPositions));
+ static ArrayAttr getIntegerArrayAttr(FuncBuilder &builder,
+ ArrayRef<int64_t> values) {
+ SmallVector<Attribute, 4> attrs;
+ attrs.reserve(values.size());
+ for (int64_t pos : values)
+ attrs.push_back(builder.getIntegerAttr(builder.getIndexType(), pos));
+ return builder.getArrayAttr(attrs);
}
// Extract raw data pointer value from a value representing a memref.
static Value *extractMemRefElementPtr(FuncBuilder &builder, Location loc,
Value *convertedMemRefValue,
Type elementTypePtr,
- bool statically_shaped) {
+ bool hasStaticShape) {
Value *buffer;
- if (statically_shaped)
+ if (hasStaticShape)
return convertedMemRefValue;
else
return builder.create<LLVM::ExtractValueOp>(
- loc, elementTypePtr, ArrayRef<Value *>{convertedMemRefValue},
- getPositionAttribute(builder, 0));
+ loc, elementTypePtr, convertedMemRefValue,
+ getIntegerArrayAttr(builder, 0));
return buffer;
}
@@ -461,13 +456,11 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
SmallVector<Value *, 4> results;
results.reserve(numResults);
for (unsigned i = 0; i < numResults; ++i) {
- auto positionNamedAttr = this->getPositionAttribute(rewriter, i);
auto type = TypeConverter::convert(op->getResult(i)->getType(),
this->dialect.getLLVMModule());
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
- op->getLoc(), type,
- ArrayRef<Value *>(newOp->getInstruction()->getResult(0)),
- llvm::makeArrayRef(positionNamedAttr)));
+ op->getLoc(), type, newOp->getInstruction()->getResult(0),
+ this->getIntegerArrayAttr(rewriter, i)));
}
return results;
}
@@ -608,13 +601,11 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
// Allocate the underlying buffer and store a pointer to it in the MemRef
// descriptor.
- auto mallocNamedAttr =
- rewriter.getNamedAttr("callee", rewriter.getFunctionAttr(mallocFunc));
Value *allocated =
rewriter
.create<LLVM::CallOp>(op->getLoc(), getVoidPtrType(),
- ArrayRef<Value *>(cumulativeSize),
- llvm::makeArrayRef(mallocNamedAttr))
+ rewriter.getFunctionAttr(mallocFunc),
+ cumulativeSize)
->getResult(0);
auto structElementType = TypeConverter::convert(elementType, getModule());
auto elementPtrType = LLVM::LLVMType::get(
@@ -634,21 +625,16 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
Value *memRefDescriptor = rewriter.create<LLVM::UndefOp>(
op->getLoc(), structType, ArrayRef<Value *>{});
- auto namedPositionAttr = getPositionAttribute(rewriter, 0);
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), structType,
- ArrayRef<Value *>{memRefDescriptor, allocated},
- llvm::makeArrayRef(namedPositionAttr));
+ op->getLoc(), structType, memRefDescriptor, allocated,
+ getIntegerArrayAttr(rewriter, 0));
// Store dynamically allocated sizes in the descriptor. Dynamic sizes are
// passed in as operands.
for (auto indexedSize : llvm::enumerate(operands)) {
- auto positionAttr =
- getPositionAttribute(rewriter, 1 + indexedSize.index());
memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), structType,
- ArrayRef<Value *>{memRefDescriptor, indexedSize.value()},
- llvm::makeArrayRef(positionAttr));
+ op->getLoc(), structType, memRefDescriptor, indexedSize.value(),
+ getIntegerArrayAttr(rewriter, 1 + indexedSize.index()));
}
// Return the final value of the descriptor.
@@ -677,20 +663,18 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
auto *type =
operands[0]->getType().cast<LLVM::LLVMType>().getUnderlyingType();
- auto statically_shaped = type->isPointerTy();
+ auto hasStaticShape = type->isPointerTy();
Type elementPtrType =
- (statically_shaped)
+ (hasStaticShape)
? rewriter.getType<LLVM::LLVMType>(type)
: rewriter.getType<LLVM::LLVMType>(
cast<llvm::StructType>(type)->getStructElementType(0));
Value *bufferPtr = extractMemRefElementPtr(
- rewriter, op->getLoc(), operands[0], elementPtrType, statically_shaped);
+ rewriter, op->getLoc(), operands[0], elementPtrType, hasStaticShape);
Value *casted = rewriter.create<LLVM::BitcastOp>(
op->getLoc(), getVoidPtrType(), bufferPtr);
- auto freeNamedAttr =
- rewriter.getNamedAttr("callee", rewriter.getFunctionAttr(freeFunc));
- rewriter.create<LLVM::CallOp>(op->getLoc(), casted,
- llvm::makeArrayRef(freeNamedAttr));
+ rewriter.create<LLVM::CallOp>(op->getLoc(), ArrayRef<Type>(),
+ rewriter.getFunctionAttr(freeFunc), casted);
return {};
}
};
@@ -734,8 +718,8 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
op->getLoc(), structType, ArrayRef<Value *>{});
// Otherwise target type is dynamic memref, so create a proper descriptor.
newDescriptor = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), structType, ArrayRef<Value *>{newDescriptor, buffer},
- getPositionAttribute(rewriter, 0));
+ op->getLoc(), structType, newDescriptor, buffer,
+ getIntegerArrayAttr(rewriter, 0));
// Fill in the dynamic sizes of the new descriptor. If the size was
// dynamic, copy it from the old descriptor. If the size was static, insert
@@ -757,12 +741,12 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
sourceSize == -1
? rewriter.create<LLVM::ExtractValueOp>(
op->getLoc(), getIndexType(),
- ArrayRef<Value *>{operands[0]}, // NB: dynamic memref
- getPositionAttribute(rewriter, sourceDynamicDimIdx++))
+ operands[0], // NB: dynamic memref
+ getIntegerArrayAttr(rewriter, sourceDynamicDimIdx++))
: createIndexConstant(rewriter, op->getLoc(), sourceSize);
newDescriptor = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), structType, ArrayRef<Value *>{newDescriptor, size},
- getPositionAttribute(rewriter, targetDynamicDimIdx++));
+ op->getLoc(), structType, newDescriptor, size,
+ getIntegerArrayAttr(rewriter, targetDynamicDimIdx++));
}
assert(sourceDynamicDimIdx - 1 == sourceType.getNumDynamicDims() &&
"source dynamic dimensions were not processed");
@@ -807,8 +791,8 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
++position;
}
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
- op->getLoc(), getIndexType(), operands,
- getPositionAttribute(rewriter, position)));
+ op->getLoc(), getIndexType(), operands[0],
+ getIntegerArrayAttr(rewriter, position)));
} else {
results.push_back(
createIndexConstant(rewriter, op->getLoc(), shape[index]));
@@ -876,9 +860,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
for (int64_t s : shape) {
if (s == -1) {
Value *size = rewriter.create<LLVM::ExtractValueOp>(
- loc, this->getIndexType(), ArrayRef<Value *>{memRefDescriptor},
- llvm::makeArrayRef(
- this->getPositionAttribute(rewriter, dynamicSizeIdx++)));
+ loc, this->getIndexType(), memRefDescriptor,
+ this->getIntegerArrayAttr(rewriter, dynamicSizeIdx++));
sizes.push_back(size);
} else {
sizes.push_back(this->createIndexConstant(rewriter, loc, s));
@@ -890,8 +873,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
Value *subscript = linearizeSubscripts(rewriter, loc, indices, sizes);
Value *dataPtr = rewriter.create<LLVM::ExtractValueOp>(
- loc, elementTypePtr, ArrayRef<Value *>{memRefDescriptor},
- llvm::makeArrayRef(this->getPositionAttribute(rewriter, 0)));
+ loc, elementTypePtr, memRefDescriptor,
+ this->getIntegerArrayAttr(rewriter, 0));
return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr,
ArrayRef<Value *>{dataPtr, subscript},
ArrayRef<NamedAttribute>{});
@@ -1018,11 +1001,9 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
Value *packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType);
for (unsigned i = 0; i < numArguments; ++i) {
- auto positionNamedAttr = getPositionAttribute(rewriter, i);
packed = rewriter.create<LLVM::InsertValueOp>(
- op->getLoc(), packedType,
- llvm::ArrayRef<Value *>{packed, operands[i]},
- llvm::makeArrayRef(positionNamedAttr));
+ op->getLoc(), packedType, packed, operands[i],
+ getIntegerArrayAttr(rewriter, i));
}
rewriter.create<LLVM::ReturnOp>(
op->getLoc(), llvm::makeArrayRef(packed), llvm::ArrayRef<Block *>(),
OpenPOWER on IntegriCloud