diff options
Diffstat (limited to 'mlir/lib/Transforms/Vectorize.cpp')
-rw-r--r-- | mlir/lib/Transforms/Vectorize.cpp | 14 |
1 file changed, 12 insertions, 2 deletions
diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index a1e87568745..b3eea35a55f 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -35,6 +35,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/Functional.h" #include "mlir/Support/LLVM.h" +#include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/DenseMap.h" @@ -718,6 +719,8 @@ struct VectorizationState { // Checks that the type of `op` is AffineStoreOp and adds it to the terminals // set. void registerTerminal(Operation *op); + // Folder used to factor out constant creation. + OperationFolder *folder; private: void registerReplacement(Value *key, Value *value); @@ -832,7 +835,11 @@ static LogicalResult vectorizeRootOrTerminal(Value *iv, LLVM_DEBUG(permutationMap.print(dbgs())); auto transfer = b.create<vector::VectorTransferReadOp>( opInst->getLoc(), vectorType, memoryOp.getMemRef(), - map(makePtrDynCaster<Value>(), indices), permutationMap); + map(makePtrDynCaster<Value>(), indices), + AffineMapAttr::get(permutationMap), + // TODO(b/144455320) add a proper padding value, not just 0.0 : f32 + state->folder->create<ConstantFloatOp>( + b, opInst->getLoc(), llvm::APFloat(0.0f), b.getF32Type())); state->registerReplacement(opInst, transfer.getOperation()); } else { state->registerTerminal(opInst); @@ -1058,7 +1065,8 @@ static Operation *vectorizeOneOperation(Operation *opInst, LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: "); LLVM_DEBUG(permutationMap.print(dbgs())); auto transfer = b.create<vector::VectorTransferWriteOp>( - opInst->getLoc(), vectorValue, memRef, indices, permutationMap); + opInst->getLoc(), vectorValue, memRef, indices, + AffineMapAttr::get(permutationMap)); auto *res = transfer.getOperation(); LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res); // "Terminals" (i.e. AffineStoreOps) are erased on the spot. 
@@ -1152,8 +1160,10 @@ static LogicalResult vectorizeNonTerminals(VectorizationState *state) { static LogicalResult vectorizeRootMatch(NestedMatch m, VectorizationStrategy *strategy) { auto loop = cast<AffineForOp>(m.getMatchedOperation()); + OperationFolder folder(loop.getContext()); VectorizationState state; state.strategy = strategy; + state.folder = &folder; // Since patterns are recursive, they can very well intersect. // Since we do not want a fully greedy strategy in general, we decouple |