diff options
Diffstat (limited to 'mlir/lib/Transforms/DialectConversion.cpp')
-rw-r--r-- | mlir/lib/Transforms/DialectConversion.cpp | 75 |
1 files changed, 49 insertions, 26 deletions
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index ac13bc2ba5b..4b4575a5e50 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -164,9 +164,10 @@ struct ArgConverter { // Rewrite Application //===--------------------------------------------------------------------===// - /// Erase any rewrites registered for the current block that is about to be - /// removed. This merely drops the rewrites without undoing them. - void notifyBlockRemoved(Block *block); + /// Erase any rewrites registered for the blocks within the given operation + /// which is about to be removed. This merely drops the rewrites without + /// undoing them. + void notifyOpRemoved(Operation *op); /// Cleanup and undo any generated conversions for the arguments of block. /// This method replaces the new block with the original, reverting the IR to @@ -194,9 +195,16 @@ struct ArgConverter { Block *block, TypeConverter::SignatureConversion &signatureConversion, ConversionValueMapping &mapping); + /// Insert a new conversion into the cache. + void insertConversion(Block *newBlock, ConvertedBlockInfo &&info); + /// A collection of blocks that have had their arguments converted. llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo; + /// A mapping from valid regions, to those containing the original blocks of a + /// conversion. + DenseMap<Region *, std::unique_ptr<Region>> regionMapping; + /// An instance of the unknown location that is used when materializing /// conversions. Location loc; @@ -212,18 +220,26 @@ struct ArgConverter { //===----------------------------------------------------------------------===// // Rewrite Application -void ArgConverter::notifyBlockRemoved(Block *block) { - auto it = conversionInfo.find(block); - if (it == conversionInfo.end()) - return; - - // Drop all uses of the original arguments and delete the original block. - Block *origBlock = it->second.origBlock; - for (BlockArgument *arg : origBlock->getArguments()) - arg->dropAllUses(); - delete origBlock; - - conversionInfo.erase(it); +void ArgConverter::notifyOpRemoved(Operation *op) { + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + // Drop any rewrites from within. + for (Operation &nestedOp : block) + if (nestedOp.getNumRegions()) + notifyOpRemoved(&nestedOp); + + // Check if this block was converted. + auto it = conversionInfo.find(&block); + if (it == conversionInfo.end()) + return; + + // Drop all uses of the original arguments and delete the original block. + Block *origBlock = it->second.origBlock; + for (BlockArgument *arg : origBlock->getArguments()) + arg->dropAllUses(); + conversionInfo.erase(it); + } + } } void ArgConverter::discardRewrites(Block *block) { @@ -239,7 +255,7 @@ void ArgConverter::discardRewrites(Block *block) { // Move the operations back the original block and the delete the new block. origBlock->getOperations().splice(origBlock->end(), block->getOperations()); - origBlock->insertBefore(block); + origBlock->moveBefore(block); block->erase(); conversionInfo.erase(it); @@ -301,9 +317,6 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) { if (castValue->use_empty()) castValue->getDefiningOp()->erase(); } - - // Drop the original block now the rewrites were applied. - delete origBlock; } } @@ -377,11 +390,24 @@ Block *ArgConverter::applySignatureConversion( } // Remove the original block from the region and return the new one. - newBlock->getParent()->getBlocks().remove(block); - conversionInfo.insert({newBlock, std::move(info)}); + insertConversion(newBlock, std::move(info)); return newBlock; } +void ArgConverter::insertConversion(Block *newBlock, + ConvertedBlockInfo &&info) { + // Get a region to insert the old block. + Region *region = newBlock->getParent(); + std::unique_ptr<Region> &mappedRegion = regionMapping[region]; + if (!mappedRegion) + mappedRegion = std::make_unique<Region>(region->getParentOp()); + + // Move the original block to the mapped region and emplace the conversion. + mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(), + info.origBlock->getIterator()); + conversionInfo.insert({newBlock, std::move(info)}); +} + //===----------------------------------------------------------------------===// // ConversionPatternRewriterImpl //===----------------------------------------------------------------------===// @@ -642,11 +668,8 @@ void ConversionPatternRewriterImpl::applyRewrites() { // If this operation defines any regions, drop any pending argument // rewrites. - if (argConverter.typeConverter && repl.op->getNumRegions()) { - for (auto ®ion : repl.op->getRegions()) - for (auto &block : region) - argConverter.notifyBlockRemoved(&block); - } + if (argConverter.typeConverter && repl.op->getNumRegions()) + argConverter.notifyOpRemoved(repl.op); } // In a second pass, erase all of the replaced operations in reverse. This |