Diffstat (limited to 'mlir/lib')
 -rw-r--r--  mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp  126
 1 file changed, 126 insertions, 0 deletions
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 490b6695d84..89bf07f27cd 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -1365,6 +1365,131 @@ struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
   }
 };
 
+/// Conversion pattern that transforms an op into:
+///   1. An `llvm.mlir.undef` operation to create a memref descriptor.
+///   2. Updates to the descriptor to introduce the data ptr, offset, size
+///      and stride.
+/// The view op is replaced by the descriptor.
+struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
+  using LLVMLegalizationPattern<ViewOp>::LLVMLegalizationPattern;
+
+  // Build and return the value for the idx^th shape dimension, either by
+  // returning the constant shape dimension or the proper dynamic size operand.
+  Value *getSize(ConversionPatternRewriter &rewriter, Location loc,
+                 ArrayRef<int64_t> shape, ArrayRef<Value *> dynamicSizes,
+                 unsigned idx) const {
+    assert(idx < shape.size());
+    if (!ShapedType::isDynamic(shape[idx]))
+      return createIndexConstant(rewriter, loc, shape[idx]);
+    // Count the number of dynamic dims in range [0, idx).
+    unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
+      return ShapedType::isDynamic(v);
+    });
+    return dynamicSizes[nDynamic];
+  }
+
+  // Build and return the idx^th stride, either by returning the constant
+  // stride or by computing the dynamic stride from the current
+  // `runningStride` and `nextSize`. The caller should keep a running stride
+  // and update it with the result returned by this function.
+  Value *getStride(ConversionPatternRewriter &rewriter, Location loc,
+                   ArrayRef<int64_t> strides, Value *nextSize,
+                   Value *runningStride, unsigned idx) const {
+    assert(idx < strides.size());
+    if (strides[idx] != MemRefType::getDynamicStrideOrOffset())
+      return createIndexConstant(rewriter, loc, strides[idx]);
+    if (nextSize)
+      return runningStride
+                 ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
+                 : nextSize;
+    assert(!runningStride);
+    return createIndexConstant(rewriter, loc, 1);
+  }
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op->getLoc();
+    auto viewOp = cast<ViewOp>(op);
+    ViewOpOperandAdaptor adaptor(operands);
+    auto sourceMemRefType = viewOp.source()->getType().cast<MemRefType>();
+    auto sourceElementTy =
+        lowering.convertType(sourceMemRefType.getElementType())
+            .dyn_cast<LLVM::LLVMType>();
+
+    auto viewMemRefType = viewOp.getType();
+    auto targetElementTy = lowering.convertType(viewMemRefType.getElementType())
+                               .dyn_cast<LLVM::LLVMType>();
+    auto targetDescTy =
+        lowering.convertType(viewMemRefType).dyn_cast<LLVM::LLVMType>();
+    if (!targetDescTy)
+      return op->emitWarning("Target descriptor type not converted to LLVM"),
+             matchFailure();
+
+    int64_t offset;
+    SmallVector<int64_t, 4> strides;
+    auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
+    if (failed(successStrides))
+      return op->emitWarning("Cannot cast to non-strided shape"),
+             matchFailure();
+    if (strides.back() != 1)
+      return op->emitWarning("Cannot cast to non-contiguous shape"),
+             matchFailure();
+
+    // Create the descriptor.
+    Value *desc = rewriter.create<LLVM::UndefOp>(loc, targetDescTy);
+
+    // Copy the buffer pointer from the old descriptor to the new one.
+    Value *sourceDescriptor = adaptor.source();
+    Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>(
+        loc, targetElementTy.getPointerTo(),
+        rewriter.create<LLVM::ExtractValueOp>(
+            loc, sourceElementTy.getPointerTo(), sourceDescriptor,
+            rewriter.getI64ArrayAttr(
+                LLVMTypeConverter::kPtrPosInMemRefDescriptor)));
+    desc = rewriter.create<LLVM::InsertValueOp>(
+        loc, desc, bitcastPtr,
+        rewriter.getI64ArrayAttr(LLVMTypeConverter::kPtrPosInMemRefDescriptor));
+
+    // Offset.
+    unsigned numDynamicSizes = llvm::size(viewOp.getDynamicSizes());
+    (void)numDynamicSizes;
+    auto sizeAndOffsetOperands = adaptor.operands();
+    assert(llvm::size(sizeAndOffsetOperands) == numDynamicSizes + 1 ||
+           offset != MemRefType::getDynamicStrideOrOffset());
+    Value *baseOffset = (offset != MemRefType::getDynamicStrideOrOffset())
+                            ? createIndexConstant(rewriter, loc, offset)
+                            // TODO(ntv): better adaptor.
+                            : sizeAndOffsetOperands.back();
+    desc = rewriter.create<LLVM::InsertValueOp>(
+        loc, desc, baseOffset,
+        rewriter.getI64ArrayAttr(
+            LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+
+    // Update sizes and strides.
+    Value *stride = nullptr, *nextSize = nullptr;
+    for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
+      // Update size.
+      Value *size = getSize(rewriter, loc, viewMemRefType.getShape(),
+                            sizeAndOffsetOperands, i);
+      desc = rewriter.create<LLVM::InsertValueOp>(
+          loc, desc, size,
+          rewriter.getI64ArrayAttr(
+              {LLVMTypeConverter::kSizePosInMemRefDescriptor, i}));
+      // Update stride.
+      stride = getStride(rewriter, loc, strides, nextSize, stride, i);
+      desc = rewriter.create<LLVM::InsertValueOp>(
+          loc, desc, stride,
+          rewriter.getI64ArrayAttr(
+              {LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
+      nextSize = size;
+    }
+
+    rewriter.replaceOp(op, desc);
+    return matchSuccess();
+  }
+};
+
 } // namespace
 
 static void ensureDistinctSuccessors(Block &bb) {
@@ -1459,6 +1584,7 @@ void mlir::populateStdToLLVMConversionPatterns(
       SubFOpLowering,
       SubIOpLowering,
       TruncateIOpLowering,
+      ViewOpLowering,
       XOrOpLowering,
       ZeroExtendIOpLowering>(*converter.getDialect(), converter);
   // clang-format on
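Note: the InsertValueOp positions used above (LLVMTypeConverter::kPtrPosInMemRefDescriptor and friends) address fields of the converted memref descriptor. As rough orientation only, and assuming a rank-2 memref of f32, the descriptor corresponds to a struct of this shape; this is a hypothetical C++ mirror, not code from the patch:

#include <cstdint>

// One field per k*PosInMemRefDescriptor index used by ViewOpLowering.
struct MemRefDescriptor2D {
  float *ptr;         // kPtrPosInMemRefDescriptor: data pointer
  int64_t offset;     // kOffsetPosInMemRefDescriptor: offset into the buffer
  int64_t sizes[2];   // kSizePosInMemRefDescriptor: one size per dimension
  int64_t strides[2]; // kStridePosInMemRefDescriptor: one stride per dimension
};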

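The sizes-and-strides loop walks dimensions from innermost to outermost, threading the previous stride and size through getStride. For a fully static, contiguous shape this collapses to the usual row-major recurrence stride[rank-1] = 1, stride[i] = stride[i+1] * size[i+1]; in the dynamic case the same products are materialized as LLVM::MulOp chains. A minimal standalone sketch of that recurrence (plain C++, function name hypothetical):

#include <cstdint>
#include <vector>

// Computes contiguous row-major strides; `running` plays the role of the
// `stride`/`nextSize` pair threaded through getStride above.
std::vector<int64_t> contiguousStrides(const std::vector<int64_t> &sizes) {
  std::vector<int64_t> strides(sizes.size());
  int64_t running = 1;
  for (int i = static_cast<int>(sizes.size()) - 1; i >= 0; --i) {
    strides[i] = running;
    running *= sizes[i];
  }
  return strides;
}

// For sizes {4, 3, 2} this yields strides {6, 2, 1}.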

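Similarly, getSize resolves a dynamic dimension by counting how many dynamic dimensions precede it and indexing the dynamic-size operand list with that count. A standalone sketch of that indexing rule (plain C++, function name hypothetical; -1 stands in for ShapedType::isDynamic):

#include <cstdint>
#include <vector>

int64_t pickDynamicSize(const std::vector<int64_t> &shape,
                        const std::vector<int64_t> &dynamicSizes,
                        unsigned idx) {
  // Dimension idx's operand index == number of dynamic dims before idx.
  unsigned nDynamic = 0;
  for (unsigned i = 0; i < idx; ++i)
    if (shape[i] == -1)
      ++nDynamic;
  return dynamicSizes[nDynamic];
}

// For shape {-1, 8, -1} and idx == 2, one dynamic dim precedes idx,
// so the result is dynamicSizes[1].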