diff options
| -rw-r--r-- | mlir/include/mlir/Analysis/Dominance.h | 16 | ||||
| -rw-r--r-- | mlir/include/mlir/Analysis/Utils.h | 6 | ||||
| -rw-r--r-- | mlir/include/mlir/IR/Block.h | 9 | ||||
| -rw-r--r-- | mlir/include/mlir/IR/Instructions.h | 3 | ||||
| -rw-r--r-- | mlir/lib/Analysis/AffineAnalysis.cpp | 29 | ||||
| -rw-r--r-- | mlir/lib/Analysis/Dominance.cpp | 68 | ||||
| -rw-r--r-- | mlir/lib/Analysis/Utils.cpp | 77 | ||||
| -rw-r--r-- | mlir/lib/IR/Block.cpp | 5 | ||||
| -rw-r--r-- | mlir/lib/Transforms/MaterializeVectors.cpp | 7 | ||||
| -rw-r--r-- | mlir/lib/Transforms/MemRefDataFlowOpt.cpp | 24 | ||||
| -rw-r--r-- | mlir/lib/Transforms/PipelineDataTransfer.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils/Utils.cpp | 7 |
12 files changed, 139 insertions, 114 deletions
diff --git a/mlir/include/mlir/Analysis/Dominance.h b/mlir/include/mlir/Analysis/Dominance.h index 4c7dac31620..12bf5e4e9ae 100644 --- a/mlir/include/mlir/Analysis/Dominance.h +++ b/mlir/include/mlir/Analysis/Dominance.h @@ -74,6 +74,22 @@ public: } }; +/// A class for computing basic postdominance information. +class PostDominanceInfo : public PostDominatorTreeBase { + using super = PostDominatorTreeBase; + +public: + PostDominanceInfo(Function *F); + + /// Return true if instruction A properly postdominates instruction B. + bool properlyPostDominates(const Instruction *a, const Instruction *b); + + /// Return true if instruction A postdominates instruction B. + bool postDominates(const Instruction *a, const Instruction *b) { + return a == b || properlyPostDominates(a, b); + } +}; + } // end namespace mlir namespace llvm { diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index e4ab4ffda1c..fc08caa83e8 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -50,12 +50,6 @@ bool properlyDominates(const Instruction &a, const Instruction &b); // TODO(bondhugula): handle 'if' inst's. void getLoopIVs(const Instruction &inst, SmallVectorImpl<ForInst *> *loops); -/// Returns true if instruction 'a' postdominates instruction b. -bool postDominates(const Instruction &a, const Instruction &b); - -/// Returns true if instruction 'a' properly postdominates instruction b. -bool properlyPostDominates(const Instruction &a, const Instruction &b); - /// Returns the nesting depth of this instruction, i.e., the number of loops /// surrounding this instruction. unsigned getNestingDepth(const Instruction &stmt); diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index 892f69a5187..5f7665260d3 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -156,11 +156,10 @@ public: /// the latter fails. /// TODO: This is very specific functionality that should live somewhere else, /// probably in Dominance.cpp. - const Instruction *findAncestorInstInBlock(const Instruction &inst) const; - // TODO: it doesn't make sense for the former method to take the instruction - // by reference but this one to take it by pointer. - Instruction *findAncestorInstInBlock(Instruction *inst) { - return const_cast<Instruction *>(findAncestorInstInBlock(*inst)); + Instruction *findAncestorInstInBlock(Instruction *inst); + const Instruction *findAncestorInstInBlock(const Instruction &inst) const { + return const_cast<Block *>(this)->findAncestorInstInBlock( + const_cast<Instruction *>(&inst)); } //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Instructions.h b/mlir/include/mlir/IR/Instructions.h index bd1b371ed06..5929dde4440 100644 --- a/mlir/include/mlir/IR/Instructions.h +++ b/mlir/include/mlir/IR/Instructions.h @@ -665,7 +665,8 @@ public: } private: - // The Block for the body. + // The Block for the body. By construction, this list always contains exactly + // one block. BlockList body; // Affine map for the lower bound. diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index bfdfb1a79b4..9275cc26b5f 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -937,17 +937,16 @@ static const Block *getCommonBlock(const MemRefAccess &srcAccess, return cast<ForInst>(commonForValue)->getBody(); } -// Returns true if the ancestor operation instruction of 'srcAccess' properly -// dominates the ancestor operation instruction of 'dstAccess' in the same +// Returns true if the ancestor operation instruction of 'srcAccess' appears +// before the ancestor operation instruction of 'dstAccess' in the same // instruction block. Returns false otherwise. // Note that because 'srcAccess' or 'dstAccess' may be nested in conditionals, -// the function is named 'srcMayExecuteBeforeDst'. +// the function is named 'srcAppearsBeforeDstInCommonBlock'. // Note that 'numCommonLoops' is the number of contiguous surrounding outer // loops. -static bool srcMayExecuteBeforeDst(const MemRefAccess &srcAccess, - const MemRefAccess &dstAccess, - const FlatAffineConstraints &srcDomain, - unsigned numCommonLoops) { +static bool srcAppearsBeforeDstInCommonBlock( + const MemRefAccess &srcAccess, const MemRefAccess &dstAccess, + const FlatAffineConstraints &srcDomain, unsigned numCommonLoops) { // Get Block common to 'srcAccess.opInst' and 'dstAccess.opInst'. auto *commonBlock = getCommonBlock(srcAccess, dstAccess, srcDomain, numCommonLoops); @@ -957,7 +956,17 @@ static bool srcMayExecuteBeforeDst(const MemRefAccess &srcAccess, assert(srcInst != nullptr); auto *dstInst = commonBlock->findAncestorInstInBlock(*dstAccess.opInst); assert(dstInst != nullptr); - return mlir::properlyDominates(*srcInst, *dstInst); + + // Do a linear scan to determine whether dstInst comes after srcInst. + auto aIter = Block::const_iterator(srcInst); + auto bIter = Block::const_iterator(dstInst); + auto aBlockStart = srcInst->getBlock()->begin(); + while (bIter != aBlockStart) { + --bIter; + if (bIter == aIter) + return true; + } + return false; } // Adds ordering constraints to 'dependenceDomain' based on number of loops @@ -1231,8 +1240,8 @@ bool mlir::checkMemrefAccessDependence( unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain); assert(loopDepth <= numCommonLoops + 1); if (loopDepth > numCommonLoops && - !srcMayExecuteBeforeDst(srcAccess, dstAccess, srcDomain, - numCommonLoops)) { + !srcAppearsBeforeDstInCommonBlock(srcAccess, dstAccess, srcDomain, + numCommonLoops)) { return false; } // Build dim and symbol position maps for each access from access operand diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp index d99001f7fb3..13bf2827a63 100644 --- a/mlir/lib/Analysis/Dominance.cpp +++ b/mlir/lib/Analysis/Dominance.cpp @@ -25,8 +25,8 @@ #include "llvm/Support/GenericDomTreeConstruction.h" using namespace mlir; -template class llvm::DominatorTreeBase<Block, false>; -template class llvm::DominatorTreeBase<Block, true>; +template class llvm::DominatorTreeBase<Block, /*IsPostDom=*/false>; +template class llvm::DominatorTreeBase<Block, /*IsPostDom=*/true>; template class llvm::DomTreeNodeBase<Block>; /// Compute the immediate-dominators map. @@ -35,6 +35,13 @@ DominanceInfo::DominanceInfo(Function *function) : DominatorTreeBase() { recalculate(function->getBlockList()); } +/// Compute the immediate-dominators map. +PostDominanceInfo::PostDominanceInfo(Function *function) + : PostDominatorTreeBase() { + // Build the post dominator tree for the function. + recalculate(function->getBlockList()); +} + bool DominanceInfo::properlyDominates(const Block *a, const Block *b) { // A block dominates itself but does not properly dominate itself. if (a == b) @@ -100,15 +107,17 @@ bool DominanceInfo::properlyDominates(const Instruction *a, // If the blocks are different, but in the same function-level block list, // then a standard block dominance query is sufficient. - if (aBlock->getParent()->getContainingFunction() && - bBlock->getParent()->getContainingFunction()) + auto *aFunction = aBlock->getParent()->getContainingFunction(); + auto *bFunction = bBlock->getParent()->getContainingFunction(); + if (aFunction && bFunction && aFunction == bFunction) return DominatorTreeBase::properlyDominates(aBlock, bBlock); // Traverse up b's hierarchy to check if b's block is contained in a's. if (auto *bAncestor = aBlock->findAncestorInstInBlock(*b)) { + // Since we already know that aBlock != bBlock, here bAncestor != b. // a and bAncestor are in the same block; check if 'a' dominates // bAncestor. - return properlyDominates(a, bAncestor); + return dominates(a, bAncestor); } return false; @@ -128,3 +137,52 @@ bool DominanceInfo::properlyDominates(const Value *a, const Instruction *b) { // we use a dominates check here, not a properlyDominates check. return dominates(cast<BlockArgument>(a)->getOwner(), b->getBlock()); } + +/// Returns true if statement 'a' properly postdominates statement b. +bool PostDominanceInfo::properlyPostDominates(const Instruction *a, + const Instruction *b) { + // If a/b are the same, then they don't properly dominate each other. + if (a == b) + return false; + + auto *aBlock = a->getBlock(); + auto *bBlock = b->getBlock(); + + // If the blocks are the same, then we do a linear scan. + if (aBlock == bBlock) { + // If one is a terminator, it postdominates the other. + if (a->isTerminator()) + return true; + + if (b->isTerminator()) + return false; + + // Otherwise, do a linear scan to determine whether A comes after B. + // TODO: This is an O(n) scan that can be bad for very large blocks. + auto aIter = Block::const_iterator(a); + auto bIter = Block::const_iterator(b); + auto fIter = bBlock->begin(); + while (aIter != fIter) { + --aIter; + if (aIter == bIter) + return true; + } + return false; + } + + // If the blocks are different, but in the same function-level block list, + // then a standard block dominance query is sufficient. + if (aBlock->getParent()->getContainingFunction() && + bBlock->getParent()->getContainingFunction()) + return PostDominatorTreeBase::properlyDominates(aBlock, bBlock); + + // Traverse up b's hierarchy to check if b's block is contained in a's. + if (const auto *bAncestor = a->getBlock()->findAncestorInstInBlock(*b)) + // Since we already know that aBlock != bBlock, here bAncestor != b. + // a and bAncestor are in the same block; check if 'a' postdominates + // bAncestor. + return postDominates(a, bAncestor); + + // b's block is not contained in A's. + return false; +} diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 75aec132060..7c07060386c 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -34,83 +34,6 @@ using namespace mlir; -/// Returns true if instruction 'a' properly dominates instruction b. -bool mlir::properlyDominates(const Instruction &a, const Instruction &b) { - if (&a == &b) - return false; - - if (a.getFunction() != b.getFunction()) - return false; - - if (a.getBlock() == b.getBlock()) { - // Do a linear scan to determine whether b comes after a. - auto aIter = Block::const_iterator(a); - auto bIter = Block::const_iterator(b); - auto aBlockStart = a.getBlock()->begin(); - while (bIter != aBlockStart) { - --bIter; - if (aIter == bIter) - return true; - } - return false; - } - - // Traverse up b's hierarchy to check if b's block is contained in a's. - if (const auto *bAncestor = a.getBlock()->findAncestorInstInBlock(b)) - // a and bAncestor are in the same block; check if the former dominates it. - return dominates(a, *bAncestor); - - // b's block is not contained in A. - return false; -} - -/// Returns true if statement 'a' properly postdominates statement b. -bool mlir::properlyPostDominates(const Instruction &a, const Instruction &b) { - // Only applicable to ML functions. - assert(a.getFunction()->isML() && b.getFunction()->isML()); - - if (&a == &b) - return false; - - if (a.getFunction() != b.getFunction()) - return false; - - if (a.getBlock() == b.getBlock()) { - // Do a linear scan to determine whether a comes after b. - auto aIter = Block::const_iterator(a); - auto bIter = Block::const_iterator(b); - auto bBlockStart = b.getBlock()->begin(); - while (aIter != bBlockStart) { - --aIter; - if (aIter == bIter) - return true; - } - return false; - } - - // Traverse up b's hierarchy to check if b's block is contained in a's. - if (const auto *bAncestor = a.getBlock()->findAncestorInstInBlock(b)) - // a and bAncestor are in the same block; check if 'a' postdominates - // bAncestor. - return postDominates(a, *bAncestor); - - // b's block is not contained in A's. - return false; -} - -/// Returns true if instruction A dominates instruction B. -bool mlir::dominates(const Instruction &a, const Instruction &b) { - return &a == &b || properlyDominates(a, b); -} - -/// Returns true if statement A postdominates statement B. -bool mlir::postDominates(const Instruction &a, const Instruction &b) { - // Only applicable to ML functions. - assert(a.getFunction()->isML() && b.getFunction()->isML()); - - return &a == &b || properlyPostDominates(a, b); -} - /// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from /// the outermost 'for' instruction to the innermost one. void mlir::getLoopIVs(const Instruction &inst, diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index 1b364034d47..a1b7aba15d3 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -60,11 +60,10 @@ void Block::eraseFromFunction() { /// Returns 'inst' if 'inst' lies in this block, or otherwise finds the /// ancestor instruction of 'inst' that lies in this block. Returns nullptr if /// the latter fails. -const Instruction * -Block::findAncestorInstInBlock(const Instruction &inst) const { +Instruction *Block::findAncestorInstInBlock(Instruction *inst) { // Traverse up the instruction hierarchy starting from the owner of operand to // find the ancestor instruction that resides in the block of 'forInst'. - const auto *currInst = &inst; + auto *currInst = inst; while (currInst->getBlock() != this) { currInst = currInst->getParentInst(); if (!currInst) diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 37f0f571a0f..1ab1f6361d3 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -21,6 +21,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/Dominance.h" #include "mlir/Analysis/LoopAnalysis.h" #include "mlir/Analysis/MLFunctionMatcher.h" #include "mlir/Analysis/SliceAnalysis.h" @@ -650,10 +651,12 @@ static bool emitSlice(MaterializationState *state, /// Additionally, this set is limited to instructions in the same lexical scope /// because we currently disallow vectorization of defs that come from another /// scope. +/// TODO(ntv): please document return value. static bool materialize(Function *f, const SetVector<OperationInst *> &terminators, MaterializationState *state) { DenseSet<Instruction *> seen; + DominanceInfo domInfo(f); for (auto *term : terminators) { // Short-circuit test, a given terminator may have been reached by some // other previous transitive use-def chains. @@ -669,13 +672,13 @@ static bool materialize(Function *f, // Note for the justification of this restriction. // TODO(ntv): relax scoping constraints. auto *enclosingScope = term->getParentInst(); - auto keepIfInSameScope = [enclosingScope](Instruction *inst) { + auto keepIfInSameScope = [enclosingScope, &domInfo](Instruction *inst) { assert(inst && "NULL inst"); if (!enclosingScope) { // by construction, everyone is always under the top scope (null scope). return true; } - return properlyDominates(*enclosingScope, *inst); + return domInfo.properlyDominates(enclosingScope, inst); }; SetVector<Instruction *> slice = getSlice(term, keepIfInSameScope, keepIfInSameScope); diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp index bad1cf9d101..1a30e2b289d 100644 --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -23,6 +23,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Analysis/AffineAnalysis.h" +#include "mlir/Analysis/Dominance.h" #include "mlir/Analysis/Utils.h" #include "mlir/IR/InstVisitor.h" #include "mlir/Pass.h" @@ -73,6 +74,11 @@ struct MemRefDataFlowOpt : public FunctionPass, InstWalker<MemRefDataFlowOpt> { // A list of memref's that are potentially dead / could be eliminated. SmallPtrSet<Value *, 4> memrefsToErase; + // Load op's whose results were replaced by those forwarded from stores. + std::vector<OperationInst *> loadOpsToErase; + + DominanceInfo *domInfo = nullptr; + PostDominanceInfo *postDomInfo = nullptr; static char passID; }; @@ -152,7 +158,7 @@ void MemRefDataFlowOpt::visitOperationInst(OperationInst *opInst) { // strictly a necessary condition since dominance isn't a prerequisite for // a memref element store to reach a load, but this is sufficient and // reasonably powerful in practice. - if (!dominates(*storeOpInst, *loadOpInst)) + if (!domInfo->dominates(storeOpInst, loadOpInst)) break; // Finally, forwarding is only possible if the load touches a single @@ -182,7 +188,7 @@ void MemRefDataFlowOpt::visitOperationInst(OperationInst *opInst) { // unique store providing the value to the load, i.e., provably the last // writer to that memref loc. if (llvm::all_of(depSrcStores, [&](OperationInst *depStore) { - return postDominates(*storeOpInst, *depStore); + return postDomInfo->postDominates(storeOpInst, depStore); })) { lastWriteStoreOp = storeOpInst; break; @@ -200,15 +206,27 @@ void MemRefDataFlowOpt::visitOperationInst(OperationInst *opInst) { loadOp->getResult()->replaceAllUsesWith(storeVal); // Record the memref for a later sweep to optimize away. memrefsToErase.insert(loadOp->getMemRef()); - loadOp->erase(); + // Record this to erase later. + loadOpsToErase.push_back(loadOpInst); } PassResult MemRefDataFlowOpt::runOnMLFunction(Function *f) { + DominanceInfo theDomInfo(f); + domInfo = &theDomInfo; + PostDominanceInfo thePostDomInfo(f); + postDomInfo = &thePostDomInfo; + + loadOpsToErase.clear(); memrefsToErase.clear(); // Walk all load's and perform load/store forwarding. walk(f); + // Erase all load op's whose results were replaced with store fwd'ed ones. + for (auto *loadOp : loadOpsToErase) { + loadOp->erase(); + } + // Check if the store fwd'ed memrefs are now left with only stores and can // thus be completely deleted. Note: the canononicalize pass should be able // to do this as well, but we'll do it here since we collected these anyway. diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index debaac3a33c..321bf20cf0b 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -227,7 +227,7 @@ static void findMatchingStartFinishInsts( auto *memref = dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()); bool escapingUses = false; for (const auto &use : memref->getUses()) { - if (!dominates(*forInst->getBody()->begin(), *use.getOwner())) { + if (!forInst->getBody()->findAncestorInstInBlock(*use.getOwner())) { LLVM_DEBUG(llvm::dbgs() << "can't pipeline: buffer is live out of loop\n";); escapingUses = true; diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 8cfe2619e2a..b51c419f767 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -24,6 +24,7 @@ #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" +#include "mlir/Analysis/Dominance.h" #include "mlir/Analysis/Utils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/InstVisitor.h" @@ -82,13 +83,17 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef, assert(oldMemRef->getType().cast<MemRefType>().getElementType() == newMemRef->getType().cast<MemRefType>().getElementType()); + std::unique_ptr<DominanceInfo> domInfo; + if (domInstFilter) + domInfo = std::make_unique<DominanceInfo>(domInstFilter->getFunction()); + // Walk all uses of old memref. Operation using the memref gets replaced. for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) { InstOperand &use = *(it++); auto *opInst = cast<OperationInst>(use.getOwner()); // Skip this use if it's not dominated by domInstFilter. - if (domInstFilter && !dominates(*domInstFilter, *opInst)) + if (domInstFilter && !domInfo->dominates(domInstFilter, opInst)) continue; // Check if the memref was used in a non-deferencing context. It is fine for |

