diff options
| author | Alex Zinenko <zinenko@google.com> | 2019-12-06 10:08:15 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-06 10:08:43 -0800 |
| commit | e216a72ab8587c443e4c5c06aabc71c36712ce7e (patch) | |
| tree | 241edd8630181789b45a0fae79fcc5c25f128022 /mlir/lib/Conversion/StandardToLLVM | |
| parent | 3c69ca1e696645a944fac6c9794d71e8424665c5 (diff) | |
| download | bcm5719-llvm-e216a72ab8587c443e4c5c06aabc71c36712ce7e.tar.gz bcm5719-llvm-e216a72ab8587c443e4c5c06aabc71c36712ce7e.zip | |
Add conversions of GPU func with memory attributions to LLVM/NVVM
GPU functions use memory attributions, a combination of Op attributes and
region arguments, to specify function-wide buffers placed in workgroup or
private memory spaces. Introduce a lowering pattern for GPU functions to be
converted to LLVM functions taking into account memory attributions. Workgroup
attributions get transformed into module-level globals with unique names
derived from function names. Private attributions get converted into
llvm.allocas inside the function body. In both cases, we inject at the
beginning of the function the IR that obtains the raw pointer to the data and
populates a MemRef descriptor based on the MemRef type of buffer, making
attributions compose with the rest of the MemRef lowering and transparent for
use with std.load and std.store. While using raw pointers instead of
descriptors might have been more efficient, it is better implemented as a
canonicalization or a separate transformation so that non-attribution memrefs
could also benefit from it.
PiperOrigin-RevId: 284208396
Diffstat (limited to 'mlir/lib/Conversion/StandardToLLVM')
| -rw-r--r-- | mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp | 62 |
1 files changed, 60 insertions, 2 deletions
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 7b15b758968..c1a7a336401 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -304,6 +304,36 @@ MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc, return MemRefDescriptor(descriptor); } +/// Builds IR creating a MemRef descriptor that represents `type` and +/// populates it with static shape and stride information extracted from the +/// type. +MemRefDescriptor +MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + MemRefType type, Value *memory) { + assert(type.hasStaticShape() && "unexpected dynamic shape"); + assert(type.getAffineMaps().empty() && "unexpected layout map"); + + auto convertedType = typeConverter.convertType(type); + assert(convertedType && "unexpected failure in memref type conversion"); + + auto descr = MemRefDescriptor::undef(builder, loc, convertedType); + descr.setAllocatedPtr(builder, loc, memory); + descr.setAlignedPtr(builder, loc, memory); + descr.setConstantOffset(builder, loc, 0); + + // Fill in sizes and strides, in reverse order to simplify stride + // calculation. + uint64_t runningStride = 1; + for (unsigned i = type.getRank(); i > 0; --i) { + unsigned dim = i - 1; + descr.setConstantSize(builder, loc, dim, type.getDimSize(dim)); + descr.setConstantStride(builder, loc, dim, runningStride); + runningStride *= type.getDimSize(dim); + } + return descr; +} + /// Builds IR extracting the allocated pointer from the descriptor. Value *MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor); @@ -326,6 +356,14 @@ void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr); } +// Creates a constant Op producing a value of `resultType` from an index-typed +// integer attribute. +static Value *createIndexAttrConstant(OpBuilder &builder, Location loc, + Type resultType, int64_t value) { + return builder.create<LLVM::ConstantOp>( + loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value)); +} + /// Builds IR extracting the offset from the descriptor. Value *MemRefDescriptor::offset(OpBuilder &builder, Location loc) { return builder.create<LLVM::ExtractValueOp>( @@ -341,6 +379,13 @@ void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc, builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor)); } +/// Builds IR inserting the offset into the descriptor. +void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc, + uint64_t offset) { + setOffset(builder, loc, + createIndexAttrConstant(builder, loc, indexType, offset)); +} + /// Builds IR extracting the pos-th size from the descriptor. Value *MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) { return builder.create<LLVM::ExtractValueOp>( @@ -356,6 +401,13 @@ void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos, builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos})); } +/// Builds IR inserting the pos-th size into the descriptor +void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc, + unsigned pos, uint64_t size) { + setSize(builder, loc, pos, + createIndexAttrConstant(builder, loc, indexType, size)); +} + /// Builds IR extracting the pos-th size from the descriptor. Value *MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) { @@ -372,6 +424,13 @@ void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos, builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos})); } +/// Builds IR inserting the pos-th stride into the descriptor +void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc, + unsigned pos, uint64_t stride) { + setStride(builder, loc, pos, + createIndexAttrConstant(builder, loc, indexType, stride)); +} + LLVM::LLVMType MemRefDescriptor::getElementType() { return value->getType().cast<LLVM::LLVMType>().getStructElementType( kAlignedPtrPosInMemRefDescriptor); @@ -448,8 +507,7 @@ public: // Create an LLVM IR pseudo-operation defining the given index constant. Value *createIndexConstant(ConversionPatternRewriter &builder, Location loc, uint64_t value) const { - auto attr = builder.getIntegerAttr(builder.getIndexType(), value); - return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr); + return createIndexAttrConstant(builder, loc, getIndexType(), value); } protected: |

