diff options
author | Alex Zinenko <zinenko@google.com> | 2019-11-14 08:03:39 -0800 |
---|---|---|
committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-11-14 08:04:10 -0800 |
commit | 7c28de4aef6da3ab2f53118ecf717e56c68352e7 (patch) | |
tree | 97a223bbd9ce3d650f8ee17560d7693256e4eb60 /mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp | |
parent | a007d4395a36adf3aad0f4b9914dcb8756a37c7d (diff) | |
download | bcm5719-llvm-7c28de4aef6da3ab2f53118ecf717e56c68352e7.tar.gz bcm5719-llvm-7c28de4aef6da3ab2f53118ecf717e56c68352e7.zip |
Use MemRefDescriptor in Linalg-to-LLVM conversion
Following up on the consolidation of MemRef descriptor conversion, update
Linalg-to-LLVM conversion to use the helper class that abstracts away the
implementation details of the MemRef descriptor. This required MemRefDescriptor
to become publicly visible. Since this conversion is heavily EDSC-based,
introduce locally an additional wrapper that uses builder and location pointed
to by the EDSC context while emitting descriptor manipulation operations.
PiperOrigin-RevId: 280429228
Diffstat (limited to 'mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp')
-rw-r--r-- | mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp | 201 |
1 files changed, 96 insertions, 105 deletions
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 0641a6b9ab0..570b6c4bcf2 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -234,126 +234,117 @@ LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context, PatternBenefit benefit) : ConversionPattern(rootOpName, benefit, context), lowering(lowering_) {} -namespace { -/// Helper class to produce LLVM dialect operations extracting or inserting -/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor. -/// The Value may be null, in which case none of the operations are valid. -class MemRefDescriptor { -public: - /// Construct a helper for the given descriptor value. - explicit MemRefDescriptor(Value *descriptor) : value(descriptor) { - if (value) { - structType = value->getType().cast<LLVM::LLVMType>(); - indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType( - LLVMTypeConverter::kOffsetPosInMemRefDescriptor); - } - } - - /// Builds IR creating an `undef` value of the descriptor type. - static MemRefDescriptor undef(OpBuilder &builder, Location loc, - Type descriptorType) { - Value *descriptor = builder.create<LLVM::UndefOp>( - loc, descriptorType.cast<LLVM::LLVMType>()); - return MemRefDescriptor(descriptor); - } - - /// Builds IR extracting the allocated pointer from the descriptor. - Value *allocatedPtr(OpBuilder &builder, Location loc) { - return extractPtr(builder, loc, - LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor); - } - - /// Builds IR inserting the allocated pointer into the descriptor. - void setAllocatedPtr(OpBuilder &builder, Location loc, Value *ptr) { - setPtr(builder, loc, LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor, - ptr); - } - - /// Builds IR extracting the aligned pointer from the descriptor. - Value *alignedPtr(OpBuilder &builder, Location loc) { - return extractPtr(builder, loc, - LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor); +/*============================================================================*/ +/* MemRefDescriptor implementation */ +/*============================================================================*/ + +/// Construct a helper for the given descriptor value. +MemRefDescriptor::MemRefDescriptor(Value *descriptor) : value(descriptor) { + if (value) { + structType = value->getType().cast<LLVM::LLVMType>(); + indexType = value->getType().cast<LLVM::LLVMType>().getStructElementType( + LLVMTypeConverter::kOffsetPosInMemRefDescriptor); } +} - /// Builds IR inserting the aligned pointer into the descriptor. - void setAlignedPtr(OpBuilder &builder, Location loc, Value *ptr) { - setPtr(builder, loc, LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor, - ptr); - } +/// Builds IR creating an `undef` value of the descriptor type. +MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc, + Type descriptorType) { + Value *descriptor = + builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>()); + return MemRefDescriptor(descriptor); +} - /// Builds IR extracting the offset from the descriptor. - Value *offset(OpBuilder &builder, Location loc) { - return builder.create<LLVM::ExtractValueOp>( - loc, indexType, value, - builder.getI64ArrayAttr( - LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); - } +/// Builds IR extracting the allocated pointer from the descriptor. +Value *MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) { + return extractPtr(builder, loc, + LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor); +} - /// Builds IR inserting the offset into the descriptor. - void setOffset(OpBuilder &builder, Location loc, Value *offset) { - value = builder.create<LLVM::InsertValueOp>( - loc, structType, value, offset, - builder.getI64ArrayAttr( - LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); - } +/// Builds IR inserting the allocated pointer into the descriptor. +void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc, + Value *ptr) { + setPtr(builder, loc, LLVMTypeConverter::kAllocatedPtrPosInMemRefDescriptor, + ptr); +} - /// Builds IR extracting the pos-th size from the descriptor. - Value *size(OpBuilder &builder, Location loc, unsigned pos) { - return builder.create<LLVM::ExtractValueOp>( - loc, indexType, value, - builder.getI64ArrayAttr( - {LLVMTypeConverter::kSizePosInMemRefDescriptor, pos})); - } +/// Builds IR extracting the aligned pointer from the descriptor. +Value *MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) { + return extractPtr(builder, loc, + LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor); +} - /// Builds IR inserting the pos-th size into the descriptor - void setSize(OpBuilder &builder, Location loc, unsigned pos, Value *size) { - value = builder.create<LLVM::InsertValueOp>( - loc, structType, value, size, - builder.getI64ArrayAttr( - {LLVMTypeConverter::kSizePosInMemRefDescriptor, pos})); - } +/// Builds IR inserting the aligned pointer into the descriptor. +void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, + Value *ptr) { + setPtr(builder, loc, LLVMTypeConverter::kAlignedPtrPosInMemRefDescriptor, + ptr); +} - /// Builds IR extracting the pos-th size from the descriptor. - Value *stride(OpBuilder &builder, Location loc, unsigned pos) { - return builder.create<LLVM::ExtractValueOp>( - loc, indexType, value, - builder.getI64ArrayAttr( - {LLVMTypeConverter::kStridePosInMemRefDescriptor, pos})); - } +/// Builds IR extracting the offset from the descriptor. +Value *MemRefDescriptor::offset(OpBuilder &builder, Location loc) { + return builder.create<LLVM::ExtractValueOp>( + loc, indexType, value, + builder.getI64ArrayAttr(LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); +} - /// Builds IR inserting the pos-th stride into the descriptor - void setStride(OpBuilder &builder, Location loc, unsigned pos, - Value *stride) { - value = builder.create<LLVM::InsertValueOp>( - loc, structType, value, stride, - builder.getI64ArrayAttr( - {LLVMTypeConverter::kStridePosInMemRefDescriptor, pos})); - } +/// Builds IR inserting the offset into the descriptor. +void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc, + Value *offset) { + value = builder.create<LLVM::InsertValueOp>( + loc, structType, value, offset, + builder.getI64ArrayAttr(LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); +} - /*implicit*/ operator Value *() { return value; } +/// Builds IR extracting the pos-th size from the descriptor. +Value *MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) { + return builder.create<LLVM::ExtractValueOp>( + loc, indexType, value, + builder.getI64ArrayAttr( + {LLVMTypeConverter::kSizePosInMemRefDescriptor, pos})); +} -private: - Value *extractPtr(OpBuilder &builder, Location loc, unsigned pos) { - Type type = structType.getStructElementType(pos); - return builder.create<LLVM::ExtractValueOp>(loc, type, value, - builder.getI64ArrayAttr(pos)); - } +/// Builds IR inserting the pos-th size into the descriptor +void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos, + Value *size) { + value = builder.create<LLVM::InsertValueOp>( + loc, structType, value, size, + builder.getI64ArrayAttr( + {LLVMTypeConverter::kSizePosInMemRefDescriptor, pos})); +} - void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value *ptr) { - value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr, - builder.getI64ArrayAttr(pos)); - } +/// Builds IR extracting the pos-th size from the descriptor. +Value *MemRefDescriptor::stride(OpBuilder &builder, Location loc, + unsigned pos) { + return builder.create<LLVM::ExtractValueOp>( + loc, indexType, value, + builder.getI64ArrayAttr( + {LLVMTypeConverter::kStridePosInMemRefDescriptor, pos})); +} - // Cached descriptor type. - LLVM::LLVMType structType; +/// Builds IR inserting the pos-th stride into the descriptor +void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos, + Value *stride) { + value = builder.create<LLVM::InsertValueOp>( + loc, structType, value, stride, + builder.getI64ArrayAttr( + {LLVMTypeConverter::kStridePosInMemRefDescriptor, pos})); +} - // Cached index type. - LLVM::LLVMType indexType; +Value *MemRefDescriptor::extractPtr(OpBuilder &builder, Location loc, + unsigned pos) { + Type type = structType.cast<LLVM::LLVMType>().getStructElementType(pos); + return builder.create<LLVM::ExtractValueOp>(loc, type, value, + builder.getI64ArrayAttr(pos)); +} - // Actual descriptor. - Value *value; -}; +void MemRefDescriptor::setPtr(OpBuilder &builder, Location loc, unsigned pos, + Value *ptr) { + value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr, + builder.getI64ArrayAttr(pos)); +} +namespace { // Base class for Standard to LLVM IR op conversions. Matches the Op type // provided as template argument. Carries a reference to the LLVM dialect in // case it is necessary for rewriters. |