diff options
Diffstat (limited to 'mlir/lib/Transforms/ViewOpGraph.cpp')
| -rw-r--r-- | mlir/lib/Transforms/ViewOpGraph.cpp | 80 |
1 files changed, 39 insertions, 41 deletions
diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp index 503a82bf82b..591562d0245 100644 --- a/mlir/lib/Transforms/ViewOpGraph.cpp +++ b/mlir/lib/Transforms/ViewOpGraph.cpp @@ -28,15 +28,17 @@ static llvm::cl::opt<int> elideIfLarger( llvm::cl::desc("Upper limit to emit elements attribute rather than elide"), llvm::cl::init(16)); +using namespace mlir; + namespace llvm { // Specialize GraphTraits to treat Block as a graph of Operations as nodes and // uses as edges. -template <> struct GraphTraits<mlir::Block *> { - using GraphType = mlir::Block *; - using NodeRef = mlir::Operation *; +template <> struct GraphTraits<Block *> { + using GraphType = Block *; + using NodeRef = Operation *; - using ChildIteratorType = mlir::UseIterator; + using ChildIteratorType = UseIterator; static ChildIteratorType child_begin(NodeRef n) { return ChildIteratorType(n); } @@ -46,49 +48,46 @@ template <> struct GraphTraits<mlir::Block *> { // Operation's destructor is private so use Operation* instead and use // mapped iterator. - static mlir::Operation *AddressOf(mlir::Operation &op) { return &op; } - using nodes_iterator = - mapped_iterator<mlir::Block::iterator, decltype(&AddressOf)>; - static nodes_iterator nodes_begin(mlir::Block *b) { + static Operation *AddressOf(Operation &op) { return &op; } + using nodes_iterator = mapped_iterator<Block::iterator, decltype(&AddressOf)>; + static nodes_iterator nodes_begin(Block *b) { return nodes_iterator(b->begin(), &AddressOf); } - static nodes_iterator nodes_end(mlir::Block *b) { + static nodes_iterator nodes_end(Block *b) { return nodes_iterator(b->end(), &AddressOf); } }; // Specialize DOTGraphTraits to produce more readable output. -template <> -struct DOTGraphTraits<mlir::Block *> : public DefaultDOTGraphTraits { +template <> struct DOTGraphTraits<Block *> : public DefaultDOTGraphTraits { using DefaultDOTGraphTraits::DefaultDOTGraphTraits; - static std::string getNodeLabel(mlir::Operation *op, mlir::Block *); + static std::string getNodeLabel(Operation *op, Block *); }; -std::string DOTGraphTraits<mlir::Block *>::getNodeLabel(mlir::Operation *op, - mlir::Block *b) { +std::string DOTGraphTraits<Block *>::getNodeLabel(Operation *op, Block *b) { // Reuse the print output for the node labels. std::string ostr; raw_string_ostream os(ostr); os << op->getName() << "\n"; - if (!op->getLoc().isa<mlir::UnknownLoc>()) { + if (!op->getLoc().isa<UnknownLoc>()) { os << op->getLoc() << "\n"; } // Print resultant types - mlir::interleaveComma(op->getResultTypes(), os); + interleaveComma(op->getResultTypes(), os); os << "\n"; for (auto attr : op->getAttrs()) { os << '\n' << attr.first << ": "; // Always emit splat attributes. - if (attr.second.isa<mlir::SplatElementsAttr>()) { + if (attr.second.isa<SplatElementsAttr>()) { attr.second.print(os); continue; } // Elide "big" elements attributes. - auto elements = attr.second.dyn_cast<mlir::ElementsAttr>(); + auto elements = attr.second.dyn_cast<ElementsAttr>(); if (elements && elements.getNumElements() > elideIfLarger) { os << std::string(elements.getType().getRank(), '[') << "..." << std::string(elements.getType().getRank(), ']') << " : " @@ -96,7 +95,7 @@ std::string DOTGraphTraits<mlir::Block *>::getNodeLabel(mlir::Operation *op, continue; } - auto array = attr.second.dyn_cast<mlir::ArrayAttr>(); + auto array = attr.second.dyn_cast<ArrayAttr>(); if (array && static_cast<int64_t>(array.size()) > elideIfLarger) { os << "[...]"; continue; @@ -114,14 +113,14 @@ namespace { // PrintOpPass is simple pass to write graph per function. // Note: this is a module pass only to avoid interleaving on the same ostream // due to multi-threading over functions. -struct PrintOpPass : public mlir::ModulePass<PrintOpPass> { - explicit PrintOpPass(llvm::raw_ostream &os = llvm::errs(), - bool short_names = false, const llvm::Twine &title = "") +struct PrintOpPass : public ModulePass<PrintOpPass> { + explicit PrintOpPass(raw_ostream &os = llvm::errs(), bool short_names = false, + const Twine &title = "") : os(os), title(title.str()), short_names(short_names) {} - std::string getOpName(mlir::Operation &op) { - auto symbolAttr = op.getAttrOfType<mlir::StringAttr>( - mlir::SymbolTable::getSymbolAttrName()); + std::string getOpName(Operation &op) { + auto symbolAttr = + op.getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()); if (symbolAttr) return symbolAttr.getValue(); ++unnamedOpCtr; @@ -129,22 +128,22 @@ struct PrintOpPass : public mlir::ModulePass<PrintOpPass> { } // Print all the ops in a module. - void processModule(mlir::ModuleOp module) { - for (mlir::Operation &op : module) { + void processModule(ModuleOp module) { + for (Operation &op : module) { // Modules may actually be nested, recurse on nesting. - if (auto nestedModule = llvm::dyn_cast<mlir::ModuleOp>(op)) { + if (auto nestedModule = dyn_cast<ModuleOp>(op)) { processModule(nestedModule); continue; } auto opName = getOpName(op); - for (mlir::Region ®ion : op.getRegions()) { + for (Region ®ion : op.getRegions()) { for (auto indexed_block : llvm::enumerate(region)) { // Suffix block number if there are more than 1 block. auto blockName = region.getBlocks().size() == 1 ? "" : ("__" + llvm::utostr(indexed_block.index())); llvm::WriteGraph(os, &indexed_block.value(), short_names, - llvm::Twine(title) + opName + blockName); + Twine(title) + opName + blockName); } } } @@ -153,29 +152,28 @@ struct PrintOpPass : public mlir::ModulePass<PrintOpPass> { void runOnModule() override { processModule(getModule()); } private: - llvm::raw_ostream &os; + raw_ostream &os; std::string title; int unnamedOpCtr = 0; bool short_names; }; } // namespace -void mlir::viewGraph(mlir::Block &block, const llvm::Twine &name, - bool shortNames, const llvm::Twine &title, - llvm::GraphProgram::Name program) { +void mlir::viewGraph(Block &block, const Twine &name, bool shortNames, + const Twine &title, llvm::GraphProgram::Name program) { llvm::ViewGraph(&block, name, shortNames, title, program); } -llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, mlir::Block &block, - bool shortNames, const llvm::Twine &title) { +raw_ostream &mlir::writeGraph(raw_ostream &os, Block &block, bool shortNames, + const Twine &title) { return llvm::WriteGraph(os, &block, shortNames, title); } -std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>> -mlir::createPrintOpGraphPass(llvm::raw_ostream &os, bool shortNames, - const llvm::Twine &title) { +std::unique_ptr<OpPassBase<ModuleOp>> +mlir::createPrintOpGraphPass(raw_ostream &os, bool shortNames, + const Twine &title) { return std::make_unique<PrintOpPass>(os, shortNames, title); } -static mlir::PassRegistration<PrintOpPass> pass("print-op-graph", - "Print op graph per region"); +static PassRegistration<PrintOpPass> pass("print-op-graph", + "Print op graph per region"); |

