summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/Analysis/VectorAnalysis.h10
-rw-r--r--mlir/lib/Analysis/VectorAnalysis.cpp18
-rw-r--r--mlir/lib/Transforms/MaterializeVectors.cpp2
-rw-r--r--mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp2
4 files changed, 12 insertions, 20 deletions
diff --git a/mlir/include/mlir/Analysis/VectorAnalysis.h b/mlir/include/mlir/Analysis/VectorAnalysis.h
index 9f9eaba056f..c7f9e8a6155 100644
--- a/mlir/include/mlir/Analysis/VectorAnalysis.h
+++ b/mlir/include/mlir/Analysis/VectorAnalysis.h
@@ -124,15 +124,15 @@ makePermutationMap(OperationInst *opInst,
namespace matcher {
/// Matches vector_transfer_read, vector_transfer_write and ops that return a
-/// vector type that is at least a 2-multiple of the sub-vector type. This
-/// allows passing over other smaller vector types in the function and avoids
-/// interfering with operations on those.
+/// vector type that is a multiple of the sub-vector type. This allows passing
+/// over other smaller vector types in the function and avoids interfering with
+/// operations on those.
/// This is a first approximation, it can easily be extended in the future.
/// TODO(ntv): this could all be much simpler if we added a bit that a vector
/// type to mark that a vector is a strict super-vector but it still does not
/// warrant adding even 1 extra bit in the IR for now.
-bool operatesOnStrictSuperVectors(const OperationInst &inst,
- VectorType subVectorType);
+bool operatesOnSuperVectors(const OperationInst &inst,
+ VectorType subVectorType);
} // end namespace matcher
} // end namespace mlir
diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp
index e092b29a13b..f00ab4c4c93 100644
--- a/mlir/lib/Analysis/VectorAnalysis.cpp
+++ b/mlir/lib/Analysis/VectorAnalysis.cpp
@@ -179,15 +179,15 @@ mlir::makePermutationMap(OperationInst *opInst,
enclosingLoopToVectorDim);
}
-bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opInst,
- VectorType subVectorType) {
+bool mlir::matcher::operatesOnSuperVectors(const OperationInst &opInst,
+ VectorType subVectorType) {
// First, extract the vector type and ditinguish between:
// a. ops that *must* lower a super-vector (i.e. vector_transfer_read,
// vector_transfer_write); and
// b. ops that *may* lower a super-vector (all other ops).
// The ops that *may* lower a super-vector only do so if the super-vector to
- // sub-vector ratio is striclty greater than 1. The ops that *must* lower a
- // super-vector are explicitly checked for this property.
+ // sub-vector ratio exists. The ops that *must* lower a super-vector are
+ // explicitly checked for this property.
/// TODO(ntv): there should be a single function for all ops to do this so we
/// do not have to special case. Maybe a trait, or just a method, unclear atm.
bool mustDivide = false;
@@ -235,13 +235,5 @@ bool mlir::matcher::operatesOnStrictSuperVectors(const OperationInst &opInst,
return false;
}
- // A strict super-vector is at least 2 sub-vectors.
- for (auto m : *ratio) {
- if (m > 1) {
- return true;
- }
- }
-
- // Not a strict super-vector.
- return false;
+ return true;
}
diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp
index ebfdbc28d6c..bbe2f85319b 100644
--- a/mlir/lib/Transforms/MaterializeVectors.cpp
+++ b/mlir/lib/Transforms/MaterializeVectors.cpp
@@ -741,7 +741,7 @@ PassResult MaterializeVectorsPass::runOnFunction(Function *f) {
if (!opInst.isa<VectorTransferWriteOp>()) {
return false;
}
- return matcher::operatesOnStrictSuperVectors(opInst, subVectorType);
+ return matcher::operatesOnSuperVectors(opInst, subVectorType);
};
auto pat = Op(filter);
auto matches = pat.match(f);
diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
index 9dfcda4081f..afacc6be9f2 100644
--- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
+++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp
@@ -106,7 +106,7 @@ void VectorizerTestPass::testVectorShapeRatio(Function *f) {
assert(subVectorType.getElementType() ==
Type::getF32(subVectorType.getContext()) &&
"Only f32 supported for now");
- if (!matcher::operatesOnStrictSuperVectors(*opInst, subVectorType)) {
+ if (!matcher::operatesOnSuperVectors(*opInst, subVectorType)) {
return false;
}
if (opInst->getNumResults() != 1) {
OpenPOWER on IntegriCloud