diff options
| author | Lei Zhang <antiagainst@google.com> | 2019-02-14 11:01:08 -0800 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 16:27:55 -0700 |
| commit | 93d8f14c0fd82044318232fd2dfd28f8bd5d9752 (patch) | |
| tree | d355becef7b19b9049e94b5434de10dfa718be7e | |
| parent | eb3f8dcb935259db192be01587907ab47b35d199 (diff) | |
| download | bcm5719-llvm-93d8f14c0fd82044318232fd2dfd28f8bd5d9752.tar.gz bcm5719-llvm-93d8f14c0fd82044318232fd2dfd28f8bd5d9752.zip | |
[TFLite] Fuse AddOp into preceding convolution ops
If we see an add op adding a constant value to a convolution op with constant
bias, we can fuse the add into the convolution op by constant folding the
bias and the add op's constant operand.
This CL also removes dangling RewriterGen check that prevents us from using
nested DAG nodes in result patterns, which is already supported.
PiperOrigin-RevId: 233989654
| -rw-r--r-- | mlir/include/mlir/IR/op_base.td | 4 | ||||
| -rw-r--r-- | mlir/tools/mlir-tblgen/RewriterGen.cpp | 5 |
2 files changed, 3 insertions, 6 deletions
diff --git a/mlir/include/mlir/IR/op_base.td b/mlir/include/mlir/IR/op_base.td index f11458daa92..7ebb1de25cb 100644 --- a/mlir/include/mlir/IR/op_base.td +++ b/mlir/include/mlir/IR/op_base.td @@ -315,11 +315,13 @@ def ArrayAttr : Attr<CPred<"true">, "array"> { let returnType = [{ ArrayAttr }]; code convertFromStorage = "{0}"; } -def ElementsAttr : Attr<CPred<"true">, "constant vector/tensor"> { +class ElementsAttrBase<Pred condition, string description> : + Attr<condition, description> { let storageType = [{ ElementsAttr }]; let returnType = [{ ElementsAttr }]; let convertFromStorage = "{0}"; } +def ElementsAttr: ElementsAttrBase<CPred<"true">, "constant vector/tensor">; def F32Attr : FloatAttrBase<F32, "32-bit float"> { let returnType = [{ APFloat }]; } diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 20072a01c71..a09aeac100a 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -346,11 +346,6 @@ void PatternEmitter::emitRewriteMethod() { DagNode resultTree = pattern.getResultPattern(0); - // TODO(jpienaar): Expand to multiple results. - for (unsigned i = 0, e = resultTree.getNumArgs(); i != e; ++i) - if (resultTree.getArgAsNestedDag(i)) - PrintFatalError(loc, "only single op result supported"); - os << R"( void rewrite(Instruction *op, std::unique_ptr<PatternState> state, PatternRewriter &rewriter) const override { |

