diff options
| author | Uday Bondhugula <bondhugula@google.com> | 2018-10-09 15:04:27 -0700 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 13:26:10 -0700 |
| commit | 82e55750d2dee6b927061574a31ed3eab2d92b16 (patch) | |
| tree | 020112d0306200d177096dfa41f8112dc1080a02 /mlir/lib/Transforms | |
| parent | 2df03be62108132c73cb94a7a0fc6bd066031d88 (diff) | |
| download | bcm5719-llvm-82e55750d2dee6b927061574a31ed3eab2d92b16.tar.gz bcm5719-llvm-82e55750d2dee6b927061574a31ed3eab2d92b16.zip | |
Add target independent standard DMA ops: dma.start, dma.wait
Add target independent standard DMA ops: dma.start, dma.wait. Update pipeline
data transfer to use these to detect DMA ops.
While on this
- return failure from mlir-opt::performActions if a pass generates invalid output
- improve error message for verify 'n' operand traits
PiperOrigin-RevId: 216429885
Diffstat (limited to 'mlir/lib/Transforms')
| -rw-r--r-- | mlir/lib/Transforms/PipelineDataTransfer.cpp | 47 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils.cpp | 7 |
2 files changed, 17 insertions, 37 deletions
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 1eeb9a9aa5c..0d025f5678f 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -46,34 +46,16 @@ MLFunctionPass *mlir::createPipelineDataTransferPass() { return new PipelineDataTransfer(); } -// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's or -// op traits for it are added. TODO(b/117228571) -static bool isDmaStartStmt(const OperationStmt &stmt) { - return stmt.getName().strref().contains("dma.in.start") || - stmt.getName().strref().contains("dma.out.start"); -} - -// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are -// added. TODO(b/117228571) -static bool isDmaFinishStmt(const OperationStmt &stmt) { - return stmt.getName().strref().contains("dma.finish"); -} - /// Given a DMA start operation, returns the operand position of either the /// source or destination memref depending on the one that is at the higher /// level of the memory hierarchy. // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are // added. TODO(b/117228571) -static unsigned getHigherMemRefPos(const OperationStmt &dmaStartStmt) { - assert(isDmaStartStmt(dmaStartStmt)); +static unsigned getHigherMemRefPos(OpPointer<DmaStartOp> dmaStartOp) { unsigned srcDmaPos = 0; - unsigned destDmaPos = - cast<MemRefType>(dmaStartStmt.getOperand(0)->getType())->getRank() + 1; + unsigned destDmaPos = dmaStartOp->getSrcMemRefRank() + 1; - if (cast<MemRefType>(dmaStartStmt.getOperand(srcDmaPos)->getType()) - ->getMemorySpace() > - cast<MemRefType>(dmaStartStmt.getOperand(destDmaPos)->getType()) - ->getMemorySpace()) + if (dmaStartOp->getSrcMemorySpace() > dmaStartOp->getDstMemorySpace()) return srcDmaPos; return destDmaPos; } @@ -81,9 +63,9 @@ static unsigned getHigherMemRefPos(const OperationStmt &dmaStartStmt) { // Returns the position of the tag memref operand given a DMA statement. // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are // added. TODO(b/117228571) -unsigned getTagMemRefPos(const OperationStmt &dmaStmt) { - assert(isDmaStartStmt(dmaStmt) || isDmaFinishStmt(dmaStmt)); - if (isDmaStartStmt(dmaStmt)) { +static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) { + assert(dmaStmt.is<DmaStartOp>() || dmaStmt.is<DmaWaitOp>()); + if (dmaStmt.is<DmaStartOp>()) { // Second to last operand. return dmaStmt.getNumOperands() - 2; } @@ -91,7 +73,8 @@ unsigned getTagMemRefPos(const OperationStmt &dmaStmt) { return 0; } -/// Doubles the buffer of the supplied memref. +/// Doubles the buffer of the supplied memref while replacing all uses of the +/// old memref. Returns false if such a replacement cannot be performed. static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) { MLFuncBuilder bInner(forStmt, forStmt->begin()); bInner.setInsertionPoint(forStmt, forStmt->begin()); @@ -130,7 +113,7 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) { return true; } -// For testing purposes, this just runs on the first for statement of an +// For testing purposes, this just runs on the first 'for' statement of an // MLFunction at the top level. // TODO(bondhugula): upgrade this to scan all the relevant 'for' statements when // the other TODOs listed inside are dealt with. @@ -158,9 +141,9 @@ PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) { auto *opStmt = dyn_cast<OperationStmt>(&stmt); if (!opStmt) continue; - if (isDmaStartStmt(*opStmt)) { + if (opStmt->is<DmaStartOp>()) { dmaStartStmts.push_back(opStmt); - } else if (isDmaFinishStmt(*opStmt)) { + } else if (opStmt->is<DmaWaitOp>()) { dmaFinishStmts.push_back(opStmt); } } @@ -182,8 +165,8 @@ PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) { // A DMA start statement has two memref's - the one from the higher level of // memory hierarchy is the one to double buffer. for (auto *dmaStartStmt : dmaStartStmts) { - MLValue *oldMemRef = cast<MLValue>( - dmaStartStmt->getOperand(getHigherMemRefPos(*dmaStartStmt))); + MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand( + getHigherMemRefPos(dmaStartStmt->getAs<DmaStartOp>()))); if (!doubleBuffer(oldMemRef, forStmt)) return PassResult::Failure; } @@ -208,10 +191,10 @@ PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) { // TODO(bondhugula): check whether such statements do not have any DMAs // nested within. opDelayMap[&stmt] = 1; - } else if (isDmaStartStmt(*opStmt)) { + } else if (opStmt->is<DmaStartOp>()) { // DMA starts are not shifted. opDelayMap[&stmt] = 0; - } else if (isDmaFinishStmt(*opStmt)) { + } else if (opStmt->is<DmaWaitOp>()) { // DMA finish op shifted by one. opDelayMap[&stmt] = 1; } else if (!opStmt->is<AffineApplyOp>()) { diff --git a/mlir/lib/Transforms/Utils.cpp b/mlir/lib/Transforms/Utils.cpp index cc1c7973858..0262eb94bd7 100644 --- a/mlir/lib/Transforms/Utils.cpp +++ b/mlir/lib/Transforms/Utils.cpp @@ -33,12 +33,9 @@ using namespace mlir; // Temporary utility: will be replaced when this is modeled through // side-effects/op traits. TODO(b/117228571) static bool isMemRefDereferencingOp(const Operation &op) { - if (op.is<LoadOp>() || op.is<StoreOp>() || - op.getName().strref().contains("dma.in.start") || - op.getName().strref().contains("dma.out.start") || - op.getName().strref().contains("dma.finish")) { + if (op.is<LoadOp>() || op.is<StoreOp>() || op.is<DmaStartOp>() || + op.is<DmaWaitOp>()) return true; - } return false; } |

