//===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= // // This file implements a generic pass for converting between MLIR dialects. // //===----------------------------------------------------------------------===// #include "mlir/Transforms/DialectConversion.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" #include "mlir/Transforms/Utils.h" using namespace mlir; namespace mlir { namespace impl { // Implementation detail class of the DialectConversion pass. Performs // function-by-function conversions by creating new functions, filling them in // with converted blocks, updating the function attributes, and replacing the // old functions with the new ones in the module. class FunctionConversion { public: // Entry point. Uses hooks defined in `conversion` to obtain the list of // conversion patterns and to convert function and block argument types. // Converts the `module` in-place by replacing all existing functions with the // converted ones. static LogicalResult convert(DialectConversion *conversion, Module *module); private: // Constructs a FunctionConversion by storing the hooks. explicit FunctionConversion(DialectConversion *conversion) : dialectConversion(conversion) {} // Utility that looks up a list of value in the value remapping table. Returns // an empty vector if one of the values is not mapped yet. SmallVector lookupValues(Operation::operand_range operands); // Converts the given function to the dialect using hooks defined in // `dialectConversion`. Returns the converted function or `nullptr` on error. Function *convertFunction(Function *f); // Converts the given region starting from the entry block and following the // block successors. Returns the converted region or `nullptr` on error. template std::unique_ptr convertRegion(MLIRContext *context, Region *region, RegionParent *parent); // Converts an operation with successors. Extracts the converted operands // from `valueRemapping` and the converted blocks from `blockRemapping`, and // passes them to `converter->rewriteTerminator` function defined in the // pattern, together with `builder`. LogicalResult convertOpWithSuccessors(DialectOpConversion *converter, Operation *op, FuncBuilder &builder); // Converts an operation without successors. Extracts the converted operands // from `valueRemapping` and passes them to the `converter->rewrite` function // defined in the pattern, together with `builder`. LogicalResult convertOp(DialectOpConversion *converter, Operation *op, FuncBuilder &builder); // Converts a block by traversing its operations sequentially, looking for // the first pattern match and dispatching the operation conversion to // either `convertOp` or `convertOpWithSuccessors` depending on the presence // of successors. If there is no match, clones the operation. // // After converting operations, traverses the successor blocks unless they // have been visited already as indicated in `visitedBlocks`. LogicalResult convertBlock(Block *block, FuncBuilder &builder, llvm::DenseSet &visitedBlocks); // Converts the module as follows. // 1. Call `convertFunction` on each function of the module and collect the // mapping between old and new functions. // 2. Remap all function attributes in the new functions to point to the new // functions instead of the old ones. // 3. Replace old functions with the new in the module. LogicalResult run(Module *m); // Pointer to a specific dialect pass. DialectConversion *dialectConversion; // Set of known conversion patterns. llvm::DenseSet conversions; // Mapping between values(blocks) in the original function and in the new // function. BlockAndValueMapping mapping; }; } // end namespace impl } // end namespace mlir SmallVector impl::FunctionConversion::lookupValues(Operation::operand_range operands) { SmallVector remapped; remapped.reserve(llvm::size(operands)); for (Value *operand : operands) { Value *value = mapping.lookupOrNull(operand); if (!value) return {}; remapped.push_back(value); } return remapped; } LogicalResult impl::FunctionConversion::convertOpWithSuccessors( DialectOpConversion *converter, Operation *op, FuncBuilder &builder) { SmallVector destinations; destinations.reserve(op->getNumSuccessors()); SmallVector operands = lookupValues(op->getOperands()); assert((!operands.empty() || op->getNumOperands() == 0) && "converting op before ops defining its operands"); SmallVector, 2> operandsPerDestination; unsigned numSuccessorOperands = 0; for (unsigned i = 0, e = op->getNumSuccessors(); i < e; ++i) numSuccessorOperands += op->getNumSuccessorOperands(i); unsigned seen = 0; unsigned firstSuccessorOperand = op->getNumOperands() - numSuccessorOperands; for (unsigned i = 0, e = op->getNumSuccessors(); i < e; ++i) { Block *successor = mapping.lookupOrNull(op->getSuccessor(i)); assert(successor && "block was not remapped"); destinations.push_back(successor); unsigned n = op->getNumSuccessorOperands(i); operandsPerDestination.push_back( llvm::makeArrayRef(operands.data() + firstSuccessorOperand + seen, n)); seen += n; } converter->rewriteTerminator( op, llvm::makeArrayRef(operands.data(), operands.data() + firstSuccessorOperand), destinations, operandsPerDestination, builder); return success(); } LogicalResult impl::FunctionConversion::convertOp(DialectOpConversion *converter, Operation *op, FuncBuilder &builder) { auto operands = lookupValues(op->getOperands()); assert((!operands.empty() || op->getNumOperands() == 0) && "converting op before ops defining its operands"); auto results = converter->rewrite(op, operands, builder); if (results.size() != op->getNumResults()) return (op->emitError("rewriting produced a different number of results"), failure()); for (unsigned i = 0, e = results.size(); i < e; ++i) mapping.map(op->getResult(i), results[i]); return success(); } LogicalResult impl::FunctionConversion::convertBlock(Block *block, FuncBuilder &builder, llvm::DenseSet &visitedBlocks) { // First, add the current block to the list of visited blocks. visitedBlocks.insert(block); // Setup the builder to the insert to the converted block. builder.setInsertionPointToStart(mapping.lookupOrNull(block)); // Iterate over ops and convert them. for (Operation &op : *block) { // Find the first matching conversion and apply it. bool converted = false; for (auto *conversion : conversions) { // Ignore patterns that are for the wrong root or are impossible to match. if (conversion->getRootKind() != op.getName() || conversion->getBenefit().isImpossibleToMatch()) continue; if (!conversion->match(&op)) continue; if (op.getNumSuccessors() != 0) { if (failed(convertOpWithSuccessors(conversion, &op, builder))) return failure(); } else if (failed(convertOp(conversion, &op, builder))) { return failure(); } converted = true; break; } // If there is no conversion provided for the op, clone the op and convert // its regions, if any. if (!converted) { auto *newOp = builder.cloneWithoutRegions(op, mapping); for (int i = 0, e = op.getNumRegions(); i < e; ++i) { auto newRegion = convertRegion(op.getContext(), &op.getRegion(i), &op); newOp->getRegion(i).takeBody(*newRegion); } } } // Recurse to children unless they have been already visited. for (Block *succ : block->getSuccessors()) { if (visitedBlocks.count(succ) != 0) continue; if (failed(convertBlock(succ, builder, visitedBlocks))) return failure(); } return success(); } template std::unique_ptr impl::FunctionConversion::convertRegion(MLIRContext *context, Region *region, RegionParent *parent) { assert(region && "expected a region"); auto newRegion = llvm::make_unique(parent); if (region->empty()) return newRegion; auto emitError = [context](llvm::Twine f) -> std::unique_ptr { context->emitError(UnknownLoc::get(context), f.str()); return nullptr; }; // Create new blocks and convert their arguments. for (Block &block : *region) { auto *newBlock = new Block; newRegion->push_back(newBlock); mapping.map(&block, newBlock); for (auto *arg : block.getArguments()) { auto convertedType = dialectConversion->convertType(arg->getType()); if (!convertedType) return emitError("could not convert block argument type"); newBlock->addArgument(convertedType); mapping.map(arg, *newBlock->args_rbegin()); } } // Start a DFS-order traversal of the CFG to make sure defs are converted // before uses in dominated blocks. llvm::DenseSet visitedBlocks; FuncBuilder builder(&newRegion->front()); if (failed(convertBlock(®ion->front(), builder, visitedBlocks))) return nullptr; // If some blocks are not reachable through successor chains, they should have // been removed by the DCE before this. if (visitedBlocks.size() != std::distance(region->begin(), region->end())) return emitError("unreachable blocks were not converted"); return newRegion; } Function *impl::FunctionConversion::convertFunction(Function *f) { assert(f && "expected function"); MLIRContext *context = f->getContext(); auto emitError = [context](llvm::Twine f) -> Function * { context->emitError(UnknownLoc::get(context), f.str()); return nullptr; }; // Create a new function with argument types and result types converted. Wrap // it into a unique_ptr to make sure it is cleaned up in case of error. SmallVector newFunctionArgAttrs; Type newFunctionType = dialectConversion->convertFunctionSignatureType( f->getType(), f->getAllArgAttrs(), newFunctionArgAttrs); if (!newFunctionType) return emitError("could not convert function type"); auto newFunction = llvm::make_unique( f->getLoc(), f->getName().strref(), newFunctionType.cast(), f->getAttrs(), newFunctionArgAttrs); // Return early if the function has no blocks. if (f->getBlocks().empty()) return newFunction.release(); auto newBody = convertRegion(context, &f->getBody(), f); if (!newBody) return emitError("could not convert function body"); newFunction->getBody().takeBody(*newBody); return newFunction.release(); } LogicalResult impl::FunctionConversion::convert(DialectConversion *conversion, Module *module) { return impl::FunctionConversion(conversion).run(module); } LogicalResult impl::FunctionConversion::run(Module *module) { if (!module) return failure(); MLIRContext *context = module->getContext(); conversions = dialectConversion->initConverters(context); // Convert the functions but don't add them to the module yet to avoid // converted functions to be converted again. SmallVector originalFuncs, convertedFuncs; DenseMap functionAttrRemapping; originalFuncs.reserve(module->getFunctions().size()); for (auto &func : *module) originalFuncs.push_back(&func); convertedFuncs.reserve(module->getFunctions().size()); for (auto *func : originalFuncs) { Function *converted = convertFunction(func); if (!converted) return failure(); auto origFuncAttr = FunctionAttr::get(func, context); auto convertedFuncAttr = FunctionAttr::get(converted, context); convertedFuncs.push_back(converted); functionAttrRemapping.insert({origFuncAttr, convertedFuncAttr}); } // Remap function attributes in the converted functions (they are not yet in // the module). Original functions will disappear anyway so there is no // need to remap attributes in them. for (const auto &funcPair : functionAttrRemapping) { remapFunctionAttrs(*funcPair.getSecond().getValue(), functionAttrRemapping); } // Remove original functions from the module, then insert converted // functions. The order is important to avoid name collisions. for (auto &func : originalFuncs) func->erase(); for (auto *func : convertedFuncs) module->getFunctions().push_back(func); return success(); } // Create a function type with arguments and results converted, and argument // attributes passed through. FunctionType DialectConversion::convertFunctionSignatureType( FunctionType type, ArrayRef argAttrs, SmallVectorImpl &convertedArgAttrs) { SmallVector arguments; SmallVector results; arguments.reserve(type.getNumInputs()); for (auto t : type.getInputs()) arguments.push_back(convertType(t)); results.reserve(type.getNumResults()); for (auto t : type.getResults()) results.push_back(convertType(t)); // Note this will cause an extra allocation only if we need // to grow the caller-provided resulting attribute vector. convertedArgAttrs.reserve(arguments.size()); for (auto attr : argAttrs) convertedArgAttrs.push_back(attr); return FunctionType::get(arguments, results, type.getContext()); } LogicalResult DialectConversion::convert(Module *m) { return impl::FunctionConversion::convert(this, m); }