Diffstat (limited to 'mlir')
-rw-r--r--  mlir/include/mlir/Dialect/StandardOps/Ops.td                  |   1
-rw-r--r--  mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp  | 126
-rw-r--r--  mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir    | 106
-rw-r--r--  mlir/test/Conversion/StandardToLLVM/foo.mlir                  |  25
4 files changed, 257 insertions(+), 1 deletion(-)
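
This patch adds a std-to-LLVM conversion pattern for the standard view op, lowering it to a sequence of memref-descriptor updates. As context for the [0], [1], [2, i] and [3, i] insertion positions used throughout the patch and its tests, the descriptor type !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> for a rank-2 memref corresponds to a struct of the following shape (an illustrative C++ sketch of the layout, not code from the patch):

    #include <cstdint>

    // Layout of the lowered rank-2 memref descriptor. The field indices match
    // the LLVMTypeConverter::k*PosInMemRefDescriptor constants used below.
    struct MemRefDescriptor2D {
      float *ptr;          // field 0: data pointer
      int64_t offset;      // field 1: offset into the buffer, in elements
      int64_t sizes[2];    // field 2: size of each dimension
      int64_t strides[2];  // field 3: stride of each dimension, in elements
    };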
diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td
index 10f94381d22..be20c382326 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -1195,7 +1195,6 @@ def ViewOp : Std_Op<"view"> {
// TODO(andydavis) Add canonicalizer to fold constants into shape and map.
}
-
def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
let summary = "integer binary xor";
let hasFolder = 1;
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 490b6695d84..89bf07f27cd 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -1365,6 +1365,131 @@ struct SplatNdOpLowering : public LLVMLegalizationPattern<SplatOp> {
}
};
+/// Conversion pattern that transforms a view op into:
+/// 1. An `llvm.mlir.undef` operation to create a memref descriptor,
+/// 2. Updates to the descriptor to introduce the data ptr, offset, sizes
+/// and strides.
+/// The view op is replaced by the descriptor.
+struct ViewOpLowering : public LLVMLegalizationPattern<ViewOp> {
+ using LLVMLegalizationPattern<ViewOp>::LLVMLegalizationPattern;
+
+ // Build and return the value for the idx^th shape dimension, either the
+ // constant shape dimension or the matching dynamic size operand.
+ Value *getSize(ConversionPatternRewriter &rewriter, Location loc,
+ ArrayRef<int64_t> shape, ArrayRef<Value *> dynamicSizes,
+ unsigned idx) const {
+ assert(idx < shape.size());
+ if (!ShapedType::isDynamic(shape[idx]))
+ return createIndexConstant(rewriter, loc, shape[idx]);
+ // Count the number of dynamic dims in the half-open range [0, idx).
+ unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) {
+ return ShapedType::isDynamic(v);
+ });
+ return dynamicSizes[nDynamic];
+ }
+
+ // Build and return the idx^th stride, either by returning the constant stride
+ // or by computing the dynamic stride from the current `runningStride` and
+ // `nextSize`. The caller should keep a running stride and update it with the
+ // result returned by this function.
+ Value *getStride(ConversionPatternRewriter &rewriter, Location loc,
+ ArrayRef<int64_t> strides, Value *nextSize,
+ Value *runningStride, unsigned idx) const {
+ assert(idx < strides.size());
+ if (strides[idx] != MemRefType::getDynamicStrideOrOffset())
+ return createIndexConstant(rewriter, loc, strides[idx]);
+ if (nextSize)
+ return runningStride
+ ? rewriter.create<LLVM::MulOp>(loc, runningStride, nextSize)
+ : nextSize;
+ assert(!runningStride);
+ return createIndexConstant(rewriter, loc, 1);
+ }
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op->getLoc();
+ auto viewOp = cast<ViewOp>(op);
+ ViewOpOperandAdaptor adaptor(operands);
+ auto sourceMemRefType = viewOp.source()->getType().cast<MemRefType>();
+ auto sourceElementTy =
+ lowering.convertType(sourceMemRefType.getElementType())
+ .dyn_cast<LLVM::LLVMType>();
+
+ auto viewMemRefType = viewOp.getType();
+ auto targetElementTy = lowering.convertType(viewMemRefType.getElementType())
+ .dyn_cast<LLVM::LLVMType>();
+ auto targetDescTy =
+ lowering.convertType(viewMemRefType).dyn_cast<LLVM::LLVMType>();
+ if (!targetDescTy)
+ return op->emitWarning("Target descriptor type not converted to LLVM"),
+ matchFailure();
+
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset);
+ if (failed(successStrides))
+ return op->emitWarning("Cannot cast to non-strided shape"),
+ matchFailure();
+ if (strides.back() != 1)
+ return op->emitWarning("Cannot cast to non-contiguous shape"),
+ matchFailure();
+
+ // Create the descriptor.
+ Value *desc = rewriter.create<LLVM::UndefOp>(loc, targetDescTy);
+
+ // Copy the buffer pointer from the old descriptor to the new one.
+ Value *sourceDescriptor = adaptor.source();
+ Value *bitcastPtr = rewriter.create<LLVM::BitcastOp>(
+ loc, targetElementTy.getPointerTo(),
+ rewriter.create<LLVM::ExtractValueOp>(
+ loc, sourceElementTy.getPointerTo(), sourceDescriptor,
+ rewriter.getI64ArrayAttr(
+ LLVMTypeConverter::kPtrPosInMemRefDescriptor)));
+ desc = rewriter.create<LLVM::InsertValueOp>(
+ loc, desc, bitcastPtr,
+ rewriter.getI64ArrayAttr(LLVMTypeConverter::kPtrPosInMemRefDescriptor));
+
+ // Offset.
+ unsigned numDynamicSizes = llvm::size(viewOp.getDynamicSizes());
+ (void)numDynamicSizes;
+ auto sizeAndOffsetOperands = adaptor.operands();
+ assert(llvm::size(sizeAndOffsetOperands) == numDynamicSizes + 1 ||
+ offset != MemRefType::getDynamicStrideOrOffset());
+ Value *baseOffset = (offset != MemRefType::getDynamicStrideOrOffset())
+ ? createIndexConstant(rewriter, loc, offset)
+ // TODO(ntv): better adaptor.
+ : sizeAndOffsetOperands.back();
+ desc = rewriter.create<LLVM::InsertValueOp>(
+ loc, desc, baseOffset,
+ rewriter.getI64ArrayAttr(
+ LLVMTypeConverter::kOffsetPosInMemRefDescriptor));
+
+ // Update sizes and strides.
+ Value *stride = nullptr, *nextSize = nullptr;
+ for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) {
+ // Update size.
+ Value *size = getSize(rewriter, loc, viewMemRefType.getShape(),
+ sizeAndOffsetOperands, i);
+ desc = rewriter.create<LLVM::InsertValueOp>(
+ loc, desc, size,
+ rewriter.getI64ArrayAttr(
+ {LLVMTypeConverter::kSizePosInMemRefDescriptor, i}));
+ // Update stride.
+ stride = getStride(rewriter, loc, strides, nextSize, stride, i);
+ desc = rewriter.create<LLVM::InsertValueOp>(
+ loc, desc, stride,
+ rewriter.getI64ArrayAttr(
+ {LLVMTypeConverter::kStridePosInMemRefDescriptor, i}));
+ nextSize = size;
+ }
+
+ rewriter.replaceOp(op, desc);
+ return matchSuccess();
+ }
+};
+
} // namespace
static void ensureDistinctSuccessors(Block &bb) {
@@ -1459,6 +1584,7 @@ void mlir::populateStdToLLVMConversionPatterns(
SubFOpLowering,
SubIOpLowering,
TruncateIOpLowering,
+ ViewOpLowering,
XOrOpLowering,
ZeroExtendIOpLowering>(*converter.getDialect(), converter);
// clang-format on
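
The size/stride loop in matchAndRewrite walks dimensions from innermost to outermost, threading a running stride through getStride: the innermost dynamic stride is 1, and each outer dynamic stride is the running stride multiplied by the next inner size. The same recurrence in standalone C++ (an illustrative sketch, not code from the patch):

    #include <cstdint>
    #include <vector>

    // Contiguous row-major strides, mirroring ViewOpLowering's running-stride
    // recurrence: stride[rank-1] = 1 and stride[i] = stride[i+1] * size[i+1].
    std::vector<int64_t> contiguousStrides(const std::vector<int64_t> &sizes) {
      std::vector<int64_t> strides(sizes.size());
      int64_t running = 1;
      for (int i = static_cast<int>(sizes.size()) - 1; i >= 0; --i) {
        strides[i] = running; // stride of dimension i
        running *= sizes[i];  // fold this dimension's size into the product
      }
      return strides;
    }

For example, sizes {64, 4} give strides {4, 1}, which is what the fully static test case below checks via llvm.mlir.constant.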
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
index fb23a76cf25..a0daac0d658 100644
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -602,3 +602,109 @@ func @splat(%a: vector<4xf32>, %b: f32) -> vector<4xf32> {
// CHECK-NEXT: [[SCALE:%[0-9]+]] = llvm.fmul [[A]], [[SPLAT]] : !llvm<"<4 x float>">
// CHECK-NEXT: llvm.return [[SCALE]] : !llvm<"<4 x float>">
+// CHECK-LABEL: func @view(
+// CHECK: %[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64
+func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
+ // CHECK: llvm.mlir.constant(2048 : index) : !llvm.i64
+ // CHECK: llvm.mlir.undef : !llvm<"{ i8*, i64, [1 x i64], [1 x i64] }">
+ %0 = alloc() : memref<2048xi8>
+
+ // Test two dynamic sizes and dynamic offset.
+ // CHECK: llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i8*, i64, [1 x i64], [1 x i64] }">
+ // CHECK: llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.insertvalue %[[ARG2]], %{{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.insertvalue %[[ARG1]], %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.insertvalue %[[ARG0]], %{{.*}}[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.mul %{{.*}}, %[[ARG1]]
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ %1 = view %0[%arg0, %arg1][%arg2]
+ : memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)>
+
+ // Test two dynamic sizes and static offset.
+ // CHECK: llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i8*, i64, [1 x i64], [1 x i64] }">
+ // CHECK: llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.insertvalue %arg0, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.mul %{{.*}}, %[[ARG1]]
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ %2 = view %0[%arg0, %arg1][]
+ : memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0] -> (d0 * s0 + d1)>
+
+ // Test one dynamic size and dynamic offset.
+ // CHECK: llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i8*, i64, [1 x i64], [1 x i64] }">
+ // CHECK: llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.insertvalue %[[ARG2]], %{{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.insertvalue %[[ARG1]], %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.mul %{{.*}}, %[[ARG1]]
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ %3 = view %0[%arg1][%arg2]
+ : memref<2048xi8> to memref<4x?xf32, (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)>
+
+ // Test one dynamic size and static offset.
+ // CHECK: llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i8*, i64, [1 x i64], [1 x i64] }">
+ // CHECK: llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.mlir.constant(16 : index) : !llvm.i64
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.insertvalue %[[ARG0]], %{{.*}}[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ %4 = view %0[%arg0][]
+ : memref<2048xi8> to memref<?x16xf32, (d0, d1) -> (d0 * 4 + d1)>
+
+ // Test static sizes and static offset.
+ // CHECK: llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i8*, i64, [1 x i64], [1 x i64] }">
+ // CHECK: llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.mlir.constant(64 : index) : !llvm.i64
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ %5 = view %0[][]
+ : memref<2048xi8> to memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)>
+
+ // Test dynamic everything.
+ // CHECK: llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i8*, i64, [1 x i64], [1 x i64] }">
+ // CHECK: llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.insertvalue %[[ARG2]], %{{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.insertvalue %[[ARG1]], %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: %[[STRIDE_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+ // CHECK: llvm.insertvalue %[[STRIDE_1]], %{{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.insertvalue %[[ARG0]], %{{.*}}[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.mul %[[STRIDE_1]], %[[ARG1]] : !llvm.i64
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ %6 = view %0[%arg0, %arg1][%arg2]
+ : memref<2048xi8> to memref<?x?xf32, (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)>
+
+ return
+}
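
Once populated, the descriptor makes the view's layout map explicit: for the test maps of the form (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1), the lowered strides are [s0, 1] and the offset is s1, so an element access linearizes as base + offset + d0 * stride0 + d1 * stride1. A minimal C++ sketch of that address computation (illustrative only, assuming the descriptor layout sketched earlier):

    #include <cstdint>

    // Load element (d0, d1) of a lowered 2-D view: the stored offset and
    // strides realize the view's layout map as a linearized index.
    float loadView2D(const float *ptr, int64_t offset,
                     const int64_t strides[2], int64_t d0, int64_t d1) {
      return ptr[offset + d0 * strides[0] + d1 * strides[1]];
    }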
diff --git a/mlir/test/Conversion/StandardToLLVM/foo.mlir b/mlir/test/Conversion/StandardToLLVM/foo.mlir
new file mode 100644
index 00000000000..76007978463
--- /dev/null
+++ b/mlir/test/Conversion/StandardToLLVM/foo.mlir
@@ -0,0 +1,25 @@
+
+// CHECK-LABEL: func @view(
+// CHECK: %[[ARG0:.*]]: !llvm.i64, %[[ARG1:.*]]: !llvm.i64, %[[ARG2:.*]]: !llvm.i64
+func @view(%arg0 : index, %arg1 : index, %arg2 : index) {
+ // CHECK: llvm.mlir.constant(2048 : index) : !llvm.i64
+ // CHECK: llvm.mlir.undef : !llvm<"{ i8*, i64, [1 x i64], [1 x i64] }">
+ %0 = alloc() : memref<2048xi8>
+
+ // Test one dynamic size and dynamic offset.
+ // CHECK: llvm.mlir.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i8*, i64, [1 x i64], [1 x i64] }">
+ // CHECK: llvm.bitcast %{{.*}} : !llvm<"i8*"> to !llvm<"float*">
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.insertvalue %[[ARG2]], %{{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.insertvalue %[[ARG1]], %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.mlir.constant(4 : index) : !llvm.i64
+ // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+ %3 = view %0[%arg1][%arg2]
+ : memref<2048xi8> to memref<4x?xf32, (d0, d1)[s0, s1] -> (d0 * s0 + d1 + s1)>
+ return
+}
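
For completeness, patterns registered via populateStdToLLVMConversionPatterns are exercised through the dialect conversion driver; a minimal pass body in the style of this revision's API (a hedged sketch, not code from the patch — the pass name is hypothetical):

    // Sketch of a module pass that would exercise ViewOpLowering, assuming
    // this revision's ModulePass and applyPartialConversion interfaces.
    struct SketchLowerToLLVMPass : public ModulePass<SketchLowerToLLVMPass> {
      void runOnModule() override {
        ModuleOp m = getModule();
        LLVMTypeConverter converter(&getContext());
        OwningRewritePatternList patterns;
        populateStdToLLVMConversionPatterns(converter, patterns);
        ConversionTarget target(getContext());
        target.addLegalDialect<LLVM::LLVMDialect>();
        if (failed(applyPartialConversion(m, target, patterns, &converter)))
          signalPassFailure();
      }
    };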