summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/Utils.cpp
blob: 2e8f0d32736bd3154ade6b569f29093af11a4f9e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
//===- Utils.cpp ---- Misc utilities for code and data transformation -----===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements miscellaneous transformation routines for non-loop IR
// structures.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/Utils.h"

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/ADT/DenseMap.h"

using namespace mlir;

/// Returns true if the given operation is one of the known ops that
/// dereferences a memref (loads, stores, and DMA start/wait).
// Temporary utility: will be replaced when this is modeled through
// side-effects/op traits. TODO(b/117228571)
static bool isMemRefDereferencingOp(const Operation &op) {
  return op.is<LoadOp>() || op.is<StoreOp>() || op.is<DmaStartOp>() ||
         op.is<DmaWaitOp>();
}

/// Replaces all uses of oldMemRef with newMemRef while optionally remapping
/// old memref's indices to the new memref using the supplied affine map
/// and adding any additional indices. The new memref could be of a different
/// shape or rank, but of the same elemental type. Additional indices are added
/// at the start for now.
///
/// Returns false — and performs no replacement at all — if any use of
/// oldMemRef appears in a non-dereferencing context (the memref potentially
/// escapes, e.g. is passed to a call); returns true on success.
// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
// extended to add additional indices at any position.
bool mlir::replaceAllMemRefUsesWith(MLValue *oldMemRef, MLValue *newMemRef,
                                    ArrayRef<SSAValue *> extraIndices,
                                    AffineMap indexRemap) {
  unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank();
  (void)newMemRefRank; // unused in opt mode
  unsigned oldMemRefRank = cast<MemRefType>(oldMemRef->getType())->getRank();
  // Rank compatibility: the remap results (if any) plus the extra indices
  // must together supply exactly newMemRefRank indices.
  if (indexRemap) {
    assert(indexRemap.getNumInputs() == oldMemRefRank);
    assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank);
  } else {
    assert(oldMemRefRank + extraIndices.size() == newMemRefRank);
  }

  // Assert same elemental type.
  assert(cast<MemRefType>(oldMemRef->getType())->getElementType() ==
         cast<MemRefType>(newMemRef->getType())->getElementType());

  // Check if memref was used in a non-dereferencing context.
  for (const StmtOperand &use : oldMemRef->getUses()) {
    auto *opStmt = cast<OperationStmt>(use.getOwner());
    // Failure: memref used in a non-dereferencing op (potentially escapes); no
    // replacement in these cases.
    if (!isMemRefDereferencingOp(*opStmt))
      return false;
  }

  // Walk all uses of old memref. Statement using the memref gets replaced.
  for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) {
    // Advance the iterator before the use's owner is erased below.
    StmtOperand &use = *(it++);
    auto *opStmt = cast<OperationStmt>(use.getOwner());
    assert(isMemRefDereferencingOp(*opStmt) &&
           "memref deferencing op expected");

    // Position of the memref in the statement's operand list.
    auto getMemRefOperandPos = [&]() -> unsigned {
      unsigned i;
      for (i = 0; i < opStmt->getNumOperands(); i++) {
        if (opStmt->getOperand(i) == oldMemRef)
          break;
      }
      assert(i < opStmt->getNumOperands() && "operand guaranteed to be found");
      return i;
    };
    unsigned memRefOperandPos = getMemRefOperandPos();

    // Construct the new operation statement using this memref.
    SmallVector<MLValue *, 8> operands;
    operands.reserve(opStmt->getNumOperands() + extraIndices.size());
    // Insert the non-memref operands.
    operands.insert(operands.end(), opStmt->operand_begin(),
                    opStmt->operand_begin() + memRefOperandPos);
    operands.push_back(newMemRef);

    MLFuncBuilder builder(opStmt);
    // Normally, we could just use extraIndices as operands, but we will
    // clone it so that each op gets its own "private" index. See b/117159533.
    for (auto *extraIndex : extraIndices) {
      OperationStmt::OperandMapTy operandMap;
      // TODO(mlir-team): An operation/SSA value should provide a method to
      // return the position of an SSA result in its defining
      // operation.
      assert(extraIndex->getDefiningStmt()->getNumResults() == 1 &&
             "single result op's expected to generate these indices");
      // TODO: actually check if this is a result of an affine_apply op.
      assert((cast<MLValue>(extraIndex)->isValidDim() ||
              cast<MLValue>(extraIndex)->isValidSymbol()) &&
             "invalid memory op index");
      auto *clonedExtraIndex =
          cast<OperationStmt>(
              builder.clone(*extraIndex->getDefiningStmt(), operandMap))
              ->getResult(0);
      operands.push_back(cast<MLValue>(clonedExtraIndex));
    }

    // Construct new indices. The indices of a memref come right after it, i.e.,
    // at position memRefOperandPos + 1.
    SmallVector<SSAValue *, 4> indices(
        opStmt->operand_begin() + memRefOperandPos + 1,
        opStmt->operand_begin() + memRefOperandPos + 1 + oldMemRefRank);
    if (indexRemap) {
      // Remap the old indices through an affine_apply of indexRemap.
      auto remapOp =
          builder.create<AffineApplyOp>(opStmt->getLoc(), indexRemap, indices);
      // Remapped indices.
      for (auto *index : remapOp->getOperation()->getResults())
        operands.push_back(cast<MLValue>(index));
    } else {
      // No remapping specified.
      for (auto *index : indices)
        operands.push_back(cast<MLValue>(index));
    }

    // Insert the remaining operands unmodified.
    operands.insert(operands.end(),
                    opStmt->operand_begin() + memRefOperandPos + 1 +
                        oldMemRefRank,
                    opStmt->operand_end());

    // Result types don't change. Both memref's are of the same elemental type.
    SmallVector<Type *, 8> resultTypes;
    resultTypes.reserve(opStmt->getNumResults());
    for (const auto *result : opStmt->getResults())
      resultTypes.push_back(result->getType());

    // Create the new operation.
    auto *repOp =
        builder.createOperation(opStmt->getLoc(), opStmt->getName(), operands,
                                resultTypes, opStmt->getAttrs());
    // Replace old memref's dereferencing op's uses.
    unsigned r = 0;
    for (auto *res : opStmt->getResults()) {
      res->replaceAllUsesWith(repOp->getResult(r++));
    }
    opStmt->eraseFromBlock();
  }
  return true;
}
OpenPOWER on IntegriCloud