summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Transforms/LoopUtils.cpp17
-rw-r--r--mlir/lib/Transforms/PipelineDataTransfer.cpp22
2 files changed, 22 insertions, 17 deletions
diff --git a/mlir/lib/Transforms/LoopUtils.cpp b/mlir/lib/Transforms/LoopUtils.cpp
index e1f1b6914d0..cc3de09ebb6 100644
--- a/mlir/lib/Transforms/LoopUtils.cpp
+++ b/mlir/lib/Transforms/LoopUtils.cpp
@@ -30,6 +30,9 @@
#include "mlir/IR/StmtVisitor.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "LoopUtils"
using namespace mlir;
@@ -205,12 +208,14 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> delays,
// better way to pipeline for such loops is to first tile them and extract
// constant trip count "full tiles" before applying this.
auto mayBeConstTripCount = getConstantTripCount(*forStmt);
- if (!mayBeConstTripCount.hasValue())
- return UtilResult::Failure;
+ if (!mayBeConstTripCount.hasValue()) {
+ LLVM_DEBUG(llvm::dbgs() << "non-constant trip count loop\n";);
+ return UtilResult::Success;
+ }
uint64_t tripCount = mayBeConstTripCount.getValue();
assert(isStmtwiseShiftValid(*forStmt, delays) &&
- "dominance preservation failed\n");
+ "shifts will lead to an invalid transformation\n");
unsigned numChildStmts = forStmt->getStatements().size();
@@ -220,8 +225,10 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> delays,
maxDelay = std::max(maxDelay, delays[i]);
}
// Such large delays are not the typical use case.
- if (maxDelay >= numChildStmts)
- return UtilResult::Failure;
+ if (maxDelay >= numChildStmts) {
+ LLVM_DEBUG(llvm::dbgs() << "stmt delays too large - unexpected\n";);
+ return UtilResult::Success;
+ }
// An array of statement groups sorted by delay amount; each group has all
// statements with the same delay in the order in which they appear in the
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index 187205bb15a..91f3b845a98 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -114,9 +114,6 @@ static bool doubleBuffer(const MLValue *oldMemRef, ForStmt *forStmt) {
/// Returns false if this succeeds on at least one 'for' stmt.
PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
- if (f->empty())
- return PassResult::Success;
-
// Do a post order walk so that inner loop DMAs are processed first. This is
// necessary since 'for' statements nested within would otherwise become
// invalid (erased) when the outer loop is pipelined (the pipelined one gets
@@ -213,7 +210,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
auto mayBeConstTripCount = getConstantTripCount(*forStmt);
if (!mayBeConstTripCount.hasValue()) {
LLVM_DEBUG(llvm::dbgs() << "unknown trip count loop\n");
- return PassResult::Failure;
+ return success();
}
SmallVector<std::pair<OperationStmt *, OperationStmt *>, 4> startWaitPairs;
@@ -221,7 +218,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
if (startWaitPairs.empty()) {
LLVM_DEBUG(llvm::dbgs() << "No dma start/finish pairs\n";);
- return failure();
+ return success();
}
// Double the buffers for the higher memory space memref's.
@@ -241,7 +238,8 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
// that there are no uses outside.
LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";);
LLVM_DEBUG(dmaStartStmt->dump());
- return failure();
+ // IR still in a valid state.
+ return success();
}
}
@@ -252,7 +250,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt)));
if (!doubleBuffer(oldTagMemRef, forStmt)) {
LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
- return failure();
+ return success();
}
}
@@ -296,14 +294,14 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
}
if (!isStmtwiseShiftValid(*forStmt, delays)) {
- // Violates SSA dominance.
- LLVM_DEBUG(llvm::dbgs() << "Dominance check failed\n";);
- return PassResult::Failure;
+ // Violates dependences.
+ LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
+ return success();
}
if (stmtBodySkew(forStmt, delays)) {
- LLVM_DEBUG(llvm::dbgs() << "stmt body skewing failed\n";);
- return PassResult::Failure;
+ LLVM_DEBUG(llvm::dbgs() << "stmt body skewing failed - unexpected\n";);
+ return success();
}
return success();
OpenPOWER on IntegriCloud