From 5052bd8582fbcfc0a4774c34141c2dd04b333613 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 1 Feb 2019 16:42:18 -0800 Subject: Define the AffineForOp and replace ForInst with it. This patch is largely mechanical, i.e. changing usages of ForInst to OpPointer. An important difference is that upon construction an AffineForOp no longer automatically creates the body and induction variable. To generate the body/iv, 'createBody' can be called on an AffineForOp with no body. PiperOrigin-RevId: 232060516 --- mlir/lib/Transforms/PipelineDataTransfer.cpp | 70 ++++++++++++++-------------- 1 file changed, 36 insertions(+), 34 deletions(-) (limited to 'mlir/lib/Transforms/PipelineDataTransfer.cpp') diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 811741d08d1..2e083bbfd79 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -21,11 +21,11 @@ #include "mlir/Transforms/Passes.h" +#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" #include "mlir/StandardOps/StandardOps.h" #include "mlir/Transforms/LoopUtils.h" @@ -38,15 +38,12 @@ using namespace mlir; namespace { -struct PipelineDataTransfer : public FunctionPass, - InstWalker { +struct PipelineDataTransfer : public FunctionPass { PipelineDataTransfer() : FunctionPass(&PipelineDataTransfer::passID) {} PassResult runOnFunction(Function *f) override; - PassResult runOnForInst(ForInst *forInst); + PassResult runOnAffineForOp(OpPointer forOp); - // Collect all 'for' instructions. - void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); } - std::vector forInsts; + std::vector> forOps; static char passID; }; @@ -79,8 +76,8 @@ static unsigned getTagMemRefPos(const OperationInst &dmaInst) { /// of the old memref by the new one while indexing the newly added dimension by /// the loop IV of the specified 'for' instruction modulo 2. Returns false if /// such a replacement cannot be performed. -static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) { - auto *forBody = forInst->getBody(); +static bool doubleBuffer(Value *oldMemRef, OpPointer forOp) { + auto *forBody = forOp->getBody(); FuncBuilder bInner(forBody, forBody->begin()); bInner.setInsertionPoint(forBody, forBody->begin()); @@ -101,6 +98,7 @@ static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) { auto newMemRefType = doubleShape(oldMemRefType); // Put together alloc operands for the dynamic dimensions of the memref. + auto *forInst = forOp->getInstruction(); FuncBuilder bOuter(forInst); SmallVector allocOperands; unsigned dynamicDimCount = 0; @@ -118,16 +116,16 @@ static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) { // Create 'iv mod 2' value to index the leading dimension. auto d0 = bInner.getAffineDimExpr(0); - int64_t step = forInst->getStep(); + int64_t step = forOp->getStep(); auto modTwoMap = bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0, {d0.floorDiv(step) % 2}, {}); - auto ivModTwoOp = bInner.create(forInst->getLoc(), modTwoMap, - forInst->getInductionVar()); + auto ivModTwoOp = bInner.create(forOp->getLoc(), modTwoMap, + forOp->getInductionVar()); - // replaceAllMemRefUsesWith will always succeed unless the forInst body has + // replaceAllMemRefUsesWith will always succeed unless the forOp body has // non-deferencing uses of the memref. if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef, {ivModTwoOp}, AffineMap(), - {}, &*forInst->getBody()->begin())) { + {}, &*forOp->getBody()->begin())) { LLVM_DEBUG(llvm::dbgs() << "memref replacement for double buffering failed\n";); ivModTwoOp->getInstruction()->erase(); @@ -143,11 +141,14 @@ PassResult PipelineDataTransfer::runOnFunction(Function *f) { // invalid (erased) when the outer loop is pipelined (the pipelined one gets // deleted and replaced by a prologue, a new steady-state loop and an // epilogue). - forInsts.clear(); - walkPostOrder(f); + forOps.clear(); + f->walkOpsPostOrder([&](OperationInst *opInst) { + if (auto forOp = opInst->dyn_cast()) + forOps.push_back(forOp); + }); bool ret = false; - for (auto *forInst : forInsts) { - ret = ret | runOnForInst(forInst); + for (auto forOp : forOps) { + ret = ret | runOnAffineForOp(forOp); } return ret ? failure() : success(); } @@ -178,13 +179,13 @@ static bool checkTagMatch(OpPointer startOp, // Identify matching DMA start/finish instructions to overlap computation with. static void findMatchingStartFinishInsts( - ForInst *forInst, + OpPointer forOp, SmallVectorImpl> &startWaitPairs) { // Collect outgoing DMA instructions - needed to check for dependences below. SmallVector, 4> outgoingDmaOps; - for (auto &inst : *forInst->getBody()) { + for (auto &inst : *forOp->getBody()) { auto *opInst = dyn_cast(&inst); if (!opInst) continue; @@ -195,7 +196,7 @@ static void findMatchingStartFinishInsts( } SmallVector dmaStartInsts, dmaFinishInsts; - for (auto &inst : *forInst->getBody()) { + for (auto &inst : *forOp->getBody()) { auto *opInst = dyn_cast(&inst); if (!opInst) continue; @@ -227,7 +228,7 @@ static void findMatchingStartFinishInsts( auto *memref = dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()); bool escapingUses = false; for (const auto &use : memref->getUses()) { - if (!forInst->getBody()->findAncestorInstInBlock(*use.getOwner())) { + if (!forOp->getBody()->findAncestorInstInBlock(*use.getOwner())) { LLVM_DEBUG(llvm::dbgs() << "can't pipeline: buffer is live out of loop\n";); escapingUses = true; @@ -251,17 +252,18 @@ static void findMatchingStartFinishInsts( } /// Overlap DMA transfers with computation in this loop. If successful, -/// 'forInst' is deleted, and a prologue, a new pipelined loop, and epilogue are +/// 'forOp' is deleted, and a prologue, a new pipelined loop, and epilogue are /// inserted right before where it was. -PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { - auto mayBeConstTripCount = getConstantTripCount(*forInst); +PassResult +PipelineDataTransfer::runOnAffineForOp(OpPointer forOp) { + auto mayBeConstTripCount = getConstantTripCount(forOp); if (!mayBeConstTripCount.hasValue()) { LLVM_DEBUG(llvm::dbgs() << "unknown trip count loop\n"); return success(); } SmallVector, 4> startWaitPairs; - findMatchingStartFinishInsts(forInst, startWaitPairs); + findMatchingStartFinishInsts(forOp, startWaitPairs); if (startWaitPairs.empty()) { LLVM_DEBUG(llvm::dbgs() << "No dma start/finish pairs\n";); @@ -280,7 +282,7 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { auto *dmaStartInst = pair.first; Value *oldMemRef = dmaStartInst->getOperand( dmaStartInst->cast()->getFasterMemPos()); - if (!doubleBuffer(oldMemRef, forInst)) { + if (!doubleBuffer(oldMemRef, forOp)) { // Normally, double buffering should not fail because we already checked // that there are no uses outside. LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";); @@ -302,7 +304,7 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { auto *dmaFinishInst = pair.second; Value *oldTagMemRef = dmaFinishInst->getOperand(getTagMemRefPos(*dmaFinishInst)); - if (!doubleBuffer(oldTagMemRef, forInst)) { + if (!doubleBuffer(oldTagMemRef, forOp)) { LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";); return success(); } @@ -315,7 +317,7 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { // Double buffering would have invalidated all the old DMA start/wait insts. startWaitPairs.clear(); - findMatchingStartFinishInsts(forInst, startWaitPairs); + findMatchingStartFinishInsts(forOp, startWaitPairs); // Store shift for instruction for later lookup for AffineApplyOp's. DenseMap instShiftMap; @@ -342,16 +344,16 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { } } // Everything else (including compute ops and dma finish) are shifted by one. - for (const auto &inst : *forInst->getBody()) { + for (const auto &inst : *forOp->getBody()) { if (instShiftMap.find(&inst) == instShiftMap.end()) { instShiftMap[&inst] = 1; } } // Get shifts stored in map. - std::vector shifts(forInst->getBody()->getInstructions().size()); + std::vector shifts(forOp->getBody()->getInstructions().size()); unsigned s = 0; - for (auto &inst : *forInst->getBody()) { + for (auto &inst : *forOp->getBody()) { assert(instShiftMap.find(&inst) != instShiftMap.end()); shifts[s++] = instShiftMap[&inst]; LLVM_DEBUG( @@ -363,13 +365,13 @@ PassResult PipelineDataTransfer::runOnForInst(ForInst *forInst) { }); } - if (!isInstwiseShiftValid(*forInst, shifts)) { + if (!isInstwiseShiftValid(forOp, shifts)) { // Violates dependences. LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";); return success(); } - if (instBodySkew(forInst, shifts)) { + if (instBodySkew(forOp, shifts)) { LLVM_DEBUG(llvm::dbgs() << "inst body skewing failed - unexpected\n";); return success(); } -- cgit v1.2.3