summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/PipelineDataTransfer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/PipelineDataTransfer.cpp')
-rw-r--r--mlir/lib/Transforms/PipelineDataTransfer.cpp43
1 files changed, 18 insertions, 25 deletions
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index 8d13800160d..ba3be5e95f4 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -61,7 +61,7 @@ FunctionPass *mlir::createPipelineDataTransferPass() {
// Returns the position of the tag memref operand given a DMA instruction.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO(b/117228571)
-static unsigned getTagMemRefPos(const OperationInst &dmaInst) {
+static unsigned getTagMemRefPos(const Instruction &dmaInst) {
assert(dmaInst.isa<DmaStartOp>() || dmaInst.isa<DmaWaitOp>());
if (dmaInst.isa<DmaStartOp>()) {
// Second to last operand.
@@ -142,7 +142,7 @@ PassResult PipelineDataTransfer::runOnFunction(Function *f) {
// deleted and replaced by a prologue, a new steady-state loop and an
// epilogue).
forOps.clear();
- f->walkPostOrder([&](OperationInst *opInst) {
+ f->walkPostOrder([&](Instruction *opInst) {
if (auto forOp = opInst->dyn_cast<AffineForOp>())
forOps.push_back(forOp);
});
@@ -180,33 +180,26 @@ static bool checkTagMatch(OpPointer<DmaStartOp> startOp,
// Identify matching DMA start/finish instructions to overlap computation with.
static void findMatchingStartFinishInsts(
OpPointer<AffineForOp> forOp,
- SmallVectorImpl<std::pair<OperationInst *, OperationInst *>>
- &startWaitPairs) {
+ SmallVectorImpl<std::pair<Instruction *, Instruction *>> &startWaitPairs) {
// Collect outgoing DMA instructions - needed to check for dependences below.
SmallVector<OpPointer<DmaStartOp>, 4> outgoingDmaOps;
for (auto &inst : *forOp->getBody()) {
- auto *opInst = dyn_cast<OperationInst>(&inst);
- if (!opInst)
- continue;
OpPointer<DmaStartOp> dmaStartOp;
- if ((dmaStartOp = opInst->dyn_cast<DmaStartOp>()) &&
+ if ((dmaStartOp = inst.dyn_cast<DmaStartOp>()) &&
dmaStartOp->isSrcMemorySpaceFaster())
outgoingDmaOps.push_back(dmaStartOp);
}
- SmallVector<OperationInst *, 4> dmaStartInsts, dmaFinishInsts;
+ SmallVector<Instruction *, 4> dmaStartInsts, dmaFinishInsts;
for (auto &inst : *forOp->getBody()) {
- auto *opInst = dyn_cast<OperationInst>(&inst);
- if (!opInst)
- continue;
// Collect DMA finish instructions.
- if (opInst->isa<DmaWaitOp>()) {
- dmaFinishInsts.push_back(opInst);
+ if (inst.isa<DmaWaitOp>()) {
+ dmaFinishInsts.push_back(&inst);
continue;
}
OpPointer<DmaStartOp> dmaStartOp;
- if (!(dmaStartOp = opInst->dyn_cast<DmaStartOp>()))
+ if (!(dmaStartOp = inst.dyn_cast<DmaStartOp>()))
continue;
// Only DMAs incoming into higher memory spaces are pipelined for now.
// TODO(bondhugula): handle outgoing DMA pipelining.
@@ -236,7 +229,7 @@ static void findMatchingStartFinishInsts(
}
}
if (!escapingUses)
- dmaStartInsts.push_back(opInst);
+ dmaStartInsts.push_back(&inst);
}
// For each start instruction, we look for a matching finish instruction.
@@ -262,7 +255,7 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
return success();
}
- SmallVector<std::pair<OperationInst *, OperationInst *>, 4> startWaitPairs;
+ SmallVector<std::pair<Instruction *, Instruction *>, 4> startWaitPairs;
findMatchingStartFinishInsts(forOp, startWaitPairs);
if (startWaitPairs.empty()) {
@@ -335,7 +328,7 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
} else {
// If a slice wasn't created, the reachable affine_apply op's from its
// operands are the ones that go with it.
- SmallVector<OperationInst *, 4> affineApplyInsts;
+ SmallVector<Instruction *, 4> affineApplyInsts;
SmallVector<Value *, 4> operands(dmaStartInst->getOperands());
getReachableAffineApplyOps(operands, affineApplyInsts);
for (const auto *inst : affineApplyInsts) {
@@ -356,13 +349,13 @@ PipelineDataTransfer::runOnAffineForOp(OpPointer<AffineForOp> forOp) {
for (auto &inst : *forOp->getBody()) {
assert(instShiftMap.find(&inst) != instShiftMap.end());
shifts[s++] = instShiftMap[&inst];
- LLVM_DEBUG(
- // Tagging instructions with shifts for debugging purposes.
- if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
- FuncBuilder b(opInst);
- opInst->setAttr(b.getIdentifier("shift"),
- b.getI64IntegerAttr(shifts[s - 1]));
- });
+
+ // Tagging instructions with shifts for debugging purposes.
+ LLVM_DEBUG({
+ FuncBuilder b(&inst);
+ inst.setAttr(b.getIdentifier("shift"),
+ b.getI64IntegerAttr(shifts[s - 1]));
+ });
}
if (!isInstwiseShiftValid(forOp, shifts)) {
OpenPOWER on IntegriCloud