//===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===// // // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Transforms/DialectConversion.h" #include "mlir/IR/Block.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" #include "mlir/Transforms/Utils.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" using namespace mlir; using namespace mlir::detail; #define DEBUG_TYPE "dialect-conversion" /// Recursively collect all of the operations to convert from within 'region'. /// If 'target' is nonnull, operations that are recursively legal have their /// regions pre-filtered to avoid considering them for legalization. static LogicalResult computeConversionSet(iterator_range region, Location regionLoc, std::vector &toConvert, ConversionTarget *target = nullptr) { if (llvm::empty(region)) return success(); // Traverse starting from the entry block. SmallVector worklist(1, &*region.begin()); DenseSet visitedBlocks; visitedBlocks.insert(worklist.front()); while (!worklist.empty()) { Block *block = worklist.pop_back_val(); // Compute the conversion set of each of the nested operations. for (Operation &op : *block) { toConvert.emplace_back(&op); // Don't check this operation's children for conversion if the operation // is recursively legal. auto legalityInfo = target ? target->isLegal(&op) : Optional(); if (legalityInfo && legalityInfo->isRecursivelyLegal) continue; for (auto ®ion : op.getRegions()) computeConversionSet(region.getBlocks(), region.getLoc(), toConvert, target); } // Recurse to children that haven't been visited. 
for (Block *succ : block->getSuccessors()) if (visitedBlocks.insert(succ).second) worklist.push_back(succ); } // Check that all blocks in the region were visited. if (llvm::any_of(llvm::drop_begin(region, 1), [&](Block &block) { return !visitedBlocks.count(&block); })) return emitError(regionLoc, "unreachable blocks were not converted"); return success(); } //===----------------------------------------------------------------------===// // Multi-Level Value Mapper //===----------------------------------------------------------------------===// namespace { /// This class wraps a BlockAndValueMapping to provide recursive lookup /// functionality, i.e. we will traverse if the mapped value also has a mapping. struct ConversionValueMapping { /// Lookup a mapped value within the map. If a mapping for the provided value /// does not exist then return the provided value. Value lookupOrDefault(Value from) const; /// Map a value to the one provided. void map(Value oldVal, Value newVal) { mapping.map(oldVal, newVal); } /// Drop the last mapping for the given value. void erase(Value value) { mapping.erase(value); } private: /// Current value mappings. BlockAndValueMapping mapping; }; } // end anonymous namespace /// Lookup a mapped value within the map. If a mapping for the provided value /// does not exist then return the provided value. Value ConversionValueMapping::lookupOrDefault(Value from) const { // If this value had a valid mapping, unmap that value as well in the case // that it was also replaced. while (auto mappedValue = mapping.lookupOrNull(from)) from = mappedValue; return from; } //===----------------------------------------------------------------------===// // ArgConverter //===----------------------------------------------------------------------===// namespace { /// This class provides a simple interface for converting the types of block /// arguments. 
This is done by creating a new block that contains the new legal /// types and extracting the block that contains the old illegal types to allow /// for undoing pending rewrites in the case of failure. struct ArgConverter { ArgConverter(TypeConverter *typeConverter, PatternRewriter &rewriter) : loc(rewriter.getUnknownLoc()), typeConverter(typeConverter), rewriter(rewriter) {} /// This structure contains the information pertaining to an argument that has /// been converted. struct ConvertedArgInfo { ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize, Value castValue = nullptr) : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {} /// The start index of in the new argument list that contains arguments that /// replace the original. unsigned newArgIdx; /// The number of arguments that replaced the original argument. unsigned newArgSize; /// The cast value that was created to cast from the new arguments to the /// old. This only used if 'newArgSize' > 1. Value castValue; }; /// This structure contains information pertaining to a block that has had its /// signature converted. struct ConvertedBlockInfo { ConvertedBlockInfo(Block *origBlock) : origBlock(origBlock) {} /// The original block that was requested to have its signature converted. Block *origBlock; /// The conversion information for each of the arguments. The information is /// None if the argument was dropped during conversion. SmallVector, 1> argInfo; }; /// Return if the signature of the given block has already been converted. bool hasBeenConverted(Block *block) const { return conversionInfo.count(block); } //===--------------------------------------------------------------------===// // Rewrite Application //===--------------------------------------------------------------------===// /// Erase any rewrites registered for the blocks within the given operation /// which is about to be removed. This merely drops the rewrites without /// undoing them. 
void notifyOpRemoved(Operation *op); /// Cleanup and undo any generated conversions for the arguments of block. /// This method replaces the new block with the original, reverting the IR to /// its original state. void discardRewrites(Block *block); /// Fully replace uses of the old arguments with the new, materializing cast /// operations as necessary. // FIXME(riverriddle) The 'mapping' parameter is only necessary because the // implementation of replaceUsesOfBlockArgument is buggy. void applyRewrites(ConversionValueMapping &mapping); //===--------------------------------------------------------------------===// // Conversion //===--------------------------------------------------------------------===// /// Attempt to convert the signature of the given block, if successful a new /// block is returned containing the new arguments. On failure, nullptr is /// returned. Block *convertSignature(Block *block, ConversionValueMapping &mapping); /// Apply the given signature conversion on the given block. The new block /// containing the updated signature is returned. Block *applySignatureConversion( Block *block, TypeConverter::SignatureConversion &signatureConversion, ConversionValueMapping &mapping); /// Insert a new conversion into the cache. void insertConversion(Block *newBlock, ConvertedBlockInfo &&info); /// A collection of blocks that have had their arguments converted. llvm::MapVector conversionInfo; /// A mapping from valid regions, to those containing the original blocks of a /// conversion. DenseMap> regionMapping; /// An instance of the unknown location that is used when materializing /// conversions. Location loc; /// The type converter to use when changing types. TypeConverter *typeConverter; /// The pattern rewriter to use when materializing conversions. 
PatternRewriter &rewriter; }; } // end anonymous namespace //===----------------------------------------------------------------------===// // Rewrite Application void ArgConverter::notifyOpRemoved(Operation *op) { for (Region ®ion : op->getRegions()) { for (Block &block : region) { // Drop any rewrites from within. for (Operation &nestedOp : block) if (nestedOp.getNumRegions()) notifyOpRemoved(&nestedOp); // Check if this block was converted. auto it = conversionInfo.find(&block); if (it == conversionInfo.end()) return; // Drop all uses of the original arguments and delete the original block. Block *origBlock = it->second.origBlock; for (BlockArgument arg : origBlock->getArguments()) arg.dropAllUses(); conversionInfo.erase(it); } } } void ArgConverter::discardRewrites(Block *block) { auto it = conversionInfo.find(block); if (it == conversionInfo.end()) return; Block *origBlock = it->second.origBlock; // Drop all uses of the new block arguments and replace uses of the new block. for (int i = block->getNumArguments() - 1; i >= 0; --i) block->getArgument(i).dropAllUses(); block->replaceAllUsesWith(origBlock); // Move the operations back the original block and the delete the new block. origBlock->getOperations().splice(origBlock->end(), block->getOperations()); origBlock->moveBefore(block); block->erase(); conversionInfo.erase(it); } void ArgConverter::applyRewrites(ConversionValueMapping &mapping) { for (auto &info : conversionInfo) { Block *newBlock = info.first; ConvertedBlockInfo &blockInfo = info.second; Block *origBlock = blockInfo.origBlock; // Process the remapping for each of the original arguments. for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) { Optional &argInfo = blockInfo.argInfo[i]; BlockArgument origArg = origBlock->getArgument(i); // Handle the case of a 1->0 value mapping. if (!argInfo) { // If a replacement value was given for this argument, use that to // replace all uses. 
auto argReplacementValue = mapping.lookupOrDefault(origArg); if (argReplacementValue != origArg) { origArg.replaceAllUsesWith(argReplacementValue); continue; } // If there are any dangling uses then replace the argument with one // generated by the type converter. This is necessary as the cast must // persist in the IR after conversion. if (!origArg.use_empty()) { rewriter.setInsertionPointToStart(newBlock); auto *newOp = typeConverter->materializeConversion( rewriter, origArg.getType(), llvm::None, loc); origArg.replaceAllUsesWith(newOp->getResult(0)); } continue; } // If mapping is 1-1, replace the remaining uses and drop the cast // operation. // FIXME(riverriddle) This should check that the result type and operand // type are the same, otherwise it should force a conversion to be // materialized. if (argInfo->newArgSize == 1) { origArg.replaceAllUsesWith( mapping.lookupOrDefault(newBlock->getArgument(argInfo->newArgIdx))); continue; } // Otherwise this is a 1->N value mapping. Value castValue = argInfo->castValue; assert(argInfo->newArgSize > 1 && castValue && "expected 1->N mapping"); // If the argument is still used, replace it with the generated cast. if (!origArg.use_empty()) origArg.replaceAllUsesWith(mapping.lookupOrDefault(castValue)); // If all users of the cast were removed, we can drop it. Otherwise, keep // the operation alive and let the user handle any remaining usages. 
if (castValue.use_empty()) castValue.getDefiningOp()->erase(); } } } //===----------------------------------------------------------------------===// // Conversion Block *ArgConverter::convertSignature(Block *block, ConversionValueMapping &mapping) { if (auto conversion = typeConverter->convertBlockSignature(block)) return applySignatureConversion(block, *conversion, mapping); return nullptr; } Block *ArgConverter::applySignatureConversion( Block *block, TypeConverter::SignatureConversion &signatureConversion, ConversionValueMapping &mapping) { // If no arguments are being changed or added, there is nothing to do. unsigned origArgCount = block->getNumArguments(); auto convertedTypes = signatureConversion.getConvertedTypes(); if (origArgCount == 0 && convertedTypes.empty()) return block; // Split the block at the beginning to get a new block to use for the updated // signature. Block *newBlock = block->splitBlock(block->begin()); block->replaceAllUsesWith(newBlock); SmallVector newArgRange(newBlock->addArguments(convertedTypes)); ArrayRef newArgs(newArgRange); // Remap each of the original arguments as determined by the signature // conversion. ConvertedBlockInfo info(block); info.argInfo.resize(origArgCount); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(newBlock); for (unsigned i = 0; i != origArgCount; ++i) { auto inputMap = signatureConversion.getInputMapping(i); if (!inputMap) continue; BlockArgument origArg = block->getArgument(i); // If inputMap->replacementValue is not nullptr, then the argument is // dropped and a replacement value is provided to be the remappedValue. if (inputMap->replacementValue) { assert(inputMap->size == 0 && "invalid to provide a replacement value when the argument isn't " "dropped"); mapping.map(origArg, inputMap->replacementValue); continue; } // If this is a 1->1 mapping, then map the argument directly. 
if (inputMap->size == 1) { mapping.map(origArg, newArgs[inputMap->inputNo]); info.argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size); continue; } // Otherwise, this is a 1->N mapping. Call into the provided type converter // to pack the new values. auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size); Operation *cast = typeConverter->materializeConversion( rewriter, origArg.getType(), replArgs, loc); assert(cast->getNumResults() == 1 && cast->getNumOperands() == replArgs.size()); mapping.map(origArg, cast->getResult(0)); info.argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, cast->getResult(0)); } // Remove the original block from the region and return the new one. insertConversion(newBlock, std::move(info)); return newBlock; } void ArgConverter::insertConversion(Block *newBlock, ConvertedBlockInfo &&info) { // Get a region to insert the old block. Region *region = newBlock->getParent(); std::unique_ptr &mappedRegion = regionMapping[region]; if (!mappedRegion) mappedRegion = std::make_unique(region->getParentOp()); // Move the original block to the mapped region and emplace the conversion. mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(), info.origBlock->getIterator()); conversionInfo.insert({newBlock, std::move(info)}); } //===----------------------------------------------------------------------===// // ConversionPatternRewriterImpl //===----------------------------------------------------------------------===// namespace { /// This class contains a snapshot of the current conversion rewriter state. /// This is useful when saving and undoing a set of rewrites. 
struct RewriterState { RewriterState(unsigned numCreatedOps, unsigned numReplacements, unsigned numBlockActions, unsigned numIgnoredOperations, unsigned numRootUpdates) : numCreatedOps(numCreatedOps), numReplacements(numReplacements), numBlockActions(numBlockActions), numIgnoredOperations(numIgnoredOperations), numRootUpdates(numRootUpdates) {} /// The current number of created operations. unsigned numCreatedOps; /// The current number of replacements queued. unsigned numReplacements; /// The current number of block actions performed. unsigned numBlockActions; /// The current number of ignored operations. unsigned numIgnoredOperations; /// The current number of operations that were updated in place. unsigned numRootUpdates; }; /// The state of an operation that was updated by a pattern in-place. This /// contains all of the necessary information to reconstruct an operation that /// was updated in place. class OperationTransactionState { public: OperationTransactionState() = default; OperationTransactionState(Operation *op) : op(op), loc(op->getLoc()), attrs(op->getAttrList()), operands(op->operand_begin(), op->operand_end()), successors(op->successor_begin(), op->successor_end()) {} /// Discard the transaction state and reset the state of the original /// operation. void resetOperation() const { op->setLoc(loc); op->setAttrs(attrs); op->setOperands(operands); for (auto it : llvm::enumerate(successors)) op->setSuccessor(it.value(), it.index()); } /// Return the original operation of this state. Operation *getOperation() const { return op; } private: Operation *op; LocationAttr loc; NamedAttributeList attrs; SmallVector operands; SmallVector successors; }; } // end anonymous namespace namespace mlir { namespace detail { struct ConversionPatternRewriterImpl { /// This class represents one requested operation replacement via 'replaceOp'. 
struct OpReplacement { OpReplacement() = default; OpReplacement(Operation *op, ValueRange newValues) : op(op), newValues(newValues.begin(), newValues.end()) {} Operation *op; SmallVector newValues; }; /// The kind of the block action performed during the rewrite. Actions can be /// undone if the conversion fails. enum class BlockActionKind { Create, Move, Split, TypeConversion }; /// Original position of the given block in its parent region. We cannot use /// a region iterator because it could have been invalidated by other region /// operations since the position was stored. struct BlockPosition { Region *region; Region::iterator::difference_type position; }; /// The storage class for an undoable block action (one of BlockActionKind), /// contains the information necessary to undo this action. struct BlockAction { static BlockAction getCreate(Block *block) { return {BlockActionKind::Create, block, {}}; } static BlockAction getMove(Block *block, BlockPosition originalPos) { return {BlockActionKind::Move, block, {originalPos}}; } static BlockAction getSplit(Block *block, Block *originalBlock) { BlockAction action{BlockActionKind::Split, block, {}}; action.originalBlock = originalBlock; return action; } static BlockAction getTypeConversion(Block *block) { return BlockAction{BlockActionKind::TypeConversion, block, {}}; } // The action kind. BlockActionKind kind; // A pointer to the block that was created by the action. Block *block; union { // In use if kind == BlockActionKind::Move and contains a pointer to the // region that originally contained the block as well as the position of // the block in that region. BlockPosition originalPosition; // In use if kind == BlockActionKind::Split and contains a pointer to the // block that was split into two parts. Block *originalBlock; }; }; ConversionPatternRewriterImpl(PatternRewriter &rewriter, TypeConverter *converter) : argConverter(converter, rewriter) {} /// Return the current state of the rewriter. 
RewriterState getCurrentState(); /// Reset the state of the rewriter to a previously saved point. void resetState(RewriterState state); /// Undo the block actions (motions, splits) one by one in reverse order until /// "numActionsToKeep" actions remains. void undoBlockActions(unsigned numActionsToKeep = 0); /// Cleanup and destroy any generated rewrite operations. This method is /// invoked when the conversion process fails. void discardRewrites(); /// Apply all requested operation rewrites. This method is invoked when the /// conversion process succeeds. void applyRewrites(); /// Convert the signature of the given block. LogicalResult convertBlockSignature(Block *block); /// Apply a signature conversion on the given region. Block * applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion); /// PatternRewriter hook for replacing the results of an operation. void replaceOp(Operation *op, ValueRange newValues, ValueRange valuesToRemoveIfDead); /// Notifies that a block was split. void notifySplitBlock(Block *block, Block *continuation); /// Notifies that the blocks of a region are about to be moved. void notifyRegionIsBeingInlinedBefore(Region ®ion, Region &parent, Region::iterator before); /// Notifies that the blocks of a region were cloned into another. void notifyRegionWasClonedBefore(iterator_range &blocks, Location origRegionLoc); /// Remap the given operands to those with potentially different types. void remapValues(Operation::operand_range operands, SmallVectorImpl &remapped); /// Returns true if the given operation is ignored, and does not need to be /// converted. bool isOpIgnored(Operation *op) const; /// Recursively marks the nested operations under 'op' as ignored. This /// removes them from being considered for legalization. void markNestedOpsIgnored(Operation *op); // Mapping between replaced values that differ in type. This happens when // replacing a value with one of a different type. 
ConversionValueMapping mapping; /// Utility used to convert block arguments. ArgConverter argConverter; /// Ordered vector of all of the newly created operations during conversion. std::vector createdOps; /// Ordered vector of any requested operation replacements. SmallVector replacements; /// Ordered list of block operations (creations, splits, motions). SmallVector blockActions; /// A set of operations that have been erased/replaced/etc that should no /// longer be considered for legalization. This is not meant to be an /// exhaustive list of all operations, but the minimal set that can be used to /// detect if a given operation should be `ignored`. For example, we may add /// the operations that define non-empty regions to the set, but not any of /// the others. This simplifies the amount of memory needed as we can query if /// the parent operation was ignored. llvm::SetVector ignoredOps; /// A transaction state for each of operations that were updated in-place. SmallVector rootUpdates; #ifndef NDEBUG /// A set of operations that have pending updates. This tracking isn't /// strictly necessary, and is thus only active during debug builds for extra /// verification. SmallPtrSet pendingRootUpdates; #endif }; } // end namespace detail } // end namespace mlir RewriterState ConversionPatternRewriterImpl::getCurrentState() { return RewriterState(createdOps.size(), replacements.size(), blockActions.size(), ignoredOps.size(), rootUpdates.size()); } void ConversionPatternRewriterImpl::resetState(RewriterState state) { // Reset any operations that were updated in place. for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i) rootUpdates[i].resetOperation(); rootUpdates.resize(state.numRootUpdates); // Undo any block actions. undoBlockActions(state.numBlockActions); // Reset any replaced operations and undo any saved mappings. 
for (auto &repl : llvm::drop_begin(replacements, state.numReplacements)) for (auto result : repl.op->getResults()) mapping.erase(result); replacements.resize(state.numReplacements); // Pop all of the newly created operations. while (createdOps.size() != state.numCreatedOps) { createdOps.back()->erase(); createdOps.pop_back(); } // Pop all of the recorded ignored operations that are no longer valid. while (ignoredOps.size() != state.numIgnoredOperations) ignoredOps.pop_back(); } void ConversionPatternRewriterImpl::undoBlockActions( unsigned numActionsToKeep) { for (auto &action : llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) { switch (action.kind) { // Delete the created block. case BlockActionKind::Create: { // Unlink all of the operations within this block, they will be deleted // separately. auto &blockOps = action.block->getOperations(); while (!blockOps.empty()) blockOps.remove(blockOps.begin()); action.block->dropAllDefinedValueUses(); action.block->erase(); break; } // Move the block back to its original position. case BlockActionKind::Move: { Region *originalRegion = action.originalPosition.region; originalRegion->getBlocks().splice( std::next(originalRegion->begin(), action.originalPosition.position), action.block->getParent()->getBlocks(), action.block); break; } // Merge back the block that was split out. case BlockActionKind::Split: { action.originalBlock->getOperations().splice( action.originalBlock->end(), action.block->getOperations()); action.block->dropAllDefinedValueUses(); action.block->erase(); break; } // Undo the type conversion. case BlockActionKind::TypeConversion: { argConverter.discardRewrites(action.block); break; } } } blockActions.resize(numActionsToKeep); } void ConversionPatternRewriterImpl::discardRewrites() { // Reset any operations that were updated in place. for (auto &state : rootUpdates) state.resetOperation(); undoBlockActions(); // Remove any newly created ops. 
for (auto *op : llvm::reverse(createdOps)) op->erase(); } void ConversionPatternRewriterImpl::applyRewrites() { // Apply all of the rewrites replacements requested during conversion. for (auto &repl : replacements) { for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i) { if (auto newValue = repl.newValues[i]) repl.op->getResult(i).replaceAllUsesWith( mapping.lookupOrDefault(newValue)); } // If this operation defines any regions, drop any pending argument // rewrites. if (argConverter.typeConverter && repl.op->getNumRegions()) argConverter.notifyOpRemoved(repl.op); } // In a second pass, erase all of the replaced operations in reverse. This // allows processing nested operations before their parent region is // destroyed. for (auto &repl : llvm::reverse(replacements)) repl.op->erase(); argConverter.applyRewrites(mapping); } LogicalResult ConversionPatternRewriterImpl::convertBlockSignature(Block *block) { // Check to see if this block should not be converted: // * There is no type converter. // * The block has already been converted. // * This is an entry block, these are converted explicitly via patterns. if (!argConverter.typeConverter || argConverter.hasBeenConverted(block) || !block->getParent() || block->isEntryBlock()) return success(); // Otherwise, try to convert the block signature. 
Block *newBlock = argConverter.convertSignature(block, mapping); if (newBlock) blockActions.push_back(BlockAction::getTypeConversion(newBlock)); return success(newBlock); } Block *ConversionPatternRewriterImpl::applySignatureConversion( Region *region, TypeConverter::SignatureConversion &conversion) { if (!region->empty()) { Block *newEntry = argConverter.applySignatureConversion( ®ion->front(), conversion, mapping); blockActions.push_back(BlockAction::getTypeConversion(newEntry)); return newEntry; } return nullptr; } void ConversionPatternRewriterImpl::replaceOp(Operation *op, ValueRange newValues, ValueRange valuesToRemoveIfDead) { assert(newValues.size() == op->getNumResults()); // Create mappings for each of the new result values. for (unsigned i = 0, e = newValues.size(); i < e; ++i) if (auto repl = newValues[i]) mapping.map(op->getResult(i), repl); // Record the requested operation replacement. replacements.emplace_back(op, newValues); /// Mark this operation as recursively ignored so that we don't need to /// convert any nested operations. markNestedOpsIgnored(op); } void ConversionPatternRewriterImpl::notifySplitBlock(Block *block, Block *continuation) { blockActions.push_back(BlockAction::getSplit(continuation, block)); } void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore( Region ®ion, Region &parent, Region::iterator before) { for (auto &pair : llvm::enumerate(region)) { Block &block = pair.value(); Region::iterator::difference_type position = pair.index(); blockActions.push_back(BlockAction::getMove(&block, {®ion, position})); } } void ConversionPatternRewriterImpl::notifyRegionWasClonedBefore( iterator_range &blocks, Location origRegionLoc) { for (Block &block : blocks) blockActions.push_back(BlockAction::getCreate(&block)); // Compute the conversion set for the inlined region. 
auto result = computeConversionSet(blocks, origRegionLoc, createdOps); // This original region has already had its conversion set computed, so there // shouldn't be any new failures. (void)result; assert(succeeded(result) && "expected region to have no unreachable blocks"); } void ConversionPatternRewriterImpl::remapValues( Operation::operand_range operands, SmallVectorImpl &remapped) { remapped.reserve(llvm::size(operands)); for (Value operand : operands) remapped.push_back(mapping.lookupOrDefault(operand)); } bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const { // Check to see if this operation or its parent were ignored. return ignoredOps.count(op) || ignoredOps.count(op->getParentOp()); } void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) { // Walk this operation and collect nested operations that define non-empty // regions. We mark such operations as 'ignored' so that we know we don't have // to convert them, or their nested ops. if (op->getNumRegions() == 0) return; op->walk([&](Operation *op) { if (llvm::any_of(op->getRegions(), [](Region ®ion) { return !region.empty(); })) ignoredOps.insert(op); }); } //===----------------------------------------------------------------------===// // ConversionPatternRewriter //===----------------------------------------------------------------------===// ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx, TypeConverter *converter) : PatternRewriter(ctx), impl(new detail::ConversionPatternRewriterImpl(*this, converter)) {} ConversionPatternRewriter::~ConversionPatternRewriter() {} /// PatternRewriter hook for replacing the results of an operation. void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues, ValueRange valuesToRemoveIfDead) { LLVM_DEBUG(llvm::dbgs() << "** Replacing operation : " << op->getName() << "\n"); impl->replaceOp(op, newValues, valuesToRemoveIfDead); } /// PatternRewriter hook for erasing a dead operation. 
The uses of this /// operation *must* be made dead by the end of the conversion process, /// otherwise an assert will be issued. void ConversionPatternRewriter::eraseOp(Operation *op) { LLVM_DEBUG(llvm::dbgs() << "** Erasing operation : " << op->getName() << "\n"); SmallVector nullRepls(op->getNumResults(), nullptr); impl->replaceOp(op, nullRepls, /*valuesToRemoveIfDead=*/llvm::None); } /// Apply a signature conversion to the entry block of the given region. Block *ConversionPatternRewriter::applySignatureConversion( Region *region, TypeConverter::SignatureConversion &conversion) { return impl->applySignatureConversion(region, conversion); } void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, Value to) { for (auto &u : from.getUses()) { if (u.getOwner() == to.getDefiningOp()) continue; u.getOwner()->replaceUsesOfWith(from, to); } impl->mapping.map(impl->mapping.lookupOrDefault(from), to); } /// Return the converted value that replaces 'key'. Return 'key' if there is /// no such a converted value. Value ConversionPatternRewriter::getRemappedValue(Value key) { return impl->mapping.lookupOrDefault(key); } /// PatternRewriter hook for splitting a block into two parts. Block *ConversionPatternRewriter::splitBlock(Block *block, Block::iterator before) { auto *continuation = PatternRewriter::splitBlock(block, before); impl->notifySplitBlock(block, continuation); return continuation; } /// PatternRewriter hook for merging a block into another. void ConversionPatternRewriter::mergeBlocks(Block *source, Block *dest, ValueRange argValues) { // TODO(riverriddle) This requires fixing the implementation of // 'replaceUsesOfBlockArgument', which currently isn't undoable. llvm_unreachable("block merging updates are currently not supported"); } /// PatternRewriter hook for moving blocks out of a region. 
void ConversionPatternRewriter::inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before) { impl->notifyRegionIsBeingInlinedBefore(region, parent, before); PatternRewriter::inlineRegionBefore(region, parent, before); } /// PatternRewriter hook for cloning blocks of one region into another. void ConversionPatternRewriter::cloneRegionBefore( Region ®ion, Region &parent, Region::iterator before, BlockAndValueMapping &mapping) { if (region.empty()) return; PatternRewriter::cloneRegionBefore(region, parent, before, mapping); // Collect the range of the cloned blocks. auto clonedBeginIt = mapping.lookup(®ion.front())->getIterator(); auto clonedBlocks = llvm::make_range(clonedBeginIt, before); impl->notifyRegionWasClonedBefore(clonedBlocks, region.getLoc()); } /// PatternRewriter hook for creating a new operation. Operation *ConversionPatternRewriter::insert(Operation *op) { LLVM_DEBUG(llvm::dbgs() << "** Inserting operation : " << op->getName() << "\n"); impl->createdOps.push_back(op); return OpBuilder::insert(op); } /// PatternRewriter hook for updating the root operation in-place. void ConversionPatternRewriter::startRootUpdate(Operation *op) { #ifndef NDEBUG impl->pendingRootUpdates.insert(op); #endif impl->rootUpdates.emplace_back(op); } /// PatternRewriter hook for updating the root operation in-place. void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) { // There is nothing to do here, we only need to track the operation at the // start of the update. #ifndef NDEBUG assert(impl->pendingRootUpdates.erase(op) && "operation did not have a pending in-place update"); #endif } /// PatternRewriter hook for updating the root operation in-place. void ConversionPatternRewriter::cancelRootUpdate(Operation *op) { #ifndef NDEBUG assert(impl->pendingRootUpdates.erase(op) && "operation did not have a pending in-place update"); #endif // Erase the last update for this operation. 
auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; }; auto &rootUpdates = impl->rootUpdates; auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp); rootUpdates.erase(rootUpdates.begin() + (rootUpdates.rend() - it)); } /// Return a reference to the internal implementation. detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { return *impl; } //===----------------------------------------------------------------------===// // Conversion Patterns //===----------------------------------------------------------------------===// /// Attempt to match and rewrite the IR root at the specified operation. PatternMatchResult ConversionPattern::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { SmallVector operands; auto &dialectRewriter = static_cast(rewriter); dialectRewriter.getImpl().remapValues(op->getOperands(), operands); // If this operation has no successors, invoke the rewrite directly. if (op->getNumSuccessors() == 0) return matchAndRewrite(op, operands, dialectRewriter); // Otherwise, we need to remap the successors. SmallVector destinations; destinations.reserve(op->getNumSuccessors()); SmallVector, 2> operandsPerDestination; unsigned firstSuccessorOperand = op->getSuccessorOperandIndex(0); for (unsigned i = 0, seen = 0, e = op->getNumSuccessors(); i < e; ++i) { destinations.push_back(op->getSuccessor(i)); // Lookup the successors operands. unsigned n = op->getNumSuccessorOperands(i); operandsPerDestination.push_back( llvm::makeArrayRef(operands.data() + firstSuccessorOperand + seen, n)); seen += n; } // Rewrite the operation. 
return matchAndRewrite( op, llvm::makeArrayRef(operands.data(), operands.data() + firstSuccessorOperand), destinations, operandsPerDestination, dialectRewriter); } //===----------------------------------------------------------------------===// // OperationLegalizer //===----------------------------------------------------------------------===// namespace { /// A set of rewrite patterns that can be used to legalize a given operation. using LegalizationPatterns = SmallVector; /// This class defines a recursive operation legalizer. class OperationLegalizer { public: using LegalizationAction = ConversionTarget::LegalizationAction; OperationLegalizer(ConversionTarget &targetInfo, const OwningRewritePatternList &patterns) : target(targetInfo) { buildLegalizationGraph(patterns); computeLegalizationGraphBenefit(); } /// Returns if the given operation is known to be illegal on the target. bool isIllegal(Operation *op) const; /// Attempt to legalize the given operation. Returns success if the operation /// was legalized, failure otherwise. LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter); /// Returns the conversion target in use by the legalizer. ConversionTarget &getTarget() { return target; } private: /// Attempt to legalize the given operation by folding it. LogicalResult legalizeWithFold(Operation *op, ConversionPatternRewriter &rewriter); /// Attempt to legalize the given operation by applying the provided pattern. /// Returns success if the operation was legalized, failure otherwise. LogicalResult legalizePattern(Operation *op, RewritePattern *pattern, ConversionPatternRewriter &rewriter); /// Build an optimistic legalization graph given the provided patterns. This /// function populates 'legalizerPatterns' with the operations that are not /// directly legal, but may be transitively legal for the current target given /// the provided patterns. 
void buildLegalizationGraph(const OwningRewritePatternList &patterns); /// Compute the benefit of each node within the computed legalization graph. /// This orders the patterns within 'legalizerPatterns' based upon two /// criteria: /// 1) Prefer patterns that have the lowest legalization depth, i.e. /// represent the more direct mapping to the target. /// 2) When comparing patterns with the same legalization depth, prefer the /// pattern with the highest PatternBenefit. This allows for users to /// prefer specific legalizations over others. void computeLegalizationGraphBenefit(); /// The current set of patterns that have been applied. SmallPtrSet appliedPatterns; /// The set of legality information for operations transitively supported by /// the target. DenseMap legalizerPatterns; /// The legalization information provided by the target. ConversionTarget ⌖ }; } // namespace bool OperationLegalizer::isIllegal(Operation *op) const { // Check if the target explicitly marked this operation as illegal. return target.getOpAction(op->getName()) == LegalizationAction::Illegal; } LogicalResult OperationLegalizer::legalize(Operation *op, ConversionPatternRewriter &rewriter) { LLVM_DEBUG(llvm::dbgs() << "Legalizing operation : " << op->getName() << "\n"); // Check if this operation is legal on the target. if (auto legalityInfo = target.isLegal(op)) { LLVM_DEBUG(llvm::dbgs() << "-- Success : Operation marked legal by the target\n"); // If this operation is recursively legal, mark its children as ignored so // that we don't consider them for legalization. if (legalityInfo->isRecursivelyLegal) { LLVM_DEBUG(llvm::dbgs() << "-- Success : Operation is recursively legal; " "Skipping internals\n"); rewriter.getImpl().markNestedOpsIgnored(op); } return success(); } // Check to see if the operation is ignored and doesn't need to be converted. 
if (rewriter.getImpl().isOpIgnored(op)) { LLVM_DEBUG(llvm::dbgs() << "-- Success : Operation marked ignored during conversion\n"); return success(); } // If the operation isn't legal, try to fold it in-place. // TODO(riverriddle) Should we always try to do this, even if the op is // already legal? if (succeeded(legalizeWithFold(op, rewriter))) { LLVM_DEBUG(llvm::dbgs() << "-- Success : Operation was folded\n"); return success(); } // Otherwise, we need to apply a legalization pattern to this operation. auto it = legalizerPatterns.find(op->getName()); if (it == legalizerPatterns.end()) { LLVM_DEBUG(llvm::dbgs() << "-- FAIL : no known legalization path.\n"); return failure(); } // The patterns are sorted by expected benefit, so try to apply each in-order. for (auto *pattern : it->second) if (succeeded(legalizePattern(op, pattern, rewriter))) return success(); LLVM_DEBUG(llvm::dbgs() << "-- FAIL : no matched legalization pattern.\n"); return failure(); } LogicalResult OperationLegalizer::legalizeWithFold(Operation *op, ConversionPatternRewriter &rewriter) { auto &rewriterImpl = rewriter.getImpl(); RewriterState curState = rewriterImpl.getCurrentState(); // Try to fold the operation. SmallVector replacementValues; rewriter.setInsertionPoint(op); if (failed(rewriter.tryFold(op, replacementValues))) return failure(); // Insert a replacement for 'op' with the folded replacement values. rewriter.replaceOp(op, replacementValues); // Recursively legalize any new constant operations. 
for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size(); i != e; ++i) { Operation *cstOp = rewriterImpl.createdOps[i]; if (failed(legalize(cstOp, rewriter))) { LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated folding constant '" << cstOp->getName() << "' was illegal.\n"); rewriterImpl.resetState(curState); return failure(); } } return success(); } LogicalResult OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern, ConversionPatternRewriter &rewriter) { LLVM_DEBUG({ llvm::dbgs() << "-* Applying rewrite pattern '" << op->getName() << " -> ("; interleaveComma(pattern->getGeneratedOps(), llvm::dbgs()); llvm::dbgs() << ")'.\n"; }); // Ensure that we don't cycle by not allowing the same pattern to be // applied twice in the same recursion stack. // TODO(riverriddle) We could eventually converge, but that requires more // complicated analysis. if (!appliedPatterns.insert(pattern).second) { LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern was already applied.\n"); return failure(); } auto &rewriterImpl = rewriter.getImpl(); RewriterState curState = rewriterImpl.getCurrentState(); auto cleanupFailure = [&] { // Reset the rewriter state and pop this pattern. rewriterImpl.resetState(curState); appliedPatterns.erase(pattern); return failure(); }; // Try to rewrite with the given pattern. rewriter.setInsertionPoint(op); auto matchedPattern = pattern->matchAndRewrite(op, rewriter); #ifndef NDEBUG assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); #endif if (!matchedPattern) { LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Pattern failed to match.\n"); return cleanupFailure(); } // If the pattern moved or created any blocks, try to legalize their types. // This ensures that the types of the block arguments are legal for the region // they were moved into. 
for (unsigned i = curState.numBlockActions, e = rewriterImpl.blockActions.size(); i != e; ++i) { auto &action = rewriterImpl.blockActions[i]; if (action.kind == ConversionPatternRewriterImpl::BlockActionKind::TypeConversion) continue; // Convert the block signature. if (failed(rewriterImpl.convertBlockSignature(action.block))) { LLVM_DEBUG(llvm::dbgs() << "-- FAIL: failed to convert types of moved block.\n"); return cleanupFailure(); } } // Check all of the replacements to ensure that the pattern actually replaced // the root operation. We also mark any other replaced ops as 'dead' so that // we don't try to legalize them later. bool replacedRoot = false; for (unsigned i = curState.numReplacements, e = rewriterImpl.replacements.size(); i != e; ++i) { Operation *replacedOp = rewriterImpl.replacements[i].op; if (replacedOp == op) replacedRoot = true; else rewriterImpl.ignoredOps.insert(replacedOp); } // Check that the root was either updated or replace. auto updatedRootInPlace = [&] { return llvm::any_of( llvm::drop_begin(rewriterImpl.rootUpdates, curState.numRootUpdates), [op](auto &state) { return state.getOperation() == op; }); }; (void)replacedRoot; (void)updatedRootInPlace; assert((replacedRoot || updatedRootInPlace()) && "expected pattern to replace the root operation"); // Recursively legalize each of the operations updated in place. for (unsigned i = curState.numRootUpdates, e = rewriterImpl.rootUpdates.size(); i != e; ++i) { auto &state = rewriterImpl.rootUpdates[i]; if (failed(legalize(state.getOperation(), rewriter))) { LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Operation updated in-place '" << op->getName() << "' was illegal.\n"); return cleanupFailure(); } } // Recursively legalize each of the new operations. 
for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size(); i != e; ++i) { Operation *op = rewriterImpl.createdOps[i]; if (failed(legalize(op, rewriter))) { LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated operation '" << op->getName() << "' was illegal.\n"); return cleanupFailure(); } } appliedPatterns.erase(pattern); return success(); } void OperationLegalizer::buildLegalizationGraph( const OwningRewritePatternList &patterns) { // A mapping between an operation and a set of operations that can be used to // generate it. DenseMap> parentOps; // A mapping between an operation and any currently invalid patterns it has. DenseMap> invalidPatterns; // A worklist of patterns to consider for legality. llvm::SetVector patternWorklist; // Build the mapping from operations to the parent ops that may generate them. for (auto &pattern : patterns) { auto root = pattern->getRootKind(); // Skip operations that are always known to be legal. if (target.getOpAction(root) == LegalizationAction::Legal) continue; // Add this pattern to the invalid set for the root op and record this root // as a parent for any generated operations. invalidPatterns[root].insert(pattern.get()); for (auto op : pattern->getGeneratedOps()) parentOps[op].insert(root); // Add this pattern to the worklist. patternWorklist.insert(pattern.get()); } while (!patternWorklist.empty()) { auto *pattern = patternWorklist.pop_back_val(); // Check to see if any of the generated operations are invalid. if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) { Optional action = target.getOpAction(op); return !legalizerPatterns.count(op) && (!action || action == LegalizationAction::Illegal); })) continue; // Otherwise, if all of the generated operation are valid, this op is now // legal so add all of the child patterns to the worklist. 
legalizerPatterns[pattern->getRootKind()].push_back(pattern); invalidPatterns[pattern->getRootKind()].erase(pattern); // Add any invalid patterns of the parent operations to see if they have now // become legal. for (auto op : parentOps[pattern->getRootKind()]) patternWorklist.set_union(invalidPatterns[op]); } } void OperationLegalizer::computeLegalizationGraphBenefit() { // The smallest pattern depth, when legalizing an operation. DenseMap minPatternDepth; // Compute the minimum legalization depth for a given operation. std::function computeDepth = [&](OperationName op) { // Check for existing depth. auto depthIt = minPatternDepth.find(op); if (depthIt != minPatternDepth.end()) return depthIt->second; // If a mapping for this operation does not exist, then this operation // is always legal. Return 0 as the depth for a directly legal operation. auto opPatternsIt = legalizerPatterns.find(op); if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty()) return 0u; // Initialize the depth to the maximum value. unsigned minDepth = std::numeric_limits::max(); // Record this initial depth in case we encounter this op again when // recursively computing the depth. minPatternDepth.try_emplace(op, minDepth); // Compute the depth for each pattern used to legalize this operation. SmallVector, 4> patternsByDepth; patternsByDepth.reserve(opPatternsIt->second.size()); for (RewritePattern *pattern : opPatternsIt->second) { unsigned depth = 0; for (auto generatedOp : pattern->getGeneratedOps()) depth = std::max(depth, computeDepth(generatedOp) + 1); patternsByDepth.emplace_back(pattern, depth); // Update the min depth for this operation. minDepth = std::min(minDepth, depth); } // Update the pattern depth. minPatternDepth[op] = minDepth; // If the operation only has one legalization pattern, there is no need to // sort them. if (patternsByDepth.size() == 1) return minDepth; // Sort the patterns by those likely to be the most beneficial. 
llvm::array_pod_sort( patternsByDepth.begin(), patternsByDepth.end(), [](const std::pair *lhs, const std::pair *rhs) { // First sort by the smaller pattern legalization depth. if (lhs->second != rhs->second) return llvm::array_pod_sort_comparator(&lhs->second, &rhs->second); // Then sort by the larger pattern benefit. auto lhsBenefit = lhs->first->getBenefit(); auto rhsBenefit = rhs->first->getBenefit(); return llvm::array_pod_sort_comparator(&rhsBenefit, &lhsBenefit); }); // Update the legalization pattern to use the new sorted list. opPatternsIt->second.clear(); for (auto &patternIt : patternsByDepth) opPatternsIt->second.push_back(patternIt.first); return minDepth; }; // For each operation that is transitively legal, compute a cost for it. for (auto &opIt : legalizerPatterns) if (!minPatternDepth.count(opIt.first)) computeDepth(opIt.first); } //===----------------------------------------------------------------------===// // OperationConverter //===----------------------------------------------------------------------===// namespace { enum OpConversionMode { // In this mode, the conversion will ignore failed conversions to allow // illegal operations to co-exist in the IR. Partial, // In this mode, all operations must be legal for the given target for the // conversion to succeed. Full, // In this mode, operations are analyzed for legality. No actual rewrites are // applied to the operations on success. Analysis, }; // This class converts operations to a given conversion target via a set of // rewrite patterns. The conversion behaves differently depending on the // conversion mode. struct OperationConverter { explicit OperationConverter(ConversionTarget &target, const OwningRewritePatternList &patterns, OpConversionMode mode, DenseSet *legalizableOps = nullptr) : opLegalizer(target, patterns), mode(mode), legalizableOps(legalizableOps) {} /// Converts the given operations to the conversion target. 
LogicalResult convertOperations(ArrayRef ops, TypeConverter *typeConverter); private: /// Converts an operation with the given rewriter. LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op); /// Converts the type signatures of the blocks nested within 'op'. LogicalResult convertBlockSignatures(ConversionPatternRewriter &rewriter, Operation *op); /// The legalizer to use when converting operations. OperationLegalizer opLegalizer; /// The conversion mode to use when legalizing operations. OpConversionMode mode; /// A set of pre-existing operations that were found to be legalizable to the /// target. This field is only used when mode == OpConversionMode::Analysis. DenseSet *legalizableOps; }; } // end anonymous namespace LogicalResult OperationConverter::convertBlockSignatures(ConversionPatternRewriter &rewriter, Operation *op) { // Check to see if type signatures need to be converted. if (!rewriter.getImpl().argConverter.typeConverter) return success(); for (auto ®ion : op->getRegions()) { for (auto &block : llvm::make_early_inc_range(region)) if (failed(rewriter.getImpl().convertBlockSignature(&block))) return failure(); } return success(); } LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter, Operation *op) { // Legalize the given operation. if (failed(opLegalizer.legalize(op, rewriter))) { // Handle the case of a failed conversion for each of the different modes. /// Full conversions expect all operations to be converted. if (mode == OpConversionMode::Full) return op->emitError() << "failed to legalize operation '" << op->getName() << "'"; /// Partial conversions allow conversions to fail iff the operation was not /// explicitly marked as illegal. 
if (mode == OpConversionMode::Partial && opLegalizer.isIllegal(op)) return op->emitError() << "failed to legalize operation '" << op->getName() << "' that was explicitly marked illegal"; } else { /// Analysis conversions don't fail if any operations fail to legalize, /// they are only interested in the operations that were successfully /// legalized. if (mode == OpConversionMode::Analysis) legalizableOps->insert(op); // If legalization succeeded, convert the types any of the blocks within // this operation. if (failed(convertBlockSignatures(rewriter, op))) return failure(); } return success(); } LogicalResult OperationConverter::convertOperations(ArrayRef ops, TypeConverter *typeConverter) { if (ops.empty()) return success(); ConversionTarget &target = opLegalizer.getTarget(); /// Compute the set of operations and blocks to convert. std::vector toConvert; for (auto *op : ops) { toConvert.emplace_back(op); for (auto ®ion : op->getRegions()) if (failed(computeConversionSet(region.getBlocks(), region.getLoc(), toConvert, &target))) return failure(); } // Convert each operation and discard rewrites on failure. ConversionPatternRewriter rewriter(ops.front()->getContext(), typeConverter); for (auto *op : toConvert) if (failed(convert(rewriter, op))) return rewriter.getImpl().discardRewrites(), failure(); // Otherwise, the body conversion succeeded. Apply rewrites if this is not an // analysis conversion. if (mode == OpConversionMode::Analysis) rewriter.getImpl().discardRewrites(); else rewriter.getImpl().applyRewrites(); return success(); } //===----------------------------------------------------------------------===// // Type Conversion //===----------------------------------------------------------------------===// /// Remap an input of the original signature with a new set of types. The /// new types are appended to the new signature conversion. 
void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo, ArrayRef types) { assert(!types.empty() && "expected valid types"); remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size()); addInputs(types); } /// Append new input types to the signature conversion, this should only be /// used if the new types are not intended to remap an existing input. void TypeConverter::SignatureConversion::addInputs(ArrayRef types) { assert(!types.empty() && "1->0 type remappings don't need to be added explicitly"); argTypes.append(types.begin(), types.end()); } /// Remap an input of the original signature with a range of types in the /// new signature. void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo, unsigned newInputNo, unsigned newInputCount) { assert(!remappedInputs[origInputNo] && "input has already been remapped"); assert(newInputCount != 0 && "expected valid input count"); remappedInputs[origInputNo] = InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr}; } /// Remap an input of the original signature to another `replacementValue` /// value. This would make the signature converter drop this argument. void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo, Value replacementValue) { assert(!remappedInputs[origInputNo] && "input has already been remapped"); remappedInputs[origInputNo] = InputMapping{origInputNo, /*size=*/0, replacementValue}; } /// This hooks allows for converting a type. LogicalResult TypeConverter::convertType(Type t, SmallVectorImpl &results) { if (auto newT = convertType(t)) { results.push_back(newT); return success(); } return failure(); } /// Convert the given set of types, filling 'results' as necessary. This /// returns failure if the conversion of any of the types fails, success /// otherwise. 
LogicalResult TypeConverter::convertTypes(ArrayRef types, SmallVectorImpl &results) { for (auto type : types) if (failed(convertType(type, results))) return failure(); return success(); } /// Return true if the given type is legal for this type converter, i.e. the /// type converts to itself. bool TypeConverter::isLegal(Type type) { SmallVector results; return succeeded(convertType(type, results)) && results.size() == 1 && results.front() == type; } /// Return true if the inputs and outputs of the given function type are /// legal. bool TypeConverter::isSignatureLegal(FunctionType funcType) { return llvm::all_of( llvm::concat(funcType.getInputs(), funcType.getResults()), [this](Type type) { return isLegal(type); }); } /// This hook allows for converting a specific argument of a signature. LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type, SignatureConversion &result) { // Try to convert the given input type. SmallVector convertedTypes; if (failed(convertType(type, convertedTypes))) return failure(); // If this argument is being dropped, there is nothing left to do. if (convertedTypes.empty()) return success(); // Otherwise, add the new inputs. result.addInputs(inputNo, convertedTypes); return success(); } /// Create a default conversion pattern that rewrites the type signature of a /// FuncOp. namespace { struct FuncOpSignatureConversion : public OpConversionPattern { FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) : OpConversionPattern(ctx), converter(converter) {} /// Hook for derived classes to implement combined matching and rewriting. PatternMatchResult matchAndRewrite(FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { FunctionType type = funcOp.getType(); // Convert the original function arguments. 
TypeConverter::SignatureConversion result(type.getNumInputs()); for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i) if (failed(converter.convertSignatureArg(i, type.getInput(i), result))) return matchFailure(); // Convert the original function results. SmallVector convertedResults; if (failed(converter.convertTypes(type.getResults(), convertedResults))) return matchFailure(); // Update the function signature in-place. rewriter.updateRootInPlace(funcOp, [&] { funcOp.setType(FunctionType::get(result.getConvertedTypes(), convertedResults, funcOp.getContext())); rewriter.applySignatureConversion(&funcOp.getBody(), result); }); return matchSuccess(); } /// The type converter to use when rewriting the signature. TypeConverter &converter; }; } // end anonymous namespace void mlir::populateFuncOpTypeConversionPattern( OwningRewritePatternList &patterns, MLIRContext *ctx, TypeConverter &converter) { patterns.insert(ctx, converter); } /// This function converts the type signature of the given block, by invoking /// 'convertSignatureArg' for each argument. This function should return a valid /// conversion for the signature on success, None otherwise. auto TypeConverter::convertBlockSignature(Block *block) -> Optional { SignatureConversion conversion(block->getNumArguments()); for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) if (failed(convertSignatureArg(i, block->getArgument(i).getType(), conversion))) return llvm::None; return conversion; } //===----------------------------------------------------------------------===// // ConversionTarget //===----------------------------------------------------------------------===// /// Register a legality action for the given operation. void ConversionTarget::setOpAction(OperationName op, LegalizationAction action) { legalOperations[op] = {action, /*isRecursivelyLegal=*/false}; } /// Register a legality action for the given dialects. 
void ConversionTarget::setDialectAction(ArrayRef dialectNames, LegalizationAction action) { for (StringRef dialect : dialectNames) legalDialects[dialect] = action; } /// Get the legality action for the given operation. auto ConversionTarget::getOpAction(OperationName op) const -> Optional { Optional info = getOpInfo(op); return info ? info->action : Optional(); } /// If the given operation instance is legal on this target, a structure /// containing legality information is returned. If the operation is not legal, /// None is returned. auto ConversionTarget::isLegal(Operation *op) const -> Optional { Optional info = getOpInfo(op->getName()); if (!info) return llvm::None; // Returns true if this operation instance is known to be legal. auto isOpLegal = [&] { // Handle dynamic legality. if (info->action == LegalizationAction::Dynamic) { // Check for callbacks on the operation or dialect. auto opFn = opLegalityFns.find(op->getName()); if (opFn != opLegalityFns.end()) return opFn->second(op); auto dialectFn = dialectLegalityFns.find(op->getName().getDialect()); if (dialectFn != dialectLegalityFns.end()) return dialectFn->second(op); // Otherwise, invoke the hook on the derived instance. return isDynamicallyLegal(op); } // Otherwise, the operation is only legal if it was marked 'Legal'. return info->action == LegalizationAction::Legal; }; if (!isOpLegal()) return llvm::None; // This operation is legal, compute any additional legality information. LegalOpDetails legalityDetails; if (info->isRecursivelyLegal) { auto legalityFnIt = opRecursiveLegalityFns.find(op->getName()); if (legalityFnIt != opRecursiveLegalityFns.end()) legalityDetails.isRecursivelyLegal = legalityFnIt->second(op); else legalityDetails.isRecursivelyLegal = true; } return legalityDetails; } /// Set the dynamic legality callback for the given operation. 
void ConversionTarget::setLegalityCallback( OperationName name, const DynamicLegalityCallbackFn &callback) { assert(callback && "expected valid legality callback"); opLegalityFns[name] = callback; } /// Set the recursive legality callback for the given operation and mark the /// operation as recursively legal. void ConversionTarget::markOpRecursivelyLegal( OperationName name, const DynamicLegalityCallbackFn &callback) { auto infoIt = legalOperations.find(name); assert(infoIt != legalOperations.end() && infoIt->second.action != LegalizationAction::Illegal && "expected operation to already be marked as legal"); infoIt->second.isRecursivelyLegal = true; if (callback) opRecursiveLegalityFns[name] = callback; else opRecursiveLegalityFns.erase(name); } /// Set the dynamic legality callback for the given dialects. void ConversionTarget::setLegalityCallback( ArrayRef dialects, const DynamicLegalityCallbackFn &callback) { assert(callback && "expected valid legality callback"); for (StringRef dialect : dialects) dialectLegalityFns[dialect] = callback; } /// Get the legalization information for the given operation. auto ConversionTarget::getOpInfo(OperationName op) const -> Optional { // Check for info for this specific operation. auto it = legalOperations.find(op); if (it != legalOperations.end()) return it->second; // Otherwise, default to checking on the parent dialect. auto dialectIt = legalDialects.find(op.getDialect()); if (dialectIt != legalDialects.end()) return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false}; return llvm::None; } //===----------------------------------------------------------------------===// // Op Conversion Entry Points //===----------------------------------------------------------------------===// /// Apply a partial conversion on the given operations, and all nested /// operations. This method converts as many operations to the target as /// possible, ignoring operations that failed to legalize. 
LogicalResult mlir::applyPartialConversion(
    ArrayRef<Operation *> ops, ConversionTarget &target,
    const OwningRewritePatternList &patterns, TypeConverter *converter) {
  OperationConverter opConverter(target, patterns, OpConversionMode::Partial);
  return opConverter.convertOperations(ops, converter);
}

/// Single-operation convenience overload; forwards to the ArrayRef overload.
LogicalResult
mlir::applyPartialConversion(Operation *op, ConversionTarget &target,
                             const OwningRewritePatternList &patterns,
                             TypeConverter *converter) {
  return applyPartialConversion(llvm::makeArrayRef(op), target, patterns,
                                converter);
}

/// Apply a complete conversion on the given operations, and all nested
/// operations. This method will return failure if the conversion of any
/// operation fails.
LogicalResult
mlir::applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
                          const OwningRewritePatternList &patterns,
                          TypeConverter *converter) {
  OperationConverter opConverter(target, patterns, OpConversionMode::Full);
  return opConverter.convertOperations(ops, converter);
}

/// Single-operation convenience overload; forwards to the ArrayRef overload.
LogicalResult
mlir::applyFullConversion(Operation *op, ConversionTarget &target,
                          const OwningRewritePatternList &patterns,
                          TypeConverter *converter) {
  return applyFullConversion(llvm::makeArrayRef(op), target, patterns,
                             converter);
}

/// Apply an analysis conversion on the given operations, and all nested
/// operations. This method analyzes which operations would be successfully
/// converted to the target if a conversion was applied. All operations that
/// were found to be legalizable to the given 'target' are placed within the
/// provided 'convertedOps' set; note that no actual rewrites are applied to the
/// operations on success and only pre-existing operations are added to the set.
LogicalResult mlir::applyAnalysisConversion(
    ArrayRef<Operation *> ops, ConversionTarget &target,
    const OwningRewritePatternList &patterns,
    DenseSet<Operation *> &convertedOps, TypeConverter *converter) {
  // The converter records every successfully legalized operation in
  // 'convertedOps' and discards the rewrites afterwards.
  OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
                                 &convertedOps);
  return opConverter.convertOperations(ops, converter);
}

/// Single-operation convenience overload; forwards to the ArrayRef overload.
LogicalResult
mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
                              const OwningRewritePatternList &patterns,
                              DenseSet<Operation *> &convertedOps,
                              TypeConverter *converter) {
  return applyAnalysisConversion(llvm::makeArrayRef(op), target, patterns,
                                 convertedOps, converter);
}