From 094ca64ab06359437524712e65669a10bac816a7 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache <ntv@google.com>
Date: Fri, 29 Mar 2019 09:34:06 -0700
Subject: Refactor vectorization patterns

This CL removes the reliance of the vectorize pass on the specification of a
`fastestVaryingDim` parameter. This parameter is a restriction meant to more
easily target a particular loop/memref combination for vectorization and is
mainly used for testing.

This also had the side effect of restricting vectorization patterns to only
those in which all memrefs were contiguous along the same loop dimension.
This simple restriction prevented matmul from vectorizing in 2-D.

This CL removes the restriction and adds a matmul test that vectorizes in 2-D
along the parallel loops. Support for reduction loops is left for future work.

PiperOrigin-RevId: 240993827
---
 mlir/lib/Transforms/Vectorize.cpp | 212 +++++++++++++++++---------------------
 1 file changed, 93 insertions(+), 119 deletions(-)

diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp
index 95b389c5559..087453235f9 100644
--- a/mlir/lib/Transforms/Vectorize.cpp
+++ b/mlir/lib/Transforms/Vectorize.cpp
@@ -553,7 +553,7 @@ static llvm::cl::OptionCategory clOptionsCategory("vectorize options");
 
 static llvm::cl::list<int> clVirtualVectorSize(
     "virtual-vector-size",
-    llvm::cl::desc("Specify n-D virtual vector size for vectorization"),
+    llvm::cl::desc("Specify an n-D virtual vector size for vectorization"),
     llvm::cl::ZeroOrMore, llvm::cl::cat(clOptionsCategory));
 
 static llvm::cl::list<int> clFastestVaryingPattern(
@@ -567,124 +567,84 @@ static llvm::cl::list<int> clFastestVaryingPattern(
 
 /// Forward declaration.
 static FilterFunctionType
 isVectorizableLoopPtrFactory(const llvm::DenseSet<Operation *> &parallelLoops,
-                             unsigned fastestVaryingMemRefDimension);
-
-// Build a bunch of predetermined patterns that will be traversed in order.
-// Due to the recursive nature of NestedPatterns, this captures
-// arbitrarily nested pairs of loops at any position in the tree.
-/// Note that this currently only matches 2 nested loops and will be extended.
-// TODO(ntv): support 3-D loop patterns with a common reduction loop that can
-// be matched to GEMMs.
-static std::vector<NestedPattern>
-defaultPatterns(const llvm::DenseSet<Operation *> &parallelLoops) {
-  using matcher::For;
-  return std::vector<NestedPattern>{
-      // 3-D patterns
-      For(isVectorizableLoopPtrFactory(parallelLoops, 2),
-          For(isVectorizableLoopPtrFactory(parallelLoops, 1),
-              For(isVectorizableLoopPtrFactory(parallelLoops, 0)))),
-      // for i { for j { A[??f(not i, not j), f(i, not j), f(not i, j)];}}
-      // test independently with:
-      //   --test-fastest-varying=1 --test-fastest-varying=0
-      For(isVectorizableLoopPtrFactory(parallelLoops, 1),
-          For(isVectorizableLoopPtrFactory(parallelLoops, 0))),
-      // for i { for j { A[??f(not i, not j), f(i, not j), ?, f(not i, j)];}}
-      // test independently with:
-      //   --test-fastest-varying=2 --test-fastest-varying=0
-      For(isVectorizableLoopPtrFactory(parallelLoops, 2),
-          For(isVectorizableLoopPtrFactory(parallelLoops, 0))),
-      // for i { for j { A[??f(not i, not j), f(i, not j), ?, ?, f(not i, j)];}}
-      // test independently with:
-      //   --test-fastest-varying=3 --test-fastest-varying=0
-      For(isVectorizableLoopPtrFactory(parallelLoops, 3),
-          For(isVectorizableLoopPtrFactory(parallelLoops, 0))),
-      // for i { for j { A[??f(not i, not j), f(not i, j), f(i, not j)];}}
-      // test independently with:
-      //   --test-fastest-varying=0 --test-fastest-varying=1
-      For(isVectorizableLoopPtrFactory(parallelLoops, 0),
-          For(isVectorizableLoopPtrFactory(parallelLoops, 1))),
-      // for i { for j { A[??f(not i, not j), f(not i, j), ?, f(i, not j)];}}
-      // test independently with:
-      //   --test-fastest-varying=0 --test-fastest-varying=2
-      For(isVectorizableLoopPtrFactory(parallelLoops, 0),
-          For(isVectorizableLoopPtrFactory(parallelLoops, 2))),
-      // for i { for j { A[??f(not i, not j), f(not i, j), ?, ?, f(i, not j)];}}
-      // test independently with:
-      //   --test-fastest-varying=0 --test-fastest-varying=3
-      For(isVectorizableLoopPtrFactory(parallelLoops, 0),
-          For(isVectorizableLoopPtrFactory(parallelLoops, 3))),
-      // for i { A[??f(not i) , f(i)];}
-      // test independently with: --test-fastest-varying=0
-      For(isVectorizableLoopPtrFactory(parallelLoops, 0)),
-      // for i { A[??f(not i) , f(i), ?];}
-      // test independently with: --test-fastest-varying=1
-      For(isVectorizableLoopPtrFactory(parallelLoops, 1)),
-      // for i { A[??f(not i) , f(i), ?, ?];}
-      // test independently with: --test-fastest-varying=2
-      For(isVectorizableLoopPtrFactory(parallelLoops, 2)),
-      // for i { A[??f(not i) , f(i), ?, ?, ?];}
-      // test independently with: --test-fastest-varying=3
-      For(isVectorizableLoopPtrFactory(parallelLoops, 3))};
-}
+                             int fastestVaryingMemRefDimension);
 
 /// Creates a vectorization pattern from the command line arguments.
 /// Up to 3-D patterns are supported.
 /// If the command line argument requests a pattern of higher order, returns an
 /// empty pattern list which will conservatively result in no vectorization.
 static std::vector<NestedPattern>
-makePatterns(const llvm::DenseSet<Operation *> &parallelLoops) {
+makePatterns(const llvm::DenseSet<Operation *> &parallelLoops, int vectorRank,
+             ArrayRef<int64_t> fastestVaryingPattern) {
   using matcher::For;
-  if (clFastestVaryingPattern.empty()) {
-    return defaultPatterns(parallelLoops);
-  }
-  switch (clFastestVaryingPattern.size()) {
+  int64_t d0 = fastestVaryingPattern.empty() ? -1 : fastestVaryingPattern[0];
+  int64_t d1 = fastestVaryingPattern.size() < 2 ? -1 : fastestVaryingPattern[1];
+  int64_t d2 = fastestVaryingPattern.size() < 3 ? -1 : fastestVaryingPattern[2];
+  switch (vectorRank) {
   case 1:
-    return {For(isVectorizableLoopPtrFactory(parallelLoops,
-                                             clFastestVaryingPattern[0]))};
+    return {For(isVectorizableLoopPtrFactory(parallelLoops, d0))};
   case 2:
-    return {For(
-        isVectorizableLoopPtrFactory(parallelLoops, clFastestVaryingPattern[0]),
-        For(isVectorizableLoopPtrFactory(parallelLoops,
-                                         clFastestVaryingPattern[1])))};
+    return {For(isVectorizableLoopPtrFactory(parallelLoops, d0),
+                For(isVectorizableLoopPtrFactory(parallelLoops, d1)))};
   case 3:
-    return {For(
-        isVectorizableLoopPtrFactory(parallelLoops, clFastestVaryingPattern[0]),
-        For(isVectorizableLoopPtrFactory(parallelLoops,
-                                         clFastestVaryingPattern[1]),
-            For(isVectorizableLoopPtrFactory(parallelLoops,
-                                             clFastestVaryingPattern[2]))))};
-  default:
+    return {For(isVectorizableLoopPtrFactory(parallelLoops, d0),
+                For(isVectorizableLoopPtrFactory(parallelLoops, d1),
+                    For(isVectorizableLoopPtrFactory(parallelLoops, d2))))};
+  default: {
     return std::vector<NestedPattern>();
   }
+  }
 }
 
 namespace {
 
+/// Base state for the vectorize pass.
+/// Command line arguments are preempted by non-empty pass arguments.
 struct Vectorize : public FunctionPass<Vectorize> {
-  Vectorize() {
-    if (!clVirtualVectorSize.empty()) {
-      vectorSizes.reserve(clVirtualVectorSize.size());
-      this->vectorSizes.assign(clVirtualVectorSize.begin(),
-                               clVirtualVectorSize.end());
-    }
-  }
-  Vectorize(ArrayRef<int64_t> virtualVectorSize) {
-    if (clVirtualVectorSize.empty()) {
-      this->vectorSizes.assign(virtualVectorSize.begin(),
-                               virtualVectorSize.end());
-    } else {
-      vectorSizes.reserve(clVirtualVectorSize.size());
-      this->vectorSizes.assign(clVirtualVectorSize.begin(),
-                               clVirtualVectorSize.end());
-    }
-  }
+  Vectorize();
+  Vectorize(ArrayRef<int64_t> virtualVectorSize);
+  Vectorize(ArrayRef<int64_t> virtualVectorSize,
+            ArrayRef<int64_t> fastestVaryingPattern);
   void runOnFunction() override;
+
+  // The virtual vector size that we vectorize to.
   SmallVector<int64_t, 4> vectorSizes;
+  // Optionally, the fixed mapping from loop to fastest varying MemRef
+  // dimension for all the MemRefs within a loop pattern:
+  //   the index represents the loop depth, the value represents the k^th
+  //   fastest varying memory dimension.
+  // This is voluntarily restrictive and is meant to precisely target a
+  // particular loop/op pair, for testing purposes.
+  SmallVector<int64_t, 4> fastestVaryingPattern;
 };
 
 } // end anonymous namespace
 
-/////// TODO(ntv): Hoist to a VectorizationStrategy.cpp when appropriate. //////
+Vectorize::Vectorize() {
+  this->vectorSizes.assign(clVirtualVectorSize.begin(),
+                           clVirtualVectorSize.end());
+  this->fastestVaryingPattern.assign(clFastestVaryingPattern.begin(),
+                                     clFastestVaryingPattern.end());
+}
+
+Vectorize::Vectorize(ArrayRef<int64_t> virtualVectorSize) : Vectorize() {
+  if (!virtualVectorSize.empty()) {
+    this->vectorSizes.assign(virtualVectorSize.begin(),
+                             virtualVectorSize.end());
+  }
+}
+
+Vectorize::Vectorize(ArrayRef<int64_t> virtualVectorSize,
+                     ArrayRef<int64_t> fastestVaryingPattern)
+    : Vectorize(virtualVectorSize) {
+  if (!fastestVaryingPattern.empty()) {
+    this->fastestVaryingPattern.assign(fastestVaryingPattern.begin(),
+                                       fastestVaryingPattern.end());
+  }
+}
+
+/////// TODO(ntv): Hoist to a VectorizationStrategy.cpp when appropriate.
+/////////
 
 namespace {
 
 struct VectorizationStrategy {
@@ -833,12 +793,12 @@ void VectorizationState::registerReplacement(Value *key, Value *value) {
 /// vectorized immediately. The resulting vector_transfer_read is immediately
 /// registered to replace all uses of the LoadOp in this pattern's scope.
 ///
-/// StoreOp are the terminals of the vectorizeNonTerminals call. They need
-/// to be vectorized late once all the use-def chains have been traversed.
-/// Additionally, they may have ssa-values operands which come from outside
-/// the scope of the current pattern.
-/// Such special cases force us to delay the vectorization of the stores
-/// until the last step. Here we merely register the store operation.
+/// StoreOp are the terminals of the vectorizeNonTerminals call. They need to
+/// be vectorized late once all the use-def chains have been traversed.
+/// Additionally, they may have ssa-value operands which come from outside the
+/// scope of the current pattern.
+/// Such special cases force us to delay the vectorization of the stores until
+/// the last step. Here we merely register the store operation.
 template <typename LoadOrStoreOpPointer>
 static LogicalResult vectorizeRootOrTerminal(Value *iv,
                                              LoadOrStoreOpPointer memoryOp,
@@ -860,6 +820,8 @@ static LogicalResult vectorizeRootOrTerminal(Value *iv,
   if (opInst->template isa<LoadOp>()) {
     auto permutationMap =
         makePermutationMap(opInst, state->strategy->loopToVectorDim);
+    if (!permutationMap)
+      return LogicalResult::Failure;
     LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
     LLVM_DEBUG(permutationMap.print(dbgs()));
     FuncBuilder b(opInst);
@@ -907,22 +869,23 @@ static LogicalResult vectorizeAffineForOp(AffineForOp loop, int64_t step,
   return success();
 }
 
-/// Returns a FilterFunctionType that can be used in NestedPattern to
-/// match a loop whose underlying load/store accesses are all varying along the
-/// `fastestVaryingMemRefDimension`.
-/// TODO(ntv): In the future, allow more interesting mixed layout permutation
-/// once we understand better the performance implications and we are confident
-/// we can build a cost model and a search procedure.
+/// Returns a FilterFunctionType that can be used in NestedPattern to match a
+/// loop whose underlying load/store accesses are either invariant or all
+/// varying along the `fastestVaryingMemRefDimension`.
 static FilterFunctionType
 isVectorizableLoopPtrFactory(const llvm::DenseSet<Operation *> &parallelLoops,
-                             unsigned fastestVaryingMemRefDimension) {
+                             int fastestVaryingMemRefDimension) {
   return [&parallelLoops, fastestVaryingMemRefDimension](Operation &forOp) {
     auto loop = forOp.cast<AffineForOp>();
     auto parallelIt = parallelLoops.find(loop);
     if (parallelIt == parallelLoops.end())
       return false;
-    return isVectorizableLoopBodyAlongFastestVaryingMemRefDim(
-        loop, fastestVaryingMemRefDimension);
+    int memRefDim = -1;
+    auto vectorizableBody = isVectorizableLoopBody(loop, &memRefDim);
+    if (!vectorizableBody)
+      return false;
+    return memRefDim == -1 || fastestVaryingMemRefDimension == -1 ||
+           memRefDim == fastestVaryingMemRefDimension;
  };
 }
 
@@ -1047,15 +1010,15 @@ static Value *vectorizeOperand(Value *operand, Operation *op,
   return nullptr;
 };
 
-/// Encodes Operation-specific behavior for vectorization. In general we
-/// assume that all operands of an op must be vectorized but this is not always
-/// true. In the future, it would be nice to have a trait that describes how a
+/// Encodes Operation-specific behavior for vectorization. In general we assume
+/// that all operands of an op must be vectorized but this is not always true.
+/// In the future, it would be nice to have a trait that describes how a
 /// particular operation vectorizes. For now we implement the case distinction
 /// here.
 /// Returns a vectorized form of an operation or nullptr if vectorization fails.
-/// TODO(ntv): consider adding a trait to Op to describe how it gets vectorized.
-/// Maybe some Ops are not vectorizable or require some tricky logic, we cannot
-/// do one-off logic here; ideally it would be TableGen'd.
+// TODO(ntv): consider adding a trait to Op to describe how it gets vectorized.
+// Maybe some Ops are not vectorizable or require some tricky logic, we cannot
+// do one-off logic here; ideally it would be TableGen'd.
 static Operation *vectorizeOneOperation(Operation *opInst,
                                         VectorizationState *state) {
   // Sanity checks.
@@ -1074,6 +1037,8 @@ static Operation *vectorizeOneOperation(Operation *opInst,
     FuncBuilder b(opInst);
     auto permutationMap =
         makePermutationMap(opInst, state->strategy->loopToVectorDim);
+    if (!permutationMap)
+      return nullptr;
     LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
     LLVM_DEBUG(permutationMap.print(dbgs()));
     auto transfer = b.create<VectorTransferWriteOp>(
@@ -1249,10 +1214,18 @@ static LogicalResult vectorizeRootMatch(NestedMatch m,
 /// Applies vectorization to the current Function by searching over a bunch of
 /// predetermined patterns.
 void Vectorize::runOnFunction() {
+  Function &f = getFunction();
+  if (!fastestVaryingPattern.empty() &&
+      fastestVaryingPattern.size() != vectorSizes.size()) {
+    f.emitNote("Fastest varying pattern specified with a different size than "
+               "the vector size.");
+    this->signalPassFailure();
+    return;
+  }
+
   // Thread-safe RAII local context, BumpPtrAllocator freed on exit.
   NestedPatternContext mlContext;
 
-  Function &f = getFunction();
   llvm::DenseSet<Operation *> parallelLoops;
   f.walkPostOrder([&parallelLoops](Operation *op) {
     if (auto loop = op->dyn_cast<AffineForOp>()) {
@@ -1262,7 +1235,8 @@ void Vectorize::runOnFunction() {
     }
   });
 
-  for (auto &pat : makePatterns(parallelLoops)) {
+  for (auto &pat :
+       makePatterns(parallelLoops, vectorSizes.size(), fastestVaryingPattern)) {
     LLVM_DEBUG(dbgs() << "\n******************************************");
     LLVM_DEBUG(dbgs() << "\n******************************************");
     LLVM_DEBUG(dbgs() << "\n[early-vect] new pattern on Function\n");
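
For illustration, here is a minimal, self-contained sketch of the matching
rule this patch introduces. The helper name `matchesFastestVarying` is
hypothetical and the snippet is plain standalone C++, not MLIR API: it only
models the final predicate inside `isVectorizableLoopPtrFactory`, where -1 on
either side means "no restriction". This relaxation (together with switching
the parameter from `unsigned` to `int` so that -1 is representable at all) is
what lets both parallel loops of matmul match in 2-D.

#include <cstdio>

// Sketch of the predicate: accept when the loop's accesses are invariant
// (memRefDim == -1), when no particular dimension was requested
// (requestedDim == -1), or when the two dimensions agree.
static bool matchesFastestVarying(int memRefDim, int requestedDim) {
  return memRefDim == -1 || requestedDim == -1 || memRefDim == requestedDim;
}

int main() {
  struct { int memRefDim, requestedDim; } cases[] = {
      {0, 0},  // same fastest varying dimension: match
      {1, 0},  // different dimensions: no match
      {-1, 0}, // loop-invariant accesses: match anything
      {1, -1}, // no restriction requested: match anything
  };
  for (auto c : cases)
    std::printf("memRefDim=%2d requestedDim=%2d -> %s\n", c.memRefDim,
                c.requestedDim,
                matchesFastestVarying(c.memRefDim, c.requestedDim)
                    ? "match" : "no match");
  return 0;
}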