summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
authorNicolas Vasilache <ntv@google.com>2020-01-05 19:37:56 -0500
committerNicolas Vasilache <ntv@google.com>2020-01-08 13:07:41 -0500
commit766ce87e9bed89bc3b5c2c904f1eb2d10be0d3be (patch)
tree9f132bfe7e416f6c3194bfffa346c5e4aab79cfb /mlir/lib
parent3811417f39a7d0a370fac2923060f5ef8dacd8d7 (diff)
downloadbcm5719-llvm-766ce87e9bed89bc3b5c2c904f1eb2d10be0d3be.tar.gz
bcm5719-llvm-766ce87e9bed89bc3b5c2c904f1eb2d10be0d3be.zip
[mlir][Linalg] Lower linalg.reshape to LLVM for the static case
Summary: This diff adds lowering of the linalg.reshape op to LLVM. A new descriptor is created with fields initialized as follows: 1. allocatedPTr, alignedPtr and offset are copied from the source descriptor 2. sizes are copied from the static destination shape 3. strides are copied from the static strides collected with `getStridesAndOffset` Only the static case in which the target view conforms to strided memref semantics is supported. Other cases are left for future work and will be added on a per-need basis. Reviewers: ftynse, mravishankar Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D72316
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp52
1 files changed, 50 insertions, 2 deletions
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index 2dd36c94d31..86890b12ade 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -122,8 +122,14 @@ public:
void setOffset(Value v) { d.setOffset(rewriter(), loc(), v); }
Value size(unsigned i) { return d.size(rewriter(), loc(), i); }
void setSize(unsigned i, Value v) { d.setSize(rewriter(), loc(), i, v); }
+ void setConstantSize(unsigned i, int64_t v) {
+ d.setConstantSize(rewriter(), loc(), i, v);
+ }
Value stride(unsigned i) { return d.stride(rewriter(), loc(), i); }
void setStride(unsigned i, Value v) { d.setStride(rewriter(), loc(), i, v); }
+ void setConstantStride(unsigned i, int64_t v) {
+ d.setConstantStride(rewriter(), loc(), i, v);
+ }
operator Value() { return d; }
@@ -161,6 +167,48 @@ public:
}
};
+// ReshapeOp creates a new view descriptor of the proper rank.
+// For now, the only conversion supported is for target MemRef with static sizes
+// and strides.
+class ReshapeOpConversion : public LLVMOpLowering {
+public:
+ explicit ReshapeOpConversion(MLIRContext *context,
+ LLVMTypeConverter &lowering_)
+ : LLVMOpLowering(ReshapeOp::getOperationName(), context, lowering_) {}
+
+ PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto reshapeOp = cast<ReshapeOp>(op);
+ MemRefType dstType = reshapeOp.getResult().getType().cast<MemRefType>();
+
+ if (!dstType.hasStaticShape())
+ return matchFailure();
+
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ auto res = getStridesAndOffset(dstType, strides, offset);
+ if (failed(res) || llvm::any_of(strides, [](int64_t val) {
+ return ShapedType::isDynamicStrideOrOffset(val);
+ }))
+ return matchFailure();
+
+ edsc::ScopedContext context(rewriter, op->getLoc());
+ ReshapeOpOperandAdaptor adaptor(operands);
+ BaseViewConversionHelper baseDesc(adaptor.view());
+ BaseViewConversionHelper desc(lowering.convertType(dstType));
+ desc.setAllocatedPtr(baseDesc.allocatedPtr());
+ desc.setAlignedPtr(baseDesc.alignedPtr());
+ desc.setOffset(baseDesc.offset());
+ for (auto en : llvm::enumerate(dstType.getShape()))
+ desc.setConstantSize(en.index(), en.value());
+ for (auto en : llvm::enumerate(strides))
+ desc.setConstantStride(en.index(), en.value());
+ rewriter.replaceOp(op, {desc});
+ return matchSuccess();
+ }
+};
+
/// Conversion pattern that transforms a linalg.slice op into:
/// 1. A function entry `alloca` operation to allocate a ViewDescriptor.
/// 2. A load of the ViewDescriptor from the pointer allocated in 1.
@@ -508,8 +556,8 @@ populateLinalgToStandardConversionPatterns(OwningRewritePatternList &patterns,
void mlir::populateLinalgToLLVMConversionPatterns(
LinalgTypeConverter &converter, OwningRewritePatternList &patterns,
MLIRContext *ctx) {
- patterns.insert<RangeOpConversion, SliceOpConversion, TransposeOpConversion,
- YieldOpConversion>(ctx, converter);
+ patterns.insert<RangeOpConversion, ReshapeOpConversion, SliceOpConversion,
+ TransposeOpConversion, YieldOpConversion>(ctx, converter);
}
namespace {
OpenPOWER on IntegriCloud