summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/Vectorize.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/Vectorize.cpp')
-rw-r--r--mlir/lib/Transforms/Vectorize.cpp21
1 files changed, 11 insertions, 10 deletions
diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp
index d7a1f531cef..511afa95993 100644
--- a/mlir/lib/Transforms/Vectorize.cpp
+++ b/mlir/lib/Transforms/Vectorize.cpp
@@ -202,15 +202,15 @@ static bool analyzeProfitability(MLFunctionMatches matches,
/// sizes specified by vectorSize. The MemRef lives in the same memory space as
/// tmpl. The MemRef should be promoted to a closer memory address space in a
/// later pass.
-static MemRefType *getVectorizedMemRefType(MemRefType *tmpl,
- ArrayRef<int> vectorSizes) {
- auto *elementType = tmpl->getElementType();
- assert(!dyn_cast<VectorType>(elementType) &&
+static MemRefType getVectorizedMemRefType(MemRefType tmpl,
+ ArrayRef<int> vectorSizes) {
+ auto elementType = tmpl.getElementType();
+ assert(!elementType.dyn_cast<VectorType>() &&
"Can't vectorize an already vector type");
- assert(tmpl->getAffineMaps().empty() &&
+ assert(tmpl.getAffineMaps().empty() &&
"Unsupported non-implicit identity map");
return MemRefType::get({1}, VectorType::get(vectorSizes, elementType), {},
- tmpl->getMemorySpace());
+ tmpl.getMemorySpace());
}
/// Creates an unaligned load with the following semantics:
@@ -258,7 +258,7 @@ static void createUnalignedLoad(MLFuncBuilder *b, Location *loc,
operands.insert(operands.end(), dstMemRef);
operands.insert(operands.end(), dstIndices.begin(), dstIndices.end());
using functional::map;
- std::function<Type *(SSAValue *)> getType = [](SSAValue *v) -> Type * {
+ std::function<Type(SSAValue *)> getType = [](SSAValue *v) -> Type {
return v->getType();
};
auto types = map(getType, operands);
@@ -310,7 +310,7 @@ static void createUnalignedStore(MLFuncBuilder *b, Location *loc,
operands.insert(operands.end(), dstMemRef);
operands.insert(operands.end(), dstIndices.begin(), dstIndices.end());
using functional::map;
- std::function<Type *(SSAValue *)> getType = [](SSAValue *v) -> Type * {
+ std::function<Type(SSAValue *)> getType = [](SSAValue *v) -> Type {
return v->getType();
};
auto types = map(getType, operands);
@@ -348,8 +348,9 @@ static std::function<ToType *(T *)> unwrapPtr() {
template <typename LoadOrStoreOpPointer>
static MLValue *materializeVector(MLValue *iv, LoadOrStoreOpPointer memoryOp,
ArrayRef<int> vectorSize) {
- auto *memRefType = cast<MemRefType>(memoryOp->getMemRef()->getType());
- auto *vectorMemRefType = getVectorizedMemRefType(memRefType, vectorSize);
+ auto memRefType =
+ memoryOp->getMemRef()->getType().template cast<MemRefType>();
+ auto vectorMemRefType = getVectorizedMemRefType(memRefType, vectorSize);
// Materialize a MemRef with 1 vector.
auto *opStmt = cast<OperationStmt>(memoryOp->getOperation());
OpenPOWER on IntegriCloud