Diffstat (limited to 'mlir/lib/Transforms/Vectorize.cpp')
-rw-r--r--  mlir/lib/Transforms/Vectorize.cpp | 184
1 file changed, 73 insertions(+), 111 deletions(-)
diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp
index e4822c27ac9..1e4102156af 100644
--- a/mlir/lib/Transforms/Vectorize.cpp
+++ b/mlir/lib/Transforms/Vectorize.cpp
@@ -671,6 +671,18 @@ struct VectorizationStrategy {
 
 } // end anonymous namespace
 
+static void vectorizeLoopIfProfitable(ForStmt *loop, unsigned depthInPattern,
+                                      unsigned patternDepth,
+                                      VectorizationStrategy *strategy) {
+  assert(patternDepth > depthInPattern);
+  if (patternDepth - depthInPattern > strategy->vectorSizes.size()) {
+    // Don't vectorize this loop
+    return;
+  }
+  strategy->loopToVectorDim[loop] =
+      strategy->vectorSizes.size() - (patternDepth - depthInPattern);
+}
+
 /// Implements a simple strawman strategy for vectorization.
 /// Given a matched pattern `matches` of depth `patternDepth`, this strategy
 /// greedily assigns the fastest varying dimension ** of the vector ** to the
@@ -696,17 +708,11 @@ static bool analyzeProfitability(MLFunctionMatches matches,
     if (fail) {
       return fail;
     }
-    assert(patternDepth > depthInPattern);
-    if (patternDepth - depthInPattern <= strategy->vectorSizes.size()) {
-      strategy->loopToVectorDim[loop] =
-          strategy->vectorSizes.size() - (patternDepth - depthInPattern);
-    } else {
-      // Don't vectorize
-      strategy->loopToVectorDim[loop] = -1;
-    }
+    vectorizeLoopIfProfitable(loop, depthInPattern, patternDepth, strategy);
   }
   return false;
 }
+
 ///// end TODO(ntv): Hoist to a VectorizationStrategy.cpp when appropriate /////
 
 namespace {
@@ -799,39 +805,6 @@ void VectorizationState::registerReplacement(const SSAValue *key,
 
 ////// TODO(ntv): Hoist to a VectorizationMaterialize.cpp when appropriate. ////
 
-/// Creates a vector_transfer_read that loads a scalar MemRef into a
-/// super-vector register.
-///
-/// Usage:
-/// This vector_transfer_read op will be implemented as a PseudoOp for
-/// different backends. In its current form it is only used to load into a
-/// vector; where the vector may have any shape that is some multiple of the
-/// hardware-specific vector size used to implement the PseudoOp efficiently.
-/// This is used to implement "non-effecting padding" for early vectorization
-/// and allows higher-level passes in the codegen to not worry about
-/// hardware-specific implementation details.
-///
-/// TODO(ntv):
-/// 1. implement this end-to-end for some backend;
-/// 2. support operation-specific padding values to properly implement
-///    "non-effecting padding";
-/// 3. support input map for on-the-fly transpositions (point 1 above);
-/// 4. support broadcast map (point 5 above).
-///
-/// TODO(andydavis,bondhugula,ntv):
-/// 1. generalize to support padding semantics and offsets within vector type.
-static OperationStmt *
-createVectorTransferRead(OperationStmt *loadOp, VectorType vectorType,
-                         SSAValue *srcMemRef, ArrayRef<SSAValue *> srcIndices) {
-  auto memRefType = srcMemRef->getType().cast<MemRefType>();
-  MLFuncBuilder b(loadOp);
-  // TODO(ntv): neutral for noneffective padding.
-  auto transfer = b.create<VectorTransferReadOp>(
-      loadOp->getLoc(), vectorType, srcMemRef, srcIndices,
-      makePermutationMap(memRefType, vectorType));
-  return cast<OperationStmt>(transfer->getOperation());
-}
-
 /// Handles the vectorization of load and store MLIR operations.
 ///
 /// LoadOp operations are the roots of the vectorizeOperations call. They are
@@ -863,10 +836,17 @@ static bool vectorizeRootOrTerminal(MLValue *iv, LoadOrStoreOpPointer memoryOp,
   // TODO(ntv): increase the expressiveness power of vector_transfer operations
   // as needed by various targets.
   if (opStmt->template isa<LoadOp>()) {
-    auto *transfer = createVectorTransferRead(
-        opStmt, vectorType, memoryOp->getMemRef(),
-        map(makePtrDynCaster<SSAValue>(), memoryOp->getIndices()));
-    state->registerReplacement(opStmt, transfer);
+    auto permutationMap =
+        makePermutationMap(opStmt, state->strategy->loopToVectorDim);
+    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
+    LLVM_DEBUG(permutationMap.print(dbgs()));
+    MLFuncBuilder b(opStmt);
+    auto transfer = b.create<VectorTransferReadOp>(
+        opStmt->getLoc(), vectorType, memoryOp->getMemRef(),
+        map(makePtrDynCaster<SSAValue>(), memoryOp->getIndices()),
+        permutationMap);
+    state->registerReplacement(opStmt,
+                               cast<OperationStmt>(transfer->getOperation()));
   } else {
     state->registerTerminator(opStmt);
   }
@@ -943,16 +923,13 @@ static bool doVectorize(MLFunctionMatches::EntryType oneMatch,
 
   // 2. This loop may have been omitted from vectorization for various reasons
   // (e.g. due to the performance model or pattern depth > vector size).
-  assert(state->strategy->loopToVectorDim.count(loop));
-  assert(state->strategy->loopToVectorDim.find(loop) !=
-             state->strategy->loopToVectorDim.end() &&
-         "Key not found");
-  int vectorDim = state->strategy->loopToVectorDim.lookup(loop);
-  if (vectorDim < 0) {
+  auto it = state->strategy->loopToVectorDim.find(loop);
+  if (it == state->strategy->loopToVectorDim.end()) {
     return false;
   }
 
   // 3. Actual post-order transformation.
+  auto vectorDim = it->second;
   assert(vectorDim < state->strategy->vectorSizes.size() &&
          "vector dim overflow");
   // a. get actual vector size
@@ -1077,40 +1054,6 @@ static MLValue *vectorizeOperand(SSAValue *operand, Statement *stmt,
   return nullptr;
 };
 
-/// Creates and returns a vector_transfer_write operation, which writes back a
-/// super-vector register into a scalar MemRef.
-///
-/// Usage:
-/// This vector_transfer_write op will be implemented as a PseudoOp for
-/// different backends. In its current form it is only used to store from a
-/// vector; where the vector may have any shape that is some multiple of
-/// the hardware-specific vector size used to implement the PseudoOp
-/// efficiently. This is used to implement "non-effecting padding" for early
-/// vectorization and allows higher-level passes in the codegen to not worry
-/// about hardware-specific implementation details.
-///
-/// TODO(ntv):
-/// 1. implement this end-to-end for some backend;
-/// 2. support write-back in the presence of races and ;
-/// 3. support input map for counterpart of broadcast (point 1 above);
-/// 4. support dstMap for writing back in non-contiguous memory regions
-///    (point 4 above).
-static OperationStmt *createVectorTransferWrite(OperationStmt *storeOp,
-                                                VectorizationState *state) {
-  auto store = storeOp->cast<StoreOp>();
-  auto *memRef = store->getMemRef();
-  auto memRefType = memRef->getType().cast<MemRefType>();
-  auto *value = store->getValueToStore();
-  auto *vectorValue = vectorizeOperand(value, storeOp, state);
-  auto vectorType = vectorValue->getType().cast<VectorType>();
-  auto indices = map(makePtrDynCaster<SSAValue>(), store->getIndices());
-  MLFuncBuilder b(storeOp);
-  auto transfer = b.create<VectorTransferWriteOp>(
-      storeOp->getLoc(), vectorValue, memRef, indices,
-      makePermutationMap(memRefType, vectorType));
-  return cast<OperationStmt>(transfer->getOperation());
-}
-
 /// Encodes OperationStmt-specific behavior for vectorization. In general we
 /// assume that all operands of an op must be vectorized but this is not always
 /// true. In the future, it would be nice to have a trait that describes how a
@@ -1121,31 +1064,41 @@ static OperationStmt *createVectorTransferWrite(OperationStmt *storeOp,
 /// Maybe some Ops are not vectorizable or require some tricky logic, we cannot
 /// do one-off logic here; ideally it would be TableGen'd.
 static OperationStmt *vectorizeOneOperationStmt(MLFuncBuilder *b,
-                                                OperationStmt *stmt,
+                                                OperationStmt *opStmt,
                                                 VectorizationState *state) {
   // Sanity checks.
-  assert(!stmt->isa<LoadOp>() &&
+  assert(!opStmt->isa<LoadOp>() &&
          "all loads must have already been fully vectorized independently");
-  assert(!stmt->isa<VectorTransferReadOp>() &&
+  assert(!opStmt->isa<VectorTransferReadOp>() &&
          "vector_transfer_read cannot be further vectorized");
-  assert(!stmt->isa<VectorTransferWriteOp>() &&
+  assert(!opStmt->isa<VectorTransferWriteOp>() &&
          "vector_transfer_write cannot be further vectorized");
 
-  if (stmt->isa<StoreOp>()) {
-    auto *res = createVectorTransferWrite(stmt, state);
-    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: ");
-    LLVM_DEBUG(res->print(dbgs()));
-    // Terminators are erased on the spot.
-    stmt->erase();
+  if (auto store = opStmt->dyn_cast<StoreOp>()) {
+    auto *memRef = store->getMemRef();
+    auto *value = store->getValueToStore();
+    auto *vectorValue = vectorizeOperand(value, opStmt, state);
+    auto indices = map(makePtrDynCaster<SSAValue>(), store->getIndices());
+    MLFuncBuilder b(opStmt);
+    auto permutationMap =
+        makePermutationMap(opStmt, state->strategy->loopToVectorDim);
+    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
+    LLVM_DEBUG(permutationMap.print(dbgs()));
+    auto transfer = b.create<VectorTransferWriteOp>(
+        opStmt->getLoc(), vectorValue, memRef, indices, permutationMap);
+    auto *res = cast<OperationStmt>(transfer->getOperation());
+    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res);
+    // "Terminators" (i.e. StoreOps) are erased on the spot.
+    opStmt->erase();
     return res;
   }
 
   auto types = map([state](SSAValue *v) { return getVectorType(v, *state); },
-                   stmt->getResults());
-  auto vectorizeOneOperand = [stmt, state](SSAValue *op) {
-    return vectorizeOperand(op, stmt, state);
+                   opStmt->getResults());
+  auto vectorizeOneOperand = [opStmt, state](SSAValue *op) {
+    return vectorizeOperand(op, opStmt, state);
   };
-  auto operands = map(vectorizeOneOperand, stmt->getOperands());
+  auto operands = map(vectorizeOneOperand, opStmt->getOperands());
   // Check whether a single operand is null. If so, vectorization failed.
   bool success = llvm::any_of(operands, [](SSAValue *op) { return op; });
   if (!success) {
@@ -1159,8 +1112,9 @@ static OperationStmt *vectorizeOneOperationStmt(MLFuncBuilder *b,
   // TODO(ntv): Is it worth considering an OperationStmt.clone operation
   // which changes the type so we can promote an OperationStmt with less
   // boilerplate?
-  return cast<OperationStmt>(b->createOperation(
-      stmt->getLoc(), stmt->getName(), operands, types, stmt->getAttrs()));
+  return cast<OperationStmt>(b->createOperation(opStmt->getLoc(),
+                                                opStmt->getName(), operands,
+                                                types, opStmt->getAttrs()));
 }
 
 /// Iterates over the OperationStmt in the loop and rewrites them using their
@@ -1313,18 +1267,26 @@ PassResult Vectorize::runOnMLFunction(MLFunction *f) {
     LLVM_DEBUG(dbgs() << "\n******************************************");
     LLVM_DEBUG(dbgs() << "\n[early-vect] new pattern on MLFunction\n");
     LLVM_DEBUG(f->print(dbgs()));
+    unsigned patternDepth = pat.getDepth();
     auto matches = pat.match(f);
-    VectorizationStrategy strategy;
-    // TODO(ntv): depending on profitability, elect to reduce the vector size.
-    strategy.vectorSizes = clVirtualVectorSize;
-    auto fail = analyzeProfitability(matches, 0, pat.getDepth(), &strategy);
-    if (fail) {
-      continue;
+    // Iterate over all the top-level matches and vectorize eagerly.
+    // This automatically prunes intersecting matches.
+    for (auto m : matches) {
+      VectorizationStrategy strategy;
+      // TODO(ntv): depending on profitability, elect to reduce the vector size.
+      strategy.vectorSizes = clVirtualVectorSize;
+      auto fail = analyzeProfitability(m.second, 1, patternDepth, &strategy);
+      if (fail) {
+        continue;
+      }
+      auto *loop = cast<ForStmt>(m.first);
+      vectorizeLoopIfProfitable(loop, 0, patternDepth, &strategy);
+      // TODO(ntv): if pattern does not apply, report it; alter the
+      // cost/benefit.
+      fail = vectorizeRootMatches(matches, &strategy);
+      assert(!fail && "top-level failure should not happen");
+      // TODO(ntv): some diagnostics.
     }
-    // TODO(ntv): if pattern does not apply, report it; alter the cost/benefit.
-    fail = vectorizeRootMatches(matches, &strategy);
-    assert(!fail && "top-level failure should not happen");
-    // TODO(ntv): some diagnotics.
   }
   LLVM_DEBUG(dbgs() << "\n");
   return PassResult::Success;
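The heart of the patch is the new vectorizeLoopIfProfitable helper and the switch from a -1 sentinel in loopToVectorDim to simply omitting unprofitable loops from the map, which doVectorize now detects with find()/end(). The following is a minimal standalone sketch of that arithmetic; Loop and Strategy are hypothetical stand-ins for ForStmt and VectorizationStrategy, and std::map stands in for llvm::DenseMap:

#include <cassert>
#include <cstdio>
#include <map>
#include <vector>

struct Loop { const char *name; };  // stand-in for ForStmt

struct Strategy {  // stand-in for VectorizationStrategy
  std::vector<int> vectorSizes;  // e.g. {32, 256}: a 2-D virtual vector
  std::map<Loop *, unsigned> loopToVectorDim;
};

// Mirrors the new helper: absence from loopToVectorDim now encodes
// "leave this loop scalar" instead of the old -1 sentinel.
static void vectorizeLoopIfProfitable(Loop *loop, unsigned depthInPattern,
                                      unsigned patternDepth,
                                      Strategy *strategy) {
  assert(patternDepth > depthInPattern);
  if (patternDepth - depthInPattern > strategy->vectorSizes.size())
    return;  // loop sits too far out in the pattern: no entry, stays scalar
  strategy->loopToVectorDim[loop] =
      strategy->vectorSizes.size() - (patternDepth - depthInPattern);
}

int main() {
  Loop i{"i"}, j{"j"}, k{"k"};
  Strategy s;
  s.vectorSizes = {32, 256};
  // A 3-deep match (i, j, k) against a 2-D vector: only the two innermost
  // loops receive a vector dimension, the innermost getting the fastest one.
  vectorizeLoopIfProfitable(&i, 0, 3, &s);  // 3 - 0 > 2 -> skipped
  vectorizeLoopIfProfitable(&j, 1, 3, &s);  // dim = 2 - (3 - 1) = 0
  vectorizeLoopIfProfitable(&k, 2, 3, &s);  // dim = 2 - (3 - 2) = 1
  for (Loop *l : {&i, &j, &k}) {
    auto it = s.loopToVectorDim.find(l);  // same lookup doVectorize now uses
    if (it == s.loopToVectorDim.end())
      std::printf("loop %s: not vectorized\n", l->name);
    else
      std::printf("loop %s: vector dim %u\n", l->name, it->second);
  }
  return 0;
}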

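The other substantive change is that the permutation maps of vector_transfer_read and vector_transfer_write are now derived per operation from the loop-to-vector-dimension assignment, via makePermutationMap(opStmt, state->strategy->loopToVectorDim), rather than from the memref and vector types alone. A toy model of that derivation follows; all names in it are hypothetical stand-ins rather than the MLIR API, and it assumes for simplicity that every memref dimension is indexed by exactly one loop induction variable:

#include <cstdio>
#include <map>
#include <vector>

int main() {
  // Loop 1 was assigned vector dim 0, loop 2 vector dim 1; loop 0 is scalar.
  std::map<int, unsigned> loopToVectorDim = {{1, 0}, {2, 1}};
  // Transposed access A[iv2][iv1]: memref dim 0 uses loop 2, dim 1 uses loop 1.
  std::vector<int> loopUsedByMemRefDim = {2, 1};

  // result[v] = the memref dimension that vector dimension v reads along.
  std::vector<int> result(loopToVectorDim.size(), -1);
  for (unsigned m = 0; m < loopUsedByMemRefDim.size(); ++m) {
    auto it = loopToVectorDim.find(loopUsedByMemRefDim[m]);
    if (it != loopToVectorDim.end())
      result[it->second] = m;
  }

  // Prints the transposition (d0, d1) -> (d1, d0), the kind of map the old
  // type-only createVectorTransferRead/Write helpers could not express.
  std::printf("(d0, d1) -> (");
  for (unsigned v = 0; v < result.size(); ++v)
    std::printf("%sd%d", v ? ", " : "", result[v]);
  std::printf(")\n");
  return 0;
}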
