summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms')
-rw-r--r--mlir/lib/Transforms/PipelineDataTransfer.cpp47
-rw-r--r--mlir/lib/Transforms/Utils.cpp7
2 files changed, 17 insertions, 37 deletions
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index 1eeb9a9aa5c..0d025f5678f 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -46,34 +46,16 @@ MLFunctionPass *mlir::createPipelineDataTransferPass() {
return new PipelineDataTransfer();
}
-// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's or
-// op traits for it are added. TODO(b/117228571)
-static bool isDmaStartStmt(const OperationStmt &stmt) {
- return stmt.getName().strref().contains("dma.in.start") ||
- stmt.getName().strref().contains("dma.out.start");
-}
-
-// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
-// added. TODO(b/117228571)
-static bool isDmaFinishStmt(const OperationStmt &stmt) {
- return stmt.getName().strref().contains("dma.finish");
-}
-
/// Given a DMA start operation, returns the operand position of either the
/// source or destination memref depending on the one that is at the higher
/// level of the memory hierarchy.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO(b/117228571)
-static unsigned getHigherMemRefPos(const OperationStmt &dmaStartStmt) {
- assert(isDmaStartStmt(dmaStartStmt));
+static unsigned getHigherMemRefPos(OpPointer<DmaStartOp> dmaStartOp) {
unsigned srcDmaPos = 0;
- unsigned destDmaPos =
- cast<MemRefType>(dmaStartStmt.getOperand(0)->getType())->getRank() + 1;
+ unsigned destDmaPos = dmaStartOp->getSrcMemRefRank() + 1;
- if (cast<MemRefType>(dmaStartStmt.getOperand(srcDmaPos)->getType())
- ->getMemorySpace() >
- cast<MemRefType>(dmaStartStmt.getOperand(destDmaPos)->getType())
- ->getMemorySpace())
+ if (dmaStartOp->getSrcMemorySpace() > dmaStartOp->getDstMemorySpace())
return srcDmaPos;
return destDmaPos;
}
@@ -81,9 +63,9 @@ static unsigned getHigherMemRefPos(const OperationStmt &dmaStartStmt) {
// Returns the position of the tag memref operand given a DMA statement.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO(b/117228571)
-unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
- assert(isDmaStartStmt(dmaStmt) || isDmaFinishStmt(dmaStmt));
- if (isDmaStartStmt(dmaStmt)) {
+static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
+ assert(dmaStmt.is<DmaStartOp>() || dmaStmt.is<DmaWaitOp>());
+ if (dmaStmt.is<DmaStartOp>()) {
// Second to last operand.
return dmaStmt.getNumOperands() - 2;
}
@@ -91,7 +73,8 @@ unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
return 0;
}
-/// Doubles the buffer of the supplied memref.
+/// Doubles the buffer of the supplied memref while replacing all uses of the
+/// old memref. Returns false if such a replacement cannot be performed.
static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
MLFuncBuilder bInner(forStmt, forStmt->begin());
bInner.setInsertionPoint(forStmt, forStmt->begin());
@@ -130,7 +113,7 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
return true;
}
-// For testing purposes, this just runs on the first for statement of an
+// For testing purposes, this just runs on the first 'for' statement of an
// MLFunction at the top level.
// TODO(bondhugula): upgrade this to scan all the relevant 'for' statements when
// the other TODOs listed inside are dealt with.
@@ -158,9 +141,9 @@ PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
auto *opStmt = dyn_cast<OperationStmt>(&stmt);
if (!opStmt)
continue;
- if (isDmaStartStmt(*opStmt)) {
+ if (opStmt->is<DmaStartOp>()) {
dmaStartStmts.push_back(opStmt);
- } else if (isDmaFinishStmt(*opStmt)) {
+ } else if (opStmt->is<DmaWaitOp>()) {
dmaFinishStmts.push_back(opStmt);
}
}
@@ -182,8 +165,8 @@ PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
// A DMA start statement has two memref's - the one from the higher level of
// memory hierarchy is the one to double buffer.
for (auto *dmaStartStmt : dmaStartStmts) {
- MLValue *oldMemRef = cast<MLValue>(
- dmaStartStmt->getOperand(getHigherMemRefPos(*dmaStartStmt)));
+ MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand(
+ getHigherMemRefPos(dmaStartStmt->getAs<DmaStartOp>())));
if (!doubleBuffer(oldMemRef, forStmt))
return PassResult::Failure;
}
@@ -208,10 +191,10 @@ PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
// TODO(bondhugula): check whether such statements do not have any DMAs
// nested within.
opDelayMap[&stmt] = 1;
- } else if (isDmaStartStmt(*opStmt)) {
+ } else if (opStmt->is<DmaStartOp>()) {
// DMA starts are not shifted.
opDelayMap[&stmt] = 0;
- } else if (isDmaFinishStmt(*opStmt)) {
+ } else if (opStmt->is<DmaWaitOp>()) {
// DMA finish op shifted by one.
opDelayMap[&stmt] = 1;
} else if (!opStmt->is<AffineApplyOp>()) {
diff --git a/mlir/lib/Transforms/Utils.cpp b/mlir/lib/Transforms/Utils.cpp
index cc1c7973858..0262eb94bd7 100644
--- a/mlir/lib/Transforms/Utils.cpp
+++ b/mlir/lib/Transforms/Utils.cpp
@@ -33,12 +33,9 @@ using namespace mlir;
// Temporary utility: will be replaced when this is modeled through
// side-effects/op traits. TODO(b/117228571)
static bool isMemRefDereferencingOp(const Operation &op) {
- if (op.is<LoadOp>() || op.is<StoreOp>() ||
- op.getName().strref().contains("dma.in.start") ||
- op.getName().strref().contains("dma.out.start") ||
- op.getName().strref().contains("dma.finish")) {
+ if (op.is<LoadOp>() || op.is<StoreOp>() || op.is<DmaStartOp>() ||
+ op.is<DmaWaitOp>())
return true;
- }
return false;
}
OpenPOWER on IntegriCloud