summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Analysis/OpStats.cpp
blob: f4c509a51323e6d514c31403f28fbefa21aaee61 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
//===- OpStats.cpp - Prints stats of operations in module -----------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "mlir/IR/Function.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/Pass.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>

using namespace mlir;

namespace {
// Pass that tallies how often each operation name occurs in the functions of a
// module and writes a per-name summary to an output stream.
struct PrintOpStatsPass : public FunctionPass,
                          public StmtWalker<PrintOpStatsPass> {
  // Identifier whose address is handed to the FunctionPass base class.
  static char passID;

  // Creates the pass. The summary is written to `os`, which defaults to
  // llvm::errs().
  explicit PrintOpStatsPass(llvm::raw_ostream &os = llvm::errs())
      : FunctionPass(&PrintOpStatsPass::passID), os(os) {}

  // Runs the base-class module traversal and then prints the collected
  // operation stats.
  PassResult runOnModule(Module *m) override;

  // Counts the operation instructions found in the basic blocks of a CFG
  // function.
  PassResult runOnCFGFunction(Function *function) override;

  // Counts the operation statements of an ML function by walking it.
  PassResult runOnMLFunction(Function *function) override;

  // Records one occurrence of the operation `stmt`.
  // NOTE(review): presumably the callback StmtWalker invokes for each
  // operation during walk() -- confirm against StmtVisitor.h.
  void visitOperationInst(OperationInst *stmt);

  // Writes the summary of the operation counts gathered so far.
  void printSummary();

private:
  // Number of occurrences seen so far, keyed by operation name.
  llvm::StringMap<int64_t> opCount;

  // Stream that receives the summary.
  llvm::raw_ostream &os;
};
} // namespace

// Out-of-class definition of the static pass identifier; the constructor passes
// the address of this object to the FunctionPass base class.
char PrintOpStatsPass::passID = 0;

// Delegates to FunctionPass::runOnModule and forwards its result. The summary
// is printed only when the negated result is true.
// NOTE(review): presumably PassResult converts to false on success, so the
// summary is printed only after a successful run -- confirm against Pass.h.
PassResult PrintOpStatsPass::runOnModule(Module *m) {
  auto status = FunctionPass::runOnModule(m);
  if (!status) {
    printSummary();
  }
  return status;
}

// Visits every instruction of every basic block in `function` and bumps the
// counter of each one that is an OperationInst. Always returns success().
PassResult PrintOpStatsPass::runOnCFGFunction(Function *function) {
  for (const auto &block : *function) {
    for (const auto &instruction : block) {
      auto *operation = dyn_cast<OperationInst>(&instruction);
      // Instructions that are not operations are not counted.
      if (!operation)
        continue;
      ++opCount[operation->getName().getStringRef()];
    }
  }
  return success();
}

// Increments the counter stored under the name of the operation `stmt`.
void PrintOpStatsPass::visitOperationInst(OperationInst *stmt) {
  auto opName = stmt->getName().getStringRef();
  ++opCount[opName];
}

// Walks `function` with this object as the walker and always returns
// success().
// NOTE(review): the walk presumably calls visitOperationInst for each
// operation statement -- confirm against StmtVisitor.h.
PassResult PrintOpStatsPass::runOnMLFunction(Function *function) {
  this->walk(function);
  return success();
}

// Prints a two-line header followed by one line per recorded operation name,
// in sorted order, of the form `  'dialect.op' , count`. Names are padded on
// the left so that all dialect prefixes end in the same column, padded on the
// right so that the commas line up, and the counts are right-aligned.
void PrintOpStatsPass::printSummary() {
  os << "Operations encountered:\n";
  os << "-----------------------\n";
  std::vector<StringRef> sorted(opCount.keys().begin(), opCount.keys().end());
  llvm::sort(sorted);

  // Returns the length of the dialect prefix of an op name, including the
  // trailing period, or 0 when the name contains no period.
  auto dialectLen = [](StringRef opName) -> size_t {
    auto dialectEnd = opName.find_last_of('.');
    return dialectEnd == StringRef::npos ? 0 : dialectEnd + 1;
  };

  // Returns the number of decimal digits needed to print `count`. Integer
  // arithmetic is used so the result is exact for every int64_t value, which a
  // floating-point log10 does not guarantee near powers of ten.
  auto numDigits = [](int64_t count) -> unsigned {
    unsigned digits = 1;
    for (; count >= 10; count /= 10)
      ++digits;
    return digits;
  };

  // Left-align the names (aligning on the dialect) and right-align count below.
  // The alignment is for readability and does not affect CSV/FileCheck parsing.
  //
  // The widest dialect prefix and the widest op part (the text after the
  // dialect prefix) are tracked independently: they may belong to different
  // names, and deriving the column width from the longest full name alone
  // makes the padding below underflow for names with a short dialect and a
  // long op part.
  size_t maxLenDialect = 0;
  size_t maxLenOp = 0;
  unsigned maxLenCount = 0;
  for (const auto &key : sorted) {
    const size_t prefix = dialectLen(key);
    maxLenDialect = std::max(maxLenDialect, prefix);
    maxLenOp = std::max(maxLenOp, key.size() - prefix);
    maxLenCount = std::max(maxLenCount, numDigits(opCount[key]));
  }

  for (const auto &key : sorted) {
    const size_t prefix = dialectLen(key);
    // Both paddings are non-negative by construction of the maxima above.
    const size_t leftPad = maxLenDialect - prefix;
    const size_t rightPad = maxLenOp - (key.size() - prefix);
    const int64_t count = opCount[key];

    os.indent(2 + leftPad) << '\'' << key << '\'';
    // Keep one extra column between the widest name and the comma.
    os.indent(rightPad + 1) << " ,";
    // One separating space plus whatever is needed to right-align the count.
    os.indent(maxLenCount - numDigits(count) + 1) << count << "\n";
  }
}

// File-local registration object for PrintOpStatsPass, constructed with the
// argument string "print-op-stats" and a one-line description.
static PassRegistration<PrintOpStatsPass>
    pass("print-op-stats", "Print statistics of operations");
OpenPOWER on IntegriCloud