summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Conversion/StandardToLLVM
diff options
context:
space:
mode:
authorAlex Zinenko <zinenko@google.com>2019-12-06 10:08:15 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-06 10:08:43 -0800
commite216a72ab8587c443e4c5c06aabc71c36712ce7e (patch)
tree241edd8630181789b45a0fae79fcc5c25f128022 /mlir/lib/Conversion/StandardToLLVM
parent3c69ca1e696645a944fac6c9794d71e8424665c5 (diff)
downloadbcm5719-llvm-e216a72ab8587c443e4c5c06aabc71c36712ce7e.tar.gz
bcm5719-llvm-e216a72ab8587c443e4c5c06aabc71c36712ce7e.zip
Add conversions of GPU func with memory attributions to LLVM/NVVM
GPU functions use memory attributions, a combination of Op attributes and region arguments, to specify function-wide buffers placed in workgroup or private memory spaces. Introduce a lowering pattern for GPU functions to be converted to LLVM functions taking into account memory attributions. Workgroup attributions get transformed into module-level globals with unique names derived from function names. Private attributions get converted into llvm.allocas inside the function body. In both cases, we inject at the beginning of the function the IR that obtains the raw pointer to the data and populates a MemRef descriptor based on the MemRef type of buffer, making attributions compose with the rest of the MemRef lowering and transparent for use with std.load and std.store. While using raw pointers instead of descriptors might have been more efficient, it is better implemented as a canonicalization or a separate transformation so that non-attribution memrefs could also benefit from it. PiperOrigin-RevId: 284208396
Diffstat (limited to 'mlir/lib/Conversion/StandardToLLVM')
-rw-r--r--mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp62
1 files changed, 60 insertions, 2 deletions
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 7b15b758968..c1a7a336401 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -304,6 +304,36 @@ MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
return MemRefDescriptor(descriptor);
}
+/// Builds IR creating a MemRef descriptor that represents `type` and
+/// populates it with static shape and stride information extracted from the
+/// type.
+MemRefDescriptor
+MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
+ LLVMTypeConverter &typeConverter,
+ MemRefType type, Value *memory) {
+ assert(type.hasStaticShape() && "unexpected dynamic shape");
+ assert(type.getAffineMaps().empty() && "unexpected layout map");
+
+ auto convertedType = typeConverter.convertType(type);
+ assert(convertedType && "unexpected failure in memref type conversion");
+
+ auto descr = MemRefDescriptor::undef(builder, loc, convertedType);
+ descr.setAllocatedPtr(builder, loc, memory);
+ descr.setAlignedPtr(builder, loc, memory);
+ descr.setConstantOffset(builder, loc, 0);
+
+ // Fill in sizes and strides, in reverse order to simplify stride
+ // calculation.
+ uint64_t runningStride = 1;
+ for (unsigned i = type.getRank(); i > 0; --i) {
+ unsigned dim = i - 1;
+ descr.setConstantSize(builder, loc, dim, type.getDimSize(dim));
+ descr.setConstantStride(builder, loc, dim, runningStride);
+ runningStride *= type.getDimSize(dim);
+ }
+ return descr;
+}
+
/// Builds IR extracting the allocated pointer from the descriptor.
Value *MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) {
return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor);
@@ -326,6 +356,14 @@ void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr);
}
+// Creates a constant Op producing a value of `resultType` from an index-typed
+// integer attribute.
+static Value *createIndexAttrConstant(OpBuilder &builder, Location loc,
+ Type resultType, int64_t value) {
+ return builder.create<LLVM::ConstantOp>(
+ loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
+}
+
/// Builds IR extracting the offset from the descriptor.
Value *MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
return builder.create<LLVM::ExtractValueOp>(
@@ -341,6 +379,13 @@ void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
}
+/// Builds IR inserting the offset into the descriptor.
+void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc,
+ uint64_t offset) {
+ setOffset(builder, loc,
+ createIndexAttrConstant(builder, loc, indexType, offset));
+}
+
/// Builds IR extracting the pos-th size from the descriptor.
Value *MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
return builder.create<LLVM::ExtractValueOp>(
@@ -356,6 +401,13 @@ void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
}
+/// Builds IR inserting the pos-th size into the descriptor
+void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc,
+ unsigned pos, uint64_t size) {
+ setSize(builder, loc, pos,
+ createIndexAttrConstant(builder, loc, indexType, size));
+}
+
/// Builds IR extracting the pos-th size from the descriptor.
Value *MemRefDescriptor::stride(OpBuilder &builder, Location loc,
unsigned pos) {
@@ -372,6 +424,13 @@ void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
}
+/// Builds IR inserting the pos-th stride into the descriptor
+void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
+ unsigned pos, uint64_t stride) {
+ setStride(builder, loc, pos,
+ createIndexAttrConstant(builder, loc, indexType, stride));
+}
+
LLVM::LLVMType MemRefDescriptor::getElementType() {
return value->getType().cast<LLVM::LLVMType>().getStructElementType(
kAlignedPtrPosInMemRefDescriptor);
@@ -448,8 +507,7 @@ public:
// Create an LLVM IR pseudo-operation defining the given index constant.
Value *createIndexConstant(ConversionPatternRewriter &builder, Location loc,
uint64_t value) const {
- auto attr = builder.getIntegerAttr(builder.getIndexType(), value);
- return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr);
+ return createIndexAttrConstant(builder, loc, getIndexType(), value);
}
protected:
OpenPOWER on IntegriCloud