diff options
| author | River Riddle <riverriddle@google.com> | 2019-12-13 12:21:42 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-13 16:47:26 -0800 |
| commit | b030e4a4ec5ef47549377cc0af71a95abcf28a98 (patch) | |
| tree | fa7dab40faf97bc094d1e9426d0f17e44afa0f44 /mlir/include | |
| parent | 7b19d736172789ce8e5ca10ae6276302004533f0 (diff) | |
| download | bcm5719-llvm-b030e4a4ec5ef47549377cc0af71a95abcf28a98.tar.gz bcm5719-llvm-b030e4a4ec5ef47549377cc0af71a95abcf28a98.zip | |
Try to fold operations in DialectConversion when trying to legalize.
This change allows for DialectConversion to attempt folding as a mechanism to legalize illegal operations. This also expands folding support in OpBuilder::createOrFold to generate new constants when folding, and also enables it to work in the context of a PatternRewriter.
PiperOrigin-RevId: 285448440
Diffstat (limited to 'mlir/include')
| -rw-r--r-- | mlir/include/mlir/Dialect/AffineOps/AffineOps.h | 5 | ||||
| -rw-r--r-- | mlir/include/mlir/Dialect/StandardOps/Ops.h | 5 | ||||
| -rw-r--r-- | mlir/include/mlir/Dialect/VectorOps/VectorOps.h | 5 | ||||
| -rw-r--r-- | mlir/include/mlir/IR/Builders.h | 22 |
4 files changed, 31 insertions, 6 deletions
diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOps.h b/mlir/include/mlir/Dialect/AffineOps/AffineOps.h index 8d36473674b..835ac24b4ae 100644 --- a/mlir/include/mlir/Dialect/AffineOps/AffineOps.h +++ b/mlir/include/mlir/Dialect/AffineOps/AffineOps.h @@ -47,6 +47,11 @@ class AffineOpsDialect : public Dialect { public: AffineOpsDialect(MLIRContext *context); static StringRef getDialectNamespace() { return "affine"; } + + /// Materialize a single constant operation from a given attribute value with + /// the desired resultant type. + Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, + Location loc) override; }; /// The "affine.apply" operation applies an affine map to a list of operands, diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.h b/mlir/include/mlir/Dialect/StandardOps/Ops.h index d01a1eaaca2..c7c8714752f 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.h @@ -42,6 +42,11 @@ class StandardOpsDialect : public Dialect { public: StandardOpsDialect(MLIRContext *context); static StringRef getDialectNamespace() { return "std"; } + + /// Materialize a single constant operation from a given attribute value with + /// the desired resultant type. + Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, + Location loc) override; }; /// The predicate indicates the type of the comparison to perform: diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h index 5b4351b3cf5..06672c7ea73 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h @@ -37,6 +37,11 @@ class VectorOpsDialect : public Dialect { public: VectorOpsDialect(MLIRContext *context); static StringRef getDialectNamespace() { return "vector"; } + + /// Materialize a single constant operation from a given attribute value with + /// the desired resultant type. + Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, + Location loc) override; }; /// Collect a set of vector-to-vector canonicalization patterns. diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 9c787c14567..766902fabfa 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -315,8 +315,17 @@ public: template <typename OpTy, typename... Args> void createOrFold(SmallVectorImpl<Value *> &results, Location location, Args &&... args) { - auto op = create<OpTy>(location, std::forward<Args>(args)...); - tryFold(op.getOperation(), results); + // Create the operation without using 'createOperation' as we don't want to + // insert it yet. + OperationState state(location, OpTy::getOperationName()); + OpTy::build(this, state, std::forward<Args>(args)...); + Operation *op = Operation::create(state); + + // Fold the operation. If successful destroy it, otherwise insert it. + if (succeeded(tryFold(op, results))) + op->destroy(); + else + insert(op); } /// Overload to create or fold a single result operation. @@ -343,6 +352,11 @@ public: return op; } + /// Attempts to fold the given operation and places new results within + /// 'results'. Returns success if the operation was folded, failure otherwise. + /// Note: This function does not erase the operation on a successful fold. + LogicalResult tryFold(Operation *op, SmallVectorImpl<Value *> &results); + /// Creates a deep copy of the specified operation, remapping any operands /// that use values outside of the operation using the map that is provided /// ( leaving them alone if no entry is present). Replaces references to @@ -367,10 +381,6 @@ public: } private: - /// Attempts to fold the given operation and places new results within - /// 'results'. - void tryFold(Operation *op, SmallVectorImpl<Value *> &results); - Block *block = nullptr; Block::iterator insertPoint; }; |

