diff options
| author | Nicolas Vasilache <ntv@google.com> | 2020-01-05 19:37:56 -0500 |
|---|---|---|
| committer | Nicolas Vasilache <ntv@google.com> | 2020-01-08 13:07:41 -0500 |
| commit | 766ce87e9bed89bc3b5c2c904f1eb2d10be0d3be (patch) | |
| tree | 9f132bfe7e416f6c3194bfffa346c5e4aab79cfb /mlir/lib | |
| parent | 3811417f39a7d0a370fac2923060f5ef8dacd8d7 (diff) | |
| download | bcm5719-llvm-766ce87e9bed89bc3b5c2c904f1eb2d10be0d3be.tar.gz bcm5719-llvm-766ce87e9bed89bc3b5c2c904f1eb2d10be0d3be.zip | |
[mlir][Linalg] Lower linalg.reshape to LLVM for the static case
Summary:
This diff adds lowering of the linalg.reshape op to LLVM.
A new descriptor is created with fields initialized as follows:
1. allocatedPTr, alignedPtr and offset are copied from the source descriptor
2. sizes are copied from the static destination shape
3. strides are copied from the static strides collected with `getStridesAndOffset`
Only the static case in which the target view conforms to strided memref
semantics is supported. Other cases are left for future work and will be added on
a per-need basis.
Reviewers: ftynse, mravishankar
Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D72316
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp | 52 |
1 files changed, 50 insertions, 2 deletions
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp index 2dd36c94d31..86890b12ade 100644 --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -122,8 +122,14 @@ public: void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); } Value size(unsigned i) { return d.size(rewriter(), loc(), i); } void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); } + void setConstantSize(unsigned i, int64_t v) { + d.setConstantSize(rewriter(), loc(), i, v); + } Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); } void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); } + void setConstantStride(unsigned i, int64_t v) { + d.setConstantStride(rewriter(), loc(), i, v); + } operator Value() { return d; } @@ -161,6 +167,48 @@ public: } }; +// ReshapeOp creates a new view descriptor of the proper rank. +// For now, the only conversion supported is for target MemRef with static sizes +// and strides. +class ReshapeOpConversion : public LLVMOpLowering { +public: + explicit ReshapeOpConversion(MLIRContext *context, + LLVMTypeConverter &lowering_) + : LLVMOpLowering(ReshapeOp::getOperationName(), context, lowering_) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + auto reshapeOp = cast<ReshapeOp>(op); + MemRefType dstType = reshapeOp.getResult().getType().cast<MemRefType>(); + + if (!dstType.hasStaticShape()) + return matchFailure(); + + int64_t offset; + SmallVector<int64_t, 4> strides; + auto res = getStridesAndOffset(dstType, strides, offset); + if (failed(res) || llvm::any_of(strides, [](int64_t val) { + return ShapedType::isDynamicStrideOrOffset(val); + })) + return matchFailure(); + + edsc::ScopedContext context(rewriter, op->getLoc()); + ReshapeOpOperandAdaptor adaptor(operands); + BaseViewConversionHelper baseDesc(adaptor.view()); + BaseViewConversionHelper desc(lowering.convertType(dstType)); + desc.setAllocatedPtr(baseDesc.allocatedPtr()); + desc.setAlignedPtr(baseDesc.alignedPtr()); + desc.setOffset(baseDesc.offset()); + for (auto en : llvm::enumerate(dstType.getShape())) + desc.setConstantSize(en.index(), en.value()); + for (auto en : llvm::enumerate(strides)) + desc.setConstantStride(en.index(), en.value()); + rewriter.replaceOp(op, {desc}); + return matchSuccess(); + } +}; + /// Conversion pattern that transforms a linalg.slice op into: /// 1. A function entry `alloca` operation to allocate a ViewDescriptor. /// 2. A load of the ViewDescriptor from the pointer allocated in 1. @@ -508,8 +556,8 @@ populateLinalgToStandardConversionPatterns(OwningRewritePatternList &patterns, void mlir::populateLinalgToLLVMConversionPatterns( LinalgTypeConverter &converter, OwningRewritePatternList &patterns, MLIRContext *ctx) { - patterns.insert<RangeOpConversion, SliceOpConversion, TransposeOpConversion, - YieldOpConversion>(ctx, converter); + patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion, + TransposeOpConversion, YieldOpConversion>(ctx, converter); } namespace { |

