summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
author: Nicolas Vasilache <ntv@google.com> 2019-08-19 10:21:15 -0700
committer: A. Unique TensorFlower <gardener@tensorflow.org> 2019-08-19 10:21:40 -0700
commit: 9bf69e6a2e9d1ef60ac9e4efa8fda9b6c3560e63 (patch)
tree: 447bea3897341fe71df28bc8b08f84320d57a452 /mlir/lib
parent: c9f37fca379035b6334b50380ef05b00026de0cc (diff)
download: bcm5719-llvm-9bf69e6a2e9d1ef60ac9e4efa8fda9b6c3560e63.tar.gz
bcm5719-llvm-9bf69e6a2e9d1ef60ac9e4efa8fda9b6c3560e63.zip
Refactor linalg lowering to LLVM
The linalg.view type used to be lowered to a struct containing a data pointer, offset, sizes/strides information. This was problematic when passing to external functions due to ABI, struct padding and alignment issues. The linalg.view type is now lowered to LLVMIR as a *pointer* to a struct containing the data pointer, offset and sizes/strides. This simplifies the interfacing with external library functions and makes it trivial to add new functions without creating a shim that would go from a value type struct to a pointer type. The consequences are that: 1. lowering explicitly uses llvm.alloca in lieu of llvm.undef and performs the proper llvm.load/llvm.store where relevant. 2. the shim creation function `getLLVMLibraryCallDefinition` disappears. 3. views are passed by pointer, scalars are passed by value. In the future, other structs will be passed by pointer (on a per-need basis). PiperOrigin-RevId: 264183671
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp15
-rw-r--r--mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp5
-rw-r--r--mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp255
3 files changed, 142 insertions, 133 deletions
diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
index b3864a39560..a3b80b1e9e0 100644
--- a/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
+++ b/mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
@@ -111,7 +111,8 @@ private:
Value *allocatePointer(OpBuilder &builder, Location loc) {
auto one = builder.create<LLVM::ConstantOp>(loc, getInt32Type(),
builder.getI32IntegerAttr(1));
- return builder.create<LLVM::AllocaOp>(loc, getPointerPointerType(), one);
+ return builder.create<LLVM::AllocaOp>(loc, getPointerPointerType(), one,
+ /*alignment=*/0);
}
void declareCudaFunctions(Location loc);
@@ -233,13 +234,13 @@ GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp,
auto arraySize = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(),
builder.getI32IntegerAttr(launchOp.getNumKernelOperands()));
- auto array =
- builder.create<LLVM::AllocaOp>(loc, getPointerPointerType(), arraySize);
+ auto array = builder.create<LLVM::AllocaOp>(loc, getPointerPointerType(),
+ arraySize, /*alignment=*/0);
for (int idx = 0, e = launchOp.getNumKernelOperands(); idx < e; ++idx) {
auto operand = launchOp.getKernelOperand(idx);
auto llvmType = operand->getType().cast<LLVM::LLVMType>();
- auto memLocation =
- builder.create<LLVM::AllocaOp>(loc, llvmType.getPointerTo(), one);
+ auto memLocation = builder.create<LLVM::AllocaOp>(
+ loc, llvmType.getPointerTo(), one, /*alignment=*/1);
builder.create<LLVM::StoreOp>(loc, operand, memLocation);
auto casted =
builder.create<LLVM::BitcastOp>(loc, getPointerType(), memLocation);
@@ -267,8 +268,8 @@ Value *GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant(
auto kernelNameSize = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(),
builder.getI32IntegerAttr(kernelFunction.getName().size() + 1));
- auto kernelName =
- builder.create<LLVM::AllocaOp>(loc, getPointerType(), kernelNameSize);
+ auto kernelName = builder.create<LLVM::AllocaOp>(
+ loc, getPointerType(), kernelNameSize, /*alignment=*/1);
for (auto byte : llvm::enumerate(kernelFunction.getName())) {
auto index = builder.create<LLVM::ConstantOp>(
loc, getInt32Type(), builder.getI32IntegerAttr(byte.index()));
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 9ba06db7aba..4240e3e7ae7 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -808,10 +808,7 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
Value *dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
transformed.indices(), rewriter, getModule());
- auto elementType = lowering.convertType(type.getElementType());
-
- rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, elementType,
- ArrayRef<Value *>{dataPtr});
+ rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr);
return matchSuccess();
}
};
diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
index de183f8f76e..05bdf24a975 100644
--- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
+++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
@@ -117,21 +117,22 @@ static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) {
if (t.isa<RangeType>())
return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
- // View descriptor contains the pointer to the data buffer, followed by a
- // 64-bit integer containing the distance between the beginning of the buffer
- // and the first element to be accessed through the view, followed by two
- // arrays, each containing as many 64-bit integers as the rank of the View.
- // The first array represents the size, in number of original elements, of the
- // view along the given dimension. When taking the view, the size is the
- // difference between the upper and the lower bound of the range. The second
- // array represents the "stride" (in tensor abstraction sense), i.e. the
- // number of consecutive elements of the underlying buffer that separate two
- // consecutive elements addressable through the view along the given
- // dimension. When taking the view, the strides are constructed as products
- // of the original sizes along the trailing dimensions, multiplied by the view
- // step. For example, a view of a MxN memref with ranges {0:M:1}, {0:N:1},
- // i.e. the view of a complete memref, will have strides N and 1. A view with
- // ranges {0:M:2}, {0:N:3} will have strides 2*N and 3.
+ // A linalg.view type converts to a *pointer to* a view descriptor. The view
+ // descriptor contains the pointer to the data buffer, followed by a 64-bit
+ // integer containing the distance between the beginning of the buffer and the
+ // first element to be accessed through the view, followed by two arrays, each
+ // containing as many 64-bit integers as the rank of the View. The first array
+ // represents the size, in number of original elements, of the view along the
+ // given dimension. When taking the view, the size is the difference between
+ // the upper and the lower bound of the range. The second array represents the
+ // "stride" (in tensor abstraction sense), i.e. the number of consecutive
+ // elements of the underlying buffer that separate two consecutive elements
+ // addressable through the view along the given dimension. When taking the
+ // view, the strides are constructed as products of the original sizes along
+ // the trailing dimensions, multiplied by the view step. For example, a view
+ // of an MxN memref with ranges {0:M:1}, {0:N:1}, i.e. the view of a complete
+ // memref, will have strides N and 1. A view with ranges {0:M:2}, {0:N:3}
+ // will have strides 2*N and 3.
//
// template <typename Elem, size_t Rank>
// struct {
@@ -139,16 +140,24 @@ static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) {
// int64_t offset;
// int64_t sizes[Rank];
// int64_t strides[Rank];
- // };
+ // } *;
if (auto viewType = t.dyn_cast<ViewType>()) {
auto ptrTy = getPtrToElementType(viewType, lowering);
auto arrayTy = LLVMType::getArrayTy(int64Ty, viewType.getRank());
- return LLVMType::getStructTy(ptrTy, int64Ty, arrayTy, arrayTy);
+ return LLVMType::getStructTy(ptrTy, int64Ty, arrayTy, arrayTy)
+ .getPointerTo();
}
return Type();
}
+static constexpr int kPtrPosInBuffer = 0;
+static constexpr int kSizePosInBuffer = 1;
+static constexpr int kPtrPosInView = 0;
+static constexpr int kOffsetPosInView = 1;
+static constexpr int kSizePosInView = 2;
+static constexpr int kStridePosInView = 3;
+
// Create an array attribute containing integer attributes with values provided
// in `position`.
static ArrayAttr positionAttr(Builder &builder, ArrayRef<int> position) {
@@ -192,10 +201,9 @@ public:
llvm::divideCeil(vectorType.getElementTypeBitWidth(), 8);
else
elementSize = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
- auto bufferType = allocOp.getResult()->getType().cast<BufferType>();
+ auto bufferType = allocOp.getBufferType();
auto elementPtrType = getPtrToElementType(bufferType, lowering);
- auto bufferDescriptorType =
- convertLinalgType(allocOp.getResult()->getType(), lowering);
+ auto bufferDescriptorTy = convertLinalgType(bufferType, lowering);
// Emit IR for creating a new buffer descriptor with an underlying malloc.
edsc::ScopedContext context(rewriter, op->getLoc());
@@ -212,11 +220,11 @@ public:
.getOperation()
->getResult(0);
allocated = bitcast(elementPtrType, allocated);
- Value *desc = undef(bufferDescriptorType);
- desc = insertvalue(bufferDescriptorType, desc, allocated,
- positionAttr(rewriter, 0));
- desc = insertvalue(bufferDescriptorType, desc, size,
- positionAttr(rewriter, 1));
+ Value *desc = undef(bufferDescriptorTy);
+ desc = insertvalue(bufferDescriptorTy, desc, allocated,
+ positionAttr(rewriter, kPtrPosInBuffer));
+ desc = insertvalue(bufferDescriptorTy, desc, size,
+ positionAttr(rewriter, kSizePosInBuffer));
rewriter.replaceOp(op, desc);
return matchSuccess();
}
@@ -246,13 +254,15 @@ public:
// Get MLIR types for extracting element pointer.
auto deallocOp = cast<BufferDeallocOp>(op);
- auto elementPtrTy = getPtrToElementType(
- deallocOp.getOperand()->getType().cast<BufferType>(), lowering);
+ auto elementPtrTy =
+ getPtrToElementType(deallocOp.getBufferType(), lowering);
// Emit MLIR for buffer_dealloc.
+ BufferDeallocOpOperandAdaptor adaptor(operands);
edsc::ScopedContext context(rewriter, op->getLoc());
- Value *casted = bitcast(voidPtrTy, extractvalue(elementPtrTy, operands[0],
- positionAttr(rewriter, 0)));
+ Value *casted =
+ bitcast(voidPtrTy, extractvalue(elementPtrTy, adaptor.buffer(),
+ positionAttr(rewriter, 0)));
llvm_call(ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
rewriter.replaceOp(op, llvm::None);
return matchSuccess();
@@ -270,8 +280,10 @@ public:
ConversionPatternRewriter &rewriter) const override {
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
edsc::ScopedContext context(rewriter, op->getLoc());
+ BufferSizeOpOperandAdaptor adaptor(operands);
rewriter.replaceOp(
- op, {extractvalue(int64Ty, operands[0], positionAttr(rewriter, 1))});
+ op, {extractvalue(int64Ty, adaptor.buffer(),
+ positionAttr(rewriter, kSizePosInBuffer))});
return matchSuccess();
}
};
@@ -288,11 +300,11 @@ public:
auto dimOp = cast<linalg::DimOp>(op);
auto indexTy = lowering.convertType(rewriter.getIndexType());
edsc::ScopedContext context(rewriter, op->getLoc());
- rewriter.replaceOp(
- op,
- {extractvalue(
- indexTy, operands[0],
- positionAttr(rewriter, {2, static_cast<int>(dimOp.getIndex())}))});
+ auto pos = positionAttr(
+ rewriter, {kSizePosInView, static_cast<int>(dimOp.getIndex())});
+ linalg::DimOpOperandAdaptor adaptor(operands);
+ Value *viewDescriptor = llvm_load(adaptor.view());
+ rewriter.replaceOp(op, {extractvalue(indexTy, viewDescriptor, pos)});
return matchSuccess();
}
};
@@ -311,7 +323,7 @@ public:
// current view indices. Use the base offset and strides stored in the view
// descriptor to emit IR iteratively computing the actual offset, followed by
// a getelementptr. This must be called under an edsc::ScopedContext.
- Value *obtainDataPtr(Operation *op, Value *viewDescriptor,
+ Value *obtainDataPtr(Operation *op, Value *viewDescriptorPtr,
ArrayRef<Value *> indices,
ConversionPatternRewriter &rewriter) const {
auto loadOp = cast<Op>(op);
@@ -323,10 +335,13 @@ public:
// Linearize subscripts as:
// base_offset + SUM_i index_i * stride_i.
- Value *base = extractvalue(elementTy, viewDescriptor, pos(0));
- Value *offset = extractvalue(int64Ty, viewDescriptor, pos(1));
+ Value *viewDescriptor = llvm_load(viewDescriptorPtr);
+ Value *base = extractvalue(elementTy, viewDescriptor, pos(kPtrPosInView));
+ Value *offset =
+ extractvalue(int64Ty, viewDescriptor, pos(kOffsetPosInView));
for (int i = 0, e = loadOp.getRank(); i < e; ++i) {
- Value *stride = extractvalue(int64Ty, viewDescriptor, pos({3, i}));
+ Value *stride =
+ extractvalue(int64Ty, viewDescriptor, pos({kStridePosInView, i}));
Value *additionalOffset = mul(indices[i], stride);
offset = add(offset, additionalOffset);
}
@@ -344,9 +359,8 @@ class LoadOpConversion : public LoadStoreOpConversion<linalg::LoadOp> {
ConversionPatternRewriter &rewriter) const override {
edsc::ScopedContext edscContext(rewriter, op->getLoc());
auto elementTy = lowering.convertType(*op->result_type_begin());
- Value *viewDescriptor = operands[0];
- ArrayRef<Value *> indices = operands.drop_front();
- auto ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
+ linalg::LoadOpOperandAdaptor adaptor(operands);
+ auto ptr = obtainDataPtr(op, adaptor.view(), adaptor.indices(), rewriter);
rewriter.replaceOp(op, {llvm_load(elementTy, ptr)});
return matchSuccess();
}
@@ -368,18 +382,23 @@ public:
edsc::ScopedContext context(rewriter, op->getLoc());
// Fill in an aggregate value of the descriptor.
+ RangeOpOperandAdaptor adaptor(operands);
Value *desc = undef(rangeDescriptorTy);
- desc = insertvalue(rangeDescriptorTy, desc, operands[0],
- positionAttr(rewriter, 0));
- desc = insertvalue(rangeDescriptorTy, desc, operands[1],
- positionAttr(rewriter, 1));
- desc = insertvalue(rangeDescriptorTy, desc, operands[2],
- positionAttr(rewriter, 2));
+ desc = insertvalue(desc, adaptor.min(), positionAttr(rewriter, 0));
+ desc = insertvalue(desc, adaptor.max(), positionAttr(rewriter, 1));
+ desc = insertvalue(desc, adaptor.step(), positionAttr(rewriter, 2));
rewriter.replaceOp(op, desc);
return matchSuccess();
}
};
+/// Conversion pattern that transforms a linalg.slice op into:
+/// 1. A function entry `alloca` operation to allocate a ViewDescriptor.
+/// 2. A load of the ViewDescriptor from the pointer allocated in 1.
+/// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
+/// and stride corresponding to the slice.
+/// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
+/// The linalg.slice op is replaced by the alloca'ed pointer.
class SliceOpConversion : public LLVMOpLowering {
public:
explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
@@ -390,7 +409,8 @@ public:
ConversionPatternRewriter &rewriter) const override {
SliceOpOperandAdaptor adaptor(operands);
auto sliceOp = cast<SliceOp>(op);
- auto viewDescriptorTy = convertLinalgType(sliceOp.getViewType(), lowering);
+ auto viewDescriptorPtrTy =
+ convertLinalgType(sliceOp.getViewType(), lowering);
auto viewType = sliceOp.getBaseViewType();
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
@@ -399,26 +419,36 @@ public:
auto pos = [&rewriter](ArrayRef<int> values) {
return positionAttr(rewriter, values);
};
- // Helper function to obtain the ptr of the given `view`.
- auto getViewPtr = [pos, this](ViewType type, Value *view) -> Value * {
- auto elementPtrTy = getPtrToElementType(type, lowering);
- return extractvalue(elementPtrTy, view, pos(0));
- };
edsc::ScopedContext context(rewriter, op->getLoc());
- // Declare the view descriptor and insert data ptr.
- Value *desc = undef(viewDescriptorTy);
- desc = insertvalue(viewDescriptorTy, desc,
- getViewPtr(viewType, adaptor.view()), pos(0));
+ // Declare the view descriptor and insert data ptr *at the entry block of
+ // the function*, which is the preferred location for LLVM's analyses.
+ auto ip = rewriter.getInsertionPoint();
+ auto ib = rewriter.getInsertionBlock();
+ rewriter.setInsertionPointToStart(
+ &op->getParentOfType<FuncOp>().getBlocks().front());
+ Value *one =
+ constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
+ // Alloca with proper alignment.
+ Value *allocatedDesc =
+ llvm_alloca(viewDescriptorPtrTy, one, /*alignment=*/8);
+ Value *desc = llvm_load(allocatedDesc);
+ rewriter.setInsertionPoint(ib, ip);
+
+ Value *baseDesc = llvm_load(adaptor.view());
+
+ auto ptrPos = pos(kPtrPosInView);
+ auto elementTy = getPtrToElementType(sliceOp.getViewType(), lowering);
+ desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos);
// TODO(ntv): extract sizes and emit asserts.
SmallVector<Value *, 4> strides(viewType.getRank());
for (int i = 0, e = viewType.getRank(); i < e; ++i) {
- strides[i] = extractvalue(int64Ty, adaptor.view(), pos({3, i}));
+ strides[i] = extractvalue(int64Ty, baseDesc, pos({kStridePosInView, i}));
}
// Compute and insert base offset.
- Value *baseOffset = extractvalue(int64Ty, adaptor.view(), pos(1));
+ Value *baseOffset = extractvalue(int64Ty, baseDesc, pos(kOffsetPosInView));
for (int i = 0, e = viewType.getRank(); i < e; ++i) {
Value *indexing = adaptor.indexings()[i];
Value *min =
@@ -428,7 +458,7 @@ public:
Value *product = mul(min, strides[i]);
baseOffset = add(baseOffset, product);
}
- desc = insertvalue(viewDescriptorTy, desc, baseOffset, pos(1));
+ desc = insertvalue(desc, baseOffset, pos(kOffsetPosInView));
// Compute and insert view sizes (max - min along the range) and strides.
// Skip the non-range operands as they will be projected away from the view.
@@ -443,14 +473,15 @@ public:
Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
Value *size = sub(max, min);
Value *stride = mul(strides[i], step);
- desc = insertvalue(viewDescriptorTy, desc, size, pos({2, numNewDims}));
- desc =
- insertvalue(viewDescriptorTy, desc, stride, pos({3, numNewDims}));
+ desc = insertvalue(desc, size, pos({kSizePosInView, numNewDims}));
+ desc = insertvalue(desc, stride, pos({kStridePosInView, numNewDims}));
++numNewDims;
}
}
- rewriter.replaceOp(op, desc);
+ // Store back in alloca'ed region.
+ llvm_store(desc, allocatedDesc);
+ rewriter.replaceOp(op, allocatedDesc);
return matchSuccess();
}
};
@@ -463,16 +494,21 @@ class StoreOpConversion : public LoadStoreOpConversion<linalg::StoreOp> {
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
edsc::ScopedContext edscContext(rewriter, op->getLoc());
- Value *data = operands[0];
- Value *viewDescriptor = operands[1];
- ArrayRef<Value *> indices = operands.drop_front(2);
- Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter);
- llvm_store(data, ptr);
+ linalg::StoreOpOperandAdaptor adaptor(operands);
+ Value *ptr = obtainDataPtr(op, adaptor.view(), adaptor.indices(), rewriter);
+ llvm_store(adaptor.value(), ptr);
rewriter.replaceOp(op, llvm::None);
return matchSuccess();
}
};
+/// Conversion pattern that transforms a linalg.view op into:
+/// 1. A function entry `alloca` operation to allocate a ViewDescriptor.
+/// 2. A load of the ViewDescriptor from the pointer allocated in 1.
+/// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size
+/// and stride.
+/// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer.
+/// The linalg.view op is replaced by the alloca'ed pointer.
class ViewOpConversion : public LLVMOpLowering {
public:
explicit ViewOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_)
@@ -482,7 +518,9 @@ public:
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto viewOp = cast<ViewOp>(op);
- auto viewDescriptorTy = convertLinalgType(viewOp.getViewType(), lowering);
+ ViewOpOperandAdaptor adaptor(operands);
+ auto viewDescriptorPtrTy =
+ convertLinalgType(viewOp.getViewType(), lowering);
auto elementTy = getPtrToElementType(viewOp.getViewType(), lowering);
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
@@ -490,21 +528,34 @@ public:
return positionAttr(rewriter, values);
};
- // First operand to `view` is the buffer descriptor.
- Value *bufferDescriptor = operands[0];
+ Value *bufferDescriptor = adaptor.buffer();
+ auto bufferTy = getPtrToElementType(
+ viewOp.buffer()->getType().cast<BufferType>(), lowering);
// Declare the descriptor of the view.
edsc::ScopedContext context(rewriter, op->getLoc());
- Value *desc = undef(viewDescriptorTy);
+ auto ip = rewriter.getInsertionPoint();
+ auto ib = rewriter.getInsertionBlock();
+ rewriter.setInsertionPointToStart(
+ &op->getParentOfType<FuncOp>().getBlocks().front());
+ Value *one =
+ constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
+ // Alloca for proper alignment.
+ Value *allocatedDesc =
+ llvm_alloca(viewDescriptorPtrTy, one, /*alignment=*/8);
+ Value *desc = llvm_load(allocatedDesc);
+ rewriter.setInsertionPoint(ib, ip);
// Copy the buffer pointer from the old descriptor to the new one.
- Value *buffer = extractvalue(elementTy, bufferDescriptor, pos(0));
- desc = insertvalue(viewDescriptorTy, desc, buffer, pos(0));
+ Value *bufferAsViewElementType =
+ bitcast(elementTy,
+ extractvalue(bufferTy, bufferDescriptor, pos(kPtrPosInBuffer)));
+ desc = insertvalue(desc, bufferAsViewElementType, pos(kPtrPosInView));
// Zero base offset.
auto indexTy = rewriter.getIndexType();
Value *baseOffset = constant(int64Ty, IntegerAttr::get(indexTy, 0));
- desc = insertvalue(viewDescriptorTy, desc, baseOffset, pos(1));
+ desc = insertvalue(desc, baseOffset, pos(kOffsetPosInView));
// Compute and insert view sizes (max - min along the range).
int numRanges = llvm::size(viewOp.ranges());
@@ -514,18 +565,20 @@ public:
Value *rangeDescriptor = operands[1 + i];
Value *step = extractvalue(int64Ty, rangeDescriptor, pos(2));
Value *stride = mul(runningStride, step);
- desc = insertvalue(viewDescriptorTy, desc, stride, pos({3, i}));
+ desc = insertvalue(desc, stride, pos({kStridePosInView, i}));
// Update size.
Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
Value *size = sub(max, min);
- desc = insertvalue(viewDescriptorTy, desc, size, pos({2, i}));
+ desc = insertvalue(desc, size, pos({kSizePosInView, i}));
// Update stride for the next dimension.
if (i > 0)
runningStride = mul(runningStride, max);
}
- rewriter.replaceOp(op, desc);
+ // Store back in alloca'ed region.
+ llvm_store(desc, allocatedDesc);
+ rewriter.replaceOp(op, allocatedDesc);
return matchSuccess();
}
};
@@ -585,32 +638,6 @@ getLLVMLibraryCallDeclaration(Operation *op, LLVMTypeConverter &lowering,
return libFn;
}
-static void getLLVMLibraryCallDefinition(FuncOp fn,
- LLVMTypeConverter &lowering) {
- // Generate the implementation function definition.
- auto implFn = getLLVMLibraryCallImplDefinition(fn);
-
- // Generate the function body.
- OpBuilder builder(fn.addEntryBlock());
- edsc::ScopedContext scope(builder, fn.getLoc());
- SmallVector<Value *, 4> implFnArgs;
-
- // Create a constant 1.
- auto one = constant(LLVMType::getInt64Ty(lowering.getDialect()),
- IntegerAttr::get(IndexType::get(fn.getContext()), 1));
- for (auto arg : fn.getArguments()) {
- // Allocate a stack for storing the argument value. The stack is passed to
- // the implementation function.
- auto alloca =
- llvm_alloca(arg->getType().cast<LLVMType>().getPointerTo(), one)
- .getValue();
- implFnArgs.push_back(alloca);
- llvm_store(arg, alloca);
- }
- llvm_call(ArrayRef<Type>(), builder.getSymbolRefAttr(implFn), implFnArgs);
- llvm_return{ArrayRef<Value *>()};
-}
-
namespace {
// The conversion class from Linalg to LLVMIR.
class LinalgTypeConverter : public LLVMTypeConverter {
@@ -622,16 +649,6 @@ public:
return result;
return convertLinalgType(t, *this);
}
-
- void addLibraryFnDeclaration(FuncOp fn) { libraryFnDeclarations.insert(fn); }
-
- ArrayRef<FuncOp> getLibraryFnDeclarations() {
- return libraryFnDeclarations.getArrayRef();
- }
-
-private:
- /// List of library functions declarations needed during dialect conversion
- llvm::SetVector<FuncOp> libraryFnDeclarations;
};
} // end anonymous namespace
@@ -652,7 +669,6 @@ public:
auto f = getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter);
if (!f)
return matchFailure();
- static_cast<LinalgTypeConverter &>(lowering).addLibraryFnDeclaration(f);
auto fAttr = rewriter.getSymbolRefAttr(f);
auto named = rewriter.getNamedAttr("callee", fAttr);
@@ -727,11 +743,6 @@ void LowerLinalgToLLVMPass::runOnModule() {
if (failed(applyPartialConversion(module, target, patterns, &converter))) {
signalPassFailure();
}
-
- // Emit the function body of any Library function that was declared.
- for (auto fn : converter.getLibraryFnDeclarations()) {
- getLLVMLibraryCallDefinition(fn, converter);
- }
}
std::unique_ptr<ModulePassBase> mlir::linalg::createLowerLinalgToLLVMPass() {
OpenPOWER on IntegriCloud