summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/DialectConversion.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/DialectConversion.cpp')
-rw-r--r--mlir/lib/Transforms/DialectConversion.cpp75
1 files changed, 49 insertions, 26 deletions
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index ac13bc2ba5b..4b4575a5e50 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -164,9 +164,10 @@ struct ArgConverter {
// Rewrite Application
//===--------------------------------------------------------------------===//
- /// Erase any rewrites registered for the current block that is about to be
- /// removed. This merely drops the rewrites without undoing them.
- void notifyBlockRemoved(Block *block);
+ /// Erase any rewrites registered for the blocks within the given operation
+ /// which is about to be removed. This merely drops the rewrites without
+ /// undoing them.
+ void notifyOpRemoved(Operation *op);
/// Cleanup and undo any generated conversions for the arguments of block.
/// This method replaces the new block with the original, reverting the IR to
@@ -194,9 +195,16 @@ struct ArgConverter {
Block *block, TypeConverter::SignatureConversion &signatureConversion,
ConversionValueMapping &mapping);
+ /// Insert a new conversion into the cache.
+ void insertConversion(Block *newBlock, ConvertedBlockInfo &&info);
+
/// A collection of blocks that have had their arguments converted.
llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
+ /// A mapping from valid regions, to those containing the original blocks of a
+ /// conversion.
+ DenseMap<Region *, std::unique_ptr<Region>> regionMapping;
+
/// An instance of the unknown location that is used when materializing
/// conversions.
Location loc;
@@ -212,18 +220,26 @@ struct ArgConverter {
//===----------------------------------------------------------------------===//
// Rewrite Application
-void ArgConverter::notifyBlockRemoved(Block *block) {
- auto it = conversionInfo.find(block);
- if (it == conversionInfo.end())
- return;
-
- // Drop all uses of the original arguments and delete the original block.
- Block *origBlock = it->second.origBlock;
- for (BlockArgument *arg : origBlock->getArguments())
- arg->dropAllUses();
- delete origBlock;
-
- conversionInfo.erase(it);
+void ArgConverter::notifyOpRemoved(Operation *op) {
+ for (Region &region : op->getRegions()) {
+ for (Block &block : region) {
+ // Drop any rewrites from within.
+ for (Operation &nestedOp : block)
+ if (nestedOp.getNumRegions())
+ notifyOpRemoved(&nestedOp);
+
+ // Check if this block was converted.
+ auto it = conversionInfo.find(&block);
+ if (it == conversionInfo.end())
+ return;
+
+ // Drop all uses of the original arguments and delete the original block.
+ Block *origBlock = it->second.origBlock;
+ for (BlockArgument *arg : origBlock->getArguments())
+ arg->dropAllUses();
+ conversionInfo.erase(it);
+ }
+ }
}
void ArgConverter::discardRewrites(Block *block) {
@@ -239,7 +255,7 @@ void ArgConverter::discardRewrites(Block *block) {
// Move the operations back the original block and the delete the new block.
origBlock->getOperations().splice(origBlock->end(), block->getOperations());
- origBlock->insertBefore(block);
+ origBlock->moveBefore(block);
block->erase();
conversionInfo.erase(it);
@@ -301,9 +317,6 @@ void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
if (castValue->use_empty())
castValue->getDefiningOp()->erase();
}
-
- // Drop the original block now the rewrites were applied.
- delete origBlock;
}
}
@@ -377,11 +390,24 @@ Block *ArgConverter::applySignatureConversion(
}
// Remove the original block from the region and return the new one.
- newBlock->getParent()->getBlocks().remove(block);
- conversionInfo.insert({newBlock, std::move(info)});
+ insertConversion(newBlock, std::move(info));
return newBlock;
}
+void ArgConverter::insertConversion(Block *newBlock,
+ ConvertedBlockInfo &&info) {
+ // Get a region to insert the old block.
+ Region *region = newBlock->getParent();
+ std::unique_ptr<Region> &mappedRegion = regionMapping[region];
+ if (!mappedRegion)
+ mappedRegion = std::make_unique<Region>(region->getParentOp());
+
+ // Move the original block to the mapped region and emplace the conversion.
+ mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(),
+ info.origBlock->getIterator());
+ conversionInfo.insert({newBlock, std::move(info)});
+}
+
//===----------------------------------------------------------------------===//
// ConversionPatternRewriterImpl
//===----------------------------------------------------------------------===//
@@ -642,11 +668,8 @@ void ConversionPatternRewriterImpl::applyRewrites() {
// If this operation defines any regions, drop any pending argument
// rewrites.
- if (argConverter.typeConverter && repl.op->getNumRegions()) {
- for (auto &region : repl.op->getRegions())
- for (auto &block : region)
- argConverter.notifyBlockRemoved(&block);
- }
+ if (argConverter.typeConverter && repl.op->getNumRegions())
+ argConverter.notifyOpRemoved(repl.op);
}
// In a second pass, erase all of the replaced operations in reverse. This
OpenPOWER on IntegriCloud