diff options
| -rw-r--r-- | mlir/include/mlir/Analysis/VectorAnalysis.h | 10 | ||||
| -rw-r--r-- | mlir/lib/Analysis/VectorAnalysis.cpp | 18 | ||||
| -rw-r--r-- | mlir/lib/Transforms/MaterializeVectors.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp | 2 |
4 files changed, 12 insertions, 20 deletions
diff --git a/mlir/include/mlir/Analysis/VectorAnalysis.h b/mlir/include/mlir/Analysis/VectorAnalysis.h index 9f9eaba056f..c7f9e8a6155 100644 --- a/mlir/include/mlir/Analysis/VectorAnalysis.h +++ b/mlir/include/mlir/Analysis/VectorAnalysis.h @@ -124,15 +124,15 @@ makePermutationMap(OperationInst *opInst, namespace matcher { /// Matches vector_transfer_read, vector_transfer_write and ops that return a -/// vector type that is at least a 2-multiple of the sub-vector type. This -/// allows passing over other smaller vector types in the function and avoids -/// interfering with operations on those. +/// vector type that is a multiple of the sub-vector type. This allows passing +/// over other smaller vector types in the function and avoids interfering with +/// operations on those. /// This is a first approximation, it can easily be extended in the future. /// TODO(ntv): this could all be much simpler if we added a bit that a vector /// type to mark that a vector is a strict super-vector but it still does not /// warrant adding even 1 extra bit in the IR for now. -bool operatesOnStrictSuperVectors(const OperationInst &inst, - VectorType subVectorType); +bool operatesOnSuperVectors(const OperationInst &inst, + VectorType subVectorType); } // end namespace matcher } // end namespace mlir diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index e092b29a13b..f00ab4c4c93 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -179,15 +179,15 @@ mlir::makePermutationMap(OperationInst *opInst, enclosingLoopToVectorDim); } -bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opInst, - VectorType subVectorType) { +bool mlir::matcher::operatesOnSuperVectors(const OperationInst &opInst, + VectorType subVectorType) { // First, extract the vector type and ditinguish between: // a. ops that *must* lower a super-vector (i.e. vector_transfer_read, // vector_transfer_write); and // b. ops that *may* lower a super-vector (all other ops). // The ops that *may* lower a super-vector only do so if the super-vector to - // sub-vector ratio is striclty greater than 1. The ops that *must* lower a - // super-vector are explicitly checked for this property. + // sub-vector ratio exists. The ops that *must* lower a super-vector are + // explicitly checked for this property. /// TODO(ntv): there should be a single function for all ops to do this so we /// do not have to special case. Maybe a trait, or just a method, unclear atm. bool mustDivide = false; @@ -235,13 +235,5 @@ bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opInst, return false; } - // A strict super-vector is at least 2 sub-vectors. - for (auto m : *ratio) { - if (m > 1) { - return true; - } - } - - // Not a strict super-vector. - return false; + return true; } diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index ebfdbc28d6c..bbe2f85319b 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -741,7 +741,7 @@ PassResult MaterializeVectorsPass::runOnFunction(Function *f) { if (!opInst.isa<VectorTransferWriteOp>()) { return false; } - return matcher::operatesOnStrictSuperVectors(opInst, subVectorType); + return matcher::operatesOnSuperVectors(opInst, subVectorType); }; auto pat = Op(filter); auto matches = pat.match(f); diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index 9dfcda4081f..afacc6be9f2 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -106,7 +106,7 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) { assert(subVectorType.getElementType() == Type::getF32(subVectorType.getContext()) && "Only f32 supported for now"); - if (!matcher::operatesOnStrictSuperVectors(*opInst, subVectorType)) { + if (!matcher::operatesOnSuperVectors(*opInst, subVectorType)) { return false; } if (opInst->getNumResults() != 1) { |

