summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/Analysis/Dominance.h16
-rw-r--r--mlir/include/mlir/Analysis/Utils.h6
-rw-r--r--mlir/include/mlir/IR/Block.h9
-rw-r--r--mlir/include/mlir/IR/Instructions.h3
-rw-r--r--mlir/lib/Analysis/AffineAnalysis.cpp29
-rw-r--r--mlir/lib/Analysis/Dominance.cpp68
-rw-r--r--mlir/lib/Analysis/Utils.cpp77
-rw-r--r--mlir/lib/IR/Block.cpp5
-rw-r--r--mlir/lib/Transforms/MaterializeVectors.cpp7
-rw-r--r--mlir/lib/Transforms/MemRefDataFlowOpt.cpp24
-rw-r--r--mlir/lib/Transforms/PipelineDataTransfer.cpp2
-rw-r--r--mlir/lib/Transforms/Utils/Utils.cpp7
12 files changed, 139 insertions, 114 deletions
diff --git a/mlir/include/mlir/Analysis/Dominance.h b/mlir/include/mlir/Analysis/Dominance.h
index 4c7dac31620..12bf5e4e9ae 100644
--- a/mlir/include/mlir/Analysis/Dominance.h
+++ b/mlir/include/mlir/Analysis/Dominance.h
@@ -74,6 +74,22 @@ public:
}
};
+/// A class for computing basic postdominance information.
+class PostDominanceInfo : public PostDominatorTreeBase {
+ using super = PostDominatorTreeBase;
+
+public:
+ PostDominanceInfo(Function *F);
+
+ /// Return true if instruction A properly postdominates instruction B.
+ bool properlyPostDominates(const Instruction *a, const Instruction *b);
+
+ /// Return true if instruction A postdominates instruction B.
+ bool postDominates(const Instruction *a, const Instruction *b) {
+ return a == b || properlyPostDominates(a, b);
+ }
+};
+
} // end namespace mlir
namespace llvm {
diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h
index e4ab4ffda1c..fc08caa83e8 100644
--- a/mlir/include/mlir/Analysis/Utils.h
+++ b/mlir/include/mlir/Analysis/Utils.h
@@ -50,12 +50,6 @@ bool properlyDominates(const Instruction &a, const Instruction &b);
// TODO(bondhugula): handle 'if' inst's.
void getLoopIVs(const Instruction &inst, SmallVectorImpl<ForInst *> *loops);
-/// Returns true if instruction 'a' postdominates instruction b.
-bool postDominates(const Instruction &a, const Instruction &b);
-
-/// Returns true if instruction 'a' properly postdominates instruction b.
-bool properlyPostDominates(const Instruction &a, const Instruction &b);
-
/// Returns the nesting depth of this instruction, i.e., the number of loops
/// surrounding this instruction.
unsigned getNestingDepth(const Instruction &stmt);
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index 892f69a5187..5f7665260d3 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -156,11 +156,10 @@ public:
/// the latter fails.
/// TODO: This is very specific functionality that should live somewhere else,
/// probably in Dominance.cpp.
- const Instruction *findAncestorInstInBlock(const Instruction &inst) const;
- // TODO: it doesn't make sense for the former method to take the instruction
- // by reference but this one to take it by pointer.
- Instruction *findAncestorInstInBlock(Instruction *inst) {
- return const_cast<Instruction *>(findAncestorInstInBlock(*inst));
+ Instruction *findAncestorInstInBlock(Instruction *inst);
+ const Instruction *findAncestorInstInBlock(const Instruction &inst) const {
+ return const_cast<Block *>(this)->findAncestorInstInBlock(
+ const_cast<Instruction *>(&inst));
}
//===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Instructions.h b/mlir/include/mlir/IR/Instructions.h
index bd1b371ed06..5929dde4440 100644
--- a/mlir/include/mlir/IR/Instructions.h
+++ b/mlir/include/mlir/IR/Instructions.h
@@ -665,7 +665,8 @@ public:
}
private:
- // The Block for the body.
+ // The Block for the body. By construction, this list always contains exactly
+ // one block.
BlockList body;
// Affine map for the lower bound.
diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp
index bfdfb1a79b4..9275cc26b5f 100644
--- a/mlir/lib/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Analysis/AffineAnalysis.cpp
@@ -937,17 +937,16 @@ static const Block *getCommonBlock(const MemRefAccess &srcAccess,
return cast<ForInst>(commonForValue)->getBody();
}
-// Returns true if the ancestor operation instruction of 'srcAccess' properly
-// dominates the ancestor operation instruction of 'dstAccess' in the same
+// Returns true if the ancestor operation instruction of 'srcAccess' appears
+// before the ancestor operation instruction of 'dstAccess' in the same
// instruction block. Returns false otherwise.
// Note that because 'srcAccess' or 'dstAccess' may be nested in conditionals,
-// the function is named 'srcMayExecuteBeforeDst'.
+// the function is named 'srcAppearsBeforeDstInCommonBlock'.
// Note that 'numCommonLoops' is the number of contiguous surrounding outer
// loops.
-static bool srcMayExecuteBeforeDst(const MemRefAccess &srcAccess,
- const MemRefAccess &dstAccess,
- const FlatAffineConstraints &srcDomain,
- unsigned numCommonLoops) {
+static bool srcAppearsBeforeDstInCommonBlock(
+ const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
+ const FlatAffineConstraints &srcDomain, unsigned numCommonLoops) {
// Get Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
auto *commonBlock =
getCommonBlock(srcAccess, dstAccess, srcDomain, numCommonLoops);
@@ -957,7 +956,17 @@ static bool srcMayExecuteBeforeDst(const MemRefAccess &srcAccess,
assert(srcInst != nullptr);
auto *dstInst = commonBlock->findAncestorInstInBlock(*dstAccess.opInst);
assert(dstInst != nullptr);
- return mlir::properlyDominates(*srcInst, *dstInst);
+
+ // Do a linear scan to determine whether dstInst comes after srcInst.
+ auto aIter = Block::const_iterator(srcInst);
+ auto bIter = Block::const_iterator(dstInst);
+ auto aBlockStart = srcInst->getBlock()->begin();
+ while (bIter != aBlockStart) {
+ --bIter;
+ if (bIter == aIter)
+ return true;
+ }
+ return false;
}
// Adds ordering constraints to 'dependenceDomain' based on number of loops
@@ -1231,8 +1240,8 @@ bool mlir::checkMemrefAccessDependence(
unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain);
assert(loopDepth <= numCommonLoops + 1);
if (loopDepth > numCommonLoops &&
- !srcMayExecuteBeforeDst(srcAccess, dstAccess, srcDomain,
- numCommonLoops)) {
+ !srcAppearsBeforeDstInCommonBlock(srcAccess, dstAccess, srcDomain,
+ numCommonLoops)) {
return false;
}
// Build dim and symbol position maps for each access from access operand
diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp
index d99001f7fb3..13bf2827a63 100644
--- a/mlir/lib/Analysis/Dominance.cpp
+++ b/mlir/lib/Analysis/Dominance.cpp
@@ -25,8 +25,8 @@
#include "llvm/Support/GenericDomTreeConstruction.h"
using namespace mlir;
-template class llvm::DominatorTreeBase<Block, false>;
-template class llvm::DominatorTreeBase<Block, true>;
+template class llvm::DominatorTreeBase<Block, /*IsPostDom=*/false>;
+template class llvm::DominatorTreeBase<Block, /*IsPostDom=*/true>;
template class llvm::DomTreeNodeBase<Block>;
/// Compute the immediate-dominators map.
@@ -35,6 +35,13 @@ DominanceInfo::DominanceInfo(Function *function) : DominatorTreeBase() {
recalculate(function->getBlockList());
}
+/// Compute the immediate-dominators map.
+PostDominanceInfo::PostDominanceInfo(Function *function)
+ : PostDominatorTreeBase() {
+ // Build the post dominator tree for the function.
+ recalculate(function->getBlockList());
+}
+
bool DominanceInfo::properlyDominates(const Block *a, const Block *b) {
// A block dominates itself but does not properly dominate itself.
if (a == b)
@@ -100,15 +107,17 @@ bool DominanceInfo::properlyDominates(const Instruction *a,
// If the blocks are different, but in the same function-level block list,
// then a standard block dominance query is sufficient.
- if (aBlock->getParent()->getContainingFunction() &&
- bBlock->getParent()->getContainingFunction())
+ auto *aFunction = aBlock->getParent()->getContainingFunction();
+ auto *bFunction = bBlock->getParent()->getContainingFunction();
+ if (aFunction && bFunction && aFunction == bFunction)
return DominatorTreeBase::properlyDominates(aBlock, bBlock);
// Traverse up b's hierarchy to check if b's block is contained in a's.
if (auto *bAncestor = aBlock->findAncestorInstInBlock(*b)) {
+ // Since we already know that aBlock != bBlock, here bAncestor != b.
// a and bAncestor are in the same block; check if 'a' dominates
// bAncestor.
- return properlyDominates(a, bAncestor);
+ return dominates(a, bAncestor);
}
return false;
@@ -128,3 +137,52 @@ bool DominanceInfo::properlyDominates(const Value *a, const Instruction *b) {
// we use a dominates check here, not a properlyDominates check.
return dominates(cast<BlockArgument>(a)->getOwner(), b->getBlock());
}
+
+/// Returns true if statement 'a' properly postdominates statement b.
+bool PostDominanceInfo::properlyPostDominates(const Instruction *a,
+ const Instruction *b) {
+ // If a/b are the same, then they don't properly dominate each other.
+ if (a == b)
+ return false;
+
+ auto *aBlock = a->getBlock();
+ auto *bBlock = b->getBlock();
+
+ // If the blocks are the same, then we do a linear scan.
+ if (aBlock == bBlock) {
+ // If one is a terminator, it postdominates the other.
+ if (a->isTerminator())
+ return true;
+
+ if (b->isTerminator())
+ return false;
+
+ // Otherwise, do a linear scan to determine whether A comes after B.
+ // TODO: This is an O(n) scan that can be bad for very large blocks.
+ auto aIter = Block::const_iterator(a);
+ auto bIter = Block::const_iterator(b);
+ auto fIter = bBlock->begin();
+ while (aIter != fIter) {
+ --aIter;
+ if (aIter == bIter)
+ return true;
+ }
+ return false;
+ }
+
+ // If the blocks are different, but in the same function-level block list,
+ // then a standard block dominance query is sufficient.
+ if (aBlock->getParent()->getContainingFunction() &&
+ bBlock->getParent()->getContainingFunction())
+ return PostDominatorTreeBase::properlyDominates(aBlock, bBlock);
+
+ // Traverse up b's hierarchy to check if b's block is contained in a's.
+ if (const auto *bAncestor = a->getBlock()->findAncestorInstInBlock(*b))
+ // Since we already know that aBlock != bBlock, here bAncestor != b.
+ // a and bAncestor are in the same block; check if 'a' postdominates
+ // bAncestor.
+ return postDominates(a, bAncestor);
+
+ // b's block is not contained in A's.
+ return false;
+}
diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp
index 75aec132060..7c07060386c 100644
--- a/mlir/lib/Analysis/Utils.cpp
+++ b/mlir/lib/Analysis/Utils.cpp
@@ -34,83 +34,6 @@
using namespace mlir;
-/// Returns true if instruction 'a' properly dominates instruction b.
-bool mlir::properlyDominates(const Instruction &a, const Instruction &b) {
- if (&a == &b)
- return false;
-
- if (a.getFunction() != b.getFunction())
- return false;
-
- if (a.getBlock() == b.getBlock()) {
- // Do a linear scan to determine whether b comes after a.
- auto aIter = Block::const_iterator(a);
- auto bIter = Block::const_iterator(b);
- auto aBlockStart = a.getBlock()->begin();
- while (bIter != aBlockStart) {
- --bIter;
- if (aIter == bIter)
- return true;
- }
- return false;
- }
-
- // Traverse up b's hierarchy to check if b's block is contained in a's.
- if (const auto *bAncestor = a.getBlock()->findAncestorInstInBlock(b))
- // a and bAncestor are in the same block; check if the former dominates it.
- return dominates(a, *bAncestor);
-
- // b's block is not contained in A.
- return false;
-}
-
-/// Returns true if statement 'a' properly postdominates statement b.
-bool mlir::properlyPostDominates(const Instruction &a, const Instruction &b) {
- // Only applicable to ML functions.
- assert(a.getFunction()->isML() && b.getFunction()->isML());
-
- if (&a == &b)
- return false;
-
- if (a.getFunction() != b.getFunction())
- return false;
-
- if (a.getBlock() == b.getBlock()) {
- // Do a linear scan to determine whether a comes after b.
- auto aIter = Block::const_iterator(a);
- auto bIter = Block::const_iterator(b);
- auto bBlockStart = b.getBlock()->begin();
- while (aIter != bBlockStart) {
- --aIter;
- if (aIter == bIter)
- return true;
- }
- return false;
- }
-
- // Traverse up b's hierarchy to check if b's block is contained in a's.
- if (const auto *bAncestor = a.getBlock()->findAncestorInstInBlock(b))
- // a and bAncestor are in the same block; check if 'a' postdominates
- // bAncestor.
- return postDominates(a, *bAncestor);
-
- // b's block is not contained in A's.
- return false;
-}
-
-/// Returns true if instruction A dominates instruction B.
-bool mlir::dominates(const Instruction &a, const Instruction &b) {
- return &a == &b || properlyDominates(a, b);
-}
-
-/// Returns true if statement A postdominates statement B.
-bool mlir::postDominates(const Instruction &a, const Instruction &b) {
- // Only applicable to ML functions.
- assert(a.getFunction()->isML() && b.getFunction()->isML());
-
- return &a == &b || properlyPostDominates(a, b);
-}
-
/// Populates 'loops' with IVs of the loops surrounding 'inst' ordered from
/// the outermost 'for' instruction to the innermost one.
void mlir::getLoopIVs(const Instruction &inst,
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
index 1b364034d47..a1b7aba15d3 100644
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -60,11 +60,10 @@ void Block::eraseFromFunction() {
/// Returns 'inst' if 'inst' lies in this block, or otherwise finds the
/// ancestor instruction of 'inst' that lies in this block. Returns nullptr if
/// the latter fails.
-const Instruction *
-Block::findAncestorInstInBlock(const Instruction &inst) const {
+Instruction *Block::findAncestorInstInBlock(Instruction *inst) {
// Traverse up the instruction hierarchy starting from the owner of operand to
// find the ancestor instruction that resides in the block of 'forInst'.
- const auto *currInst = &inst;
+ auto *currInst = inst;
while (currInst->getBlock() != this) {
currInst = currInst->getParentInst();
if (!currInst)
diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp
index 37f0f571a0f..1ab1f6361d3 100644
--- a/mlir/lib/Transforms/MaterializeVectors.cpp
+++ b/mlir/lib/Transforms/MaterializeVectors.cpp
@@ -21,6 +21,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/Dominance.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/MLFunctionMatcher.h"
#include "mlir/Analysis/SliceAnalysis.h"
@@ -650,10 +651,12 @@ static bool emitSlice(MaterializationState *state,
/// Additionally, this set is limited to instructions in the same lexical scope
/// because we currently disallow vectorization of defs that come from another
/// scope.
+/// TODO(ntv): please document return value.
static bool materialize(Function *f,
const SetVector<OperationInst *> &terminators,
MaterializationState *state) {
DenseSet<Instruction *> seen;
+ DominanceInfo domInfo(f);
for (auto *term : terminators) {
// Short-circuit test, a given terminator may have been reached by some
// other previous transitive use-def chains.
@@ -669,13 +672,13 @@ static bool materialize(Function *f,
// Note for the justification of this restriction.
// TODO(ntv): relax scoping constraints.
auto *enclosingScope = term->getParentInst();
- auto keepIfInSameScope = [enclosingScope](Instruction *inst) {
+ auto keepIfInSameScope = [enclosingScope, &domInfo](Instruction *inst) {
assert(inst && "NULL inst");
if (!enclosingScope) {
// by construction, everyone is always under the top scope (null scope).
return true;
}
- return properlyDominates(*enclosingScope, *inst);
+ return domInfo.properlyDominates(enclosingScope, inst);
};
SetVector<Instruction *> slice =
getSlice(term, keepIfInSameScope, keepIfInSameScope);
diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
index bad1cf9d101..1a30e2b289d 100644
--- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
+++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp
@@ -23,6 +23,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/Dominance.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
@@ -73,6 +74,11 @@ struct MemRefDataFlowOpt : public FunctionPass, InstWalker<MemRefDataFlowOpt> {
// A list of memref's that are potentially dead / could be eliminated.
SmallPtrSet<Value *, 4> memrefsToErase;
+ // Load op's whose results were replaced by those forwarded from stores.
+ std::vector<OperationInst *> loadOpsToErase;
+
+ DominanceInfo *domInfo = nullptr;
+ PostDominanceInfo *postDomInfo = nullptr;
static char passID;
};
@@ -152,7 +158,7 @@ void MemRefDataFlowOpt::visitOperationInst(OperationInst *opInst) {
// strictly a necessary condition since dominance isn't a prerequisite for
// a memref element store to reach a load, but this is sufficient and
// reasonably powerful in practice.
- if (!dominates(*storeOpInst, *loadOpInst))
+ if (!domInfo->dominates(storeOpInst, loadOpInst))
break;
// Finally, forwarding is only possible if the load touches a single
@@ -182,7 +188,7 @@ void MemRefDataFlowOpt::visitOperationInst(OperationInst *opInst) {
// unique store providing the value to the load, i.e., provably the last
// writer to that memref loc.
if (llvm::all_of(depSrcStores, [&](OperationInst *depStore) {
- return postDominates(*storeOpInst, *depStore);
+ return postDomInfo->postDominates(storeOpInst, depStore);
})) {
lastWriteStoreOp = storeOpInst;
break;
@@ -200,15 +206,27 @@ void MemRefDataFlowOpt::visitOperationInst(OperationInst *opInst) {
loadOp->getResult()->replaceAllUsesWith(storeVal);
// Record the memref for a later sweep to optimize away.
memrefsToErase.insert(loadOp->getMemRef());
- loadOp->erase();
+ // Record this to erase later.
+ loadOpsToErase.push_back(loadOpInst);
}
PassResult MemRefDataFlowOpt::runOnMLFunction(Function *f) {
+ DominanceInfo theDomInfo(f);
+ domInfo = &theDomInfo;
+ PostDominanceInfo thePostDomInfo(f);
+ postDomInfo = &thePostDomInfo;
+
+ loadOpsToErase.clear();
memrefsToErase.clear();
// Walk all load's and perform load/store forwarding.
walk(f);
+ // Erase all load op's whose results were replaced with store fwd'ed ones.
+ for (auto *loadOp : loadOpsToErase) {
+ loadOp->erase();
+ }
+
// Check if the store fwd'ed memrefs are now left with only stores and can
// thus be completely deleted. Note: the canononicalize pass should be able
// to do this as well, but we'll do it here since we collected these anyway.
diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp
index debaac3a33c..321bf20cf0b 100644
--- a/mlir/lib/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp
@@ -227,7 +227,7 @@ static void findMatchingStartFinishInsts(
auto *memref = dmaStartOp->getOperand(dmaStartOp->getFasterMemPos());
bool escapingUses = false;
for (const auto &use : memref->getUses()) {
- if (!dominates(*forInst->getBody()->begin(), *use.getOwner())) {
+ if (!forInst->getBody()->findAncestorInstInBlock(*use.getOwner())) {
LLVM_DEBUG(llvm::dbgs()
<< "can't pipeline: buffer is live out of loop\n";);
escapingUses = true;
diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp
index 8cfe2619e2a..b51c419f767 100644
--- a/mlir/lib/Transforms/Utils/Utils.cpp
+++ b/mlir/lib/Transforms/Utils/Utils.cpp
@@ -24,6 +24,7 @@
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/Analysis/Dominance.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/InstVisitor.h"
@@ -82,13 +83,17 @@ bool mlir::replaceAllMemRefUsesWith(const Value *oldMemRef, Value *newMemRef,
assert(oldMemRef->getType().cast<MemRefType>().getElementType() ==
newMemRef->getType().cast<MemRefType>().getElementType());
+ std::unique_ptr<DominanceInfo> domInfo;
+ if (domInstFilter)
+ domInfo = std::make_unique<DominanceInfo>(domInstFilter->getFunction());
+
// Walk all uses of old memref. Operation using the memref gets replaced.
for (auto it = oldMemRef->use_begin(); it != oldMemRef->use_end();) {
InstOperand &use = *(it++);
auto *opInst = cast<OperationInst>(use.getOwner());
// Skip this use if it's not dominated by domInstFilter.
- if (domInstFilter && !dominates(*domInstFilter, *opInst))
+ if (domInstFilter && !domInfo->dominates(domInstFilter, opInst))
continue;
// Check if the memref was used in a non-deferencing context. It is fine for
OpenPOWER on IntegriCloud