Diffstat (limited to 'mlir/lib/Transforms/PipelineDataTransfer.cpp')
-rw-r--r--  mlir/lib/Transforms/PipelineDataTransfer.cpp  200
1 file changed, 100 insertions(+), 100 deletions(-)
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index c8a6ced4ed1..debaac3a33c 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -25,7 +25,7 @@
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
-#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
@@ -39,14 +39,14 @@ using namespace mlir;
namespace {
struct PipelineDataTransfer : public FunctionPass,
- StmtWalker<PipelineDataTransfer> {
+ InstWalker<PipelineDataTransfer> {
PipelineDataTransfer() : FunctionPass(&PipelineDataTransfer::passID) {}
PassResult runOnMLFunction(Function *f) override;
- PassResult runOnForStmt(ForStmt *forStmt);
+ PassResult runOnForInst(ForInst *forInst);
- // Collect all 'for' statements.
- void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); }
- std::vector<ForStmt *> forStmts;
+ // Collect all 'for' instructions.
+ void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); }
+ std::vector<ForInst *> forInsts;
static char passID;
};
@@ -61,26 +61,26 @@ FunctionPass *mlir::createPipelineDataTransferPass() {
return new PipelineDataTransfer();
}
-// Returns the position of the tag memref operand given a DMA statement.
+// Returns the position of the tag memref operand given a DMA instruction.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO(b/117228571)
-static unsigned getTagMemRefPos(const OperationInst &dmaStmt) {
- assert(dmaStmt.isa<DmaStartOp>() || dmaStmt.isa<DmaWaitOp>());
- if (dmaStmt.isa<DmaStartOp>()) {
+static unsigned getTagMemRefPos(const OperationInst &dmaInst) {
+ assert(dmaInst.isa<DmaStartOp>() || dmaInst.isa<DmaWaitOp>());
+ if (dmaInst.isa<DmaStartOp>()) {
// Second to last operand.
- return dmaStmt.getNumOperands() - 2;
+ return dmaInst.getNumOperands() - 2;
}
- // First operand for a dma finish statement.
+ // First operand for a dma finish instruction.
return 0;
}
-/// Doubles the buffer of the supplied memref on the specified 'for' statement
+/// Doubles the buffer of the supplied memref on the specified 'for' instruction
/// by adding a leading dimension of size two to the memref. Replaces all uses
/// of the old memref by the new one while indexing the newly added dimension by
-/// the loop IV of the specified 'for' statement modulo 2. Returns false if such
-/// a replacement cannot be performed.
-static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) {
- auto *forBody = forStmt->getBody();
+/// the loop IV of the specified 'for' instruction modulo 2. Returns false if
+/// such a replacement cannot be performed.
+static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) {
+ auto *forBody = forInst->getBody();
FuncBuilder bInner(forBody, forBody->begin());
bInner.setInsertionPoint(forBody, forBody->begin());
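
(Illustrative aside, not part of this commit: the effect of doubleBuffer() can be pictured with a plain C++ stand-in. A memref of shape <4xf32> becomes <2x4xf32>, and every access is prefixed with the loop IV modulo 2, so consecutive iterations always touch disjoint slots. All names below are hypothetical.)

#include <cstdio>

int main() {
  // Doubled buffer: the pass adds a leading dimension of size two to the memref.
  float buffer[2][4] = {};
  for (int iv = 0; iv < 6; ++iv) {
    // What the inserted affine_apply (d0) -> (d0 mod 2) computes from the loop IV.
    int slot = iv % 2;
    buffer[slot][0] = static_cast<float>(iv); // stands in for the DMA filling this slot
    std::printf("iteration %d writes slot %d\n", iv, slot);
  }
  return 0;
}
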
@@ -101,33 +101,33 @@ static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) {
auto newMemRefType = doubleShape(oldMemRefType);
// Put together alloc operands for the dynamic dimensions of the memref.
- FuncBuilder bOuter(forStmt);
+ FuncBuilder bOuter(forInst);
SmallVector<Value *, 4> allocOperands;
unsigned dynamicDimCount = 0;
for (auto dimSize : oldMemRefType.getShape()) {
if (dimSize == -1)
- allocOperands.push_back(bOuter.create<DimOp>(forStmt->getLoc(), oldMemRef,
+ allocOperands.push_back(bOuter.create<DimOp>(forInst->getLoc(), oldMemRef,
dynamicDimCount++));
}
- // Create and place the alloc right before the 'for' statement.
+ // Create and place the alloc right before the 'for' instruction.
// TODO(mlir-team): we are assuming scoped allocation here, and aren't
// inserting a dealloc -- this isn't the right thing.
Value *newMemRef =
- bOuter.create<AllocOp>(forStmt->getLoc(), newMemRefType, allocOperands);
+ bOuter.create<AllocOp>(forInst->getLoc(), newMemRefType, allocOperands);
// Create 'iv mod 2' value to index the leading dimension.
auto d0 = bInner.getAffineDimExpr(0);
auto modTwoMap =
bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0, {d0 % 2}, {});
auto ivModTwoOp =
- bInner.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt);
+ bInner.create<AffineApplyOp>(forInst->getLoc(), modTwoMap, forInst);
- // replaceAllMemRefUsesWith will always succeed unless the forStmt body has
+ // replaceAllMemRefUsesWith will always succeed unless the forInst body has
// non-dereferencing uses of the memref.
if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, ivModTwoOp->getResult(0),
AffineMap::Null(), {},
- &*forStmt->getBody()->begin())) {
+ &*forInst->getBody()->begin())) {
LLVM_DEBUG(llvm::dbgs()
<< "memref replacement for double buffering failed\n";);
ivModTwoOp->getInstruction()->erase();
@@ -139,15 +139,15 @@ static bool doubleBuffer(Value *oldMemRef, ForStmt *forStmt) {
/// Returns success if the IR is in a valid state.
PassResult PipelineDataTransfer::runOnMLFunction(Function *f) {
// Do a post order walk so that inner loop DMAs are processed first. This is
- // necessary since 'for' statements nested within would otherwise become
+ // necessary since 'for' instructions nested within would otherwise become
// invalid (erased) when the outer loop is pipelined (the pipelined one gets
// deleted and replaced by a prologue, a new steady-state loop and an
// epilogue).
- forStmts.clear();
+ forInsts.clear();
walkPostOrder(f);
bool ret = false;
- for (auto *forStmt : forStmts) {
- ret = ret | runOnForStmt(forStmt);
+ for (auto *forInst : forInsts) {
+ ret = ret | runOnForInst(forInst);
}
return ret ? failure() : success();
}
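
(Illustrative aside, not part of this commit: the post-order walk guarantees inner loops are processed before the outer loop containing them, so pipelining the outer loop, which erases it and its body, never invalidates a loop still sitting in the worklist. A generic C++ stand-in for that traversal order, all types hypothetical:)

#include <cstdio>
#include <vector>

struct Loop {
  int id;
  std::vector<Loop *> children;
};

// Post-order: children (inner loops) are appended before their parent (outer loop).
void collectPostOrder(Loop *l, std::vector<Loop *> &out) {
  for (Loop *c : l->children)
    collectPostOrder(c, out);
  out.push_back(l);
}

int main() {
  Loop inner{1, {}};
  Loop outer{0, {&inner}};
  std::vector<Loop *> order;
  collectPostOrder(&outer, order);
  for (Loop *l : order)
    std::printf("process loop %d\n", l->id); // prints 1 (inner) then 0 (outer)
  return 0;
}
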
@@ -176,36 +176,36 @@ static bool checkTagMatch(OpPointer<DmaStartOp> startOp,
return true;
}
-// Identify matching DMA start/finish statements to overlap computation with.
-static void findMatchingStartFinishStmts(
- ForStmt *forStmt,
+// Identify matching DMA start/finish instructions to overlap computation with.
+static void findMatchingStartFinishInsts(
+ ForInst *forInst,
SmallVectorImpl<std::pair<OperationInst *, OperationInst *>>
&startWaitPairs) {
- // Collect outgoing DMA statements - needed to check for dependences below.
+ // Collect outgoing DMA instructions - needed to check for dependences below.
SmallVector<OpPointer<DmaStartOp>, 4> outgoingDmaOps;
- for (auto &stmt : *forStmt->getBody()) {
- auto *opStmt = dyn_cast<OperationInst>(&stmt);
- if (!opStmt)
+ for (auto &inst : *forInst->getBody()) {
+ auto *opInst = dyn_cast<OperationInst>(&inst);
+ if (!opInst)
continue;
OpPointer<DmaStartOp> dmaStartOp;
- if ((dmaStartOp = opStmt->dyn_cast<DmaStartOp>()) &&
+ if ((dmaStartOp = opInst->dyn_cast<DmaStartOp>()) &&
dmaStartOp->isSrcMemorySpaceFaster())
outgoingDmaOps.push_back(dmaStartOp);
}
- SmallVector<OperationInst *, 4> dmaStartStmts, dmaFinishStmts;
- for (auto &stmt : *forStmt->getBody()) {
- auto *opStmt = dyn_cast<OperationInst>(&stmt);
- if (!opStmt)
+ SmallVector<OperationInst *, 4> dmaStartInsts, dmaFinishInsts;
+ for (auto &inst : *forInst->getBody()) {
+ auto *opInst = dyn_cast<OperationInst>(&inst);
+ if (!opInst)
continue;
- // Collect DMA finish statements.
- if (opStmt->isa<DmaWaitOp>()) {
- dmaFinishStmts.push_back(opStmt);
+ // Collect DMA finish instructions.
+ if (opInst->isa<DmaWaitOp>()) {
+ dmaFinishInsts.push_back(opInst);
continue;
}
OpPointer<DmaStartOp> dmaStartOp;
- if (!(dmaStartOp = opStmt->dyn_cast<DmaStartOp>()))
+ if (!(dmaStartOp = opInst->dyn_cast<DmaStartOp>()))
continue;
// Only DMAs incoming into higher memory spaces are pipelined for now.
// TODO(bondhugula): handle outgoing DMA pipelining.
@@ -227,7 +227,7 @@ static void findMatchingStartFinishStmts(
auto *memref = dmaStartOp->getOperand(dmaStartOp->getFasterMemPos());
bool escapingUses = false;
for (const auto &use : memref->getUses()) {
- if (!dominates(*forStmt->getBody()->begin(), *use.getOwner())) {
+ if (!dominates(*forInst->getBody()->begin(), *use.getOwner())) {
LLVM_DEBUG(llvm::dbgs()
<< "can't pipeline: buffer is live out of loop\n";);
escapingUses = true;
@@ -235,15 +235,15 @@ static void findMatchingStartFinishStmts(
}
}
if (!escapingUses)
- dmaStartStmts.push_back(opStmt);
+ dmaStartInsts.push_back(opInst);
}
- // For each start statement, we look for a matching finish statement.
- for (auto *dmaStartStmt : dmaStartStmts) {
- for (auto *dmaFinishStmt : dmaFinishStmts) {
- if (checkTagMatch(dmaStartStmt->cast<DmaStartOp>(),
- dmaFinishStmt->cast<DmaWaitOp>())) {
- startWaitPairs.push_back({dmaStartStmt, dmaFinishStmt});
+ // For each start instruction, we look for a matching finish instruction.
+ for (auto *dmaStartInst : dmaStartInsts) {
+ for (auto *dmaFinishInst : dmaFinishInsts) {
+ if (checkTagMatch(dmaStartInst->cast<DmaStartOp>(),
+ dmaFinishInst->cast<DmaWaitOp>())) {
+ startWaitPairs.push_back({dmaStartInst, dmaFinishInst});
break;
}
}
@@ -251,17 +251,17 @@ static void findMatchingStartFinishStmts(
}
/// Overlap DMA transfers with computation in this loop. If successful,
-/// 'forStmt' is deleted, and a prologue, a new pipelined loop, and epilogue are
+/// 'forInst' is deleted, and a prologue, a new pipelined loop, and epilogue are
/// inserted right before where it was.
-PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
- auto mayBeConstTripCount = getConstantTripCount(*forStmt);
+PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) {
+ auto mayBeConstTripCount = getConstantTripCount(*forInst);
if (!mayBeConstTripCount.hasValue()) {
LLVM_DEBUG(llvm::dbgs() << "unknown trip count loop\n");
return success();
}
SmallVector<std::pair<OperationInst *, OperationInst *>, 4> startWaitPairs;
- findMatchingStartFinishStmts(forStmt, startWaitPairs);
+ findMatchingStartFinishInsts(forInst, startWaitPairs);
if (startWaitPairs.empty()) {
LLVM_DEBUG(llvm::dbgs() << "No dma start/finish pairs\n";);
@@ -269,22 +269,22 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
}
// Double the buffers for the higher memory space memref's.
- // Identify memref's to replace by scanning through all DMA start statements.
- // A DMA start statement has two memref's - the one from the higher level of
- // memory hierarchy is the one to double buffer.
+ // Identify memref's to replace by scanning through all DMA start
+ // instructions. A DMA start instruction has two memref's - the one from the
+ // higher level of memory hierarchy is the one to double buffer.
// TODO(bondhugula): check whether double-buffering is even necessary.
// TODO(bondhugula): make this work with different layouts: assuming here that
// the dimension we are adding here for the double buffering is the outermost
// dimension.
for (auto &pair : startWaitPairs) {
- auto *dmaStartStmt = pair.first;
- Value *oldMemRef = dmaStartStmt->getOperand(
- dmaStartStmt->cast<DmaStartOp>()->getFasterMemPos());
- if (!doubleBuffer(oldMemRef, forStmt)) {
+ auto *dmaStartInst = pair.first;
+ Value *oldMemRef = dmaStartInst->getOperand(
+ dmaStartInst->cast<DmaStartOp>()->getFasterMemPos());
+ if (!doubleBuffer(oldMemRef, forInst)) {
// Normally, double buffering should not fail because we already checked
// that there are no uses outside.
LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";);
- LLVM_DEBUG(dmaStartStmt->dump());
+ LLVM_DEBUG(dmaStartInst->dump());
// IR still in a valid state.
return success();
}
@@ -293,80 +293,80 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
// operation could have been used on it if it was dynamically shaped in
// order to create the double buffer above)
if (oldMemRef->use_empty())
- if (auto *allocStmt = oldMemRef->getDefiningInst())
- allocStmt->erase();
+ if (auto *allocInst = oldMemRef->getDefiningInst())
+ allocInst->erase();
}
// Double the buffers for tag memrefs.
for (auto &pair : startWaitPairs) {
- auto *dmaFinishStmt = pair.second;
+ auto *dmaFinishInst = pair.second;
Value *oldTagMemRef =
- dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt));
- if (!doubleBuffer(oldTagMemRef, forStmt)) {
+ dmaFinishInst->getOperand(getTagMemRefPos(*dmaFinishInst));
+ if (!doubleBuffer(oldTagMemRef, forInst)) {
LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
return success();
}
// If the old tag has no more uses, remove its 'dead' alloc if it was
// alloc'ed.
if (oldTagMemRef->use_empty())
- if (auto *allocStmt = oldTagMemRef->getDefiningInst())
- allocStmt->erase();
+ if (auto *allocInst = oldTagMemRef->getDefiningInst())
+ allocInst->erase();
}
- // Double buffering would have invalidated all the old DMA start/wait stmts.
+ // Double buffering would have invalidated all the old DMA start/wait insts.
startWaitPairs.clear();
- findMatchingStartFinishStmts(forStmt, startWaitPairs);
+ findMatchingStartFinishInsts(forInst, startWaitPairs);
- // Store shift for statement for later lookup for AffineApplyOp's.
- DenseMap<const Statement *, unsigned> stmtShiftMap;
+ // Store shift for instruction for later lookup for AffineApplyOp's.
+ DenseMap<const Instruction *, unsigned> instShiftMap;
for (auto &pair : startWaitPairs) {
- auto *dmaStartStmt = pair.first;
- assert(dmaStartStmt->isa<DmaStartOp>());
- stmtShiftMap[dmaStartStmt] = 0;
- // Set shifts for DMA start stmt's affine operand computation slices to 0.
- if (auto *slice = mlir::createAffineComputationSlice(dmaStartStmt)) {
- stmtShiftMap[slice] = 0;
+ auto *dmaStartInst = pair.first;
+ assert(dmaStartInst->isa<DmaStartOp>());
+ instShiftMap[dmaStartInst] = 0;
+ // Set shifts for DMA start inst's affine operand computation slices to 0.
+ if (auto *slice = mlir::createAffineComputationSlice(dmaStartInst)) {
+ instShiftMap[slice] = 0;
} else {
// If a slice wasn't created, the reachable affine_apply op's from its
// operands are the ones that go with it.
- SmallVector<OperationInst *, 4> affineApplyStmts;
- SmallVector<Value *, 4> operands(dmaStartStmt->getOperands());
- getReachableAffineApplyOps(operands, affineApplyStmts);
- for (const auto *stmt : affineApplyStmts) {
- stmtShiftMap[stmt] = 0;
+ SmallVector<OperationInst *, 4> affineApplyInsts;
+ SmallVector<Value *, 4> operands(dmaStartInst->getOperands());
+ getReachableAffineApplyOps(operands, affineApplyInsts);
+ for (const auto *inst : affineApplyInsts) {
+ instShiftMap[inst] = 0;
}
}
}
// Everything else (including compute ops and dma finish) are shifted by one.
- for (const auto &stmt : *forStmt->getBody()) {
- if (stmtShiftMap.find(&stmt) == stmtShiftMap.end()) {
- stmtShiftMap[&stmt] = 1;
+ for (const auto &inst : *forInst->getBody()) {
+ if (instShiftMap.find(&inst) == instShiftMap.end()) {
+ instShiftMap[&inst] = 1;
}
}
// Get shifts stored in map.
- std::vector<uint64_t> shifts(forStmt->getBody()->getInstructions().size());
+ std::vector<uint64_t> shifts(forInst->getBody()->getInstructions().size());
unsigned s = 0;
- for (auto &stmt : *forStmt->getBody()) {
- assert(stmtShiftMap.find(&stmt) != stmtShiftMap.end());
- shifts[s++] = stmtShiftMap[&stmt];
+ for (auto &inst : *forInst->getBody()) {
+ assert(instShiftMap.find(&inst) != instShiftMap.end());
+ shifts[s++] = instShiftMap[&inst];
LLVM_DEBUG(
- // Tagging statements with shifts for debugging purposes.
- if (auto *opStmt = dyn_cast<OperationInst>(&stmt)) {
- FuncBuilder b(opStmt);
- opStmt->setAttr(b.getIdentifier("shift"),
+ // Tagging instructions with shifts for debugging purposes.
+ if (auto *opInst = dyn_cast<OperationInst>(&inst)) {
+ FuncBuilder b(opInst);
+ opInst->setAttr(b.getIdentifier("shift"),
b.getI64IntegerAttr(shifts[s - 1]));
});
}
- if (!isStmtwiseShiftValid(*forStmt, shifts)) {
+ if (!isInstwiseShiftValid(*forInst, shifts)) {
// Violates dependences.
LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
return success();
}
- if (stmtBodySkew(forStmt, shifts)) {
- LLVM_DEBUG(llvm::dbgs() << "stmt body skewing failed - unexpected\n";);
+ if (instBodySkew(forInst, shifts)) {
+ LLVM_DEBUG(llvm::dbgs() << "inst body skewing failed - unexpected\n";);
return success();
}
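
(Illustrative aside, not part of this commit: the shift vector built above, 0 for each dma_start and its index computation and 1 for everything else, is what the body-skewing step turns into a prologue, a steady-state loop, and an epilogue. A self-contained C++ sketch of that schedule, with a synchronous lambda standing in for the DMA and all names hypothetical:)

#include <array>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> slowMem = {10, 20, 30, 40, 50};
  const int n = static_cast<int>(slowMem.size());
  std::array<int, 2> fastBuf = {0, 0}; // doubled fast-memory buffer

  // Stand-ins for dma_start/dma_wait and for the compute that consumes the data.
  auto dmaStart = [&](int i) { fastBuf[i % 2] = slowMem[i]; };
  auto compute = [&](int i) { std::printf("iter %d computes %d\n", i, fastBuf[i % 2] * 2); };

  dmaStart(0);                      // prologue: shift-0 work of iteration 0
  for (int i = 0; i + 1 < n; ++i) { // steady state: overlap transfer i+1 with compute i
    dmaStart(i + 1);
    compute(i);
  }
  compute(n - 1);                   // epilogue: shift-1 work of the last iteration
  return 0;
}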