diff options
Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp | 201 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp | 157 |
2 files changed, 157 insertions, 201 deletions
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 0641a6b9ab0..570b6c4bcf2 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -234,126 +234,117 @@ LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context, PatternBenefit benefit) : ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {} -namespace { -/// Helper class to produce LLVM dialect operations extracting or inserting -/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor. -/// The Value may be null, in which case none of the operations are valid. -class MemRefDescriptor { -public: - /// Construct a helper for the given descriptor value. - explicit MemRefDescriptor(Value *descriptor) : value(descriptor) { - if (value) { - structType = value->getType().cast<LLVM::LLVMType>(); - indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType( - LLVMTypeConverter::kOffsetPosInMemRefDescriptor); - } - } - - /// Builds IR creating an `undef` value of the descriptor type. - static MemRefDescriptor undef(OpBuilder &builder, Location loc, - Type descriptorType) { - Value *descriptor = builder.create<LLVM::UndefOp>( - loc, descriptorType.cast<LLVM::LLVMType>()); - return MemRefDescriptor(descriptor); - } - - /// Builds IR extracting the allocated pointer from the descriptor. - Value *allocatedPtr(OpBuilder &builder, Location loc) { - return extractPtr(builder, loc, - LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor); - } - - /// Builds IR inserting the allocated pointer into the descriptor. - void setAllocatedPtr(OpBuilder &builder, Location loc, Value *ptr) { - setPtr(builder, loc, LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor, - ptr); - } - - /// Builds IR extracting the aligned pointer from the descriptor. - Value *alignedPtr(OpBuilder &builder, Location loc) { - return extractPtr(builder, loc, - LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor); +/*============================================================================*/ +/* MemRefDescriptor implementation */ +/*============================================================================*/ + +/// Construct a helper for the given descriptor value. +MemRefDescriptor::MemRefDescriptor(Value *descriptor) : value(descriptor) { + if (value) { + structType = value->getType().cast<LLVM::LLVMType>(); + indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType( + LLVMTypeConverter::kOffsetPosInMemRefDescriptor); } +} - /// Builds IR inserting the aligned pointer into the descriptor. - void setAlignedPtr(OpBuilder &builder, Location loc, Value *ptr) { - setPtr(builder, loc, LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor, - ptr); - } +/// Builds IR creating an `undef` value of the descriptor type. +MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc, + Type descriptorType) { + Value *descriptor = + builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>()); + return MemRefDescriptor(descriptor); +} - /// Builds IR extracting the offset from the descriptor. - Value *offset(OpBuilder &builder, Location loc) { - return builder.create<LLVM::ExtractValueOp>( - loc, indexType, value, - builder.getI64ArrayAttr( - LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); - } +/// Builds IR extracting the allocated pointer from the descriptor. +Value *MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) { + return extractPtr(builder, loc, + LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor); +} - /// Builds IR inserting the offset into the descriptor. - void setOffset(OpBuilder &builder, Location loc, Value *offset) { - value = builder.create<LLVM::InsertValueOp>( - loc, structType, value, offset, - builder.getI64ArrayAttr( - LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); - } +/// Builds IR inserting the allocated pointer into the descriptor. +void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc, + Value *ptr) { + setPtr(builder, loc, LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor, + ptr); +} - /// Builds IR extracting the pos-th size from the descriptor. - Value *size(OpBuilder &builder, Location loc, unsigned pos) { - return builder.create<LLVM::ExtractValueOp>( - loc, indexType, value, - builder.getI64ArrayAttr( - {LLVMTypeConverter::kSizePosInMemRefDescriptor, pos})); - } +/// Builds IR extracting the aligned pointer from the descriptor. +Value *MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) { + return extractPtr(builder, loc, + LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor); +} - /// Builds IR inserting the pos-th size into the descriptor - void setSize(OpBuilder &builder, Location loc, unsigned pos, Value *size) { - value = builder.create<LLVM::InsertValueOp>( - loc, structType, value, size, - builder.getI64ArrayAttr( - {LLVMTypeConverter::kSizePosInMemRefDescriptor, pos})); - } +/// Builds IR inserting the aligned pointer into the descriptor. +void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, + Value *ptr) { + setPtr(builder, loc, LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor, + ptr); +} - /// Builds IR extracting the pos-th size from the descriptor. - Value *stride(OpBuilder &builder, Location loc, unsigned pos) { - return builder.create<LLVM::ExtractValueOp>( - loc, indexType, value, - builder.getI64ArrayAttr( - {LLVMTypeConverter::kStridePosInMemRefDescriptor, pos})); - } +/// Builds IR extracting the offset from the descriptor. +Value *MemRefDescriptor::offset(OpBuilder &builder, Location loc) { + return builder.create<LLVM::ExtractValueOp>( + loc, indexType, value, + builder.getI64ArrayAttr(LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); +} - /// Builds IR inserting the pos-th stride into the descriptor - void setStride(OpBuilder &builder, Location loc, unsigned pos, - Value *stride) { - value = builder.create<LLVM::InsertValueOp>( - loc, structType, value, stride, - builder.getI64ArrayAttr( - {LLVMTypeConverter::kStridePosInMemRefDescriptor, pos})); - } +/// Builds IR inserting the offset into the descriptor. +void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc, + Value *offset) { + value = builder.create<LLVM::InsertValueOp>( + loc, structType, value, offset, + builder.getI64ArrayAttr(LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); +} - /*implicit*/ operator Value *() { return value; } +/// Builds IR extracting the pos-th size from the descriptor. +Value *MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) { + return builder.create<LLVM::ExtractValueOp>( + loc, indexType, value, + builder.getI64ArrayAttr( + {LLVMTypeConverter::kSizePosInMemRefDescriptor, pos})); +} -private: - Value *extractPtr(OpBuilder &builder, Location loc, unsigned pos) { - Type type = structType.getStructElementType(pos); - return builder.create<LLVM::ExtractValueOp>(loc, type, value, - builder.getI64ArrayAttr(pos)); - } +/// Builds IR inserting the pos-th size into the descriptor +void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos, + Value *size) { + value = builder.create<LLVM::InsertValueOp>( + loc, structType, value, size, + builder.getI64ArrayAttr( + {LLVMTypeConverter::kSizePosInMemRefDescriptor, pos})); +} - void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value *ptr) { - value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr, - builder.getI64ArrayAttr(pos)); - } +/// Builds IR extracting the pos-th size from the descriptor. +Value *MemRefDescriptor::stride(OpBuilder &builder, Location loc, + unsigned pos) { + return builder.create<LLVM::ExtractValueOp>( + loc, indexType, value, + builder.getI64ArrayAttr( + {LLVMTypeConverter::kStridePosInMemRefDescriptor, pos})); +} - // Cached descriptor type. - LLVM::LLVMType structType; +/// Builds IR inserting the pos-th stride into the descriptor +void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos, + Value *stride) { + value = builder.create<LLVM::InsertValueOp>( + loc, structType, value, stride, + builder.getI64ArrayAttr( + {LLVMTypeConverter::kStridePosInMemRefDescriptor, pos})); +} - // Cached index type. - LLVM::LLVMType indexType; +Value *MemRefDescriptor::extractPtr(OpBuilder &builder, Location loc, + unsigned pos) { + Type type = structType.cast<LLVM::LLVMType>().getStructElementType(pos); + return builder.create<LLVM::ExtractValueOp>(loc, type, value, + builder.getI64ArrayAttr(pos)); +} - // Actual descriptor. - Value *value; -}; +void MemRefDescriptor::setPtr(OpBuilder &builder, Location loc, unsigned pos, + Value *ptr) { + value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr, + builder.getI64ArrayAttr(pos)); +} +namespace { // Base class for Standard to LLVM IR op conversions. Matches the Op type // provided as template argument. Carries a reference to the LLVM dialect in // case it is necessary for rewriters. diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp index 9d0395346af..61614aaf417 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -128,33 +128,33 @@ static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) { } namespace { -/// Factor out the common information for all view conversions: -/// 1. common types in (standard and LLVM dialects) -/// 2. `pos` method -/// 3. view descriptor construction `desc`. +/// EDSC-compatible wrapper for MemRefDescriptor. class BaseViewConversionHelper { public: - BaseViewConversionHelper(Location loc, MemRefType memRefType, - ConversionPatternRewriter &rewriter, - LLVMTypeConverter &lowering) - : zeroDMemRef(memRefType.getRank() == 0), - elementTy(getPtrToElementType(memRefType, lowering)), - int64Ty( - lowering.convertType(rewriter.getIntegerType(64)).cast<LLVMType>()), - desc(nullptr), rewriter(rewriter) { - assert(isStrided(memRefType) && "expected strided memref type"); - viewDescriptorTy = lowering.convertType(memRefType).cast<LLVMType>(); - desc = rewriter.create<LLVM::UndefOp>(loc, viewDescriptorTy); - } - - ArrayAttr pos(ArrayRef<int64_t> values) const { - return rewriter.getI64ArrayAttr(values); - }; - - bool zeroDMemRef; - LLVMType elementTy, int64Ty, viewDescriptorTy; - Value *desc; - ConversionPatternRewriter &rewriter; + BaseViewConversionHelper(Type type) + : d(MemRefDescriptor::undef(rewriter(), loc(), type)) {} + + BaseViewConversionHelper(Value *v) : d(v) {} + + /// Wrappers around MemRefDescriptor that use EDSC builder and location. + Value *allocatedPtr() { return d.allocatedPtr(rewriter(), loc()); } + void setAllocatedPtr(Value *v) { d.setAllocatedPtr(rewriter(), loc(), v); } + Value *alignedPtr() { return d.alignedPtr(rewriter(), loc()); } + void setAlignedPtr(Value *v) { d.setAlignedPtr(rewriter(), loc(), v); } + Value *offset() { return d.offset(rewriter(), loc()); } + void setOffset(Value *v) { d.setOffset(rewriter(), loc(), v); } + Value *size(unsigned i) { return d.size(rewriter(), loc(), i); } + void setSize(unsigned i, Value *v) { d.setSize(rewriter(), loc(), i, v); } + Value *stride(unsigned i) { return d.stride(rewriter(), loc(), i); } + void setStride(unsigned i, Value *v) { d.setStride(rewriter(), loc(), i, v); } + + operator Value *() { return d; } + +private: + OpBuilder &rewriter() { return ScopedContext::getBuilder(); } + Location loc() { return ScopedContext::getLocation(); } + + MemRefDescriptor d; }; } // namespace @@ -200,53 +200,46 @@ public: PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands, ConversionPatternRewriter &rewriter) const override { + edsc::ScopedContext context(rewriter, op->getLoc()); SliceOpOperandAdaptor adaptor(operands); - Value *baseDesc = adaptor.view(); + BaseViewConversionHelper baseDesc(adaptor.view()); auto sliceOp = cast<SliceOp>(op); auto memRefType = sliceOp.getBaseViewType(); + auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)) + .cast<LLVM::LLVMType>(); - BaseViewConversionHelper helper(op->getLoc(), sliceOp.getViewType(), - rewriter, lowering); - LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty; - Value *desc = helper.desc; - - edsc::ScopedContext context(rewriter, op->getLoc()); + BaseViewConversionHelper desc(lowering.convertType(sliceOp.getViewType())); // TODO(ntv): extract sizes and emit asserts. SmallVector<Value *, 4> strides(memRefType.getRank()); for (int i = 0, e = memRefType.getRank(); i < e; ++i) - strides[i] = extractvalue( - int64Ty, baseDesc, - helper.pos({LLVMTypeConverter::kStridePosInMemRefDescriptor, i})); + strides[i] = baseDesc.stride(i); + + auto pos = [&rewriter](ArrayRef<int64_t> values) { + return rewriter.getI64ArrayAttr(values); + }; // Compute base offset. - Value *baseOffset = extractvalue( - int64Ty, baseDesc, - helper.pos(LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); + Value *baseOffset = baseDesc.offset(); for (int i = 0, e = memRefType.getRank(); i < e; ++i) { Value *indexing = adaptor.indexings()[i]; Value *min = indexing; if (sliceOp.indexing(i)->getType().isa<RangeType>()) - min = extractvalue(int64Ty, indexing, helper.pos(0)); + min = extractvalue(int64Ty, indexing, pos(0)); baseOffset = add(baseOffset, mul(min, strides[i])); } // Insert the base and aligned pointers. - auto ptrPos = - helper.pos(LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor); - desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos); - ptrPos = helper.pos(LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor); - desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos); + desc.setAllocatedPtr(baseDesc.allocatedPtr()); + desc.setAlignedPtr(baseDesc.alignedPtr()); // Insert base offset. - desc = insertvalue( - desc, baseOffset, - helper.pos(LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); + desc.setOffset(baseOffset); // Corner case, no sizes or strides: early return the descriptor. - if (helper.zeroDMemRef) - return rewriter.replaceOp(op, desc), matchSuccess(); + if (sliceOp.getViewType().getRank() == 0) + return rewriter.replaceOp(op, {desc}), matchSuccess(); Value *zero = constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); @@ -258,12 +251,11 @@ public: if (indexing->getType().isa<RangeType>()) { int rank = en.index(); Value *rangeDescriptor = adaptor.indexings()[rank]; - Value *min = extractvalue(int64Ty, rangeDescriptor, helper.pos(0)); - Value *max = extractvalue(int64Ty, rangeDescriptor, helper.pos(1)); - Value *step = extractvalue(int64Ty, rangeDescriptor, helper.pos(2)); - Value *baseSize = extractvalue( - int64Ty, baseDesc, - helper.pos({LLVMTypeConverter::kSizePosInMemRefDescriptor, rank})); + Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0)); + Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1)); + Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2)); + Value *baseSize = baseDesc.size(rank); + // Bound upper by base view upper bound. max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max, baseSize); @@ -272,19 +264,13 @@ public: size = llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size); Value *stride = mul(strides[rank], step); - desc = insertvalue( - desc, size, - helper.pos( - {LLVMTypeConverter::kSizePosInMemRefDescriptor, numNewDims})); - desc = insertvalue( - desc, stride, - helper.pos( - {LLVMTypeConverter::kStridePosInMemRefDescriptor, numNewDims})); + desc.setSize(numNewDims, size); + desc.setStride(numNewDims, stride); ++numNewDims; } } - rewriter.replaceOp(op, desc); + rewriter.replaceOp(op, {desc}); return matchSuccess(); } }; @@ -306,56 +292,35 @@ public: matchAndRewrite(Operation *op, ArrayRef<Value *> operands, ConversionPatternRewriter &rewriter) const override { // Initialize the common boilerplate and alloca at the top of the FuncOp. + edsc::ScopedContext context(rewriter, op->getLoc()); TransposeOpOperandAdaptor adaptor(operands); - Value *baseDesc = adaptor.view(); + BaseViewConversionHelper baseDesc(adaptor.view()); auto transposeOp = cast<TransposeOp>(op); // No permutation, early exit. if (transposeOp.permutation().isIdentity()) - return rewriter.replaceOp(op, baseDesc), matchSuccess(); + return rewriter.replaceOp(op, {baseDesc}), matchSuccess(); - BaseViewConversionHelper helper(op->getLoc(), transposeOp.getViewType(), - rewriter, lowering); - LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty; - Value *desc = helper.desc; + BaseViewConversionHelper desc( + lowering.convertType(transposeOp.getViewType())); - edsc::ScopedContext context(rewriter, op->getLoc()); // Copy the base and aligned pointers from the old descriptor to the new // one. - ArrayAttr ptrPos = - helper.pos(LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor); - desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos); - ptrPos = helper.pos(LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor); - desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos); + desc.setAllocatedPtr(baseDesc.allocatedPtr()); + desc.setAlignedPtr(baseDesc.alignedPtr()); // Copy the offset pointer from the old descriptor to the new one. - ArrayAttr offPos = - helper.pos(LLVMTypeConverter::kOffsetPosInMemRefDescriptor); - desc = insertvalue(desc, extractvalue(int64Ty, baseDesc, offPos), offPos); + desc.setOffset(baseDesc.offset()); // Iterate over the dimensions and apply size/stride permutation. for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) { int sourcePos = en.index(); int targetPos = en.value().cast<AffineDimExpr>().getPosition(); - Value *size = extractvalue( - int64Ty, baseDesc, - helper.pos( - {LLVMTypeConverter::kSizePosInMemRefDescriptor, sourcePos})); - desc = - insertvalue(desc, size, - helper.pos({LLVMTypeConverter::kSizePosInMemRefDescriptor, - targetPos})); - Value *stride = extractvalue( - int64Ty, baseDesc, - helper.pos( - {LLVMTypeConverter::kStridePosInMemRefDescriptor, sourcePos})); - desc = insertvalue( - desc, stride, - helper.pos( - {LLVMTypeConverter::kStridePosInMemRefDescriptor, targetPos})); + desc.setSize(targetPos, baseDesc.size(sourcePos)); + desc.setStride(targetPos, baseDesc.stride(sourcePos)); } - rewriter.replaceOp(op, desc); + rewriter.replaceOp(op, {desc}); return matchSuccess(); } }; |