summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/PipelineDataTransfer.cpp
diff options
context:
space:
mode:
authorUday Bondhugula <bondhugula@google.com>2018-10-09 15:04:27 -0700
committerjpienaar <jpienaar@google.com>2019-03-29 13:26:10 -0700
commit82e55750d2dee6b927061574a31ed3eab2d92b16 (patch)
tree020112d0306200d177096dfa41f8112dc1080a02 /mlir/lib/Transforms/PipelineDataTransfer.cpp
parent2df03be62108132c73cb94a7a0fc6bd066031d88 (diff)
downloadbcm5719-llvm-82e55750d2dee6b927061574a31ed3eab2d92b16.tar.gz
bcm5719-llvm-82e55750d2dee6b927061574a31ed3eab2d92b16.zip
Add target independent standard DMA ops: dma.start, dma.wait
Add target independent standard DMA ops: dma.start, dma.wait. Update pipeline data transfer to use these to detect DMA ops. While on this - return failure from mlir-opt::performActions if a pass generates invalid output - improve error message for verify 'n' operand traits PiperOrigin-RevId: 216429885
Diffstat (limited to 'mlir/lib/Transforms/PipelineDataTransfer.cpp')
-rw-r--r--mlir/lib/Transforms/PipelineDataTransfer.cpp47
1 files changed, 15 insertions, 32 deletions
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index 1eeb9a9aa5c..0d025f5678f 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -46,34 +46,16 @@ MLFunctionPass *mlir::createPipelineDataTransferPass() {
return new PipelineDataTransfer();
}
-// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's or
-// op traits for it are added. TODO(b/117228571)
-static bool isDmaStartStmt(const OperationStmt &stmt) {
- return stmt.getName().strref().contains("dma.in.start") ||
- stmt.getName().strref().contains("dma.out.start");
-}
-
-// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
-// added. TODO(b/117228571)
-static bool isDmaFinishStmt(const OperationStmt &stmt) {
- return stmt.getName().strref().contains("dma.finish");
-}
-
/// Given a DMA start operation, returns the operand position of either the
/// source or destination memref depending on the one that is at the higher
/// level of the memory hierarchy.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO(b/117228571)
-static unsigned getHigherMemRefPos(const OperationStmt &dmaStartStmt) {
- assert(isDmaStartStmt(dmaStartStmt));
+static unsigned getHigherMemRefPos(OpPointer<DmaStartOp> dmaStartOp) {
unsigned srcDmaPos = 0;
- unsigned destDmaPos =
- cast<MemRefType>(dmaStartStmt.getOperand(0)->getType())->getRank() + 1;
+ unsigned destDmaPos = dmaStartOp->getSrcMemRefRank() + 1;
- if (cast<MemRefType>(dmaStartStmt.getOperand(srcDmaPos)->getType())
- ->getMemorySpace() >
- cast<MemRefType>(dmaStartStmt.getOperand(destDmaPos)->getType())
- ->getMemorySpace())
+ if (dmaStartOp->getSrcMemorySpace() > dmaStartOp->getDstMemorySpace())
return srcDmaPos;
return destDmaPos;
}
@@ -81,9 +63,9 @@ static unsigned getHigherMemRefPos(const OperationStmt &dmaStartStmt) {
// Returns the position of the tag memref operand given a DMA statement.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO(b/117228571)
-unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
- assert(isDmaStartStmt(dmaStmt) || isDmaFinishStmt(dmaStmt));
- if (isDmaStartStmt(dmaStmt)) {
+static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
+ assert(dmaStmt.is<DmaStartOp>() || dmaStmt.is<DmaWaitOp>());
+ if (dmaStmt.is<DmaStartOp>()) {
// Second to last operand.
return dmaStmt.getNumOperands() - 2;
}
@@ -91,7 +73,8 @@ unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
return 0;
}
-/// Doubles the buffer of the supplied memref.
+/// Doubles the buffer of the supplied memref while replacing all uses of the
+/// old memref. Returns false if such a replacement cannot be performed.
static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
MLFuncBuilder bInner(forStmt, forStmt->begin());
bInner.setInsertionPoint(forStmt, forStmt->begin());
@@ -130,7 +113,7 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
return true;
}
-// For testing purposes, this just runs on the first for statement of an
+// For testing purposes, this just runs on the first 'for' statement of an
// MLFunction at the top level.
// TODO(bondhugula): upgrade this to scan all the relevant 'for' statements when
// the other TODOs listed inside are dealt with.
@@ -158,9 +141,9 @@ PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
auto *opStmt = dyn_cast<OperationStmt>(&stmt);
if (!opStmt)
continue;
- if (isDmaStartStmt(*opStmt)) {
+ if (opStmt->is<DmaStartOp>()) {
dmaStartStmts.push_back(opStmt);
- } else if (isDmaFinishStmt(*opStmt)) {
+ } else if (opStmt->is<DmaWaitOp>()) {
dmaFinishStmts.push_back(opStmt);
}
}
@@ -182,8 +165,8 @@ PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
// A DMA start statement has two memref's - the one from the higher level of
// memory hierarchy is the one to double buffer.
for (auto *dmaStartStmt : dmaStartStmts) {
- MLValue *oldMemRef = cast<MLValue>(
- dmaStartStmt->getOperand(getHigherMemRefPos(*dmaStartStmt)));
+ MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand(
+ getHigherMemRefPos(dmaStartStmt->getAs<DmaStartOp>())));
if (!doubleBuffer(oldMemRef, forStmt))
return PassResult::Failure;
}
@@ -208,10 +191,10 @@ PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
// TODO(bondhugula): check whether such statements do not have any DMAs
// nested within.
opDelayMap[&stmt] = 1;
- } else if (isDmaStartStmt(*opStmt)) {
+ } else if (opStmt->is<DmaStartOp>()) {
// DMA starts are not shifted.
opDelayMap[&stmt] = 0;
- } else if (isDmaFinishStmt(*opStmt)) {
+ } else if (opStmt->is<DmaWaitOp>()) {
// DMA finish op shifted by one.
opDelayMap[&stmt] = 1;
} else if (!opStmt->is<AffineApplyOp>()) {
OpenPOWER on IntegriCloud