diff options
Diffstat (limited to 'mlir')
| -rw-r--r-- | mlir/include/mlir/Dialect/StandardOps/Ops.td | 1 | ||||
| -rw-r--r-- | mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp | 126 | ||||
| -rw-r--r-- | mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir | 106 | ||||
| -rw-r--r-- | mlir/test/Conversion/StandardToLLVM/foo.mlir | 25 |
4 files changed, 257 insertions, 1 deletions
diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td index 10f94381d22..be20c382326 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -1195,7 +1195,6 @@ def ViewOp : Std_Op<"view"> { // TODO(andydavis) Add canonicalizer to fold constants into shape and map. } - def XOrOp : IntArithmeticOp<"xor", [Commutative]> { let summary = "integer binary xor"; let hasFolder = 1; diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 490b6695d84..89bf07f27cd 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -1365,6 +1365,131 @@ struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> { } }; +/// Conversion pattern that transforms a op into: +/// 1. An `llvm.mlir.undef` operation to create a memref descriptor +/// 2. Updates to the descriptor to introduce the data ptr, offset, size +/// and stride. +/// The view op is replaced by the descriptor. +struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> { + using LLVMLegalizationPattern<ViewOp>::LLVMLegalizationPattern; + + // Build and return the value for the idx^th shape dimension, either by + // returning the constant shape dimension or counting the proper dynamic size. + Value *getSize(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef<int64_t> shape, ArrayRef<Value *> dynamicSizes, + unsigned idx) const { + assert(idx < shape.size()); + if (!ShapedType::isDynamic(shape[idx])) + return createIndexConstant(rewriter, loc, shape[idx]); + // Count the number of dynamic dims in range [0, idx] + unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) { + return ShapedType::isDynamic(v); + }); + return dynamicSizes[nDynamic]; + } + + // Build and return the idx^th stride, either by returning the constant stride + // or by computing the dynamic stride from the current `runningStride` and + // `nextSize`. The caller should keep a running stride and update it with the + // result returned by this function. + Value *getStride(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef<int64_t> strides, Value *nextSize, + Value *runningStride, unsigned idx) const { + assert(idx < strides.size()); + if (strides[idx] != MemRefType::getDynamicStrideOrOffset()) + return createIndexConstant(rewriter, loc, strides[idx]); + if (nextSize) + return runningStride + ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize) + : nextSize; + assert(!runningStride); + return createIndexConstant(rewriter, loc, 1); + } + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto viewOp = cast<ViewOp>(op); + ViewOpOperandAdaptor adaptor(operands); + auto sourceMemRefType = viewOp.source()->getType().cast<MemRefType>(); + auto sourceElementTy = + lowering.convertType(sourceMemRefType.getElementType()) + .dyn_cast<LLVM::LLVMType>(); + + auto viewMemRefType = viewOp.getType(); + auto targetElementTy = lowering.convertType(viewMemRefType.getElementType()) + .dyn_cast<LLVM::LLVMType>(); + auto targetDescTy = + lowering.convertType(viewMemRefType).dyn_cast<LLVM::LLVMType>(); + if (!targetDescTy) + return op->emitWarning("Target descriptor type not converted to LLVM"), + matchFailure(); + + int64_t offset; + SmallVector<int64_t, 4> strides; + auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset); + if (failed(successStrides)) + return op->emitWarning("Cannot cast to non-strided shape"), + matchFailure(); + if (strides.back() != 1) + return op->emitWarning("Cannot cast to non-contiguous shape"), + matchFailure(); + + // Create the descriptor. + Value *desc = rewriter.create<LLVM::UndefOp>(loc, targetDescTy); + + // Copy the buffer pointer from the old descriptor to the new one. + Value *sourceDescriptor = adaptor.source(); + Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>( + loc, targetElementTy.getPointerTo(), + rewriter.create<LLVM::ExtractValueOp>( + loc, sourceElementTy.getPointerTo(), sourceDescriptor, + rewriter.getI64ArrayAttr( + LLVMTypeConverter::kPtrPosInMemRefDescriptor))); + desc = rewriter.create<LLVM::InsertValueOp>( + loc, desc, bitcastPtr, + rewriter.getI64ArrayAttr(LLVMTypeConverter::kPtrPosInMemRefDescriptor)); + + // Offset. + unsigned numDynamicSizes = llvm::size(viewOp.getDynamicSizes()); + (void)numDynamicSizes; + auto sizeAndOffsetOperands = adaptor.operands(); + assert(llvm::size(sizeAndOffsetOperands) == numDynamicSizes + 1 || + offset != MemRefType::getDynamicStrideOrOffset()); + Value *baseOffset = (offset != MemRefType::getDynamicStrideOrOffset()) + ? createIndexConstant(rewriter, loc, offset) + // TODO(ntv): better adaptor. + : sizeAndOffsetOperands.back(); + desc = rewriter.create<LLVM::InsertValueOp>( + loc, desc, baseOffset, + rewriter.getI64ArrayAttr( + LLVMTypeConverter::kOffsetPosInMemRefDescriptor)); + + // Update sizes and strides. + Value *stride = nullptr, *nextSize = nullptr; + for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { + // Update size. + Value *size = getSize(rewriter, loc, viewMemRefType.getShape(), + sizeAndOffsetOperands, i); + desc = rewriter.create<LLVM::InsertValueOp>( + loc, desc, size, + rewriter.getI64ArrayAttr( + {LLVMTypeConverter::kSizePosInMemRefDescriptor, i})); + // Update stride. + stride = getStride(rewriter, loc, strides, nextSize, stride, i); + desc = rewriter.create<LLVM::InsertValueOp>( + loc, desc, stride, + rewriter.getI64ArrayAttr( + {LLVMTypeConverter::kStridePosInMemRefDescriptor, i})); + nextSize = size; + } + + rewriter.replaceOp(op, desc); + return matchSuccess(); + } +}; + } // namespace static void ensureDistinctSuccessors(Block &bb) { @@ -1459,6 +1584,7 @@ void mlir::populateStdToLLVMConversionPatterns( SubFOpLowering, SubIOpLowering, TruncateIOpLowering, + ViewOpLowering, XOrOpLowering, ZeroExtendIOpLowering>(*converter.getDialect(), converter); // clang-format on diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir index fb23a76cf25..a0daac0d658 100644 --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -602,3 +602,109 @@ func @splat(%a: vector<4xf32>, %b: f32) -> vector<4xf32> { // CHECK-NEXT: [[SCALE:%[0-9]+]] = llvm.fmul [[A]], [[SPLAT]] : !llvm<"<4 x float>"> // CHECK-NEXT: llvm.return [[SCALE]] : !llvm<"<4 x float>"> +// CHECK-LABEL: func @view( +// CHECK: %[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64 +func @view(%arg0 : index, %arg1 : index, %arg2 : index) { + // CHECK: llvm.mlir.constant(2048 : index) : !llvm.i64 + // CHECK: llvm.mlir.undef : !llvm<"{ i8*, i64, [1 x i64], [1 x i64] }"> + %0 = alloc() : memref<2048xi8> + + // Test two dynamic sizes and dynamic offset. + // CHECK: llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i8*, i64, [1 x i64], [1 x i64] }"> + // CHECK: llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*"> + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.insertvalue %[[ARG2]], %{{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.insertvalue %[[ARG1]], %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.mlir.constant(1 : index) : !llvm.i64 + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.insertvalue %[[ARG0]], %{{.*}}[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.mul %{{.*}}, %[[ARG1]] + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + %1 = view %0[%arg0, %arg1][%arg2] + : memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)> + + // Test two dynamic sizes and static offset. + // CHECK: llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i8*, i64, [1 x i64], [1 x i64] }"> + // CHECK: llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*"> + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.mlir.constant(1 : index) : !llvm.i64 + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.insertvalue %arg0, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.mul %{{.*}}, %[[ARG1]] + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + %2 = view %0[%arg0, %arg1][] + : memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0] -> (d0 * s0 + d1)> + + // Test one dynamic size and dynamic offset. + // CHECK: llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i8*, i64, [1 x i64], [1 x i64] }"> + // CHECK: llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*"> + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.insertvalue %[[ARG2]], %{{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.insertvalue %[[ARG1]], %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.mlir.constant(1 : index) : !llvm.i64 + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.mlir.constant(4 : index) : !llvm.i64 + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.mul %{{.*}}, %[[ARG1]] + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + %3 = view %0[%arg1][%arg2] + : memref<2048xi8> to memref<4x?xf32, (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)> + + // Test one dynamic size and static offset. + // CHECK: llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i8*, i64, [1 x i64], [1 x i64] }"> + // CHECK: llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*"> + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.mlir.constant(16 : index) : !llvm.i64 + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.mlir.constant(1 : index) : !llvm.i64 + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.insertvalue %[[ARG0]], %{{.*}}[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.mlir.constant(4 : index) : !llvm.i64 + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + %4 = view %0[%arg0][] + : memref<2048xi8> to memref<?x16xf32, (d0, d1) -> (d0 * 4 + d1)> + + // Test static sizes and static offset. + // CHECK: llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i8*, i64, [1 x i64], [1 x i64] }"> + // CHECK: llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*"> + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.mlir.constant(0 : index) : !llvm.i64 + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.mlir.constant(4 : index) : !llvm.i64 + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.mlir.constant(1 : index) : !llvm.i64 + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.mlir.constant(64 : index) : !llvm.i64 + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.mlir.constant(4 : index) : !llvm.i64 + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + %5 = view %0[][] + : memref<2048xi8> to memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)> + + // Test dynamic everything. + // CHECK: llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i8*, i64, [1 x i64], [1 x i64] }"> + // CHECK: llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*"> + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.insertvalue %[[ARG2]], %{{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.insertvalue %[[ARG1]], %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: %[[STRIDE_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 + // CHECK: llvm.insertvalue %[[STRIDE_1]], %{{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.insertvalue %[[ARG0]], %{{.*}}[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.mul %[[STRIDE_1]], %[[ARG1]] : !llvm.i64 + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + %6 = view %0[%arg0, %arg1][%arg2] + : memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)> + + return +} diff --git a/mlir/test/Conversion/StandardToLLVM/foo.mlir b/mlir/test/Conversion/StandardToLLVM/foo.mlir new file mode 100644 index 00000000000..76007978463 --- /dev/null +++ b/mlir/test/Conversion/StandardToLLVM/foo.mlir @@ -0,0 +1,25 @@ + +// CHECK-LABEL: func @view( +// CHECK: %[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64 +func @view(%arg0 : index, %arg1 : index, %arg2 : index) { + // CHECK: llvm.mlir.constant(2048 : index) : !llvm.i64 + // CHECK: llvm.mlir.undef : !llvm<"{ i8*, i64, [1 x i64], [1 x i64] }"> + %0 = alloc() : memref<2048xi8> + + // Test one dynamic size and dynamic offset. + // CHECK: llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i8*, i64, [1 x i64], [1 x i64] }"> + // CHECK: llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*"> + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.insertvalue %[[ARG2]], %{{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.insertvalue %[[ARG1]], %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.mlir.constant(1 : index) : !llvm.i64 + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.mlir.constant(4 : index) : !llvm.i64 + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: llvm.mlir.constant(4 : index) : !llvm.i64 + // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> + %3 = view %0[%arg1][%arg2] + : memref<2048xi8> to memref<4x?xf32, (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)> + return +} |

