diff options
| -rw-r--r-- | mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp | 97 |
1 files changed, 39 insertions, 58 deletions
diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp index 332d6324879..8150defbd7d 100644 --- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp +++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp @@ -375,37 +375,32 @@ public: Value *createIndexConstant(FuncBuilder &builder, Location loc, uint64_t value) const { auto attr = builder.getIntegerAttr(builder.getIndexType(), value); - auto namedAttr = builder.getNamedAttr("value", attr); - return builder.create<LLVM::ConstantOp>( - loc, getIndexType(), ArrayRef<Value *>{}, - ArrayRef<NamedAttribute>{namedAttr}); + return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr); } // Get the array attribute named "position" containing the given list of // integers as integer attribute elements. - static NamedAttribute getPositionAttribute(FuncBuilder &builder, - ArrayRef<int64_t> positions) { - SmallVector<Attribute, 4> attrPositions; - attrPositions.reserve(positions.size()); - for (int64_t pos : positions) - attrPositions.push_back( - builder.getIntegerAttr(builder.getIndexType(), pos)); - return builder.getNamedAttr("position", - builder.getArrayAttr(attrPositions)); + static ArrayAttr getIntegerArrayAttr(FuncBuilder &builder, + ArrayRef<int64_t> values) { + SmallVector<Attribute, 4> attrs; + attrs.reserve(values.size()); + for (int64_t pos : values) + attrs.push_back(builder.getIntegerAttr(builder.getIndexType(), pos)); + return builder.getArrayAttr(attrs); } // Extract raw data pointer value from a value representing a memref. static Value *extractMemRefElementPtr(FuncBuilder &builder, Location loc, Value *convertedMemRefValue, Type elementTypePtr, - bool statically_shaped) { + bool hasStaticShape) { Value *buffer; - if (statically_shaped) + if (hasStaticShape) return convertedMemRefValue; else return builder.create<LLVM::ExtractValueOp>( - loc, elementTypePtr, ArrayRef<Value *>{convertedMemRefValue}, - getPositionAttribute(builder, 0)); + loc, elementTypePtr, convertedMemRefValue, + getIntegerArrayAttr(builder, 0)); return buffer; } @@ -461,13 +456,11 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> { SmallVector<Value *, 4> results; results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { - auto positionNamedAttr = this->getPositionAttribute(rewriter, i); auto type = TypeConverter::convert(op->getResult(i)->getType(), this->dialect.getLLVMModule()); results.push_back(rewriter.create<LLVM::ExtractValueOp>( - op->getLoc(), type, - ArrayRef<Value *>(newOp->getInstruction()->getResult(0)), - llvm::makeArrayRef(positionNamedAttr))); + op->getLoc(), type, newOp->getInstruction()->getResult(0), + this->getIntegerArrayAttr(rewriter, i))); } return results; } @@ -608,13 +601,11 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> { // Allocate the underlying buffer and store a pointer to it in the MemRef // descriptor. - auto mallocNamedAttr = - rewriter.getNamedAttr("callee", rewriter.getFunctionAttr(mallocFunc)); Value *allocated = rewriter .create<LLVM::CallOp>(op->getLoc(), getVoidPtrType(), - ArrayRef<Value *>(cumulativeSize), - llvm::makeArrayRef(mallocNamedAttr)) + rewriter.getFunctionAttr(mallocFunc), + cumulativeSize) ->getResult(0); auto structElementType = TypeConverter::convert(elementType, getModule()); auto elementPtrType = LLVM::LLVMType::get( @@ -634,21 +625,16 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> { Value *memRefDescriptor = rewriter.create<LLVM::UndefOp>( op->getLoc(), structType, ArrayRef<Value *>{}); - auto namedPositionAttr = getPositionAttribute(rewriter, 0); memRefDescriptor = rewriter.create<LLVM::InsertValueOp>( - op->getLoc(), structType, - ArrayRef<Value *>{memRefDescriptor, allocated}, - llvm::makeArrayRef(namedPositionAttr)); + op->getLoc(), structType, memRefDescriptor, allocated, + getIntegerArrayAttr(rewriter, 0)); // Store dynamically allocated sizes in the descriptor. Dynamic sizes are // passed in as operands. for (auto indexedSize : llvm::enumerate(operands)) { - auto positionAttr = - getPositionAttribute(rewriter, 1 + indexedSize.index()); memRefDescriptor = rewriter.create<LLVM::InsertValueOp>( - op->getLoc(), structType, - ArrayRef<Value *>{memRefDescriptor, indexedSize.value()}, - llvm::makeArrayRef(positionAttr)); + op->getLoc(), structType, memRefDescriptor, indexedSize.value(), + getIntegerArrayAttr(rewriter, 1 + indexedSize.index())); } // Return the final value of the descriptor. @@ -677,20 +663,18 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> { auto *type = operands[0]->getType().cast<LLVM::LLVMType>().getUnderlyingType(); - auto statically_shaped = type->isPointerTy(); + auto hasStaticShape = type->isPointerTy(); Type elementPtrType = - (statically_shaped) + (hasStaticShape) ? rewriter.getType<LLVM::LLVMType>(type) : rewriter.getType<LLVM::LLVMType>( cast<llvm::StructType>(type)->getStructElementType(0)); Value *bufferPtr = extractMemRefElementPtr( - rewriter, op->getLoc(), operands[0], elementPtrType, statically_shaped); + rewriter, op->getLoc(), operands[0], elementPtrType, hasStaticShape); Value *casted = rewriter.create<LLVM::BitcastOp>( op->getLoc(), getVoidPtrType(), bufferPtr); - auto freeNamedAttr = - rewriter.getNamedAttr("callee", rewriter.getFunctionAttr(freeFunc)); - rewriter.create<LLVM::CallOp>(op->getLoc(), casted, - llvm::makeArrayRef(freeNamedAttr)); + rewriter.create<LLVM::CallOp>(op->getLoc(), ArrayRef<Type>(), + rewriter.getFunctionAttr(freeFunc), casted); return {}; } }; @@ -734,8 +718,8 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> { op->getLoc(), structType, ArrayRef<Value *>{}); // Otherwise target type is dynamic memref, so create a proper descriptor. newDescriptor = rewriter.create<LLVM::InsertValueOp>( - op->getLoc(), structType, ArrayRef<Value *>{newDescriptor, buffer}, - getPositionAttribute(rewriter, 0)); + op->getLoc(), structType, newDescriptor, buffer, + getIntegerArrayAttr(rewriter, 0)); // Fill in the dynamic sizes of the new descriptor. If the size was // dynamic, copy it from the old descriptor. If the size was static, insert @@ -757,12 +741,12 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> { sourceSize == -1 ? rewriter.create<LLVM::ExtractValueOp>( op->getLoc(), getIndexType(), - ArrayRef<Value *>{operands[0]}, // NB: dynamic memref - getPositionAttribute(rewriter, sourceDynamicDimIdx++)) + operands[0], // NB: dynamic memref + getIntegerArrayAttr(rewriter, sourceDynamicDimIdx++)) : createIndexConstant(rewriter, op->getLoc(), sourceSize); newDescriptor = rewriter.create<LLVM::InsertValueOp>( - op->getLoc(), structType, ArrayRef<Value *>{newDescriptor, size}, - getPositionAttribute(rewriter, targetDynamicDimIdx++)); + op->getLoc(), structType, newDescriptor, size, + getIntegerArrayAttr(rewriter, targetDynamicDimIdx++)); } assert(sourceDynamicDimIdx - 1 == sourceType.getNumDynamicDims() && "source dynamic dimensions were not processed"); @@ -807,8 +791,8 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> { ++position; } results.push_back(rewriter.create<LLVM::ExtractValueOp>( - op->getLoc(), getIndexType(), operands, - getPositionAttribute(rewriter, position))); + op->getLoc(), getIndexType(), operands[0], + getIntegerArrayAttr(rewriter, position))); } else { results.push_back( createIndexConstant(rewriter, op->getLoc(), shape[index])); @@ -876,9 +860,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> { for (int64_t s : shape) { if (s == -1) { Value *size = rewriter.create<LLVM::ExtractValueOp>( - loc, this->getIndexType(), ArrayRef<Value *>{memRefDescriptor}, - llvm::makeArrayRef( - this->getPositionAttribute(rewriter, dynamicSizeIdx++))); + loc, this->getIndexType(), memRefDescriptor, + this->getIntegerArrayAttr(rewriter, dynamicSizeIdx++)); sizes.push_back(size); } else { sizes.push_back(this->createIndexConstant(rewriter, loc, s)); @@ -890,8 +873,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> { Value *subscript = linearizeSubscripts(rewriter, loc, indices, sizes); Value *dataPtr = rewriter.create<LLVM::ExtractValueOp>( - loc, elementTypePtr, ArrayRef<Value *>{memRefDescriptor}, - llvm::makeArrayRef(this->getPositionAttribute(rewriter, 0))); + loc, elementTypePtr, memRefDescriptor, + this->getIntegerArrayAttr(rewriter, 0)); return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, ArrayRef<Value *>{dataPtr, subscript}, ArrayRef<NamedAttribute>{}); @@ -1018,11 +1001,9 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> { Value *packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType); for (unsigned i = 0; i < numArguments; ++i) { - auto positionNamedAttr = getPositionAttribute(rewriter, i); packed = rewriter.create<LLVM::InsertValueOp>( - op->getLoc(), packedType, - llvm::ArrayRef<Value *>{packed, operands[i]}, - llvm::makeArrayRef(positionNamedAttr)); + op->getLoc(), packedType, packed, operands[i], + getIntegerArrayAttr(rewriter, i)); } rewriter.create<LLVM::ReturnOp>( op->getLoc(), llvm::makeArrayRef(packed), llvm::ArrayRef<Block *>(), |

