diff options
Diffstat (limited to 'mlir/lib/Transforms')
| -rw-r--r-- | mlir/lib/Transforms/MaterializeVectors.cpp | 11 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp | 71 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Vectorize.cpp | 192 |
3 files changed, 140 insertions, 134 deletions
diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 2744b1d624c..432ad1f39b8 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -201,9 +201,6 @@ struct MaterializeVectorsPass : public FunctionPass { PassResult runOnFunction(Function *f) override; - // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit. - NestedPatternContext mlContext; - static char passID; }; @@ -744,6 +741,9 @@ static bool materialize(Function *f, } PassResult MaterializeVectorsPass::runOnFunction(Function *f) { + // Thread-safe RAII local context, BumpPtrAllocator freed on exit. + NestedPatternContext mlContext; + // TODO(ntv): Check to see if this supports arbitrary top-level code. if (f->getBlocks().size() != 1) return success(); @@ -768,10 +768,11 @@ PassResult MaterializeVectorsPass::runOnFunction(Function *f) { return matcher::operatesOnSuperVectors(opInst, subVectorType); }; auto pat = Op(filter); - auto matches = pat.match(f); + SmallVector<NestedMatch, 8> matches; + pat.match(f, &matches); SetVector<OperationInst *> terminators; for (auto m : matches) { - terminators.insert(cast<OperationInst>(m.first)); + terminators.insert(cast<OperationInst>(m.getMatchedInstruction())); } auto fail = materialize(f, terminators, &state); diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index a01b8fdf216..a9b9752ef51 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -30,6 +30,7 @@ #include "mlir/Support/Functional.h" #include "mlir/Support/STLExtras.h" #include "mlir/Transforms/Passes.h" +#include "third_party/llvm/llvm/include/llvm/ADT/STLExtras.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -94,9 +95,6 @@ struct VectorizerTestPass : public FunctionPass { void testComposeMaps(Function *f); void testNormalizeMaps(Function *f); - // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit. - NestedPatternContext MLContext; - static char passID; }; @@ -128,9 +126,10 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) { return true; }; auto pat = Op(filter); - auto matches = pat.match(f); + SmallVector<NestedMatch, 8> matches; + pat.match(f, &matches); for (auto m : matches) { - auto *opInst = cast<OperationInst>(m.first); + auto *opInst = cast<OperationInst>(m.getMatchedInstruction()); // This is a unit test that only checks and prints shape ratio. // As a consequence we write only Ops with a single return type for the // purpose of this test. If we need to test more intricate behavior in the @@ -153,7 +152,7 @@ static std::string toString(Instruction *inst) { return res; } -static NestedMatch matchTestSlicingOps(Function *f) { +static NestedPattern patternTestSlicingOps() { // Just use a custom op name for this test, it makes life easier. constexpr auto kTestSlicingOpName = "slicing-test-op"; using functional::map; @@ -163,17 +162,18 @@ static NestedMatch matchTestSlicingOps(Function *f) { const auto &opInst = cast<OperationInst>(inst); return opInst.getName().getStringRef() == kTestSlicingOpName; }; - auto pat = Op(filter); - return pat.match(f); + return Op(filter); } void VectorizerTestPass::testBackwardSlicing(Function *f) { - auto matches = matchTestSlicingOps(f); + SmallVector<NestedMatch, 8> matches; + patternTestSlicingOps().match(f, &matches); for (auto m : matches) { SetVector<Instruction *> backwardSlice; - getBackwardSlice(m.first, &backwardSlice); + getBackwardSlice(m.getMatchedInstruction(), &backwardSlice); auto strs = map(toString, backwardSlice); - outs() << "\nmatched: " << *m.first << " backward static slice: "; + outs() << "\nmatched: " << *m.getMatchedInstruction() + << " backward static slice: "; for (const auto &s : strs) { outs() << "\n" << s; } @@ -181,12 +181,14 @@ void VectorizerTestPass::testBackwardSlicing(Function *f) { } void VectorizerTestPass::testForwardSlicing(Function *f) { - auto matches = matchTestSlicingOps(f); + SmallVector<NestedMatch, 8> matches; + patternTestSlicingOps().match(f, &matches); for (auto m : matches) { SetVector<Instruction *> forwardSlice; - getForwardSlice(m.first, &forwardSlice); + getForwardSlice(m.getMatchedInstruction(), &forwardSlice); auto strs = map(toString, forwardSlice); - outs() << "\nmatched: " << *m.first << " forward static slice: "; + outs() << "\nmatched: " << *m.getMatchedInstruction() + << " forward static slice: "; for (const auto &s : strs) { outs() << "\n" << s; } @@ -194,11 +196,12 @@ void VectorizerTestPass::testForwardSlicing(Function *f) { } void VectorizerTestPass::testSlicing(Function *f) { - auto matches = matchTestSlicingOps(f); + SmallVector<NestedMatch, 8> matches; + patternTestSlicingOps().match(f, &matches); for (auto m : matches) { - SetVector<Instruction *> staticSlice = getSlice(m.first); + SetVector<Instruction *> staticSlice = getSlice(m.getMatchedInstruction()); auto strs = map(toString, staticSlice); - outs() << "\nmatched: " << *m.first << " static slice: "; + outs() << "\nmatched: " << *m.getMatchedInstruction() << " static slice: "; for (const auto &s : strs) { outs() << "\n" << s; } @@ -214,12 +217,12 @@ static bool customOpWithAffineMapAttribute(const Instruction &inst) { void VectorizerTestPass::testComposeMaps(Function *f) { using matcher::Op; auto pattern = Op(customOpWithAffineMapAttribute); - auto matches = pattern.match(f); + SmallVector<NestedMatch, 8> matches; + pattern.match(f, &matches); SmallVector<AffineMap, 4> maps; maps.reserve(matches.size()); - std::reverse(matches.begin(), matches.end()); - for (auto m : matches) { - auto *opInst = cast<OperationInst>(m.first); + for (auto m : llvm::reverse(matches)) { + auto *opInst = cast<OperationInst>(m.getMatchedInstruction()); auto map = opInst->getAttr(VectorizerTestPass::kTestAffineMapAttrName) .cast<AffineMapAttr>() .getValue(); @@ -248,29 +251,31 @@ void VectorizerTestPass::testNormalizeMaps(Function *f) { // Save matched AffineApplyOp that all need to be erased in the end. auto pattern = Op(affineApplyOp); - auto toErase = pattern.match(f); - std::reverse(toErase.begin(), toErase.end()); + SmallVector<NestedMatch, 8> toErase; + pattern.match(f, &toErase); { // Compose maps. auto pattern = Op(singleResultAffineApplyOpWithoutUses); - for (auto m : pattern.match(f)) { - auto app = cast<OperationInst>(m.first)->cast<AffineApplyOp>(); - FuncBuilder b(m.first); - - using ValueTy = decltype(*(app->getOperands().begin())); - SmallVector<Value *, 8> operands = - functional::map([](ValueTy v) { return static_cast<Value *>(v); }, - app->getOperands().begin(), app->getOperands().end()); + SmallVector<NestedMatch, 8> matches; + pattern.match(f, &matches); + for (auto m : matches) { + auto app = + cast<OperationInst>(m.getMatchedInstruction())->cast<AffineApplyOp>(); + FuncBuilder b(m.getMatchedInstruction()); + SmallVector<Value *, 8> operands(app->getOperands()); makeComposedAffineApply(&b, app->getLoc(), app->getAffineMap(), operands); } } // We should now be able to erase everything in reverse order in this test. - for (auto m : toErase) { - m.first->erase(); + for (auto m : llvm::reverse(toErase)) { + m.getMatchedInstruction()->erase(); } } PassResult VectorizerTestPass::runOnFunction(Function *f) { + // Thread-safe RAII local context, BumpPtrAllocator freed on exit. + NestedPatternContext mlContext; + // Only support single block functions at this point. if (f->getBlocks().size() != 1) return success(); diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index cfde1ecf0a8..73893599d17 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -655,9 +655,6 @@ struct Vectorize : public FunctionPass { PassResult runOnFunction(Function *f) override; - // Thread-safe RAII contexts local to pass, BumpPtrAllocator freed on exit. - NestedPatternContext MLContext; - static char passID; }; @@ -703,13 +700,13 @@ static void vectorizeLoopIfProfitable(ForInst *loop, unsigned depthInPattern, /// 3. account for impact of vectorization on maximal loop fusion. /// Then we can quantify the above to build a cost model and search over /// strategies. -static bool analyzeProfitability(NestedMatch matches, unsigned depthInPattern, - unsigned patternDepth, +static bool analyzeProfitability(ArrayRef<NestedMatch> matches, + unsigned depthInPattern, unsigned patternDepth, VectorizationStrategy *strategy) { for (auto m : matches) { - auto *loop = cast<ForInst>(m.first); - bool fail = analyzeProfitability(m.second, depthInPattern + 1, patternDepth, - strategy); + auto *loop = cast<ForInst>(m.getMatchedInstruction()); + bool fail = analyzeProfitability(m.getMatchedChildren(), depthInPattern + 1, + patternDepth, strategy); if (fail) { return fail; } @@ -875,9 +872,10 @@ static bool vectorizeForInst(ForInst *loop, int64_t step, state->terminators.count(opInst) == 0; }; auto loadAndStores = matcher::Op(notVectorizedThisPattern); - auto matches = loadAndStores.match(loop); - for (auto ls : matches) { - auto *opInst = cast<OperationInst>(ls.first); + SmallVector<NestedMatch, 8> loadAndStoresMatches; + loadAndStores.match(loop, &loadAndStoresMatches); + for (auto ls : loadAndStoresMatches) { + auto *opInst = cast<OperationInst>(ls.getMatchedInstruction()); auto load = opInst->dyn_cast<LoadOp>(); auto store = opInst->dyn_cast<StoreOp>(); LLVM_DEBUG(opInst->print(dbgs())); @@ -907,15 +905,15 @@ isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) { } /// Forward-declaration. -static bool vectorizeNonRoot(NestedMatch matches, VectorizationState *state); +static bool vectorizeNonRoot(ArrayRef<NestedMatch> matches, + VectorizationState *state); /// Apply vectorization of `loop` according to `state`. This is only triggered /// if all vectorizations in `childrenMatches` have already succeeded /// recursively in DFS post-order. -static bool doVectorize(NestedMatch::EntryType oneMatch, - VectorizationState *state) { - ForInst *loop = cast<ForInst>(oneMatch.first); - NestedMatch childrenMatches = oneMatch.second; +static bool doVectorize(NestedMatch oneMatch, VectorizationState *state) { + ForInst *loop = cast<ForInst>(oneMatch.getMatchedInstruction()); + auto childrenMatches = oneMatch.getMatchedChildren(); // 1. DFS postorder recursion, if any of my children fails, I fail too. auto fail = vectorizeNonRoot(childrenMatches, state); @@ -949,7 +947,8 @@ static bool doVectorize(NestedMatch::EntryType oneMatch, /// Non-root pattern iterates over the matches at this level, calls doVectorize /// and exits early if anything below fails. -static bool vectorizeNonRoot(NestedMatch matches, VectorizationState *state) { +static bool vectorizeNonRoot(ArrayRef<NestedMatch> matches, + VectorizationState *state) { for (auto m : matches) { auto fail = doVectorize(m, state); if (fail) { @@ -1185,99 +1184,100 @@ static bool vectorizeOperations(VectorizationState *state) { /// The root match thus needs to maintain a clone for handling failure. /// Each root may succeed independently but will otherwise clean after itself if /// anything below it fails. -static bool vectorizeRootMatches(NestedMatch matches, - VectorizationStrategy *strategy) { - for (auto m : matches) { - auto *loop = cast<ForInst>(m.first); - VectorizationState state; - state.strategy = strategy; - - // Since patterns are recursive, they can very well intersect. - // Since we do not want a fully greedy strategy in general, we decouple - // pattern matching, from profitability analysis, from application. - // As a consequence we must check that each root pattern is still - // vectorizable. If a pattern is not vectorizable anymore, we just skip it. - // TODO(ntv): implement a non-greedy profitability analysis that keeps only - // non-intersecting patterns. - if (!isVectorizableLoop(*loop)) { - LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable"); - continue; - } - FuncBuilder builder(loop); // builder to insert in place of loop - ForInst *clonedLoop = cast<ForInst>(builder.clone(*loop)); - auto fail = doVectorize(m, &state); - /// Sets up error handling for this root loop. This is how the root match - /// maintains a clone for handling failure and restores the proper state via - /// RAII. - ScopeGuard sg2([&fail, loop, clonedLoop]() { - if (fail) { - loop->getInductionVar()->replaceAllUsesWith( - clonedLoop->getInductionVar()); - loop->erase(); - } else { - clonedLoop->erase(); - } - }); +static bool vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) { + auto *loop = cast<ForInst>(m.getMatchedInstruction()); + VectorizationState state; + state.strategy = strategy; + + // Since patterns are recursive, they can very well intersect. + // Since we do not want a fully greedy strategy in general, we decouple + // pattern matching, from profitability analysis, from application. + // As a consequence we must check that each root pattern is still + // vectorizable. If a pattern is not vectorizable anymore, we just skip it. + // TODO(ntv): implement a non-greedy profitability analysis that keeps only + // non-intersecting patterns. + if (!isVectorizableLoop(*loop)) { + LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable"); + return true; + } + FuncBuilder builder(loop); // builder to insert in place of loop + ForInst *clonedLoop = cast<ForInst>(builder.clone(*loop)); + auto fail = doVectorize(m, &state); + /// Sets up error handling for this root loop. This is how the root match + /// maintains a clone for handling failure and restores the proper state via + /// RAII. + ScopeGuard sg2([&fail, loop, clonedLoop]() { if (fail) { - LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed root doVectorize"); - continue; + loop->getInductionVar()->replaceAllUsesWith( + clonedLoop->getInductionVar()); + loop->erase(); + } else { + clonedLoop->erase(); } + }); + if (fail) { + LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed root doVectorize"); + return true; + } - // Form the root operationsthat have been set in the replacementMap. - // For now, these roots are the loads for which vector_transfer_read - // operations have been inserted. - auto getDefiningInst = [](const Value *val) { - return const_cast<Value *>(val)->getDefiningInst(); - }; - using ReferenceTy = decltype(*(state.replacementMap.begin())); - auto getKey = [](ReferenceTy it) { return it.first; }; - auto roots = map(getDefiningInst, map(getKey, state.replacementMap)); - - // Vectorize the root operations and everything reached by use-def chains - // except the terminators (store instructions) that need to be - // post-processed separately. - fail = vectorizeOperations(&state); - if (fail) { - LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed vectorizeOperations"); - continue; - } + // Form the root operationsthat have been set in the replacementMap. + // For now, these roots are the loads for which vector_transfer_read + // operations have been inserted. + auto getDefiningInst = [](const Value *val) { + return const_cast<Value *>(val)->getDefiningInst(); + }; + using ReferenceTy = decltype(*(state.replacementMap.begin())); + auto getKey = [](ReferenceTy it) { return it.first; }; + auto roots = map(getDefiningInst, map(getKey, state.replacementMap)); + + // Vectorize the root operations and everything reached by use-def chains + // except the terminators (store instructions) that need to be + // post-processed separately. + fail = vectorizeOperations(&state); + if (fail) { + LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed vectorizeOperations"); + return true; + } - // Finally, vectorize the terminators. If anything fails to vectorize, skip. - auto vectorizeOrFail = [&fail, &state](OperationInst *inst) { - if (fail) { - return; - } - FuncBuilder b(inst); - auto *res = vectorizeOneOperationInst(&b, inst, &state); - if (res == nullptr) { - fail = true; - } - }; - apply(vectorizeOrFail, state.terminators); + // Finally, vectorize the terminators. If anything fails to vectorize, skip. + auto vectorizeOrFail = [&fail, &state](OperationInst *inst) { if (fail) { - LLVM_DEBUG( - dbgs() << "\n[early-vect]+++++ failed to vectorize terminators"); - continue; - } else { - LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ success vectorizing pattern"); + return; } - - // Finish this vectorization pattern. - state.finishVectorizationPattern(); + FuncBuilder b(inst); + auto *res = vectorizeOneOperationInst(&b, inst, &state); + if (res == nullptr) { + fail = true; + } + }; + apply(vectorizeOrFail, state.terminators); + if (fail) { + LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed to vectorize terminators"); + return true; + } else { + LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ success vectorizing pattern"); } + + // Finish this vectorization pattern. + state.finishVectorizationPattern(); return false; } /// Applies vectorization to the current Function by searching over a bunch of /// predetermined patterns. PassResult Vectorize::runOnFunction(Function *f) { - for (auto pat : makePatterns()) { + // Thread-safe RAII local context, BumpPtrAllocator freed on exit. + NestedPatternContext mlContext; + + for (auto &pat : makePatterns()) { LLVM_DEBUG(dbgs() << "\n******************************************"); LLVM_DEBUG(dbgs() << "\n******************************************"); LLVM_DEBUG(dbgs() << "\n[early-vect] new pattern on Function\n"); LLVM_DEBUG(f->print(dbgs())); unsigned patternDepth = pat.getDepth(); - auto matches = pat.match(f); + + SmallVector<NestedMatch, 8> matches; + pat.match(f, &matches); // Iterate over all the top-level matches and vectorize eagerly. // This automatically prunes intersecting matches. for (auto m : matches) { @@ -1285,16 +1285,16 @@ PassResult Vectorize::runOnFunction(Function *f) { // TODO(ntv): depending on profitability, elect to reduce the vector size. strategy.vectorSizes.assign(clVirtualVectorSize.begin(), clVirtualVectorSize.end()); - auto fail = analyzeProfitability(m.second, 1, patternDepth, &strategy); + auto fail = analyzeProfitability(m.getMatchedChildren(), 1, patternDepth, + &strategy); if (fail) { continue; } - auto *loop = cast<ForInst>(m.first); + auto *loop = cast<ForInst>(m.getMatchedInstruction()); vectorizeLoopIfProfitable(loop, 0, patternDepth, &strategy); // TODO(ntv): if pattern does not apply, report it; alter the // cost/benefit. - fail = vectorizeRootMatches(matches, &strategy); - assert(!fail && "top-level failure should not happen"); + fail = vectorizeRootMatch(m, &strategy); // TODO(ntv): some diagnostics. } } |

