Diffstat (limited to 'mlir')
-rw-r--r--  mlir/include/mlir/Analysis/NestedMatcher.h                165
-rw-r--r--  mlir/lib/Analysis/LoopAnalysis.cpp                         14
-rw-r--r--  mlir/lib/Analysis/NestedMatcher.cpp                       158
-rw-r--r--  mlir/lib/Transforms/MaterializeVectors.cpp                 11
-rw-r--r--  mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp   71
-rw-r--r--  mlir/lib/Transforms/Vectorize.cpp                         192
6 files changed, 291 insertions, 320 deletions
diff --git a/mlir/include/mlir/Analysis/NestedMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h
index 161bb217a10..2a1c469348d 100644
--- a/mlir/include/mlir/Analysis/NestedMatcher.h
+++ b/mlir/include/mlir/Analysis/NestedMatcher.h
@@ -20,20 +20,19 @@
#include "mlir/IR/InstVisitor.h"
#include "llvm/Support/Allocator.h"
-#include <utility>
namespace mlir {
-struct NestedPatternStorage;
-struct NestedMatchStorage;
+struct NestedPattern;
class Instruction;
-/// An NestedPattern captures nested patterns. It is used in conjunction with
-/// a scoped NestedPatternContext which is an llvm::BumPtrAllocator that
-/// handles memory allocations efficiently and avoids ownership issues.
+/// A NestedPattern captures nested patterns in the IR.
+/// It is used in conjunction with a scoped NestedPatternContext which is an
+/// llvm::BumpPtrAllocator that handles memory allocations efficiently and
+/// avoids ownership issues.
///
-/// In order to use NestedPatterns, first create a scoped context. When the
-/// context goes out of scope, everything is freed.
+/// In order to use NestedPatterns, first create a scoped context.
+/// When the context goes out of scope, everything is freed.
/// This design simplifies the API by avoiding references to the context and
/// makes it clear that references to matchers must not escape.
///
@@ -45,109 +44,145 @@ class Instruction;
/// // do work on matches
/// } // everything is freed
///
-
-/// Recursive abstraction for matching results.
-/// Provides iteration over the Instruction* captured by a Matcher.
///
-/// Implemented as a POD value-type with underlying storage pointer.
-/// The underlying storage lives in a scoped bumper allocator whose lifetime
-/// is managed by an RAII NestedPatternContext.
-/// This is used by value everywhere.
+/// Nested abstraction for matching results.
+/// Provides access to the nested Instruction* captured by a Matcher.
+///
+/// A NestedMatch contains an Instruction* and its children NestedMatches and is
+/// thus cheap to copy. NestedMatch is stored in a scoped bumper allocator whose
+/// lifetime is managed by an RAII NestedPatternContext.
struct NestedMatch {
- using EntryType = std::pair<Instruction *, NestedMatch>;
- using iterator = EntryType *;
-
- static NestedMatch build(ArrayRef<NestedMatch::EntryType> elements = {});
+ static NestedMatch build(Instruction *instruction,
+ ArrayRef<NestedMatch> nestedMatches);
NestedMatch(const NestedMatch &) = default;
NestedMatch &operator=(const NestedMatch &) = default;
- explicit operator bool() { return !empty(); }
+ explicit operator bool() { return matchedInstruction != nullptr; }
- iterator begin();
- iterator end();
- EntryType &front();
- EntryType &back();
- unsigned size() { return end() - begin(); }
- unsigned empty() { return size() == 0; }
+ Instruction *getMatchedInstruction() { return matchedInstruction; }
+ ArrayRef<NestedMatch> getMatchedChildren() { return matchedChildren; }
private:
friend class NestedPattern;
friend class NestedPatternContext;
- friend class NestedMatchStorage;
/// Underlying global bump allocator managed by a NestedPatternContext.
static llvm::BumpPtrAllocator *&allocator();
- NestedMatch(NestedMatchStorage *storage) : storage(storage){};
-
- /// Copy the specified array of elements into memory managed by our bump
- /// pointer allocator. The elements are all PODs by constructions.
- static NestedMatch copyInto(ArrayRef<NestedMatch::EntryType> elements);
+ NestedMatch() = default;
- /// POD payload.
- NestedMatchStorage *storage;
+ /// Payload, holds a NestedMatch and all its children along this branch.
+ Instruction *matchedInstruction;
+ ArrayRef<NestedMatch> matchedChildren;
};
-/// A NestedPattern is a special type of InstWalker that:
+/// A NestedPattern is a nested InstWalker that:
/// 1. recursively matches a substructure in the tree;
/// 2. uses a filter function to refine matches with extra semantic
/// constraints (passed via a lambda of type FilterFunctionType);
-/// 3. TODO(ntv) Optionally applies actions (lambda).
+/// 3. TODO(ntv) optionally applies actions (lambda).
///
-/// Implemented as a POD value-type with underlying storage pointer.
-/// The underlying storage lives in a scoped bumper allocator whose lifetime
-/// is managed by an RAII NestedPatternContext.
-/// This should be used by value everywhere.
+/// Nested patterns are meant to capture imperfectly nested loops while matching
+/// properties over the whole loop nest. For instance, in vectorization we are
+/// interested in capturing all the imperfectly nested loops of a certain type
+/// and such that all the loads and stores have certain access patterns along the
+/// loops' induction variables. Such NestedMatches are first captured using the
+/// `match` function and are later processed to analyze properties and apply
+/// transformations in a non-greedy way.
+///
+/// The NestedMatches captured in the IR can grow large, especially after
+/// aggressive unrolling. As experience has shown, it is generally better to use
+/// a plain InstWalker to match flat patterns but the current implementation is
+/// competitive nonetheless.
using FilterFunctionType = std::function<bool(const Instruction &)>;
static bool defaultFilterFunction(const Instruction &) { return true; };
-struct NestedPattern : public InstWalker<NestedPattern> {
+struct NestedPattern {
NestedPattern(Instruction::Kind k, ArrayRef<NestedPattern> nested,
FilterFunctionType filter = defaultFilterFunction);
NestedPattern(const NestedPattern &) = default;
NestedPattern &operator=(const NestedPattern &) = default;
- /// Returns all the matches in `function`.
- NestedMatch match(Function *function);
+ /// Returns all the top-level matches in `function`.
+ void match(Function *function, SmallVectorImpl<NestedMatch> *matches) {
+ State state(*this, matches);
+ state.walkPostOrder(function);
+ }
- /// Returns all the matches nested under `instruction`.
- NestedMatch match(Instruction *instruction);
+ /// Returns all the top-level matches in `inst`.
+ void match(Instruction *inst, SmallVectorImpl<NestedMatch> *matches) {
+ State state(*this, matches);
+ state.walkPostOrder(inst);
+ }
- unsigned getDepth();
+ /// Returns the depth of the pattern.
+ unsigned getDepth() const;
private:
friend class NestedPatternContext;
- friend InstWalker<NestedPattern>;
+ friend class NestedMatch;
+ friend class InstWalker<NestedPattern>;
+ friend struct State;
+
+ /// Helper state that temporarily holds matches for the next level of nesting.
+ struct State : public InstWalker<State> {
+ State(NestedPattern &pattern, SmallVectorImpl<NestedMatch> *matches)
+ : pattern(pattern), matches(matches) {}
+ void visitForInst(ForInst *forInst) { pattern.matchOne(forInst, matches); }
+ void visitOperationInst(OperationInst *opInst) {
+ pattern.matchOne(opInst, matches);
+ }
+
+ private:
+ NestedPattern &pattern;
+ SmallVectorImpl<NestedMatch> *matches;
+ };
/// Underlying global bump allocator managed by a NestedPatternContext.
static llvm::BumpPtrAllocator *&allocator();
- Instruction::Kind getKind();
- ArrayRef<NestedPattern> getNestedPatterns();
- FilterFunctionType getFilterFunction();
-
- void matchOne(Instruction *elem);
-
- void visitForInst(ForInst *forInst) { matchOne(forInst); }
- void visitOperationInst(OperationInst *opInst) { matchOne(opInst); }
-
- /// POD paylod.
- /// Storage for the PatternMatcher.
- NestedPatternStorage *storage;
-
- // By-value POD wrapper to underlying storage pointer.
- NestedMatch matches;
+ /// Matches this pattern against a single `inst` and fills matches with the
+ /// result.
+ void matchOne(Instruction *inst, SmallVectorImpl<NestedMatch> *matches);
+
+ /// Instruction kind matched by this pattern.
+ Instruction::Kind kind;
+
+ /// Nested patterns to be matched.
+ ArrayRef<NestedPattern> nestedPatterns;
+
+ /// Extra filter function to apply to prune patterns as the IR is walked.
+ FilterFunctionType filter;
+
+ /// skip is an implementation detail needed so that we can implement match
+ /// without switching on the type of the Instruction. The idea is that a
+ /// NestedPattern first checks if it matches locally and then recursively
+ /// applies its nested matchers to its elem->nested. Since we want to rely on
+  /// the InstWalker impl rather than duplicate its logic, we allow an
+ /// off-by-one traversal to account for the fact that we write:
+ ///
+ /// void match(Instruction *elem) {
+ /// for (auto &c : getNestedPatterns()) {
+ /// NestedPattern childPattern(...);
+ /// ^~~~ Needs off-by-one skip.
+ ///
+ Instruction *skip;
};
/// RAII structure to transparently manage the bump allocator for
-/// NestedPattern and NestedMatch classes.
+/// NestedPattern and NestedMatch classes. This avoids passing a context to
+/// all the API functions.
struct NestedPatternContext {
NestedPatternContext() {
- NestedPattern::allocator() = &allocator;
+ assert(NestedMatch::allocator() == nullptr &&
+ "Only a single NestedPatternContext is supported");
+ assert(NestedPattern::allocator() == nullptr &&
+ "Only a single NestedPatternContext is supported");
NestedMatch::allocator() = &allocator;
+ NestedPattern::allocator() = &allocator;
}
~NestedPatternContext() {
- NestedPattern::allocator() = nullptr;
NestedMatch::allocator() = nullptr;
+ NestedPattern::allocator() = nullptr;
}
llvm::BumpPtrAllocator allocator;
};
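
A minimal usage sketch of the reworked API, assembled from the declarations above (not part of the patch): matcher::For, matcher::Op and matcher::isLoadOrStore are assumed to be the helper constructors and filters referenced elsewhere in this commit, and the function name and filter choice are purely illustrative.

// Sketch only; assumes mlir/Analysis/NestedMatcher.h and the matcher:: helpers.
void findLoadStoreLoops(Function *f) {
  // RAII context owning the BumpPtrAllocator behind all patterns and matches.
  NestedPatternContext context;
  // Loops that (transitively) contain load or store operations.
  auto pattern = matcher::For(matcher::Op(matcher::isLoadOrStore));
  SmallVector<NestedMatch, 8> matches;
  pattern.match(f, &matches);
  for (auto m : matches) {
    auto *loop = cast<ForInst>(m.getMatchedInstruction());
    (void)loop; // ... analyze or transform the matched loop here ...
    // The children hold the nested load/store matches along this branch.
    for (auto child : m.getMatchedChildren())
      (void)cast<OperationInst>(child.getMatchedInstruction());
  }
} // `context` goes out of scope here; all patterns and matches are freed.
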
diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp
index 07c903a6613..7d88a3d9b9f 100644
--- a/mlir/lib/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Analysis/LoopAnalysis.cpp
@@ -242,7 +242,8 @@ static bool isVectorizableLoopWithCond(const ForInst &loop,
// No vectorization across conditionals for now.
auto conditionals = matcher::If();
auto *forInst = const_cast<ForInst *>(&loop);
- auto conditionalsMatched = conditionals.match(forInst);
+ SmallVector<NestedMatch, 8> conditionalsMatched;
+ conditionals.match(forInst, &conditionalsMatched);
if (!conditionalsMatched.empty()) {
return false;
}
@@ -252,21 +253,24 @@ static bool isVectorizableLoopWithCond(const ForInst &loop,
auto &opInst = cast<OperationInst>(inst);
return opInst.getNumBlockLists() != 0 && !opInst.isa<AffineIfOp>();
});
- auto regionsMatched = regions.match(forInst);
+ SmallVector<NestedMatch, 8> regionsMatched;
+ regions.match(forInst, &regionsMatched);
if (!regionsMatched.empty()) {
return false;
}
auto vectorTransfers = matcher::Op(isVectorTransferReadOrWrite);
- auto vectorTransfersMatched = vectorTransfers.match(forInst);
+ SmallVector<NestedMatch, 8> vectorTransfersMatched;
+ vectorTransfers.match(forInst, &vectorTransfersMatched);
if (!vectorTransfersMatched.empty()) {
return false;
}
auto loadAndStores = matcher::Op(matcher::isLoadOrStore);
- auto loadAndStoresMatched = loadAndStores.match(forInst);
+ SmallVector<NestedMatch, 8> loadAndStoresMatched;
+ loadAndStores.match(forInst, &loadAndStoresMatched);
for (auto ls : loadAndStoresMatched) {
- auto *op = cast<OperationInst>(ls.first);
+ auto *op = cast<OperationInst>(ls.getMatchedInstruction());
auto load = op->dyn_cast<LoadOp>();
auto store = op->dyn_cast<StoreOp>();
// Only scalar types are considered vectorizable, all load/store must be
diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp
index 491a9bef1b9..43e3b332a58 100644
--- a/mlir/lib/Analysis/NestedMatcher.cpp
+++ b/mlir/lib/Analysis/NestedMatcher.cpp
@@ -20,93 +20,49 @@
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/raw_ostream.h"
-namespace mlir {
-
-/// Underlying storage for NestedMatch.
-struct NestedMatchStorage {
- MutableArrayRef<NestedMatch::EntryType> matches;
-};
-
-/// Underlying storage for NestedPattern.
-struct NestedPatternStorage {
- NestedPatternStorage(Instruction::Kind k, ArrayRef<NestedPattern> c,
- FilterFunctionType filter, Instruction *skip)
- : kind(k), nestedPatterns(c), filter(filter), skip(skip) {}
-
- Instruction::Kind kind;
- ArrayRef<NestedPattern> nestedPatterns;
- FilterFunctionType filter;
- /// skip is needed so that we can implement match without switching on the
- /// type of the Instruction.
- /// The idea is that a NestedPattern first checks if it matches locally
- /// and then recursively applies its nested matchers to its elem->nested.
- /// Since we want to rely on the InstWalker impl rather than duplicate its
- /// the logic, we allow an off-by-one traversal to account for the fact that
- /// we write:
- ///
- /// void match(Instruction *elem) {
- /// for (auto &c : getNestedPatterns()) {
- /// NestedPattern childPattern(...);
- /// ^~~~ Needs off-by-one skip.
- ///
- Instruction *skip;
-};
-
-} // end namespace mlir
-
using namespace mlir;
llvm::BumpPtrAllocator *&NestedMatch::allocator() {
- static thread_local llvm::BumpPtrAllocator *allocator = nullptr;
+ thread_local llvm::BumpPtrAllocator *allocator = nullptr;
return allocator;
}
-NestedMatch NestedMatch::build(ArrayRef<NestedMatch::EntryType> elements) {
- auto *matches =
- allocator()->Allocate<NestedMatch::EntryType>(elements.size());
- std::uninitialized_copy(elements.begin(), elements.end(), matches);
- auto *storage = allocator()->Allocate<NestedMatchStorage>();
- new (storage) NestedMatchStorage();
- storage->matches =
- MutableArrayRef<NestedMatch::EntryType>(matches, elements.size());
+NestedMatch NestedMatch::build(Instruction *instruction,
+ ArrayRef<NestedMatch> nestedMatches) {
auto *result = allocator()->Allocate<NestedMatch>();
- new (result) NestedMatch(storage);
+ auto *children = allocator()->Allocate<NestedMatch>(nestedMatches.size());
+ std::uninitialized_copy(nestedMatches.begin(), nestedMatches.end(), children);
+ new (result) NestedMatch();
+ result->matchedInstruction = instruction;
+ result->matchedChildren =
+ ArrayRef<NestedMatch>(children, nestedMatches.size());
return *result;
}
-NestedMatch::iterator NestedMatch::begin() { return storage->matches.begin(); }
-NestedMatch::iterator NestedMatch::end() { return storage->matches.end(); }
-NestedMatch::EntryType &NestedMatch::front() {
- return *storage->matches.begin();
-}
-NestedMatch::EntryType &NestedMatch::back() {
- return *(storage->matches.begin() + size() - 1);
-}
-
-/// Calls walk on `function`.
-NestedMatch NestedPattern::match(Function *function) {
- assert(!matches && "NestedPattern already matched!");
- this->walkPostOrder(function);
- return matches;
+llvm::BumpPtrAllocator *&NestedPattern::allocator() {
+ thread_local llvm::BumpPtrAllocator *allocator = nullptr;
+ return allocator;
}
-/// Calls walk on `instruction`.
-NestedMatch NestedPattern::match(Instruction *instruction) {
- assert(!matches && "NestedPattern already matched!");
- this->walkPostOrder(instruction);
- return matches;
+NestedPattern::NestedPattern(Instruction::Kind k,
+ ArrayRef<NestedPattern> nested,
+ FilterFunctionType filter)
+ : kind(k), nestedPatterns(ArrayRef<NestedPattern>(nested)), filter(filter) {
+ auto *newNested = allocator()->Allocate<NestedPattern>(nested.size());
+ std::uninitialized_copy(nested.begin(), nested.end(), newNested);
+ nestedPatterns = ArrayRef<NestedPattern>(newNested, nested.size());
}
-unsigned NestedPattern::getDepth() {
- auto nested = getNestedPatterns();
- if (nested.empty()) {
+unsigned NestedPattern::getDepth() const {
+ if (nestedPatterns.empty()) {
return 1;
}
unsigned depth = 0;
- for (auto c : nested) {
+ for (auto &c : nestedPatterns) {
depth = std::max(depth, c.getDepth());
}
return depth + 1;
@@ -122,69 +78,39 @@ unsigned NestedPattern::getDepth() {
/// appended to the list of matches;
/// 5. TODO(ntv) Optionally applies actions (lambda), in which case we will
/// want to traverse in post-order DFS to avoid invalidating iterators.
-void NestedPattern::matchOne(Instruction *elem) {
- if (storage->skip == elem) {
+void NestedPattern::matchOne(Instruction *inst,
+ SmallVectorImpl<NestedMatch> *matches) {
+ if (skip == inst) {
return;
}
// Structural filter
- if (elem->getKind() != getKind()) {
+ if (inst->getKind() != kind) {
return;
}
// Local custom filter function
- if (!getFilterFunction()(*elem)) {
+ if (!filter(*inst)) {
return;
}
- SmallVector<NestedMatch::EntryType, 8> nestedEntries;
- for (auto c : getNestedPatterns()) {
- /// We create a new nestedPattern here because a matcher holds its
- /// results. So we concretely need multiple copies of a given matcher, one
- /// for each matching result.
- NestedPattern nestedPattern(c);
+ if (nestedPatterns.empty()) {
+ SmallVector<NestedMatch, 8> nestedMatches;
+ matches->push_back(NestedMatch::build(inst, nestedMatches));
+ return;
+ }
+ // Take a copy of each nested pattern so we can match it.
+ for (auto nestedPattern : nestedPatterns) {
+ SmallVector<NestedMatch, 8> nestedMatches;
// Skip elem in the walk immediately following. Without this we would
// essentially need to reimplement walkPostOrder here.
- nestedPattern.storage->skip = elem;
- nestedPattern.walkPostOrder(elem);
- if (!nestedPattern.matches) {
+ nestedPattern.skip = inst;
+ nestedPattern.match(inst, &nestedMatches);
+ // If we could not match even one of the specified nestedPattern, early exit
+ // as this whole branch is not a match.
+ if (nestedMatches.empty()) {
return;
}
- for (auto m : nestedPattern.matches) {
- nestedEntries.push_back(m);
- }
+ matches->push_back(NestedMatch::build(inst, nestedMatches));
}
-
- SmallVector<NestedMatch::EntryType, 8> newEntries(
- matches.storage->matches.begin(), matches.storage->matches.end());
- newEntries.push_back(std::make_pair(elem, NestedMatch::build(nestedEntries)));
- matches = NestedMatch::build(newEntries);
-}
-
-llvm::BumpPtrAllocator *&NestedPattern::allocator() {
- static thread_local llvm::BumpPtrAllocator *allocator = nullptr;
- return allocator;
-}
-
-NestedPattern::NestedPattern(Instruction::Kind k,
- ArrayRef<NestedPattern> nested,
- FilterFunctionType filter)
- : storage(allocator()->Allocate<NestedPatternStorage>()),
- matches(NestedMatch::build({})) {
- auto *newChildren = allocator()->Allocate<NestedPattern>(nested.size());
- std::uninitialized_copy(nested.begin(), nested.end(), newChildren);
- // Initialize with placement new.
- new (storage) NestedPatternStorage(
- k, ArrayRef<NestedPattern>(newChildren, nested.size()), filter,
- nullptr /* skip */);
-}
-
-Instruction::Kind NestedPattern::getKind() { return storage->kind; }
-
-ArrayRef<NestedPattern> NestedPattern::getNestedPatterns() {
- return storage->nestedPatterns;
-}
-
-FilterFunctionType NestedPattern::getFilterFunction() {
- return storage->filter;
}
static bool isAffineIfOp(const Instruction &inst) {
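
Because a NestedMatch now carries its children directly instead of exposing flat (Instruction*, NestedMatch) entries, consumers recurse over getMatchedChildren(). A hypothetical helper, not part of this patch and assuming <algorithm> for std::max, that computes the depth of a captured match tree:

// Depth of a match tree: the matched instruction plus its deepest child chain.
static unsigned getMatchDepth(NestedMatch m) {
  unsigned childDepth = 0;
  for (auto child : m.getMatchedChildren()) // NestedMatch is cheap to copy.
    childDepth = std::max(childDepth, getMatchDepth(child));
  return childDepth + 1;
}
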
diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp
index 2744b1d624c..432ad1f39b8 100644
--- a/mlir/lib/Transforms/MaterializeVectors.cpp
+++ b/mlir/lib/Transforms/MaterializeVectors.cpp
@@ -201,9 +201,6 @@ struct MaterializeVectorsPass : public FunctionPass {
PassResult runOnFunction(Function *f) override;
- // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit.
- NestedPatternContext mlContext;
-
static char passID;
};
@@ -744,6 +741,9 @@ static bool materialize(Function *f,
}
PassResult MaterializeVectorsPass::runOnFunction(Function *f) {
+ // Thread-safe RAII local context, BumpPtrAllocator freed on exit.
+ NestedPatternContext mlContext;
+
// TODO(ntv): Check to see if this supports arbitrary top-level code.
if (f->getBlocks().size() != 1)
return success();
@@ -768,10 +768,11 @@ PassResult MaterializeVectorsPass::runOnFunction(Function *f) {
return matcher::operatesOnSuperVectors(opInst, subVectorType);
};
auto pat = Op(filter);
- auto matches = pat.match(f);
+ SmallVector<NestedMatch, 8> matches;
+ pat.match(f, &matches);
SetVector<OperationInst *> terminators;
for (auto m : matches) {
- terminators.insert(cast<OperationInst>(m.first));
+ terminators.insert(cast<OperationInst>(m.getMatchedInstruction()));
}
auto fail = materialize(f, terminators, &state);
diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
index a01b8fdf216..a9b9752ef51 100644
--- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
+++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
@@ -30,6 +30,7 @@
#include "mlir/Support/Functional.h"
#include "mlir/Support/STLExtras.h"
#include "mlir/Transforms/Passes.h"
+#include "third_party/llvm/llvm/include/llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@@ -94,9 +95,6 @@ struct VectorizerTestPass : public FunctionPass {
void testComposeMaps(Function *f);
void testNormalizeMaps(Function *f);
- // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit.
- NestedPatternContext MLContext;
-
static char passID;
};
@@ -128,9 +126,10 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) {
return true;
};
auto pat = Op(filter);
- auto matches = pat.match(f);
+ SmallVector<NestedMatch, 8> matches;
+ pat.match(f, &matches);
for (auto m : matches) {
- auto *opInst = cast<OperationInst>(m.first);
+ auto *opInst = cast<OperationInst>(m.getMatchedInstruction());
// This is a unit test that only checks and prints shape ratio.
// As a consequence we write only Ops with a single return type for the
// purpose of this test. If we need to test more intricate behavior in the
@@ -153,7 +152,7 @@ static std::string toString(Instruction *inst) {
return res;
}
-static NestedMatch matchTestSlicingOps(Function *f) {
+static NestedPattern patternTestSlicingOps() {
// Just use a custom op name for this test, it makes life easier.
constexpr auto kTestSlicingOpName = "slicing-test-op";
using functional::map;
@@ -163,17 +162,18 @@ static NestedMatch matchTestSlicingOps(Function *f) {
const auto &opInst = cast<OperationInst>(inst);
return opInst.getName().getStringRef() == kTestSlicingOpName;
};
- auto pat = Op(filter);
- return pat.match(f);
+ return Op(filter);
}
void VectorizerTestPass::testBackwardSlicing(Function *f) {
- auto matches = matchTestSlicingOps(f);
+ SmallVector<NestedMatch, 8> matches;
+ patternTestSlicingOps().match(f, &matches);
for (auto m : matches) {
SetVector<Instruction *> backwardSlice;
- getBackwardSlice(m.first, &backwardSlice);
+ getBackwardSlice(m.getMatchedInstruction(), &backwardSlice);
auto strs = map(toString, backwardSlice);
- outs() << "\nmatched: " << *m.first << " backward static slice: ";
+ outs() << "\nmatched: " << *m.getMatchedInstruction()
+ << " backward static slice: ";
for (const auto &s : strs) {
outs() << "\n" << s;
}
@@ -181,12 +181,14 @@ void VectorizerTestPass::testBackwardSlicing(Function *f) {
}
void VectorizerTestPass::testForwardSlicing(Function *f) {
- auto matches = matchTestSlicingOps(f);
+ SmallVector<NestedMatch, 8> matches;
+ patternTestSlicingOps().match(f, &matches);
for (auto m : matches) {
SetVector<Instruction *> forwardSlice;
- getForwardSlice(m.first, &forwardSlice);
+ getForwardSlice(m.getMatchedInstruction(), &forwardSlice);
auto strs = map(toString, forwardSlice);
- outs() << "\nmatched: " << *m.first << " forward static slice: ";
+ outs() << "\nmatched: " << *m.getMatchedInstruction()
+ << " forward static slice: ";
for (const auto &s : strs) {
outs() << "\n" << s;
}
@@ -194,11 +196,12 @@ void VectorizerTestPass::testForwardSlicing(Function *f) {
}
void VectorizerTestPass::testSlicing(Function *f) {
- auto matches = matchTestSlicingOps(f);
+ SmallVector<NestedMatch, 8> matches;
+ patternTestSlicingOps().match(f, &matches);
for (auto m : matches) {
- SetVector<Instruction *> staticSlice = getSlice(m.first);
+ SetVector<Instruction *> staticSlice = getSlice(m.getMatchedInstruction());
auto strs = map(toString, staticSlice);
- outs() << "\nmatched: " << *m.first << " static slice: ";
+ outs() << "\nmatched: " << *m.getMatchedInstruction() << " static slice: ";
for (const auto &s : strs) {
outs() << "\n" << s;
}
@@ -214,12 +217,12 @@ static bool customOpWithAffineMapAttribute(const Instruction &inst) {
void VectorizerTestPass::testComposeMaps(Function *f) {
using matcher::Op;
auto pattern = Op(customOpWithAffineMapAttribute);
- auto matches = pattern.match(f);
+ SmallVector<NestedMatch, 8> matches;
+ pattern.match(f, &matches);
SmallVector<AffineMap, 4> maps;
maps.reserve(matches.size());
- std::reverse(matches.begin(), matches.end());
- for (auto m : matches) {
- auto *opInst = cast<OperationInst>(m.first);
+ for (auto m : llvm::reverse(matches)) {
+ auto *opInst = cast<OperationInst>(m.getMatchedInstruction());
auto map = opInst->getAttr(VectorizerTestPass::kTestAffineMapAttrName)
.cast<AffineMapAttr>()
.getValue();
@@ -248,29 +251,31 @@ void VectorizerTestPass::testNormalizeMaps(Function *f) {
// Save matched AffineApplyOp that all need to be erased in the end.
auto pattern = Op(affineApplyOp);
- auto toErase = pattern.match(f);
- std::reverse(toErase.begin(), toErase.end());
+ SmallVector<NestedMatch, 8> toErase;
+ pattern.match(f, &toErase);
{
// Compose maps.
auto pattern = Op(singleResultAffineApplyOpWithoutUses);
- for (auto m : pattern.match(f)) {
- auto app = cast<OperationInst>(m.first)->cast<AffineApplyOp>();
- FuncBuilder b(m.first);
-
- using ValueTy = decltype(*(app->getOperands().begin()));
- SmallVector<Value *, 8> operands =
- functional::map([](ValueTy v) { return static_cast<Value *>(v); },
- app->getOperands().begin(), app->getOperands().end());
+ SmallVector<NestedMatch, 8> matches;
+ pattern.match(f, &matches);
+ for (auto m : matches) {
+ auto app =
+ cast<OperationInst>(m.getMatchedInstruction())->cast<AffineApplyOp>();
+ FuncBuilder b(m.getMatchedInstruction());
+ SmallVector<Value *, 8> operands(app->getOperands());
makeComposedAffineApply(&b, app->getLoc(), app->getAffineMap(), operands);
}
}
// We should now be able to erase everything in reverse order in this test.
- for (auto m : toErase) {
- m.first->erase();
+ for (auto m : llvm::reverse(toErase)) {
+ m.getMatchedInstruction()->erase();
}
}
PassResult VectorizerTestPass::runOnFunction(Function *f) {
+ // Thread-safe RAII local context, BumpPtrAllocator freed on exit.
+ NestedPatternContext mlContext;
+
// Only support single block functions at this point.
if (f->getBlocks().size() != 1)
return success();
diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp
index cfde1ecf0a8..73893599d17 100644
--- a/mlir/lib/Transforms/Vectorize.cpp
+++ b/mlir/lib/Transforms/Vectorize.cpp
@@ -655,9 +655,6 @@ struct Vectorize : public FunctionPass {
PassResult runOnFunction(Function *f) override;
- // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit.
- NestedPatternContext MLContext;
-
static char passID;
};
@@ -703,13 +700,13 @@ static void vectorizeLoopIfProfitable(ForInst *loop, unsigned depthInPattern,
/// 3. account for impact of vectorization on maximal loop fusion.
/// Then we can quantify the above to build a cost model and search over
/// strategies.
-static bool analyzeProfitability(NestedMatch matches, unsigned depthInPattern,
- unsigned patternDepth,
+static bool analyzeProfitability(ArrayRef<NestedMatch> matches,
+ unsigned depthInPattern, unsigned patternDepth,
VectorizationStrategy *strategy) {
for (auto m : matches) {
- auto *loop = cast<ForInst>(m.first);
- bool fail = analyzeProfitability(m.second, depthInPattern + 1, patternDepth,
- strategy);
+ auto *loop = cast<ForInst>(m.getMatchedInstruction());
+ bool fail = analyzeProfitability(m.getMatchedChildren(), depthInPattern + 1,
+ patternDepth, strategy);
if (fail) {
return fail;
}
@@ -875,9 +872,10 @@ static bool vectorizeForInst(ForInst *loop, int64_t step,
state->terminators.count(opInst) == 0;
};
auto loadAndStores = matcher::Op(notVectorizedThisPattern);
- auto matches = loadAndStores.match(loop);
- for (auto ls : matches) {
- auto *opInst = cast<OperationInst>(ls.first);
+ SmallVector<NestedMatch, 8> loadAndStoresMatches;
+ loadAndStores.match(loop, &loadAndStoresMatches);
+ for (auto ls : loadAndStoresMatches) {
+ auto *opInst = cast<OperationInst>(ls.getMatchedInstruction());
auto load = opInst->dyn_cast<LoadOp>();
auto store = opInst->dyn_cast<StoreOp>();
LLVM_DEBUG(opInst->print(dbgs()));
@@ -907,15 +905,15 @@ isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) {
}
/// Forward-declaration.
-static bool vectorizeNonRoot(NestedMatch matches, VectorizationState *state);
+static bool vectorizeNonRoot(ArrayRef<NestedMatch> matches,
+ VectorizationState *state);
/// Apply vectorization of `loop` according to `state`. This is only triggered
/// if all vectorizations in `childrenMatches` have already succeeded
/// recursively in DFS post-order.
-static bool doVectorize(NestedMatch::EntryType oneMatch,
- VectorizationState *state) {
- ForInst *loop = cast<ForInst>(oneMatch.first);
- NestedMatch childrenMatches = oneMatch.second;
+static bool doVectorize(NestedMatch oneMatch, VectorizationState *state) {
+ ForInst *loop = cast<ForInst>(oneMatch.getMatchedInstruction());
+ auto childrenMatches = oneMatch.getMatchedChildren();
// 1. DFS postorder recursion, if any of my children fails, I fail too.
auto fail = vectorizeNonRoot(childrenMatches, state);
@@ -949,7 +947,8 @@ static bool doVectorize(NestedMatch::EntryType oneMatch,
/// Non-root pattern iterates over the matches at this level, calls doVectorize
/// and exits early if anything below fails.
-static bool vectorizeNonRoot(NestedMatch matches, VectorizationState *state) {
+static bool vectorizeNonRoot(ArrayRef<NestedMatch> matches,
+ VectorizationState *state) {
for (auto m : matches) {
auto fail = doVectorize(m, state);
if (fail) {
@@ -1185,99 +1184,100 @@ static bool vectorizeOperations(VectorizationState *state) {
/// The root match thus needs to maintain a clone for handling failure.
/// Each root may succeed independently but will otherwise clean after itself if
/// anything below it fails.
-static bool vectorizeRootMatches(NestedMatch matches,
- VectorizationStrategy *strategy) {
- for (auto m : matches) {
- auto *loop = cast<ForInst>(m.first);
- VectorizationState state;
- state.strategy = strategy;
-
- // Since patterns are recursive, they can very well intersect.
- // Since we do not want a fully greedy strategy in general, we decouple
- // pattern matching, from profitability analysis, from application.
- // As a consequence we must check that each root pattern is still
- // vectorizable. If a pattern is not vectorizable anymore, we just skip it.
- // TODO(ntv): implement a non-greedy profitability analysis that keeps only
- // non-intersecting patterns.
- if (!isVectorizableLoop(*loop)) {
- LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable");
- continue;
- }
- FuncBuilder builder(loop); // builder to insert in place of loop
- ForInst *clonedLoop = cast<ForInst>(builder.clone(*loop));
- auto fail = doVectorize(m, &state);
- /// Sets up error handling for this root loop. This is how the root match
- /// maintains a clone for handling failure and restores the proper state via
- /// RAII.
- ScopeGuard sg2([&fail, loop, clonedLoop]() {
- if (fail) {
- loop->getInductionVar()->replaceAllUsesWith(
- clonedLoop->getInductionVar());
- loop->erase();
- } else {
- clonedLoop->erase();
- }
- });
+static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) {
+ auto *loop = cast<ForInst>(m.getMatchedInstruction());
+ VectorizationState state;
+ state.strategy = strategy;
+
+ // Since patterns are recursive, they can very well intersect.
+ // Since we do not want a fully greedy strategy in general, we decouple
+ // pattern matching, from profitability analysis, from application.
+ // As a consequence we must check that each root pattern is still
+ // vectorizable. If a pattern is not vectorizable anymore, we just skip it.
+ // TODO(ntv): implement a non-greedy profitability analysis that keeps only
+ // non-intersecting patterns.
+ if (!isVectorizableLoop(*loop)) {
+ LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable");
+ return true;
+ }
+ FuncBuilder builder(loop); // builder to insert in place of loop
+ ForInst *clonedLoop = cast<ForInst>(builder.clone(*loop));
+ auto fail = doVectorize(m, &state);
+ /// Sets up error handling for this root loop. This is how the root match
+ /// maintains a clone for handling failure and restores the proper state via
+ /// RAII.
+ ScopeGuard sg2([&fail, loop, clonedLoop]() {
if (fail) {
- LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed root doVectorize");
- continue;
+ loop->getInductionVar()->replaceAllUsesWith(
+ clonedLoop->getInductionVar());
+ loop->erase();
+ } else {
+ clonedLoop->erase();
}
+ });
+ if (fail) {
+ LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed root doVectorize");
+ return true;
+ }
- // Form the root operationsthat have been set in the replacementMap.
- // For now, these roots are the loads for which vector_transfer_read
- // operations have been inserted.
- auto getDefiningInst = [](const Value *val) {
- return const_cast<Value *>(val)->getDefiningInst();
- };
- using ReferenceTy = decltype(*(state.replacementMap.begin()));
- auto getKey = [](ReferenceTy it) { return it.first; };
- auto roots = map(getDefiningInst, map(getKey, state.replacementMap));
-
- // Vectorize the root operations and everything reached by use-def chains
- // except the terminators (store instructions) that need to be
- // post-processed separately.
- fail = vectorizeOperations(&state);
- if (fail) {
- LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed vectorizeOperations");
- continue;
- }
+  // Form the root operations that have been set in the replacementMap.
+ // For now, these roots are the loads for which vector_transfer_read
+ // operations have been inserted.
+ auto getDefiningInst = [](const Value *val) {
+ return const_cast<Value *>(val)->getDefiningInst();
+ };
+ using ReferenceTy = decltype(*(state.replacementMap.begin()));
+ auto getKey = [](ReferenceTy it) { return it.first; };
+ auto roots = map(getDefiningInst, map(getKey, state.replacementMap));
+
+ // Vectorize the root operations and everything reached by use-def chains
+ // except the terminators (store instructions) that need to be
+ // post-processed separately.
+ fail = vectorizeOperations(&state);
+ if (fail) {
+ LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed vectorizeOperations");
+ return true;
+ }
- // Finally, vectorize the terminators. If anything fails to vectorize, skip.
- auto vectorizeOrFail = [&fail, &state](OperationInst *inst) {
- if (fail) {
- return;
- }
- FuncBuilder b(inst);
- auto *res = vectorizeOneOperationInst(&b, inst, &state);
- if (res == nullptr) {
- fail = true;
- }
- };
- apply(vectorizeOrFail, state.terminators);
+ // Finally, vectorize the terminators. If anything fails to vectorize, skip.
+ auto vectorizeOrFail = [&fail, &state](OperationInst *inst) {
if (fail) {
- LLVM_DEBUG(
- dbgs() << "\n[early-vect]+++++ failed to vectorize terminators");
- continue;
- } else {
- LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ success vectorizing pattern");
+ return;
}
-
- // Finish this vectorization pattern.
- state.finishVectorizationPattern();
+ FuncBuilder b(inst);
+ auto *res = vectorizeOneOperationInst(&b, inst, &state);
+ if (res == nullptr) {
+ fail = true;
+ }
+ };
+ apply(vectorizeOrFail, state.terminators);
+ if (fail) {
+ LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed to vectorize terminators");
+ return true;
+ } else {
+ LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ success vectorizing pattern");
}
+
+ // Finish this vectorization pattern.
+ state.finishVectorizationPattern();
return false;
}
/// Applies vectorization to the current Function by searching over a bunch of
/// predetermined patterns.
PassResult Vectorize::runOnFunction(Function *f) {
- for (auto pat : makePatterns()) {
+ // Thread-safe RAII local context, BumpPtrAllocator freed on exit.
+ NestedPatternContext mlContext;
+
+ for (auto &pat : makePatterns()) {
LLVM_DEBUG(dbgs() << "\n******************************************");
LLVM_DEBUG(dbgs() << "\n******************************************");
LLVM_DEBUG(dbgs() << "\n[early-vect] new pattern on Function\n");
LLVM_DEBUG(f->print(dbgs()));
unsigned patternDepth = pat.getDepth();
- auto matches = pat.match(f);
+
+ SmallVector<NestedMatch, 8> matches;
+ pat.match(f, &matches);
// Iterate over all the top-level matches and vectorize eagerly.
// This automatically prunes intersecting matches.
for (auto m : matches) {
@@ -1285,16 +1285,16 @@ PassResult Vectorize::runOnFunction(Function *f) {
// TODO(ntv): depending on profitability, elect to reduce the vector size.
strategy.vectorSizes.assign(clVirtualVectorSize.begin(),
clVirtualVectorSize.end());
- auto fail = analyzeProfitability(m.second, 1, patternDepth, &strategy);
+ auto fail = analyzeProfitability(m.getMatchedChildren(), 1, patternDepth,
+ &strategy);
if (fail) {
continue;
}
- auto *loop = cast<ForInst>(m.first);
+ auto *loop = cast<ForInst>(m.getMatchedInstruction());
vectorizeLoopIfProfitable(loop, 0, patternDepth, &strategy);
// TODO(ntv): if pattern does not apply, report it; alter the
// cost/benefit.
- fail = vectorizeRootMatches(matches, &strategy);
- assert(!fail && "top-level failure should not happen");
+ fail = vectorizeRootMatch(m, &strategy);
// TODO(ntv): some diagnostics.
}
}
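
Each of the three passes above now constructs its NestedPatternContext inside runOnFunction rather than holding it as a pass member; together with the new single-context assertions, this means helpers called from within a pass must reuse the pass-local context instead of opening their own. A hypothetical pass body showing the expected shape (the pass name and filter are illustrative):

PassResult MyMatcherPass::runOnFunction(Function *f) {
  // Thread-safe RAII local context; its BumpPtrAllocator is freed on return.
  // Constructing a second NestedPatternContext while this one is alive would
  // trip the single-context assertion added in NestedMatcher.h.
  NestedPatternContext mlContext;

  SmallVector<NestedMatch, 8> matches;
  matcher::Op(matcher::isLoadOrStore).match(f, &matches);
  // ... consume `matches` while mlContext is still in scope ...
  return success();
}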