diff options
Diffstat (limited to 'mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp')
-rw-r--r-- | mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp | 301 |
1 files changed, 152 insertions, 149 deletions
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index fdc90851b64..67b545c4ec8 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -256,20 +256,20 @@ LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context, /*============================================================================*/ /* StructBuilder implementation */ /*============================================================================*/ -StructBuilder::StructBuilder(Value *v) : value(v) { +StructBuilder::StructBuilder(ValuePtr v) : value(v) { assert(value != nullptr && "value cannot be null"); structType = value->getType().cast<LLVM::LLVMType>(); } -Value *StructBuilder::extractPtr(OpBuilder &builder, Location loc, - unsigned pos) { +ValuePtr StructBuilder::extractPtr(OpBuilder &builder, Location loc, + unsigned pos) { Type type = structType.cast<LLVM::LLVMType>().getStructElementType(pos); return builder.create<LLVM::ExtractValueOp>(loc, type, value, builder.getI64ArrayAttr(pos)); } void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos, - Value *ptr) { + ValuePtr ptr) { value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr, builder.getI64ArrayAttr(pos)); } @@ -278,7 +278,7 @@ void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos, /*============================================================================*/ /// Construct a helper for the given descriptor value. -MemRefDescriptor::MemRefDescriptor(Value *descriptor) +MemRefDescriptor::MemRefDescriptor(ValuePtr descriptor) : StructBuilder(descriptor) { assert(value != nullptr && "value cannot be null"); indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType( @@ -289,7 +289,7 @@ MemRefDescriptor::MemRefDescriptor(Value *descriptor) MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc, Type descriptorType) { - Value *descriptor = + ValuePtr descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>()); return MemRefDescriptor(descriptor); } @@ -300,7 +300,7 @@ MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc, MemRefDescriptor MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, - MemRefType type, Value *memory) { + MemRefType type, ValuePtr memory) { assert(type.hasStaticShape() && "unexpected dynamic shape"); assert(type.getAffineMaps().empty() && "unexpected layout map"); @@ -325,37 +325,37 @@ MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc, } /// Builds IR extracting the allocated pointer from the descriptor. -Value *MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) { +ValuePtr MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor); } /// Builds IR inserting the allocated pointer into the descriptor. void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc, - Value *ptr) { + ValuePtr ptr) { setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr); } /// Builds IR extracting the aligned pointer from the descriptor. -Value *MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) { +ValuePtr MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor); } /// Builds IR inserting the aligned pointer into the descriptor. void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, - Value *ptr) { + ValuePtr ptr) { setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr); } // Creates a constant Op producing a value of `resultType` from an index-typed // integer attribute. -static Value *createIndexAttrConstant(OpBuilder &builder, Location loc, - Type resultType, int64_t value) { +static ValuePtr createIndexAttrConstant(OpBuilder &builder, Location loc, + Type resultType, int64_t value) { return builder.create<LLVM::ConstantOp>( loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value)); } /// Builds IR extracting the offset from the descriptor. -Value *MemRefDescriptor::offset(OpBuilder &builder, Location loc) { +ValuePtr MemRefDescriptor::offset(OpBuilder &builder, Location loc) { return builder.create<LLVM::ExtractValueOp>( loc, indexType, value, builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor)); @@ -363,7 +363,7 @@ Value *MemRefDescriptor::offset(OpBuilder &builder, Location loc) { /// Builds IR inserting the offset into the descriptor. void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc, - Value *offset) { + ValuePtr offset) { value = builder.create<LLVM::InsertValueOp>( loc, structType, value, offset, builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor)); @@ -377,7 +377,8 @@ void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc, } /// Builds IR extracting the pos-th size from the descriptor. -Value *MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) { +ValuePtr MemRefDescriptor::size(OpBuilder &builder, Location loc, + unsigned pos) { return builder.create<LLVM::ExtractValueOp>( loc, indexType, value, builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos})); @@ -385,7 +386,7 @@ Value *MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) { /// Builds IR inserting the pos-th size into the descriptor void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos, - Value *size) { + ValuePtr size) { value = builder.create<LLVM::InsertValueOp>( loc, structType, value, size, builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos})); @@ -399,8 +400,8 @@ void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc, } /// Builds IR extracting the pos-th size from the descriptor. -Value *MemRefDescriptor::stride(OpBuilder &builder, Location loc, - unsigned pos) { +ValuePtr MemRefDescriptor::stride(OpBuilder &builder, Location loc, + unsigned pos) { return builder.create<LLVM::ExtractValueOp>( loc, indexType, value, builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos})); @@ -408,7 +409,7 @@ Value *MemRefDescriptor::stride(OpBuilder &builder, Location loc, /// Builds IR inserting the pos-th stride into the descriptor void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos, - Value *stride) { + ValuePtr stride) { value = builder.create<LLVM::InsertValueOp>( loc, structType, value, stride, builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos})); @@ -431,30 +432,30 @@ LLVM::LLVMType MemRefDescriptor::getElementType() { /*============================================================================*/ /// Construct a helper for the given descriptor value. -UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value *descriptor) +UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(ValuePtr descriptor) : StructBuilder(descriptor) {} /// Builds IR creating an `undef` value of the descriptor type. UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder, Location loc, Type descriptorType) { - Value *descriptor = + ValuePtr descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>()); return UnrankedMemRefDescriptor(descriptor); } -Value *UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) { +ValuePtr UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor); } void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc, - Value *v) { + ValuePtr v) { setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v); } -Value *UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder, - Location loc) { +ValuePtr UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder, + Location loc) { return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor); } void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder, - Location loc, Value *v) { + Location loc, ValuePtr v) { setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v); } namespace { @@ -495,8 +496,8 @@ public: } // Create an LLVM IR pseudo-operation defining the given index constant. - Value *createIndexConstant(ConversionPatternRewriter &builder, Location loc, - uint64_t value) const { + ValuePtr createIndexConstant(ConversionPatternRewriter &builder, Location loc, + uint64_t value) const { return createIndexAttrConstant(builder, loc, getIndexType(), value); } @@ -508,7 +509,7 @@ struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> { using LLVMLegalizationPattern<FuncOp>::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { auto funcOp = cast<FuncOp>(op); FunctionType type = funcOp.getType(); @@ -556,8 +557,8 @@ struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> { Block *firstBlock = &newFuncOp.getBody().front(); rewriter.setInsertionPoint(firstBlock, firstBlock->begin()); for (unsigned idx : promotedArgIndices) { - BlockArgument *arg = firstBlock->getArgument(idx); - Value *loaded = rewriter.create<LLVM::LoadOp>(funcOp.getLoc(), arg); + BlockArgumentPtr arg = firstBlock->getArgument(idx); + ValuePtr loaded = rewriter.create<LLVM::LoadOp>(funcOp.getLoc(), arg); rewriter.replaceUsesOfBlockArgument(arg, loaded); } } @@ -656,7 +657,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> { // Convert the type of the result to an LLVM type, pass operands as is, // preserve attributes. PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { unsigned numResults = op->getNumResults(); @@ -680,7 +681,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> { // Otherwise, it had been converted to an operation producing a structure. // Extract individual results from the structure and return them as list. - SmallVector<Value *, 4> results; + SmallVector<ValuePtr, 4> results; results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { auto type = this->lowering.convertType(op->getResult(i)->getType()); @@ -721,7 +722,7 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> { // Convert the type of the result to an LLVM type, pass operands as is, // preserve attributes. PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { ValidateOpCount<SourceOp, OpCount>(); static_assert( @@ -732,7 +733,7 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> { "expected same operands and result type"); // Cannot convert ops if their operands are not of LLVM type. - for (Value *operand : operands) { + for (ValuePtr operand : operands) { if (!operand || !operand->getType().isa<LLVM::LLVMType>()) return this->matchFailure(); } @@ -755,16 +756,16 @@ struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> { if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy) return this->matchFailure(); - Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy); + ValuePtr desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy); nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) { // For this unrolled `position` corresponding to the `linearIndex`^th // element, extract operand vectors - SmallVector<Value *, OpCount> extractedOperands; + SmallVector<ValuePtr, OpCount> extractedOperands; for (unsigned i = 0; i < OpCount; ++i) { extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>( loc, llvmVectorTy, operands[i], position)); } - Value *newVal = rewriter.create<TargetOp>( + ValuePtr newVal = rewriter.create<TargetOp>( loc, llvmVectorTy, extractedOperands, op->getAttrs()); desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, newVal, position); @@ -927,7 +928,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> { return matchSuccess(); } - void rewrite(Operation *op, ArrayRef<Value *> operands, + void rewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto allocOp = cast<AllocOp>(op); @@ -936,7 +937,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> { // Get actual sizes of the memref as values: static sizes are constant // values and dynamic sizes are passed to 'alloc' as operands. In case of // zero-dimensional memref, assume a scalar (size 1). - SmallVector<Value *, 4> sizes; + SmallVector<ValuePtr, 4> sizes; sizes.reserve(type.getRank()); unsigned i = 0; for (int64_t s : type.getShape()) @@ -946,10 +947,10 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> { sizes.push_back(createIndexConstant(rewriter, loc, 1)); // Compute the total number of memref elements. - Value *cumulativeSize = sizes.front(); + ValuePtr cumulativeSize = sizes.front(); for (unsigned i = 1, e = sizes.size(); i < e; ++i) cumulativeSize = rewriter.create<LLVM::MulOp>( - loc, getIndexType(), ArrayRef<Value *>{cumulativeSize, sizes[i]}); + loc, getIndexType(), ArrayRef<ValuePtr>{cumulativeSize, sizes[i]}); // Compute the size of an individual element. This emits the MLIR equivalent // of the following sizeof(...) implementation in LLVM IR: @@ -962,17 +963,17 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> { auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType); auto one = createIndexConstant(rewriter, loc, 1); auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType, - ArrayRef<Value *>{nullPtr, one}); + ArrayRef<ValuePtr>{nullPtr, one}); auto elementSize = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep); cumulativeSize = rewriter.create<LLVM::MulOp>( - loc, getIndexType(), ArrayRef<Value *>{cumulativeSize, elementSize}); + loc, getIndexType(), ArrayRef<ValuePtr>{cumulativeSize, elementSize}); // Allocate the underlying buffer and store a pointer to it in the MemRef // descriptor. - Value *allocated = nullptr; + ValuePtr allocated = nullptr; int alignment = 0; - Value *alignmentValue = nullptr; + ValuePtr alignmentValue = nullptr; if (auto alignAttr = allocOp.alignment()) alignment = alignAttr.getValue().getSExtValue(); @@ -1008,8 +1009,8 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> { auto structElementType = lowering.convertType(elementType); auto elementPtrType = structElementType.cast<LLVM::LLVMType>().getPointerTo( type.getMemorySpace()); - Value *bitcastAllocated = rewriter.create<LLVM::BitcastOp>( - loc, elementPtrType, ArrayRef<Value *>(allocated)); + ValuePtr bitcastAllocated = rewriter.create<LLVM::BitcastOp>( + loc, elementPtrType, ArrayRef<ValuePtr>(allocated)); int64_t offset; SmallVector<int64_t, 4> strides; @@ -1031,22 +1032,22 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> { memRefDescriptor.setAllocatedPtr(rewriter, loc, bitcastAllocated); // Field 2: Actual aligned pointer to payload. - Value *bitcastAligned = bitcastAllocated; + ValuePtr bitcastAligned = bitcastAllocated; if (!useAlloca && alignment != 0) { assert(alignmentValue); // offset = (align - (ptr % align))% align - Value *intVal = rewriter.create<LLVM::PtrToIntOp>( + ValuePtr intVal = rewriter.create<LLVM::PtrToIntOp>( loc, this->getIndexType(), allocated); - Value *ptrModAlign = + ValuePtr ptrModAlign = rewriter.create<LLVM::URemOp>(loc, intVal, alignmentValue); - Value *subbed = + ValuePtr subbed = rewriter.create<LLVM::SubOp>(loc, alignmentValue, ptrModAlign); - Value *offset = + ValuePtr offset = rewriter.create<LLVM::URemOp>(loc, subbed, alignmentValue); - Value *aligned = rewriter.create<LLVM::GEPOp>(loc, allocated->getType(), - allocated, offset); + ValuePtr aligned = rewriter.create<LLVM::GEPOp>(loc, allocated->getType(), + allocated, offset); bitcastAligned = rewriter.create<LLVM::BitcastOp>( - loc, elementPtrType, ArrayRef<Value *>(aligned)); + loc, elementPtrType, ArrayRef<ValuePtr>(aligned)); } memRefDescriptor.setAlignedPtr(rewriter, loc, bitcastAligned); @@ -1061,10 +1062,10 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> { // Fields 4 and 5: Sizes and strides of the strided MemRef. // Store all sizes in the descriptor. Only dynamic sizes are passed in as // operands to AllocOp. - Value *runningStride = nullptr; + ValuePtr runningStride = nullptr; // Iterate strides in reverse order, compute runningStride and strideValues. auto nStrides = strides.size(); - SmallVector<Value *, 4> strideValues(nStrides, nullptr); + SmallVector<ValuePtr, 4> strideValues(nStrides, nullptr); for (auto indexedStride : llvm::enumerate(llvm::reverse(strides))) { int64_t index = nStrides - 1 - indexedStride.index(); if (strides[index] == MemRefType::getDynamicStrideOrOffset()) @@ -1101,7 +1102,7 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> { using Base = LLVMLegalizationPattern<CallOpType>; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { OperandAdaptor<CallOpType> transformed(operands); auto callOp = cast<CallOpType>(op); @@ -1139,7 +1140,7 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern<CallOpType> { // TODO(aminim, ntv, riverriddle, zinenko): this seems like patching around // a particular interaction between MemRefType and CallOp lowering. Find a // way to avoid special casing. - SmallVector<Value *, 4> results; + SmallVector<ValuePtr, 4> results; results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { auto type = this->lowering.convertType(op->getResult(i)->getType()); @@ -1173,7 +1174,7 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> { useAlloca(useAlloca) {} PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { if (useAlloca) return rewriter.eraseOp(op), matchSuccess(); @@ -1193,7 +1194,7 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> { } MemRefDescriptor memref(transformed.memref()); - Value *casted = rewriter.create<LLVM::BitcastOp>( + ValuePtr casted = rewriter.create<LLVM::BitcastOp>( op->getLoc(), getVoidPtrType(), memref.allocatedPtr(rewriter, op->getLoc())); rewriter.replaceOpWithNewOp<LLVM::CallOp>( @@ -1209,7 +1210,7 @@ struct TanhOpLowering : public LLVMLegalizationPattern<TanhOp> { using LLVMLegalizationPattern<TanhOp>::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { using LLVMFuncOpT = LLVM::LLVMFuncOp; @@ -1283,7 +1284,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> { : matchFailure(); } - void rewrite(Operation *op, ArrayRef<Value *> operands, + void rewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { auto memRefCastOp = cast<MemRefCastOp>(op); OperandAdaptor<MemRefCastOp> transformed(operands); @@ -1324,7 +1325,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> { memRefDesc.setRank(rewriter, loc, rankVal); // d2 = InsertValueOp d1, voidptr, 1 memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr); - rewriter.replaceOp(op, (Value *)memRefDesc); + rewriter.replaceOp(op, (ValuePtr)memRefDesc); } else if (srcType.isa<UnrankedMemRefType>() && dstType.isa<MemRefType>()) { // Casting from unranked type to ranked. @@ -1355,7 +1356,7 @@ struct DimOpLowering : public LLVMLegalizationPattern<DimOp> { using LLVMLegalizationPattern<DimOp>::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { auto dimOp = cast<DimOp>(op); OperandAdaptor<DimOp> transformed(operands); @@ -1397,43 +1398,45 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> { // by accumulating the running linearized value. // Note that `indices` and `allocSizes` are passed in the same order as they // appear in load/store operations and memref type declarations. - Value *linearizeSubscripts(ConversionPatternRewriter &builder, Location loc, - ArrayRef<Value *> indices, - ArrayRef<Value *> allocSizes) const { + ValuePtr linearizeSubscripts(ConversionPatternRewriter &builder, Location loc, + ArrayRef<ValuePtr> indices, + ArrayRef<ValuePtr> allocSizes) const { assert(indices.size() == allocSizes.size() && "mismatching number of indices and allocation sizes"); assert(!indices.empty() && "cannot linearize a 0-dimensional access"); - Value *linearized = indices.front(); + ValuePtr linearized = indices.front(); for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) { linearized = builder.create<LLVM::MulOp>( loc, this->getIndexType(), - ArrayRef<Value *>{linearized, allocSizes[i]}); + ArrayRef<ValuePtr>{linearized, allocSizes[i]}); linearized = builder.create<LLVM::AddOp>( - loc, this->getIndexType(), ArrayRef<Value *>{linearized, indices[i]}); + loc, this->getIndexType(), + ArrayRef<ValuePtr>{linearized, indices[i]}); } return linearized; } // This is a strided getElementPtr variant that linearizes subscripts as: // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`. - Value *getStridedElementPtr(Location loc, Type elementTypePtr, - Value *descriptor, ArrayRef<Value *> indices, - ArrayRef<int64_t> strides, int64_t offset, - ConversionPatternRewriter &rewriter) const { + ValuePtr getStridedElementPtr(Location loc, Type elementTypePtr, + ValuePtr descriptor, ArrayRef<ValuePtr> indices, + ArrayRef<int64_t> strides, int64_t offset, + ConversionPatternRewriter &rewriter) const { MemRefDescriptor memRefDescriptor(descriptor); - Value *base = memRefDescriptor.alignedPtr(rewriter, loc); - Value *offsetValue = offset == MemRefType::getDynamicStrideOrOffset() - ? memRefDescriptor.offset(rewriter, loc) - : this->createIndexConstant(rewriter, loc, offset); + ValuePtr base = memRefDescriptor.alignedPtr(rewriter, loc); + ValuePtr offsetValue = + offset == MemRefType::getDynamicStrideOrOffset() + ? memRefDescriptor.offset(rewriter, loc) + : this->createIndexConstant(rewriter, loc, offset); for (int i = 0, e = indices.size(); i < e; ++i) { - Value *stride = + ValuePtr stride = strides[i] == MemRefType::getDynamicStrideOrOffset() ? memRefDescriptor.stride(rewriter, loc, i) : this->createIndexConstant(rewriter, loc, strides[i]); - Value *additionalOffset = + ValuePtr additionalOffset = rewriter.create<LLVM::MulOp>(loc, indices[i], stride); offsetValue = rewriter.create<LLVM::AddOp>(loc, offsetValue, additionalOffset); @@ -1441,10 +1444,10 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> { return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue); } - Value *getDataPtr(Location loc, MemRefType type, Value *memRefDesc, - ArrayRef<Value *> indices, - ConversionPatternRewriter &rewriter, - llvm::Module &module) const { + ValuePtr getDataPtr(Location loc, MemRefType type, ValuePtr memRefDesc, + ArrayRef<ValuePtr> indices, + ConversionPatternRewriter &rewriter, + llvm::Module &module) const { LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType(); int64_t offset; SmallVector<int64_t, 4> strides; @@ -1462,14 +1465,14 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> { using Base::Base; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { auto loadOp = cast<LoadOp>(op); OperandAdaptor<LoadOp> transformed(operands); auto type = loadOp.getMemRefType(); - Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), - transformed.indices(), rewriter, getModule()); + ValuePtr dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), + transformed.indices(), rewriter, getModule()); rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr); return matchSuccess(); } @@ -1481,13 +1484,13 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> { using Base::Base; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { auto type = cast<StoreOp>(op).getMemRefType(); OperandAdaptor<StoreOp> transformed(operands); - Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), - transformed.indices(), rewriter, getModule()); + ValuePtr dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), + transformed.indices(), rewriter, getModule()); rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(), dataPtr); return matchSuccess(); @@ -1500,14 +1503,14 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> { using Base::Base; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { auto prefetchOp = cast<PrefetchOp>(op); OperandAdaptor<PrefetchOp> transformed(operands); auto type = prefetchOp.getMemRefType(); - Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), - transformed.indices(), rewriter, getModule()); + ValuePtr dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), + transformed.indices(), rewriter, getModule()); // Replace with llvm.prefetch. auto llvmI32Type = lowering.convertType(rewriter.getIntegerType(32)); @@ -1535,7 +1538,7 @@ struct IndexCastOpLowering : public LLVMLegalizationPattern<IndexCastOp> { using LLVMLegalizationPattern<IndexCastOp>::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { IndexCastOpOperandAdaptor transformed(operands); auto indexCastOp = cast<IndexCastOp>(op); @@ -1570,7 +1573,7 @@ struct CmpIOpLowering : public LLVMLegalizationPattern<CmpIOp> { using LLVMLegalizationPattern<CmpIOp>::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { auto cmpiOp = cast<CmpIOp>(op); CmpIOpOperandAdaptor transformed(operands); @@ -1589,7 +1592,7 @@ struct CmpFOpLowering : public LLVMLegalizationPattern<CmpFOp> { using LLVMLegalizationPattern<CmpFOp>::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { auto cmpfOp = cast<CmpFOp>(op); CmpFOpOperandAdaptor transformed(operands); @@ -1641,9 +1644,9 @@ struct OneToOneLLVMTerminatorLowering using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> properOperands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> properOperands, ArrayRef<Block *> destinations, - ArrayRef<ArrayRef<Value *>> operands, + ArrayRef<ArrayRef<ValuePtr>> operands, ConversionPatternRewriter &rewriter) const override { SmallVector<ValueRange, 2> operandRanges(operands.begin(), operands.end()); rewriter.replaceOpWithNewOp<TargetOp>(op, properOperands, destinations, @@ -1662,19 +1665,19 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> { using LLVMLegalizationPattern<ReturnOp>::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { unsigned numArguments = op->getNumOperands(); // If ReturnOp has 0 or 1 operand, create it and return immediately. if (numArguments == 0) { rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( - op, ArrayRef<Value *>(), ArrayRef<Block *>(), op->getAttrs()); + op, ArrayRef<ValuePtr>(), ArrayRef<Block *>(), op->getAttrs()); return matchSuccess(); } if (numArguments == 1) { rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( - op, ArrayRef<Value *>(operands.front()), ArrayRef<Block *>(), + op, ArrayRef<ValuePtr>(operands.front()), ArrayRef<Block *>(), op->getAttrs()); return matchSuccess(); } @@ -1684,7 +1687,7 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> { auto packedType = lowering.packFunctionResults(llvm::to_vector<4>(op->getOperandTypes())); - Value *packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType); + ValuePtr packed = rewriter.create<LLVM::UndefOp>(op->getLoc(), packedType); for (unsigned i = 0; i < numArguments; ++i) { packed = rewriter.create<LLVM::InsertValueOp>( op->getLoc(), packedType, packed, operands[i], @@ -1712,7 +1715,7 @@ struct SplatOpLowering : public LLVMLegalizationPattern<SplatOp> { using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { auto splatOp = cast<SplatOp>(op); VectorType resultType = splatOp.getType().dyn_cast<VectorType>(); @@ -1721,7 +1724,7 @@ struct SplatOpLowering : public LLVMLegalizationPattern<SplatOp> { // First insert it into an undef vector so we can shuffle it. auto vectorType = lowering.convertType(splatOp.getType()); - Value *undef = rewriter.create<LLVM::UndefOp>(op->getLoc(), vectorType); + ValuePtr undef = rewriter.create<LLVM::UndefOp>(op->getLoc(), vectorType); auto zero = rewriter.create<LLVM::ConstantOp>( op->getLoc(), lowering.convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); @@ -1746,7 +1749,7 @@ struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> { using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { auto splatOp = cast<SplatOp>(op); OperandAdaptor<SplatOp> adaptor(operands); @@ -1763,16 +1766,16 @@ struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> { return matchFailure(); // Construct returned value. - Value *desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy); + ValuePtr desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy); // Construct a 1-D vector with the splatted value that we insert in all the // places within the returned descriptor. - Value *vdesc = rewriter.create<LLVM::UndefOp>(loc, llvmVectorTy); + ValuePtr vdesc = rewriter.create<LLVM::UndefOp>(loc, llvmVectorTy); auto zero = rewriter.create<LLVM::ConstantOp>( loc, lowering.convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); - Value *v = rewriter.create<LLVM::InsertElementOp>(loc, llvmVectorTy, vdesc, - adaptor.input(), zero); + ValuePtr v = rewriter.create<LLVM::InsertElementOp>( + loc, llvmVectorTy, vdesc, adaptor.input(), zero); // Shuffle the value across the desired number of elements. int64_t width = resultType.getDimSize(resultType.getRank() - 1); @@ -1800,21 +1803,21 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> { using LLVMLegalizationPattern<SubViewOp>::LLVMLegalizationPattern; PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto viewOp = cast<SubViewOp>(op); // TODO(b/144779634, ravishankarm) : After Tblgen is adapted to support // having multiple variadic operands where each operand can have different // number of entries, clean all of this up. - SmallVector<Value *, 2> dynamicOffsets( + SmallVector<ValuePtr, 2> dynamicOffsets( std::next(operands.begin()), std::next(operands.begin(), 1 + viewOp.getNumOffsets())); - SmallVector<Value *, 2> dynamicSizes( + SmallVector<ValuePtr, 2> dynamicSizes( std::next(operands.begin(), 1 + viewOp.getNumOffsets()), std::next(operands.begin(), 1 + viewOp.getNumOffsets() + viewOp.getNumSizes())); - SmallVector<Value *, 2> dynamicStrides( + SmallVector<ValuePtr, 2> dynamicStrides( std::next(operands.begin(), 1 + viewOp.getNumOffsets() + viewOp.getNumSizes()), operands.end()); @@ -1851,8 +1854,8 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> { auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); // Copy the buffer pointer from the old descriptor to the new one. - Value *extracted = sourceMemRef.allocatedPtr(rewriter, loc); - Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>( + ValuePtr extracted = sourceMemRef.allocatedPtr(rewriter, loc); + ValuePtr bitcastPtr = rewriter.create<LLVM::BitcastOp>( loc, targetElementTy.getPointerTo(), extracted); targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); @@ -1862,7 +1865,7 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> { targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); // Extract strides needed to compute offset. - SmallVector<Value *, 4> strideValues; + SmallVector<ValuePtr, 4> strideValues; strideValues.reserve(viewMemRefType.getRank()); for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) strideValues.push_back(sourceMemRef.stride(rewriter, loc, i)); @@ -1879,9 +1882,9 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> { } // Offset. - Value *baseOffset = sourceMemRef.offset(rewriter, loc); + ValuePtr baseOffset = sourceMemRef.offset(rewriter, loc); for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) { - Value *min = dynamicOffsets[i]; + ValuePtr min = dynamicOffsets[i]; baseOffset = rewriter.create<LLVM::AddOp>( loc, baseOffset, rewriter.create<LLVM::MulOp>(loc, min, strideValues[i])); @@ -1891,7 +1894,7 @@ struct SubViewOpLowering : public LLVMLegalizationPattern<SubViewOp> { // Update sizes and strides. for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { targetMemRef.setSize(rewriter, loc, i, dynamicSizes[i]); - Value *newStride; + ValuePtr newStride; if (dynamicStrides.empty()) newStride = rewriter.create<LLVM::ConstantOp>( loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i])); @@ -1916,9 +1919,9 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> { // Build and return the value for the idx^th shape dimension, either by // returning the constant shape dimension or counting the proper dynamic size. - Value *getSize(ConversionPatternRewriter &rewriter, Location loc, - ArrayRef<int64_t> shape, ArrayRef<Value *> dynamicSizes, - unsigned idx) const { + ValuePtr getSize(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef<int64_t> shape, ArrayRef<ValuePtr> dynamicSizes, + unsigned idx) const { assert(idx < shape.size()); if (!ShapedType::isDynamic(shape[idx])) return createIndexConstant(rewriter, loc, shape[idx]); @@ -1933,9 +1936,9 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> { // or by computing the dynamic stride from the current `runningStride` and // `nextSize`. The caller should keep a running stride and update it with the // result returned by this function. - Value *getStride(ConversionPatternRewriter &rewriter, Location loc, - ArrayRef<int64_t> strides, Value *nextSize, - Value *runningStride, unsigned idx) const { + ValuePtr getStride(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef<int64_t> strides, ValuePtr nextSize, + ValuePtr runningStride, unsigned idx) const { assert(idx < strides.size()); if (strides[idx] != MemRefType::getDynamicStrideOrOffset()) return createIndexConstant(rewriter, loc, strides[idx]); @@ -1948,7 +1951,7 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> { } PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + matchAndRewrite(Operation *op, ArrayRef<ValuePtr> operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto viewOp = cast<ViewOp>(op); @@ -1975,8 +1978,8 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> { auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); // Field 1: Copy the allocated pointer, used for malloc/free. - Value *extracted = sourceMemRef.allocatedPtr(rewriter, loc); - Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>( + ValuePtr extracted = sourceMemRef.allocatedPtr(rewriter, loc); + ValuePtr bitcastPtr = rewriter.create<LLVM::BitcastOp>( loc, targetElementTy.getPointerTo(), extracted); targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); @@ -1993,10 +1996,10 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> { auto sizeAndOffsetOperands = adaptor.operands(); assert(llvm::size(sizeAndOffsetOperands) == numDynamicSizes + (hasDynamicOffset ? 1 : 0)); - Value *baseOffset = !hasDynamicOffset - ? createIndexConstant(rewriter, loc, offset) - // TODO(ntv): better adaptor. - : sizeAndOffsetOperands.front(); + ValuePtr baseOffset = !hasDynamicOffset + ? createIndexConstant(rewriter, loc, offset) + // TODO(ntv): better adaptor. + : sizeAndOffsetOperands.front(); targetMemRef.setOffset(rewriter, loc, baseOffset); // Early exit for 0-D corner case. @@ -2007,14 +2010,14 @@ struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> { if (strides.back() != 1) return op->emitWarning("cannot cast to non-contiguous shape"), matchFailure(); - Value *stride = nullptr, *nextSize = nullptr; + ValuePtr stride = nullptr, nextSize = nullptr; // Drop the dynamic stride from the operand list, if present. - ArrayRef<Value *> sizeOperands(sizeAndOffsetOperands); + ArrayRef<ValuePtr> sizeOperands(sizeAndOffsetOperands); if (hasDynamicOffset) sizeOperands = sizeOperands.drop_front(); for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { // Update size. - Value *size = + ValuePtr size = getSize(rewriter, loc, viewMemRefType.getShape(), sizeOperands, i); targetMemRef.setSize(rewriter, loc, i, size); // Update stride. @@ -2058,7 +2061,7 @@ static void ensureDistinctSuccessors(Block &bb) { auto *dummyBlock = new Block(); bb.getParent()->push_back(dummyBlock); auto builder = OpBuilder(dummyBlock); - SmallVector<Value *, 8> operands( + SmallVector<ValuePtr, 8> operands( terminator->getSuccessorOperands(*position)); builder.create<BranchOp>(terminator->getLoc(), successor.first, operands); terminator->setSuccessor(dummyBlock, *position); @@ -2179,33 +2182,33 @@ Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) { return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes); } -Value *LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, - Value *operand, - OpBuilder &builder) { +ValuePtr LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, + ValuePtr operand, + OpBuilder &builder) { auto *context = builder.getContext(); auto int64Ty = LLVM::LLVMType::getInt64Ty(getDialect()); auto indexType = IndexType::get(context); // Alloca with proper alignment. We do not expect optimizations of this // alloca op and so we omit allocating at the entry block. auto ptrType = operand->getType().cast<LLVM::LLVMType>().getPointerTo(); - Value *one = builder.create<LLVM::ConstantOp>(loc, int64Ty, - IntegerAttr::get(indexType, 1)); - Value *allocated = + ValuePtr one = builder.create<LLVM::ConstantOp>( + loc, int64Ty, IntegerAttr::get(indexType, 1)); + ValuePtr allocated = builder.create<LLVM::AllocaOp>(loc, ptrType, one, /*alignment=*/0); // Store into the alloca'ed descriptor. builder.create<LLVM::StoreOp>(loc, operand, allocated); return allocated; } -SmallVector<Value *, 4> +SmallVector<ValuePtr, 4> LLVMTypeConverter::promoteMemRefDescriptors(Location loc, ValueRange opOperands, ValueRange operands, OpBuilder &builder) { - SmallVector<Value *, 4> promotedOperands; + SmallVector<ValuePtr, 4> promotedOperands; promotedOperands.reserve(operands.size()); for (auto it : llvm::zip(opOperands, operands)) { - auto *operand = std::get<0>(it); - auto *llvmOperand = std::get<1>(it); + auto operand = std::get<0>(it); + auto llvmOperand = std::get<1>(it); if (!operand->getType().isa<MemRefType>() && !operand->getType().isa<UnrankedMemRefType>()) { promotedOperands.push_back(operand); |