summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp126
1 files changed, 126 insertions, 0 deletions
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 490b6695d84..89bf07f27cd 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -1365,6 +1365,131 @@ struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
}
};
+/// Conversion pattern that transforms a op into:
+/// 1. An `llvm.mlir.undef` operation to create a memref descriptor
+/// 2. Updates to the descriptor to introduce the data ptr, offset, size
+/// and stride.
+/// The view op is replaced by the descriptor.
+struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
+ using LLVMLegalizationPattern<ViewOp>::LLVMLegalizationPattern;
+
+ // Build and return the value for the idx^th shape dimension, either by
+ // returning the constant shape dimension or counting the proper dynamic size.
+ Value *getSize(ConversionPatternRewriter &rewriter, Location loc,
+ ArrayRef<int64_t> shape, ArrayRef<Value *> dynamicSizes,
+ unsigned idx) const {
+ assert(idx < shape.size());
+ if (!ShapedType::isDynamic(shape[idx]))
+ return createIndexConstant(rewriter, loc, shape[idx]);
+ // Count the number of dynamic dims in range [0, idx]
+ unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
+ return ShapedType::isDynamic(v);
+ });
+ return dynamicSizes[nDynamic];
+ }
+
+ // Build and return the idx^th stride, either by returning the constant stride
+ // or by computing the dynamic stride from the current `runningStride` and
+ // `nextSize`. The caller should keep a running stride and update it with the
+ // result returned by this function.
+ Value *getStride(ConversionPatternRewriter &rewriter, Location loc,
+ ArrayRef<int64_t> strides, Value *nextSize,
+ Value *runningStride, unsigned idx) const {
+ assert(idx < strides.size());
+ if (strides[idx] != MemRefType::getDynamicStrideOrOffset())
+ return createIndexConstant(rewriter, loc, strides[idx]);
+ if (nextSize)
+ return runningStride
+ ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
+ : nextSize;
+ assert(!runningStride);
+ return createIndexConstant(rewriter, loc, 1);
+ }
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op->getLoc();
+ auto viewOp = cast<ViewOp>(op);
+ ViewOpOperandAdaptor adaptor(operands);
+ auto sourceMemRefType = viewOp.source()->getType().cast<MemRefType>();
+ auto sourceElementTy =
+ lowering.convertType(sourceMemRefType.getElementType())
+ .dyn_cast<LLVM::LLVMType>();
+
+ auto viewMemRefType = viewOp.getType();
+ auto targetElementTy = lowering.convertType(viewMemRefType.getElementType())
+ .dyn_cast<LLVM::LLVMType>();
+ auto targetDescTy =
+ lowering.convertType(viewMemRefType).dyn_cast<LLVM::LLVMType>();
+ if (!targetDescTy)
+ return op->emitWarning("Target descriptor type not converted to LLVM"),
+ matchFailure();
+
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
+ if (failed(successStrides))
+ return op->emitWarning("Cannot cast to non-strided shape"),
+ matchFailure();
+ if (strides.back() != 1)
+ return op->emitWarning("Cannot cast to non-contiguous shape"),
+ matchFailure();
+
+ // Create the descriptor.
+ Value *desc = rewriter.create<LLVM::UndefOp>(loc, targetDescTy);
+
+ // Copy the buffer pointer from the old descriptor to the new one.
+ Value *sourceDescriptor = adaptor.source();
+ Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>(
+ loc, targetElementTy.getPointerTo(),
+ rewriter.create<LLVM::ExtractValueOp>(
+ loc, sourceElementTy.getPointerTo(), sourceDescriptor,
+ rewriter.getI64ArrayAttr(
+ LLVMTypeConverter::kPtrPosInMemRefDescriptor)));
+ desc = rewriter.create<LLVM::InsertValueOp>(
+ loc, desc, bitcastPtr,
+ rewriter.getI64ArrayAttr(LLVMTypeConverter::kPtrPosInMemRefDescriptor));
+
+ // Offset.
+ unsigned numDynamicSizes = llvm::size(viewOp.getDynamicSizes());
+ (void)numDynamicSizes;
+ auto sizeAndOffsetOperands = adaptor.operands();
+ assert(llvm::size(sizeAndOffsetOperands) == numDynamicSizes + 1 ||
+ offset != MemRefType::getDynamicStrideOrOffset());
+ Value *baseOffset = (offset != MemRefType::getDynamicStrideOrOffset())
+ ? createIndexConstant(rewriter, loc, offset)
+ // TODO(ntv): better adaptor.
+ : sizeAndOffsetOperands.back();
+ desc = rewriter.create<LLVM::InsertValueOp>(
+ loc, desc, baseOffset,
+ rewriter.getI64ArrayAttr(
+ LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+
+ // Update sizes and strides.
+ Value *stride = nullptr, *nextSize = nullptr;
+ for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
+ // Update size.
+ Value *size = getSize(rewriter, loc, viewMemRefType.getShape(),
+ sizeAndOffsetOperands, i);
+ desc = rewriter.create<LLVM::InsertValueOp>(
+ loc, desc, size,
+ rewriter.getI64ArrayAttr(
+ {LLVMTypeConverter::kSizePosInMemRefDescriptor, i}));
+ // Update stride.
+ stride = getStride(rewriter, loc, strides, nextSize, stride, i);
+ desc = rewriter.create<LLVM::InsertValueOp>(
+ loc, desc, stride,
+ rewriter.getI64ArrayAttr(
+ {LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
+ nextSize = size;
+ }
+
+ rewriter.replaceOp(op, desc);
+ return matchSuccess();
+ }
+};
+
} // namespace
static void ensureDistinctSuccessors(Block &bb) {
@@ -1459,6 +1584,7 @@ void mlir::populateStdToLLVMConversionPatterns(
SubFOpLowering,
SubIOpLowering,
TruncateIOpLowering,
+ ViewOpLowering,
XOrOpLowering,
ZeroExtendIOpLowering>(*converter.getDialect(), converter);
// clang-format on
OpenPOWER on IntegriCloud