diff options
| -rw-r--r-- | mlir/include/mlir/IR/Builders.h | 29 | ||||
| -rw-r--r-- | mlir/include/mlir/IR/PatternMatch.h | 4 | ||||
| -rw-r--r-- | mlir/include/mlir/Transforms/DialectConversion.h | 10 | ||||
| -rw-r--r-- | mlir/lib/IR/Builders.cpp | 18 | ||||
| -rw-r--r-- | mlir/lib/Transforms/DialectConversion.cpp | 18 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | 11 |
6 files changed, 33 insertions, 57 deletions
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index c5ed7b16b56..9c787c14567 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -281,6 +281,9 @@ public: /// Returns the current insertion point of the builder. Block::iterator getInsertionPoint() const { return insertPoint; } + /// Insert the given operation at the current insertion point and return it. + virtual Operation *insert(Operation *op); + /// Add new block and set the insertion point to the end of it. The block is /// inserted at the provided insertion point of 'parent'. Block *createBlock(Region *parent, Region::iterator insertPt = {}); @@ -293,7 +296,7 @@ public: Block *getBlock() const { return block; } /// Creates an operation given the fields represented as an OperationState. - virtual Operation *createOperation(const OperationState &state); + Operation *createOperation(const OperationState &state); /// Create an operation of specific op type at the current insertion point. template <typename OpTy, typename... Args> @@ -346,28 +349,21 @@ public: /// cloned sub-operations to the corresponding operation that is copied, /// and adds those mappings to the map. Operation *clone(Operation &op, BlockAndValueMapping &mapper) { - Operation *cloneOp = op.clone(mapper); - insert(cloneOp); - return cloneOp; - } - Operation *clone(Operation &op) { - Operation *cloneOp = op.clone(); - insert(cloneOp); - return cloneOp; + return insert(op.clone(mapper)); } + Operation *clone(Operation &op) { return insert(op.clone()); } /// Creates a deep copy of this operation but keep the operation regions /// empty. Operands are remapped using `mapper` (if present), and `mapper` is /// updated to contain the results. Operation *cloneWithoutRegions(Operation &op, BlockAndValueMapping &mapper) { - Operation *cloneOp = op.cloneWithoutRegions(mapper); - insert(cloneOp); - return cloneOp; + return insert(op.cloneWithoutRegions(mapper)); } Operation *cloneWithoutRegions(Operation &op) { - Operation *cloneOp = op.cloneWithoutRegions(); - insert(cloneOp); - return cloneOp; + return insert(op.cloneWithoutRegions()); + } + template <typename OpT> OpT cloneWithoutRegions(OpT op) { + return cast<OpT>(cloneWithoutRegions(*op.getOperation())); } private: @@ -375,9 +371,6 @@ private: /// 'results'. void tryFold(Operation *op, SmallVectorImpl<Value *> &results); - /// Insert the given operation at the current insertion point. - void insert(Operation *op); - Block *block = nullptr; Block::iterator insertPoint; }; diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 366d2b893af..4805152cf4c 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -302,9 +302,9 @@ public: return OpTy(); } - /// This is implemented to create the specified operations and serves as a + /// This is implemented to insert the specified operation and serves as a /// notification hook for rewriters that want to know about new operations. - virtual Operation *createOperation(const OperationState &state) = 0; + virtual Operation *insert(Operation *op) = 0; /// Move the blocks that belong to "region" before the given position in /// another region "parent". The two regions must be different. The caller diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index fee58a4904a..249b4c114c9 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -332,12 +332,6 @@ public: /// Replace all the uses of the block argument `from` with value `to`. void replaceUsesOfBlockArgument(BlockArgument *from, Value *to); - /// Clone the given operation without cloning its regions. - Operation *cloneWithoutRegions(Operation *op); - template <typename OpT> OpT cloneWithoutRegions(OpT op) { - return cast<OpT>(cloneWithoutRegions(op.getOperation())); - } - /// Return the converted value that replaces 'key'. Return 'key' if there is /// no such a converted value. Value *getRemappedValue(Value *key); @@ -376,8 +370,8 @@ public: BlockAndValueMapping &mapping) override; using PatternRewriter::cloneRegionBefore; - /// PatternRewriter hook for creating a new operation. - Operation *createOperation(const OperationState &state) override; + /// PatternRewriter hook for inserting a new operation. + Operation *insert(Operation *op) override; /// PatternRewriter hook for updating the root operation in-place. void notifyRootUpdated(Operation *op) override; diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 4d6cd3550ca..8c54df4d55b 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -306,6 +306,13 @@ AffineMap Builder::getShiftedAffineMap(AffineMap map, int64_t shift) { OpBuilder::~OpBuilder() {} +/// Insert the given operation at the current insertion point and return it. +Operation *OpBuilder::insert(Operation *op) { + if (block) + block->getOperations().insert(insertPoint, op); + return op; +} + /// Add new block and set the insertion point to the end of it. The block is /// inserted at the provided insertion point of 'parent'. Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt) { @@ -328,10 +335,7 @@ Block *OpBuilder::createBlock(Block *insertBefore) { /// Create an operation given the fields represented as an OperationState. Operation *OpBuilder::createOperation(const OperationState &state) { - assert(block && "createOperation() called without setting builder's block"); - auto *op = Operation::create(state); - insert(op); - return op; + return insert(Operation::create(state)); } /// Attempts to fold the given operation and places new results within @@ -359,9 +363,3 @@ void OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value *> &results) { [](OpFoldResult result) { return result.get<Value *>(); }); op->erase(); } - -/// Insert the given operation at the current insertion point. -void OpBuilder::insert(Operation *op) { - if (block) - block->getOperations().insert(insertPoint, op); -} diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 6d34db90912..ea4ad681693 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -802,13 +802,6 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument *from, impl->mapping.map(impl->mapping.lookupOrDefault(from), to); } -/// Clone the given operation without cloning its regions. -Operation *ConversionPatternRewriter::cloneWithoutRegions(Operation *op) { - Operation *newOp = OpBuilder::cloneWithoutRegions(*op); - impl->createdOps.push_back(newOp); - return newOp; -} - /// Return the converted value that replaces 'key'. Return 'key' if there is /// no such a converted value. Value *ConversionPatternRewriter::getRemappedValue(Value *key) { @@ -854,12 +847,11 @@ void ConversionPatternRewriter::cloneRegionBefore( } /// PatternRewriter hook for creating a new operation. -Operation * -ConversionPatternRewriter::createOperation(const OperationState &state) { - LLVM_DEBUG(llvm::dbgs() << "** Creating operation : " << state.name << "\n"); - auto *result = OpBuilder::createOperation(state); - impl->createdOps.push_back(result); - return result; +Operation *ConversionPatternRewriter::insert(Operation *op) { + LLVM_DEBUG(llvm::dbgs() << "** Inserting operation : " << op->getName() + << "\n"); + impl->createdOps.push_back(op); + return OpBuilder::insert(op); } /// PatternRewriter hook for updating the root operation in-place. diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index aa4563c96e4..e2ca3f8fc5e 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -86,12 +86,11 @@ public: // These are hooks implemented for PatternRewriter. protected: - // Implement the hook for creating operations, and make sure that newly - // created ops are added to the worklist for processing. - Operation *createOperation(const OperationState &state) override { - auto *result = OpBuilder::createOperation(state); - addToWorklist(result); - return result; + // Implement the hook for inserting operations, and make sure that newly + // inserted ops are added to the worklist for processing. + Operation *insert(Operation *op) override { + addToWorklist(op); + return OpBuilder::insert(op); } // If an operation is about to be removed, make sure it is not in our |

