summary | refs | log | tree | commit | diff | stats
path: root/mlir/lib/Transforms
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms')
-rw-r--r--  mlir/lib/Transforms/MaterializeVectors.cpp              | 11
-rw-r--r--  mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp | 71
-rw-r--r--  mlir/lib/Transforms/Vectorize.cpp                        | 192
3 files changed, 140 insertions(+), 134 deletions(-)
diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp
index 2744b1d624c..432ad1f39b8 100644
--- a/mlir/lib/Transforms/MaterializeVectors.cpp
+++ b/mlir/lib/Transforms/MaterializeVectors.cpp
@@ -201,9 +201,6 @@ struct MaterializeVectorsPass : public FunctionPass {
PassResult runOnFunction(Function *f) override;
- // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit.
- NestedPatternContext mlContext;
-
static char passID;
};
@@ -744,6 +741,9 @@ static bool materialize(Function *f,
}
PassResult MaterializeVectorsPass::runOnFunction(Function *f) {
+ // Thread-safe RAII local context, BumpPtrAllocator freed on exit.
+ NestedPatternContext mlContext;
+
// TODO(ntv): Check to see if this supports arbitrary top-level code.
if (f->getBlocks().size() != 1)
return success();
@@ -768,10 +768,11 @@ PassResult MaterializeVectorsPass::runOnFunction(Function *f) {
return matcher::operatesOnSuperVectors(opInst, subVectorType);
};
auto pat = Op(filter);
- auto matches = pat.match(f);
+ SmallVector<NestedMatch, 8> matches;
+ pat.match(f, &matches);
SetVector<OperationInst *> terminators;
for (auto m : matches) {
- terminators.insert(cast<OperationInst>(m.first));
+ terminators.insert(cast<OperationInst>(m.getMatchedInstruction()));
}
auto fail = materialize(f, terminators, &state);
diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
index a01b8fdf216..a9b9752ef51 100644
--- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
+++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
@@ -30,6 +30,7 @@
#include "mlir/Support/Functional.h"
#include "mlir/Support/STLExtras.h"
#include "mlir/Transforms/Passes.h"
+#include "third_party/llvm/llvm/include/llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@@ -94,9 +95,6 @@ struct VectorizerTestPass : public FunctionPass {
void testComposeMaps(Function *f);
void testNormalizeMaps(Function *f);
- // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit.
- NestedPatternContext MLContext;
-
static char passID;
};
@@ -128,9 +126,10 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) {
return true;
};
auto pat = Op(filter);
- auto matches = pat.match(f);
+ SmallVector<NestedMatch, 8> matches;
+ pat.match(f, &matches);
for (auto m : matches) {
- auto *opInst = cast<OperationInst>(m.first);
+ auto *opInst = cast<OperationInst>(m.getMatchedInstruction());
// This is a unit test that only checks and prints shape ratio.
// As a consequence we write only Ops with a single return type for the
// purpose of this test. If we need to test more intricate behavior in the
@@ -153,7 +152,7 @@ static std::string toString(Instruction *inst) {
return res;
}
-static NestedMatch matchTestSlicingOps(Function *f) {
+static NestedPattern patternTestSlicingOps() {
// Just use a custom op name for this test, it makes life easier.
constexpr auto kTestSlicingOpName = "slicing-test-op";
using functional::map;
@@ -163,17 +162,18 @@ static NestedMatch matchTestSlicingOps(Function *f) {
const auto &opInst = cast<OperationInst>(inst);
return opInst.getName().getStringRef() == kTestSlicingOpName;
};
- auto pat = Op(filter);
- return pat.match(f);
+ return Op(filter);
}
void VectorizerTestPass::testBackwardSlicing(Function *f) {
- auto matches = matchTestSlicingOps(f);
+ SmallVector<NestedMatch, 8> matches;
+ patternTestSlicingOps().match(f, &matches);
for (auto m : matches) {
SetVector<Instruction *> backwardSlice;
- getBackwardSlice(m.first, &backwardSlice);
+ getBackwardSlice(m.getMatchedInstruction(), &backwardSlice);
auto strs = map(toString, backwardSlice);
- outs() << "\nmatched: " << *m.first << " backward static slice: ";
+ outs() << "\nmatched: " << *m.getMatchedInstruction()
+ << " backward static slice: ";
for (const auto &s : strs) {
outs() << "\n" << s;
}
@@ -181,12 +181,14 @@ void VectorizerTestPass::testBackwardSlicing(Function *f) {
}
void VectorizerTestPass::testForwardSlicing(Function *f) {
- auto matches = matchTestSlicingOps(f);
+ SmallVector<NestedMatch, 8> matches;
+ patternTestSlicingOps().match(f, &matches);
for (auto m : matches) {
SetVector<Instruction *> forwardSlice;
- getForwardSlice(m.first, &forwardSlice);
+ getForwardSlice(m.getMatchedInstruction(), &forwardSlice);
auto strs = map(toString, forwardSlice);
- outs() << "\nmatched: " << *m.first << " forward static slice: ";
+ outs() << "\nmatched: " << *m.getMatchedInstruction()
+ << " forward static slice: ";
for (const auto &s : strs) {
outs() << "\n" << s;
}
@@ -194,11 +196,12 @@ void VectorizerTestPass::testForwardSlicing(Function *f) {
}
void VectorizerTestPass::testSlicing(Function *f) {
- auto matches = matchTestSlicingOps(f);
+ SmallVector<NestedMatch, 8> matches;
+ patternTestSlicingOps().match(f, &matches);
for (auto m : matches) {
- SetVector<Instruction *> staticSlice = getSlice(m.first);
+ SetVector<Instruction *> staticSlice = getSlice(m.getMatchedInstruction());
auto strs = map(toString, staticSlice);
- outs() << "\nmatched: " << *m.first << " static slice: ";
+ outs() << "\nmatched: " << *m.getMatchedInstruction() << " static slice: ";
for (const auto &s : strs) {
outs() << "\n" << s;
}
@@ -214,12 +217,12 @@ static bool customOpWithAffineMapAttribute(const Instruction &inst) {
void VectorizerTestPass::testComposeMaps(Function *f) {
using matcher::Op;
auto pattern = Op(customOpWithAffineMapAttribute);
- auto matches = pattern.match(f);
+ SmallVector<NestedMatch, 8> matches;
+ pattern.match(f, &matches);
SmallVector<AffineMap, 4> maps;
maps.reserve(matches.size());
- std::reverse(matches.begin(), matches.end());
- for (auto m : matches) {
- auto *opInst = cast<OperationInst>(m.first);
+ for (auto m : llvm::reverse(matches)) {
+ auto *opInst = cast<OperationInst>(m.getMatchedInstruction());
auto map = opInst->getAttr(VectorizerTestPass::kTestAffineMapAttrName)
.cast<AffineMapAttr>()
.getValue();
@@ -248,29 +251,31 @@ void VectorizerTestPass::testNormalizeMaps(Function *f) {
// Save matched AffineApplyOp that all need to be erased in the end.
auto pattern = Op(affineApplyOp);
- auto toErase = pattern.match(f);
- std::reverse(toErase.begin(), toErase.end());
+ SmallVector<NestedMatch, 8> toErase;
+ pattern.match(f, &toErase);
{
// Compose maps.
auto pattern = Op(singleResultAffineApplyOpWithoutUses);
- for (auto m : pattern.match(f)) {
- auto app = cast<OperationInst>(m.first)->cast<AffineApplyOp>();
- FuncBuilder b(m.first);
-
- using ValueTy = decltype(*(app->getOperands().begin()));
- SmallVector<Value *, 8> operands =
- functional::map([](ValueTy v) { return static_cast<Value *>(v); },
- app->getOperands().begin(), app->getOperands().end());
+ SmallVector<NestedMatch, 8> matches;
+ pattern.match(f, &matches);
+ for (auto m : matches) {
+ auto app =
+ cast<OperationInst>(m.getMatchedInstruction())->cast<AffineApplyOp>();
+ FuncBuilder b(m.getMatchedInstruction());
+ SmallVector<Value *, 8> operands(app->getOperands());
makeComposedAffineApply(&b, app->getLoc(), app->getAffineMap(), operands);
}
}
// We should now be able to erase everything in reverse order in this test.
- for (auto m : toErase) {
- m.first->erase();
+ for (auto m : llvm::reverse(toErase)) {
+ m.getMatchedInstruction()->erase();
}
}
PassResult VectorizerTestPass::runOnFunction(Function *f) {
+ // Thread-safe RAII local context, BumpPtrAllocator freed on exit.
+ NestedPatternContext mlContext;
+
// Only support single block functions at this point.
if (f->getBlocks().size() != 1)
return success();
diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp
index cfde1ecf0a8..73893599d17 100644
--- a/mlir/lib/Transforms/Vectorize.cpp
+++ b/mlir/lib/Transforms/Vectorize.cpp
@@ -655,9 +655,6 @@ struct Vectorize : public FunctionPass {
PassResult runOnFunction(Function *f) override;
- // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit.
- NestedPatternContext MLContext;
-
static char passID;
};
@@ -703,13 +700,13 @@ static void vectorizeLoopIfProfitable(ForInst *loop, unsigned depthInPattern,
/// 3. account for impact of vectorization on maximal loop fusion.
/// Then we can quantify the above to build a cost model and search over
/// strategies.
-static bool analyzeProfitability(NestedMatch matches, unsigned depthInPattern,
- unsigned patternDepth,
+static bool analyzeProfitability(ArrayRef<NestedMatch> matches,
+ unsigned depthInPattern, unsigned patternDepth,
VectorizationStrategy *strategy) {
for (auto m : matches) {
- auto *loop = cast<ForInst>(m.first);
- bool fail = analyzeProfitability(m.second, depthInPattern + 1, patternDepth,
- strategy);
+ auto *loop = cast<ForInst>(m.getMatchedInstruction());
+ bool fail = analyzeProfitability(m.getMatchedChildren(), depthInPattern + 1,
+ patternDepth, strategy);
if (fail) {
return fail;
}
@@ -875,9 +872,10 @@ static bool vectorizeForInst(ForInst *loop, int64_t step,
state->terminators.count(opInst) == 0;
};
auto loadAndStores = matcher::Op(notVectorizedThisPattern);
- auto matches = loadAndStores.match(loop);
- for (auto ls : matches) {
- auto *opInst = cast<OperationInst>(ls.first);
+ SmallVector<NestedMatch, 8> loadAndStoresMatches;
+ loadAndStores.match(loop, &loadAndStoresMatches);
+ for (auto ls : loadAndStoresMatches) {
+ auto *opInst = cast<OperationInst>(ls.getMatchedInstruction());
auto load = opInst->dyn_cast<LoadOp>();
auto store = opInst->dyn_cast<StoreOp>();
LLVM_DEBUG(opInst->print(dbgs()));
@@ -907,15 +905,15 @@ isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) {
}
/// Forward-declaration.
-static bool vectorizeNonRoot(NestedMatch matches, VectorizationState *state);
+static bool vectorizeNonRoot(ArrayRef<NestedMatch> matches,
+ VectorizationState *state);
/// Apply vectorization of `loop` according to `state`. This is only triggered
/// if all vectorizations in `childrenMatches` have already succeeded
/// recursively in DFS post-order.
-static bool doVectorize(NestedMatch::EntryType oneMatch,
- VectorizationState *state) {
- ForInst *loop = cast<ForInst>(oneMatch.first);
- NestedMatch childrenMatches = oneMatch.second;
+static bool doVectorize(NestedMatch oneMatch, VectorizationState *state) {
+ ForInst *loop = cast<ForInst>(oneMatch.getMatchedInstruction());
+ auto childrenMatches = oneMatch.getMatchedChildren();
// 1. DFS postorder recursion, if any of my children fails, I fail too.
auto fail = vectorizeNonRoot(childrenMatches, state);
@@ -949,7 +947,8 @@ static bool doVectorize(NestedMatch::EntryType oneMatch,
/// Non-root pattern iterates over the matches at this level, calls doVectorize
/// and exits early if anything below fails.
-static bool vectorizeNonRoot(NestedMatch matches, VectorizationState *state) {
+static bool vectorizeNonRoot(ArrayRef<NestedMatch> matches,
+ VectorizationState *state) {
for (auto m : matches) {
auto fail = doVectorize(m, state);
if (fail) {
@@ -1185,99 +1184,100 @@ static bool vectorizeOperations(VectorizationState *state) {
/// The root match thus needs to maintain a clone for handling failure.
/// Each root may succeed independently but will otherwise clean after itself if
/// anything below it fails.
-static bool vectorizeRootMatches(NestedMatch matches,
- VectorizationStrategy *strategy) {
- for (auto m : matches) {
- auto *loop = cast<ForInst>(m.first);
- VectorizationState state;
- state.strategy = strategy;
-
- // Since patterns are recursive, they can very well intersect.
- // Since we do not want a fully greedy strategy in general, we decouple
- // pattern matching, from profitability analysis, from application.
- // As a consequence we must check that each root pattern is still
- // vectorizable. If a pattern is not vectorizable anymore, we just skip it.
- // TODO(ntv): implement a non-greedy profitability analysis that keeps only
- // non-intersecting patterns.
- if (!isVectorizableLoop(*loop)) {
- LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable");
- continue;
- }
- FuncBuilder builder(loop); // builder to insert in place of loop
- ForInst *clonedLoop = cast<ForInst>(builder.clone(*loop));
- auto fail = doVectorize(m, &state);
- /// Sets up error handling for this root loop. This is how the root match
- /// maintains a clone for handling failure and restores the proper state via
- /// RAII.
- ScopeGuard sg2([&fail, loop, clonedLoop]() {
- if (fail) {
- loop->getInductionVar()->replaceAllUsesWith(
- clonedLoop->getInductionVar());
- loop->erase();
- } else {
- clonedLoop->erase();
- }
- });
+static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) {
+ auto *loop = cast<ForInst>(m.getMatchedInstruction());
+ VectorizationState state;
+ state.strategy = strategy;
+
+ // Since patterns are recursive, they can very well intersect.
+ // Since we do not want a fully greedy strategy in general, we decouple
+ // pattern matching, from profitability analysis, from application.
+ // As a consequence we must check that each root pattern is still
+ // vectorizable. If a pattern is not vectorizable anymore, we just skip it.
+ // TODO(ntv): implement a non-greedy profitability analysis that keeps only
+ // non-intersecting patterns.
+ if (!isVectorizableLoop(*loop)) {
+ LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable");
+ return true;
+ }
+ FuncBuilder builder(loop); // builder to insert in place of loop
+ ForInst *clonedLoop = cast<ForInst>(builder.clone(*loop));
+ auto fail = doVectorize(m, &state);
+ /// Sets up error handling for this root loop. This is how the root match
+ /// maintains a clone for handling failure and restores the proper state via
+ /// RAII.
+ ScopeGuard sg2([&fail, loop, clonedLoop]() {
if (fail) {
- LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed root doVectorize");
- continue;
+ loop->getInductionVar()->replaceAllUsesWith(
+ clonedLoop->getInductionVar());
+ loop->erase();
+ } else {
+ clonedLoop->erase();
}
+ });
+ if (fail) {
+ LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed root doVectorize");
+ return true;
+ }
- // Form the root operationsthat have been set in the replacementMap.
- // For now, these roots are the loads for which vector_transfer_read
- // operations have been inserted.
- auto getDefiningInst = [](const Value *val) {
- return const_cast<Value *>(val)->getDefiningInst();
- };
- using ReferenceTy = decltype(*(state.replacementMap.begin()));
- auto getKey = [](ReferenceTy it) { return it.first; };
- auto roots = map(getDefiningInst, map(getKey, state.replacementMap));
-
- // Vectorize the root operations and everything reached by use-def chains
- // except the terminators (store instructions) that need to be
- // post-processed separately.
- fail = vectorizeOperations(&state);
- if (fail) {
- LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed vectorizeOperations");
- continue;
- }
+ // Form the root operationsthat have been set in the replacementMap.
+ // For now, these roots are the loads for which vector_transfer_read
+ // operations have been inserted.
+ auto getDefiningInst = [](const Value *val) {
+ return const_cast<Value *>(val)->getDefiningInst();
+ };
+ using ReferenceTy = decltype(*(state.replacementMap.begin()));
+ auto getKey = [](ReferenceTy it) { return it.first; };
+ auto roots = map(getDefiningInst, map(getKey, state.replacementMap));
+
+ // Vectorize the root operations and everything reached by use-def chains
+ // except the terminators (store instructions) that need to be
+ // post-processed separately.
+ fail = vectorizeOperations(&state);
+ if (fail) {
+ LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed vectorizeOperations");
+ return true;
+ }
- // Finally, vectorize the terminators. If anything fails to vectorize, skip.
- auto vectorizeOrFail = [&fail, &state](OperationInst *inst) {
- if (fail) {
- return;
- }
- FuncBuilder b(inst);
- auto *res = vectorizeOneOperationInst(&b, inst, &state);
- if (res == nullptr) {
- fail = true;
- }
- };
- apply(vectorizeOrFail, state.terminators);
+ // Finally, vectorize the terminators. If anything fails to vectorize, skip.
+ auto vectorizeOrFail = [&fail, &state](OperationInst *inst) {
if (fail) {
- LLVM_DEBUG(
- dbgs() << "\n[early-vect]+++++ failed to vectorize terminators");
- continue;
- } else {
- LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ success vectorizing pattern");
+ return;
}
-
- // Finish this vectorization pattern.
- state.finishVectorizationPattern();
+ FuncBuilder b(inst);
+ auto *res = vectorizeOneOperationInst(&b, inst, &state);
+ if (res == nullptr) {
+ fail = true;
+ }
+ };
+ apply(vectorizeOrFail, state.terminators);
+ if (fail) {
+ LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed to vectorize terminators");
+ return true;
+ } else {
+ LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ success vectorizing pattern");
}
+
+ // Finish this vectorization pattern.
+ state.finishVectorizationPattern();
return false;
}
/// Applies vectorization to the current Function by searching over a bunch of
/// predetermined patterns.
PassResult Vectorize::runOnFunction(Function *f) {
- for (auto pat : makePatterns()) {
+ // Thread-safe RAII local context, BumpPtrAllocator freed on exit.
+ NestedPatternContext mlContext;
+
+ for (auto &pat : makePatterns()) {
LLVM_DEBUG(dbgs() << "\n******************************************");
LLVM_DEBUG(dbgs() << "\n******************************************");
LLVM_DEBUG(dbgs() << "\n[early-vect] new pattern on Function\n");
LLVM_DEBUG(f->print(dbgs()));
unsigned patternDepth = pat.getDepth();
- auto matches = pat.match(f);
+
+ SmallVector<NestedMatch, 8> matches;
+ pat.match(f, &matches);
// Iterate over all the top-level matches and vectorize eagerly.
// This automatically prunes intersecting matches.
for (auto m : matches) {
@@ -1285,16 +1285,16 @@ PassResult Vectorize::runOnFunction(Function *f) {
// TODO(ntv): depending on profitability, elect to reduce the vector size.
strategy.vectorSizes.assign(clVirtualVectorSize.begin(),
clVirtualVectorSize.end());
- auto fail = analyzeProfitability(m.second, 1, patternDepth, &strategy);
+ auto fail = analyzeProfitability(m.getMatchedChildren(), 1, patternDepth,
+ &strategy);
if (fail) {
continue;
}
- auto *loop = cast<ForInst>(m.first);
+ auto *loop = cast<ForInst>(m.getMatchedInstruction());
vectorizeLoopIfProfitable(loop, 0, patternDepth, &strategy);
// TODO(ntv): if pattern does not apply, report it; alter the
// cost/benefit.
- fail = vectorizeRootMatches(matches, &strategy);
- assert(!fail && "top-level failure should not happen");
+ fail = vectorizeRootMatch(m, &strategy);
// TODO(ntv): some diagnostics.
}
}
OpenPOWER on IntegriCloud