//===- ConvertToCFG.cpp - ML function to CFG function conversion ----------===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= // // This file implements APIs to convert ML functions into CFG functions. // //===----------------------------------------------------------------------===// #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/StmtVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Support/Functional.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" #include "llvm/ADT/DenseSet.h" #include "llvm/Support/CommandLine.h" using namespace mlir; //===----------------------------------------------------------------------===// // ML function converter //===----------------------------------------------------------------------===// namespace { // Generates CFG function equivalent to the given ML function. class FunctionConverter : public StmtVisitor { public: FunctionConverter(Function *cfgFunc) : cfgFunc(cfgFunc), builder(cfgFunc) {} Function *convert(Function *mlFunc); void visitForStmt(ForStmt *forStmt); void visitIfStmt(IfStmt *ifStmt); void visitOperationInst(OperationInst *opStmt); private: Value *getConstantIndexValue(int64_t value); void visitStmtBlock(StmtBlock *stmtBlock); Value *buildMinMaxReductionSeq( Location loc, CmpIPredicate predicate, llvm::iterator_range values); Function *cfgFunc; FuncBuilder builder; // Mapping between original Values and lowered Values. llvm::DenseMap valueRemapping; }; } // end anonymous namespace // Return a vector of OperationInst's arguments as Values. For each // statement operands, represented as Value, lookup its Value conterpart in // the valueRemapping table. static llvm::SmallVector operandsAs(Statement *opStmt, const llvm::DenseMap &valueRemapping) { llvm::SmallVector operands; for (const Value *operand : opStmt->getOperands()) { assert(valueRemapping.count(operand) != 0 && "operand is not defined"); operands.push_back(valueRemapping.lookup(operand)); } return operands; } // Convert an operation statement into an operation instruction. // // The operation description (name, number and types of operands or results) // remains the same but the values must be updated to be Values. Update the // mapping Value->Value as the conversion is performed. The operation // instruction is appended to current block (end of SESE region). void FunctionConverter::visitOperationInst(OperationInst *opStmt) { // Set up basic operation state (context, name, operands). OperationState state(cfgFunc->getContext(), opStmt->getLoc(), opStmt->getName()); state.addOperands(operandsAs(opStmt, valueRemapping)); // Set up operation return types. The corresponding Values will become // available after the operation is created. state.addTypes(functional::map( [](Value *result) { return result->getType(); }, opStmt->getResults())); // Copy attributes. for (auto attr : opStmt->getAttrs()) { state.addAttribute(attr.first.strref(), attr.second); } auto opInst = builder.createOperation(state); // Make results of the operation accessible to the following operations // through remapping. assert(opInst->getNumResults() == opStmt->getNumResults()); for (unsigned i = 0, n = opInst->getNumResults(); i < n; ++i) { valueRemapping.insert( std::make_pair(opStmt->getResult(i), opInst->getResult(i))); } } // Create a Value for the given integer constant of index type. Value *FunctionConverter::getConstantIndexValue(int64_t value) { auto op = builder.create(builder.getUnknownLoc(), value); return op->getResult(); } // Visit all statements in the given statement block. void FunctionConverter::visitStmtBlock(StmtBlock *stmtBlock) { for (auto &stmt : *stmtBlock) this->visit(&stmt); } // Given a range of values, emit the code that reduces them with "min" or "max" // depending on the provided comparison predicate. The predicate defines which // comparison to perform, "lt" for "min", "gt" for "max" and is used for the // `cmpi` operation followed by the `select` operation: // // %cond = cmpi "predicate" %v0, %v1 // %result = select %cond, %v0, %v1 // // Multiple values are scanned in a linear sequence. This creates a data // dependences that wouldn't exist in a tree reduction, but is easier to // recognize as a reduction by the subsequent passes. Value *FunctionConverter::buildMinMaxReductionSeq( Location loc, CmpIPredicate predicate, llvm::iterator_range values) { assert(!llvm::empty(values) && "empty min/max chain"); auto valueIt = values.begin(); Value *value = *valueIt++; for (; valueIt != values.end(); ++valueIt) { auto cmpOp = builder.create(loc, predicate, value, *valueIt); auto selectOp = builder.create(loc, cmpOp->getResult(), value, *valueIt); value = selectOp->getResult(); } return value; } // Convert a "for" loop to a flow of basic blocks. // // Create an SESE region for the loop (including its body) and append it to the // end of the current region. The loop region consists of the initialization // block that sets up the initial value of the loop induction variable (%iv) and // computes the loop bounds that are loop-invariant in MLFunctions; the // condition block that checks the exit condition of the loop; the body SESE // region; and the end block that post-dominates the loop. The end block of the // loop becomes the new end of the current SESE region. The body of the loop is // constructed recursively after starting a new region (it may be, for example, // a nested loop). Induction variable modification is appended to the body SESE // region that always loops back to the condition block. // // +--------------------------------+ // | | // | | // | br init | // +--------------------------------+ // | // v // +--------------------------------+ // | init: | // | | // | | // | br cond(%iv) | // +--------------------------------+ // | // -------| | // | v v // | +--------------------------------+ // | | cond(%iv): | // | | | // | | cond_br %r, body, end | // | +--------------------------------+ // | | | // | | -------------| // | v | // | +--------------------------------+ | // | | body: | | // | | | | // | | <...> | | // | +--------------------------------+ | // | | | // | ... | // | | | // | v | // | +--------------------------------+ | // | | body-end: | | // | | | | // | | %new_iv = | | // | | br cond(%new_iv) | | // | +--------------------------------+ | // | | | // |----------- |-------------------- // v // +--------------------------------+ // | end: | // | | // | | // +--------------------------------+ // void FunctionConverter::visitForStmt(ForStmt *forStmt) { // First, store the loop insertion location so that we can go back to it after // creating the new blocks (block creation updates the insertion point). BasicBlock *loopInsertionPoint = builder.getInsertionBlock(); // Create blocks so that they appear in more human-readable order in the // output. BasicBlock *loopInitBlock = builder.createBlock(); BasicBlock *loopConditionBlock = builder.createBlock(); BasicBlock *loopBodyFirstBlock = builder.createBlock(); // At the loop insertion location, branch immediately to the loop init block. builder.setInsertionPointToEnd(loopInsertionPoint); builder.create(builder.getUnknownLoc(), loopInitBlock); // The loop condition block has an argument for loop induction variable. // Create it upfront and make the loop induction variable -> basic block // argument remapping available to the following instructions. ForStatement // is-a Value corresponding to the loop induction variable. builder.setInsertionPointToEnd(loopConditionBlock); Value *iv = loopConditionBlock->addArgument(builder.getIndexType()); valueRemapping.insert(std::make_pair(forStmt, iv)); // Recursively construct loop body region. // Walking manually because we need custom logic before and after traversing // the list of children. builder.setInsertionPointToEnd(loopBodyFirstBlock); visitStmtBlock(forStmt->getBody()); // Builder point is currently at the last block of the loop body. Append the // induction variable stepping to this block and branch back to the exit // condition block. Construct an affine map f : (x -> x+step) and apply this // map to the induction variable. auto affStep = builder.getAffineConstantExpr(forStmt->getStep()); auto affDim = builder.getAffineDimExpr(0); auto affStepMap = builder.getAffineMap(1, 0, {affDim + affStep}, {}); auto stepOp = builder.create(forStmt->getLoc(), affStepMap, iv); Value *nextIvValue = stepOp->getResult(0); builder.create(builder.getUnknownLoc(), loopConditionBlock, nextIvValue); // Create post-loop block here so that it appears after all loop body blocks. BasicBlock *postLoopBlock = builder.createBlock(); builder.setInsertionPointToEnd(loopInitBlock); // Compute loop bounds using affine_apply after remapping its operands. auto remapOperands = [this](const Value *value) -> Value * { return valueRemapping.lookup(value); }; auto operands = functional::map(remapOperands, forStmt->getLowerBoundOperands()); auto lbAffineApply = builder.create( forStmt->getLoc(), forStmt->getLowerBoundMap(), operands); Value *lowerBound = buildMinMaxReductionSeq( forStmt->getLoc(), CmpIPredicate::SGT, lbAffineApply->getResults()); operands = functional::map(remapOperands, forStmt->getUpperBoundOperands()); auto ubAffineApply = builder.create( forStmt->getLoc(), forStmt->getUpperBoundMap(), operands); Value *upperBound = buildMinMaxReductionSeq( forStmt->getLoc(), CmpIPredicate::SLT, ubAffineApply->getResults()); builder.create(builder.getUnknownLoc(), loopConditionBlock, lowerBound); builder.setInsertionPointToEnd(loopConditionBlock); auto comparisonOp = builder.create( forStmt->getLoc(), CmpIPredicate::SLT, iv, upperBound); auto comparisonResult = comparisonOp->getResult(); builder.create(builder.getUnknownLoc(), comparisonResult, loopBodyFirstBlock, ArrayRef(), postLoopBlock, ArrayRef()); // Finally, make sure building can continue by setting the post-loop block // (end of loop SESE region) as the insertion point. builder.setInsertionPointToEnd(postLoopBlock); } // Convert an "if" statement into a flow of basic blocks. // // Create an SESE region for the if statement (including its "then" and optional // "else" statement blocks) and append it to the end of the current region. The // conditional region consists of a sequence of condition-checking blocks that // implement the short-circuit scheme, followed by a "then" SESE region and an // "else" SESE region, and the continuation block that post-dominates all blocks // of the "if" statement. The flow of blocks that correspond to the "then" and // "else" clauses are constructed recursively, enabling easy nesting of "if" // statements and if-then-else-if chains. // // +--------------------------------+ // | | // | | // | %zero = constant 0 : index | // | %v = affine_apply #expr1(%ops) | // | %c = cmpi "sge" %v, %zero | // | cond_br %c, %next, %else | // +--------------------------------+ // | | // | --------------| // v | // +--------------------------------+ | // | next: | | // | | | // | cond_br %c, %next2, %else | | // +--------------------------------+ | // | | | // ... --------------| // | | // v | // +--------------------------------+ | // | last: | | // | | | // | cond_br %c, %then, %else | | // +--------------------------------+ | // | | | // | --------------| // v | // +--------------------------------+ | // | then: | | // | | | // +--------------------------------+ | // | | // ... | // | | // v | // +--------------------------------+ | // | then_end: | | // | | | // | br continue | | // +--------------------------------+ | // | | // |---------- |------------- // | V // | +--------------------------------+ // | | else: | // | | | // | +--------------------------------+ // | | // | ... // | | // | v // | +--------------------------------+ // | | else_end: | // | | | // | | br continue | // | +--------------------------------+ // | | // ------| | // v v // +--------------------------------+ // | continue: | // | | // | | // +--------------------------------+ // void FunctionConverter::visitIfStmt(IfStmt *ifStmt) { assert(ifStmt != nullptr); auto integerSet = ifStmt->getCondition().getIntegerSet(); // Create basic blocks for the 'then' block and for the 'else' block. // Although 'else' block may be empty in absence of an 'else' clause, create // it anyway for the sake of consistency and output IR readability. Also // create extra blocks for condition checking to prepare for short-circuit // logic: conditions in the 'if' statement are conjunctive, so we can jump to // the false branch as soon as one condition fails. `cond_br` requires // another block as a target when the condition is true, and that block will // contain the next condition. BasicBlock *ifInsertionBlock = builder.getInsertionBlock(); SmallVector ifConditionExtraBlocks; unsigned numConstraints = integerSet.getNumConstraints(); ifConditionExtraBlocks.reserve(numConstraints - 1); for (unsigned i = 0, e = numConstraints - 1; i < e; ++i) { ifConditionExtraBlocks.push_back(builder.createBlock()); } BasicBlock *thenBlock = builder.createBlock(); BasicBlock *elseBlock = builder.createBlock(); builder.setInsertionPointToEnd(ifInsertionBlock); // Implement short-circuit logic. For each affine expression in the 'if' // condition, convert it into an affine map and call `affine_apply` to obtain // the resulting value. Perform the equality or the greater-than-or-equality // test between this value and zero depending on the equality flag of the // condition. If the test fails, jump immediately to the false branch, which // may be the else block if it is present or the continuation block otherwise. // If the test succeeds, jump to the next block testing testing the next // conjunct of the condition in the similar way. When all conjuncts have been // handled, jump to the 'then' block instead. Value *zeroConstant = getConstantIndexValue(0); ifConditionExtraBlocks.push_back(thenBlock); for (auto tuple : llvm::zip(integerSet.getConstraints(), integerSet.getEqFlags(), ifConditionExtraBlocks)) { AffineExpr constraintExpr = std::get<0>(tuple); bool isEquality = std::get<1>(tuple); BasicBlock *nextBlock = std::get<2>(tuple); // Build and apply an affine map. auto affineMap = builder.getAffineMap(integerSet.getNumDims(), integerSet.getNumSymbols(), constraintExpr, {}); auto affineApplyOp = builder.create( ifStmt->getLoc(), affineMap, operandsAs(ifStmt, valueRemapping)); Value *affResult = affineApplyOp->getResult(0); // Compare the result of the apply and branch. auto comparisonOp = builder.create( ifStmt->getLoc(), isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE, affResult, zeroConstant); builder.create(ifStmt->getLoc(), comparisonOp->getResult(), nextBlock, /*trueArgs*/ ArrayRef(), elseBlock, /*falseArgs*/ ArrayRef()); builder.setInsertionPointToEnd(nextBlock); } ifConditionExtraBlocks.pop_back(); // Recursively traverse the 'then' block. builder.setInsertionPointToEnd(thenBlock); visitStmtBlock(ifStmt->getThen()); BasicBlock *lastThenBlock = builder.getInsertionBlock(); // Recursively traverse the 'else' block if present. builder.setInsertionPointToEnd(elseBlock); if (ifStmt->hasElse()) visitStmtBlock(ifStmt->getElse()); BasicBlock *lastElseBlock = builder.getInsertionBlock(); // Create the continuation block here so that it appears lexically after the // 'then' and 'else' blocks, branch from end of 'then' and 'else' SESE regions // to the continuation block. BasicBlock *continuationBlock = builder.createBlock(); builder.setInsertionPointToEnd(lastThenBlock); builder.create(ifStmt->getLoc(), continuationBlock); builder.setInsertionPointToEnd(lastElseBlock); builder.create(ifStmt->getLoc(), continuationBlock); // Make sure building can continue by setting up the continuation block as the // insertion point. builder.setInsertionPointToEnd(continuationBlock); } // Entry point of the function convertor. // // Conversion is performed by recursively visiting statements of a Function. // It reasons in terms of single-entry single-exit (SESE) regions that are not // materialized in the code. Instead, the pointer to the last block of the // region is maintained throughout the conversion as the insertion point of the // IR builder since we never change the first block after its creation. "Block" // statements such as loops and branches create new SESE regions for their // bodies, and surround them with additional basic blocks for the control flow. // Individual operations are simply appended to the end of the last basic block // of the current region. The SESE invariant allows us to easily handle nested // structures of arbitrary complexity. // // During the conversion, we maintain a mapping between the Values present in // the original function and their Value images in the function under // construction. When an Value is used, it gets replaced with the // corresponding Value that has been defined previously. The value flow // starts with function arguments converted to basic block arguments. Function *FunctionConverter::convert(Function *mlFunc) { auto outerBlock = builder.createBlock(); // CFGFunctions do not have explicit arguments but use the arguments to the // first basic block instead. Create those from the Function arguments and // set up the value remapping. outerBlock->addArguments(mlFunc->getType().getInputs()); assert(mlFunc->getNumArguments() == outerBlock->getNumArguments()); for (unsigned i = 0, n = mlFunc->getNumArguments(); i < n; ++i) { const Value *mlArgument = mlFunc->getArgument(i); Value *cfgArgument = outerBlock->getArgument(i); valueRemapping.insert(std::make_pair(mlArgument, cfgArgument)); } // Convert statements in order. for (auto &stmt : *mlFunc->getBody()) { visit(&stmt); } return cfgFunc; } //===----------------------------------------------------------------------===// // Module converter //===----------------------------------------------------------------------===// namespace { // ModuleConverter class does CFG conversion for the whole module. class ModuleConverter : public ModulePass { public: explicit ModuleConverter() : ModulePass(&ModuleConverter::passID) {} PassResult runOnModule(Module *m) override; static char passID; private: // Generates CFG functions for all ML functions in the module. void convertMLFunctions(); // Generates CFG function for the given ML function. Function *convert(Function *mlFunc); // Replaces all ML function references in the module // with references to the generated CFG functions. void replaceReferences(); // Replaces function references in the given function. void replaceReferences(Function *cfgFunc); // Replaces MLFunctions with their CFG counterparts in the module. void replaceFunctions(); // Map from ML functions to generated CFG functions. llvm::DenseMap generatedFuncs; Module *module = nullptr; }; } // end anonymous namespace char ModuleConverter::passID = 0; // Iterates over all functions in the module generating CFG functions // equivalent to ML functions and replacing references to ML functions // with references to the generated ML functions. The names of the converted // functions match those of the original functions to avoid breaking any // external references to the current module. Therefore, converted functions // are added to the module at the end of the pass, after removing the original // functions to avoid name clashes. Conversion procedure has access to the // module as member of ModuleConverter and must not rely on the converted // function to belong to the module. PassResult ModuleConverter::runOnModule(Module *m) { module = m; convertMLFunctions(); replaceReferences(); replaceFunctions(); return success(); } void ModuleConverter::convertMLFunctions() { for (Function &fn : *module) { if (fn.isML()) generatedFuncs[&fn] = convert(&fn); } } // Creates CFG function equivalent to the given ML function. Function *ModuleConverter::convert(Function *mlFunc) { // Use the same name as for ML function; do not add the converted function to // the module yet to avoid collision. auto name = mlFunc->getName().str(); auto *cfgFunc = new Function(Function::Kind::CFGFunc, mlFunc->getLoc(), name, mlFunc->getType(), mlFunc->getAttrs()); // Generates the body of the CFG function. return FunctionConverter(cfgFunc).convert(mlFunc); } // Replace references to MLFunctions with the references to the converted // CFGFunctions. Since this all MLFunctions are converted at this point, it is // unnecessary to replace references in the MLFunctions that are going to be // removed anyway. However, it is necessary to replace the references in the // converted CFGFunctions that have not been added to the module yet. void ModuleConverter::replaceReferences() { // Build the remapping between function attributes pointing to ML functions // and the newly created function attributes pointing to the converted CFG // functions. llvm::DenseMap remappingTable; for (const Function &fn : *module) { if (!fn.isML()) continue; Function *convertedFunc = generatedFuncs.lookup(&fn); assert(convertedFunc && "ML function was not converted"); MLIRContext *context = module->getContext(); auto mlFuncAttr = FunctionAttr::get(&fn, context); auto cfgFuncAttr = FunctionAttr::get(convertedFunc, module->getContext()); remappingTable.insert({mlFuncAttr, cfgFuncAttr}); } // Remap in existing functions. remapFunctionAttrs(*module, remappingTable); // Remap in generated functions. for (auto pair : generatedFuncs) { remapFunctionAttrs(*pair.second, remappingTable); } } // Replace the value of a function attribute named "name" attached to the // operation "op" and containing a Function-typed value with the result of // converting "func" to a Function. static inline void replaceMLFunctionAttr( OperationInst &op, Identifier name, const Function *func, const llvm::DenseMap &generatedFuncs) { if (!func->isML()) return; Builder b(op.getContext()); auto *cfgFunc = generatedFuncs.lookup(func); op.setAttr(name, b.getFunctionAttr(cfgFunc)); } // The CFG and ML functions have the same name. First, erase the Function. // Then insert the Function at the same place. void ModuleConverter::replaceFunctions() { for (auto pair : generatedFuncs) { auto &functions = module->getFunctions(); auto it = functions.erase(pair.first); functions.insert(it, pair.second); } } //===----------------------------------------------------------------------===// // Entry point method //===----------------------------------------------------------------------===// /// Replaces all ML functions in the module with equivalent CFG functions. /// Function references are appropriately patched to refer to the newly /// generated CFG functions. Converted functions have the same names as the /// original functions to preserve module linking. ModulePass *mlir::createConvertToCFGPass() { return new ModuleConverter(); } static PassRegistration pass("convert-to-cfg", "Convert all ML functions in the module to CFG ones");