diff options
Diffstat (limited to 'mlir/lib/Transforms/Vectorize.cpp')
-rw-r--r-- | mlir/lib/Transforms/Vectorize.cpp | 40 |
1 files changed, 20 insertions, 20 deletions
diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index e3212d54e42..d8f5b1dc0e4 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -705,7 +705,7 @@ struct VectorizationState { // Map of old scalar Operation to new vectorized Operation. DenseMap<Operation *, Operation *> vectorizationMap; // Map of old scalar Value to new vectorized Value. - DenseMap<Value *, Value *> replacementMap; + DenseMap<ValuePtr, ValuePtr> replacementMap; // The strategy drives which loop to vectorize by which amount. const VectorizationStrategy *strategy; // Use-def roots. These represent the starting points for the worklist in the @@ -728,7 +728,7 @@ struct VectorizationState { OperationFolder *folder; private: - void registerReplacement(Value *key, Value *value); + void registerReplacement(ValuePtr key, ValuePtr value); }; } // end namespace @@ -768,7 +768,7 @@ void VectorizationState::finishVectorizationPattern() { } } -void VectorizationState::registerReplacement(Value *key, Value *value) { +void VectorizationState::registerReplacement(ValuePtr key, ValuePtr value) { assert(replacementMap.count(key) == 0 && "replacement already registered"); replacementMap.insert(std::make_pair(key, value)); } @@ -776,7 +776,7 @@ void VectorizationState::registerReplacement(Value *key, Value *value) { // Apply 'map' with 'mapOperands' returning resulting values in 'results'. static void computeMemoryOpIndices(Operation *op, AffineMap map, ValueRange mapOperands, - SmallVectorImpl<Value *> &results) { + SmallVectorImpl<ValuePtr> &results) { OpBuilder builder(op); for (auto resultExpr : map.getResults()) { auto singleResMap = @@ -803,7 +803,7 @@ static void computeMemoryOpIndices(Operation *op, AffineMap map, /// Such special cases force us to delay the vectorization of the stores until /// the last step. Here we merely register the store operation. template <typename LoadOrStoreOpPointer> -static LogicalResult vectorizeRootOrTerminal(Value *iv, +static LogicalResult vectorizeRootOrTerminal(ValuePtr iv, LoadOrStoreOpPointer memoryOp, VectorizationState *state) { auto memRefType = memoryOp.getMemRef()->getType().template cast<MemRefType>(); @@ -823,7 +823,7 @@ static LogicalResult vectorizeRootOrTerminal(Value *iv, if (auto load = dyn_cast<AffineLoadOp>(opInst)) { OpBuilder b(opInst); ValueRange mapOperands = load.getMapOperands(); - SmallVector<Value *, 8> indices; + SmallVector<ValuePtr, 8> indices; indices.reserve(load.getMemRefType().getRank()); if (load.getAffineMap() != b.getMultiDimIdentityMap(load.getMemRefType().getRank())) { @@ -838,8 +838,7 @@ static LogicalResult vectorizeRootOrTerminal(Value *iv, LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: "); LLVM_DEBUG(permutationMap.print(dbgs())); auto transfer = b.create<vector::TransferReadOp>( - opInst->getLoc(), vectorType, memoryOp.getMemRef(), - map(makePtrDynCaster<Value>(), indices), + opInst->getLoc(), vectorType, memoryOp.getMemRef(), indices, AffineMapAttr::get(permutationMap), // TODO(b/144455320) add a proper padding value, not just 0.0 : f32 state->folder->create<ConstantFloatOp>(b, opInst->getLoc(), @@ -951,7 +950,8 @@ vectorizeLoopsAndLoadsRecursively(NestedMatch oneMatch, /// element type. /// If `type` is not a valid vector type or if the scalar constant is not a /// valid vector element type, returns nullptr. -static Value *vectorizeConstant(Operation *op, ConstantOp constant, Type type) { +static ValuePtr vectorizeConstant(Operation *op, ConstantOp constant, + Type type) { if (!type || !type.isa<VectorType>() || !VectorType::isValidElementType(constant.getType())) { return nullptr; @@ -989,8 +989,8 @@ static Value *vectorizeConstant(Operation *op, ConstantOp constant, Type type) { /// vectorization is possible with the above logic. Returns nullptr otherwise. /// /// TODO(ntv): handle more complex cases. -static Value *vectorizeOperand(Value *operand, Operation *op, - VectorizationState *state) { +static ValuePtr vectorizeOperand(ValuePtr operand, Operation *op, + VectorizationState *state) { LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: "); LLVM_DEBUG(operand->print(dbgs())); // 1. If this value has already been vectorized this round, we are done. @@ -1004,7 +1004,7 @@ static Value *vectorizeOperand(Value *operand, Operation *op, // been vectorized. This would be invalid IR. auto it = state->replacementMap.find(operand); if (it != state->replacementMap.end()) { - auto *res = it->second; + auto res = it->second; LLVM_DEBUG(dbgs() << "-> delayed replacement by: "); LLVM_DEBUG(res->print(dbgs())); return res; @@ -1047,12 +1047,12 @@ static Operation *vectorizeOneOperation(Operation *opInst, if (auto store = dyn_cast<AffineStoreOp>(opInst)) { OpBuilder b(opInst); - auto *memRef = store.getMemRef(); - auto *value = store.getValueToStore(); - auto *vectorValue = vectorizeOperand(value, opInst, state); + auto memRef = store.getMemRef(); + auto value = store.getValueToStore(); + auto vectorValue = vectorizeOperand(value, opInst, state); ValueRange mapOperands = store.getMapOperands(); - SmallVector<Value *, 8> indices; + SmallVector<ValuePtr, 8> indices; indices.reserve(store.getMemRefType().getRank()); if (store.getAffineMap() != b.getMultiDimIdentityMap(store.getMemRefType().getRank())) { @@ -1081,16 +1081,16 @@ static Operation *vectorizeOneOperation(Operation *opInst, return nullptr; SmallVector<Type, 8> vectorTypes; - for (auto *v : opInst->getResults()) { + for (auto v : opInst->getResults()) { vectorTypes.push_back( VectorType::get(state->strategy->vectorSizes, v->getType())); } - SmallVector<Value *, 8> vectorOperands; - for (auto *v : opInst->getOperands()) { + SmallVector<ValuePtr, 8> vectorOperands; + for (auto v : opInst->getOperands()) { vectorOperands.push_back(vectorizeOperand(v, opInst, state)); } // Check whether a single operand is null. If so, vectorization failed. - bool success = llvm::all_of(vectorOperands, [](Value *op) { return op; }); + bool success = llvm::all_of(vectorOperands, [](ValuePtr op) { return op; }); if (!success) { LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ an operand failed vectorize"); return nullptr; |