summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/ViewFunctionGraph.cpp
blob: 1f2ab69409e4b382202b955800c5c96e02dad816 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
//===- ViewFunctionGraph.cpp - View/write graphviz graphs -----------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "mlir/Transforms/ViewFunctionGraph.h"
#include "mlir/IR/FunctionGraphTraits.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace llvm {

// Specialize DOTGraphTraits to produce more readable output.
template <> struct DOTGraphTraits<Function *> : public DefaultDOTGraphTraits {
  using DefaultDOTGraphTraits::DefaultDOTGraphTraits;

  /// Returns the DOT label for `block`: the block's printed form, with a single
  /// leading newline dropped and every remaining newline rewritten as the
  /// graphviz "\l" escape so each line of the label is left-justified.
  static std::string getNodeLabel(Block *block, Function *);
};

std::string DOTGraphTraits<Function *>::getNodeLabel(Block *block, Function *) {
  // Reuse the print output for the node labels.
  std::string buffer;
  raw_string_ostream os(buffer);
  block->print(os);
  const std::string &printed = os.str(); // flushes the stream into `buffer`

  // Skip one leading newline emitted by the printer so the label does not
  // start with an empty line.
  std::string::size_type begin = 0;
  if (!printed.empty() && printed.front() == '\n')
    begin = 1;

  // Rewrite each '\n' as the two characters "\l" (graphviz: end the line,
  // left-justified). Building a fresh string keeps this linear in the label
  // size; inserting into the middle of the printed string once per line, as
  // an in-place rewrite would, is quadratic for large blocks.
  std::string label;
  label.reserve(printed.size());
  for (std::string::size_type i = begin, e = printed.size(); i != e; ++i) {
    if (printed[i] == '\n')
      label += "\\l";
    else
      label += printed[i];
  }
  return label;
}

} // end namespace llvm

/// Displays the CFG of `function` using `llvm::ViewGraph`, forwarding the
/// window/file `name`, `shortNames`, `title` and the graphviz `program`
/// unchanged.
void mlir::viewGraph(Function &function, const llvm::Twine &name,
                     bool shortNames, const llvm::Twine &title,
                     llvm::GraphProgram::Name program) {
  // The LLVM graph utilities are instantiated on `Function *` (matching the
  // DOTGraphTraits<Function *> specialization in this file), so pass a
  // pointer rather than the reference.
  Function *graph = &function;
  llvm::ViewGraph(graph, name, shortNames, title, program);
}

/// Writes the CFG of `function` in DOT form to `os` via `llvm::WriteGraph`
/// and returns the stream reference that `llvm::WriteGraph` returns.
llvm::raw_ostream &mlir::writeGraph(llvm::raw_ostream &os, Function &function,
                                    bool shortNames, const llvm::Twine &title) {
  // Graph utilities operate on `Function *`; see DOTGraphTraits<Function *>.
  Function *graph = &function;
  llvm::raw_ostream &out = llvm::WriteGraph(os, graph, shortNames, title);
  return out;
}

/// Convenience entry point: displays this function's CFG, naming the graph
/// "cfgfunc " followed by the function's name.
void mlir::Function::viewGraph() {
  // Keep the name in a local so the Twine built below only refers to objects
  // that are alive for the whole call.
  auto functionName = getName().str();
  ::mlir::viewGraph(*this, llvm::Twine("cfgfunc ") + functionName);
}

namespace {
/// Function pass that writes the CFG of each function it runs on to an output
/// stream, using mlir::writeGraph (DOT format).
struct PrintCFGPass : public FunctionPass<PrintCFGPass> {
  /// \param os          stream the graph is written to; it is held by
  ///                    reference, so it must outlive the pass.
  /// \param shortNames  forwarded to mlir::writeGraph.
  /// \param title       graph title; copied, because a Twine only references
  ///                    its operands, and callers (including this default
  ///                    argument) typically pass a temporary.
  PrintCFGPass(llvm::raw_ostream &os = llvm::errs(), bool shortNames = false,
               const llvm::Twine &title = "")
      : os(os), shortNames(shortNames), title(title.str()) {}

  void runOnFunction() override {
    mlir::writeGraph(os, getFunction(), shortNames, title);
  }

private:
  llvm::raw_ostream &os;
  bool shortNames;
  // Owned copy of the title. Storing `const llvm::Twine &` here would dangle
  // as soon as the constructor's (usually temporary) Twine argument dies.
  std::string title;
};
} // namespace

/// Creates a heap-allocated PrintCFGPass that writes to `os`, forwarding
/// `shortNames` and `title` to its constructor. The pointer is returned
/// without being retained here; presumably the caller (e.g. a pass manager)
/// takes ownership — confirm against callers.
FunctionPassBase *mlir::createPrintCFGGraphPass(llvm::raw_ostream &os,
                                                bool shortNames,
                                                const llvm::Twine &title) {
  auto *printPass = new PrintCFGPass(os, shortNames, title);
  return printPass;
}

// Registers PrintCFGPass under the name "print-cfg-graph". The registry
// presumably default-constructs the pass (writing to llvm::errs() with the
// constructor's default arguments) — confirm against PassRegistration.
static PassRegistration<PrintCFGPass>
    pass("print-cfg-graph", "Print CFG graph per function");
OpenPOWER on IntegriCloud