diff options
Diffstat (limited to 'mlir/lib/Transforms/Vectorize.cpp')
| -rw-r--r-- | mlir/lib/Transforms/Vectorize.cpp | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index d7a1f531cef..511afa95993 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -202,15 +202,15 @@ static bool analyzeProfitability(MLFunctionMatches matches, /// sizes specified by vectorSize. The MemRef lives in the same memory space as /// tmpl. The MemRef should be promoted to a closer memory address space in a /// later pass. -static MemRefType *getVectorizedMemRefType(MemRefType *tmpl, - ArrayRef<int> vectorSizes) { - auto *elementType = tmpl->getElementType(); - assert(!dyn_cast<VectorType>(elementType) && +static MemRefType getVectorizedMemRefType(MemRefType tmpl, + ArrayRef<int> vectorSizes) { + auto elementType = tmpl.getElementType(); + assert(!elementType.dyn_cast<VectorType>() && "Can't vectorize an already vector type"); - assert(tmpl->getAffineMaps().empty() && + assert(tmpl.getAffineMaps().empty() && "Unsupported non-implicit identity map"); return MemRefType::get({1}, VectorType::get(vectorSizes, elementType), {}, - tmpl->getMemorySpace()); + tmpl.getMemorySpace()); } /// Creates an unaligned load with the following semantics: @@ -258,7 +258,7 @@ static void createUnalignedLoad(MLFuncBuilder *b, Location *loc, operands.insert(operands.end(), dstMemRef); operands.insert(operands.end(), dstIndices.begin(), dstIndices.end()); using functional::map; - std::function<Type *(SSAValue *)> getType = [](SSAValue *v) -> Type * { + std::function<Type(SSAValue *)> getType = [](SSAValue *v) -> Type { return v->getType(); }; auto types = map(getType, operands); @@ -310,7 +310,7 @@ static void createUnalignedStore(MLFuncBuilder *b, Location *loc, operands.insert(operands.end(), dstMemRef); operands.insert(operands.end(), dstIndices.begin(), dstIndices.end()); using functional::map; - std::function<Type *(SSAValue *)> getType = [](SSAValue *v) -> Type * { + std::function<Type(SSAValue *)> getType = [](SSAValue *v) -> Type { return v->getType(); }; auto types = map(getType, operands); @@ -348,8 +348,9 @@ static std::function<ToType *(T *)> unwrapPtr() { template <typename LoadOrStoreOpPointer> static MLValue *materializeVector(MLValue *iv, LoadOrStoreOpPointer memoryOp, ArrayRef<int> vectorSize) { - auto *memRefType = cast<MemRefType>(memoryOp->getMemRef()->getType()); - auto *vectorMemRefType = getVectorizedMemRefType(memRefType, vectorSize); + auto memRefType = + memoryOp->getMemRef()->getType().template cast<MemRefType>(); + auto vectorMemRefType = getVectorizedMemRefType(memRefType, vectorSize); // Materialize a MemRef with 1 vector. auto *opStmt = cast<OperationStmt>(memoryOp->getOperation()); |

