diff options
Diffstat (limited to 'mlir/lib/Transforms')
| -rw-r--r-- | mlir/lib/Transforms/PipelineDataTransfer.cpp | 47 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils.cpp | 7 |
2 files changed, 17 insertions, 37 deletions
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 1eeb9a9aa5c..0d025f5678f 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -46,34 +46,16 @@ MLFunctionPass *mlir::createPipelineDataTransferPass() { return new PipelineDataTransfer(); } -// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's or -// op traits for it are added. TODO(b/117228571) -static bool isDmaStartStmt(const OperationStmt &stmt) { - return stmt.getName().strref().contains("dma.in.start") || - stmt.getName().strref().contains("dma.out.start"); -} - -// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are -// added. TODO(b/117228571) -static bool isDmaFinishStmt(const OperationStmt &stmt) { - return stmt.getName().strref().contains("dma.finish"); -} - /// Given a DMA start operation, returns the operand position of either the /// source or destination memref depending on the one that is at the higher /// level of the memory hierarchy. // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are // added. TODO(b/117228571) -static unsigned getHigherMemRefPos(const OperationStmt &dmaStartStmt) { - assert(isDmaStartStmt(dmaStartStmt)); +static unsigned getHigherMemRefPos(OpPointer<DmaStartOp> dmaStartOp) { unsigned srcDmaPos = 0; - unsigned destDmaPos = - cast<MemRefType>(dmaStartStmt.getOperand(0)->getType())->getRank() + 1; + unsigned destDmaPos = dmaStartOp->getSrcMemRefRank() + 1; - if (cast<MemRefType>(dmaStartStmt.getOperand(srcDmaPos)->getType()) - ->getMemorySpace() > - cast<MemRefType>(dmaStartStmt.getOperand(destDmaPos)->getType()) - ->getMemorySpace()) + if (dmaStartOp->getSrcMemorySpace() > dmaStartOp->getDstMemorySpace()) return srcDmaPos; return destDmaPos; } @@ -81,9 +63,9 @@ static unsigned getHigherMemRefPos(const OperationStmt &dmaStartStmt) { // Returns the position of the tag memref operand given a DMA statement. // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are // added. TODO(b/117228571) -unsigned getTagMemRefPos(const OperationStmt &dmaStmt) { - assert(isDmaStartStmt(dmaStmt) || isDmaFinishStmt(dmaStmt)); - if (isDmaStartStmt(dmaStmt)) { +static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) { + assert(dmaStmt.is<DmaStartOp>() || dmaStmt.is<DmaWaitOp>()); + if (dmaStmt.is<DmaStartOp>()) { // Second to last operand. return dmaStmt.getNumOperands() - 2; } @@ -91,7 +73,8 @@ unsigned getTagMemRefPos(const OperationStmt &dmaStmt) { return 0; } -/// Doubles the buffer of the supplied memref. +/// Doubles the buffer of the supplied memref while replacing all uses of the +/// old memref. Returns false if such a replacement cannot be performed. static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) { MLFuncBuilder bInner(forStmt, forStmt->begin()); bInner.setInsertionPoint(forStmt, forStmt->begin()); @@ -130,7 +113,7 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) { return true; } -// For testing purposes, this just runs on the first for statement of an +// For testing purposes, this just runs on the first 'for' statement of an // MLFunction at the top level. // TODO(bondhugula): upgrade this to scan all the relevant 'for' statements when // the other TODOs listed inside are dealt with. @@ -158,9 +141,9 @@ PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) { auto *opStmt = dyn_cast<OperationStmt>(&stmt); if (!opStmt) continue; - if (isDmaStartStmt(*opStmt)) { + if (opStmt->is<DmaStartOp>()) { dmaStartStmts.push_back(opStmt); - } else if (isDmaFinishStmt(*opStmt)) { + } else if (opStmt->is<DmaWaitOp>()) { dmaFinishStmts.push_back(opStmt); } } @@ -182,8 +165,8 @@ PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) { // A DMA start statement has two memref's - the one from the higher level of // memory hierarchy is the one to double buffer. for (auto *dmaStartStmt : dmaStartStmts) { - MLValue *oldMemRef = cast<MLValue>( - dmaStartStmt->getOperand(getHigherMemRefPos(*dmaStartStmt))); + MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand( + getHigherMemRefPos(dmaStartStmt->getAs<DmaStartOp>()))); if (!doubleBuffer(oldMemRef, forStmt)) return PassResult::Failure; } @@ -208,10 +191,10 @@ PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) { // TODO(bondhugula): check whether such statements do not have any DMAs // nested within. opDelayMap[&stmt] = 1; - } else if (isDmaStartStmt(*opStmt)) { + } else if (opStmt->is<DmaStartOp>()) { // DMA starts are not shifted. opDelayMap[&stmt] = 0; - } else if (isDmaFinishStmt(*opStmt)) { + } else if (opStmt->is<DmaWaitOp>()) { // DMA finish op shifted by one. opDelayMap[&stmt] = 1; } else if (!opStmt->is<AffineApplyOp>()) { diff --git a/mlir/lib/Transforms/Utils.cpp b/mlir/lib/Transforms/Utils.cpp index cc1c7973858..0262eb94bd7 100644 --- a/mlir/lib/Transforms/Utils.cpp +++ b/mlir/lib/Transforms/Utils.cpp @@ -33,12 +33,9 @@ using namespace mlir; // Temporary utility: will be replaced when this is modeled through // side-effects/op traits. TODO(b/117228571) static bool isMemRefDereferencingOp(const Operation &op) { - if (op.is<LoadOp>() || op.is<StoreOp>() || - op.getName().strref().contains("dma.in.start") || - op.getName().strref().contains("dma.out.start") || - op.getName().strref().contains("dma.finish")) { + if (op.is<LoadOp>() || op.is<StoreOp>() || op.is<DmaStartOp>() || + op.is<DmaWaitOp>()) return true; - } return false; } |

