diff options
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Transforms/LoopUtils.cpp | 17 | ||||
| -rw-r--r-- | mlir/lib/Transforms/PipelineDataTransfer.cpp | 22 |
2 files changed, 22 insertions, 17 deletions
diff --git a/mlir/lib/Transforms/LoopUtils.cpp b/mlir/lib/Transforms/LoopUtils.cpp index e1f1b6914d0..cc3de09ebb6 100644 --- a/mlir/lib/Transforms/LoopUtils.cpp +++ b/mlir/lib/Transforms/LoopUtils.cpp @@ -30,6 +30,9 @@ #include "mlir/IR/StmtVisitor.h" #include "mlir/StandardOps/StandardOps.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "LoopUtils" using namespace mlir; @@ -205,12 +208,14 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> delays, // better way to pipeline for such loops is to first tile them and extract // constant trip count "full tiles" before applying this. auto mayBeConstTripCount = getConstantTripCount(*forStmt); - if (!mayBeConstTripCount.hasValue()) - return UtilResult::Failure; + if (!mayBeConstTripCount.hasValue()) { + LLVM_DEBUG(llvm::dbgs() << "non-constant trip count loop\n";); + return UtilResult::Success; + } uint64_t tripCount = mayBeConstTripCount.getValue(); assert(isStmtwiseShiftValid(*forStmt, delays) && - "dominance preservation failed\n"); + "shifts will lead to an invalid transformation\n"); unsigned numChildStmts = forStmt->getStatements().size(); @@ -220,8 +225,10 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> delays, maxDelay = std::max(maxDelay, delays[i]); } // Such large delays are not the typical use case. - if (maxDelay >= numChildStmts) - return UtilResult::Failure; + if (maxDelay >= numChildStmts) { + LLVM_DEBUG(llvm::dbgs() << "stmt delays too large - unexpected\n";); + return UtilResult::Success; + } // An array of statement groups sorted by delay amount; each group has all // statements with the same delay in the order in which they appear in the diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index 187205bb15a..91f3b845a98 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -114,9 +114,6 @@ static bool doubleBuffer(const MLValue *oldMemRef, ForStmt *forStmt) { /// Returns false if this succeeds on at least one 'for' stmt. PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) { - if (f->empty()) - return PassResult::Success; - // Do a post order walk so that inner loop DMAs are processed first. This is // necessary since 'for' statements nested within would otherwise become // invalid (erased) when the outer loop is pipelined (the pipelined one gets @@ -213,7 +210,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { auto mayBeConstTripCount = getConstantTripCount(*forStmt); if (!mayBeConstTripCount.hasValue()) { LLVM_DEBUG(llvm::dbgs() << "unknown trip count loop\n"); - return PassResult::Failure; + return success(); } SmallVector<std::pair<OperationStmt *, OperationStmt *>, 4> startWaitPairs; @@ -221,7 +218,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { if (startWaitPairs.empty()) { LLVM_DEBUG(llvm::dbgs() << "No dma start/finish pairs\n";); - return failure(); + return success(); } // Double the buffers for the higher memory space memref's. @@ -241,7 +238,8 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { // that there are no uses outside. LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";); LLVM_DEBUG(dmaStartStmt->dump()); - return failure(); + // IR still in a valid state. + return success(); } } @@ -252,7 +250,7 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt))); if (!doubleBuffer(oldTagMemRef, forStmt)) { LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";); - return failure(); + return success(); } } @@ -296,14 +294,14 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) { } if (!isStmtwiseShiftValid(*forStmt, delays)) { - // Violates SSA dominance. - LLVM_DEBUG(llvm::dbgs() << "Dominance check failed\n";); - return PassResult::Failure; + // Violates dependences. + LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";); + return success(); } if (stmtBodySkew(forStmt, delays)) { - LLVM_DEBUG(llvm::dbgs() << "stmt body skewing failed\n";); - return PassResult::Failure; + LLVM_DEBUG(llvm::dbgs() << "stmt body skewing failed - unexpected\n";); + return success(); } return success(); |

