diff options
| author | Chris Lattner <clattner@google.com> | 2018-12-28 16:05:35 -0800 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 14:44:30 -0700 |
| commit | 456ad6a8e0ca78ce6277da897a0b820533387d84 (patch) | |
| tree | d9fbb26651eed51b02281be03c9bbc66522cbacf /mlir/lib/Transforms/Vectorize.cpp | |
| parent | b1d9cc4d1ef5a1f81ca566fc06960df2bf31ddfe (diff) | |
| download | bcm5719-llvm-456ad6a8e0ca78ce6277da897a0b820533387d84.tar.gz bcm5719-llvm-456ad6a8e0ca78ce6277da897a0b820533387d84.zip | |
Standardize naming of statements -> instructions, revisiting the code base to be
consistent and moving the using declarations over. Hopefully this is the last
truly massive patch in this refactoring.
This is step 21/n towards merging instructions and statements, NFC.
PiperOrigin-RevId: 227178245
Diffstat (limited to 'mlir/lib/Transforms/Vectorize.cpp')
| -rw-r--r-- | mlir/lib/Transforms/Vectorize.cpp | 207 |
1 file changed, 104 insertions, 103 deletions
diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index ddbd6256782..bbb703cd627 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -252,7 +252,7 @@ using namespace mlir; /// ========== /// The algorithm proceeds in a few steps: /// 1. defining super-vectorization patterns and matching them on the tree of -/// ForStmt. A super-vectorization pattern is defined as a recursive data +/// ForInst. A super-vectorization pattern is defined as a recursive data /// structures that matches and captures nested, imperfectly-nested loops /// that have a. comformable loop annotations attached (e.g. parallel, /// reduction, vectoriable, ...) as well as b. all contiguous load/store @@ -279,7 +279,7 @@ using namespace mlir; /// it by its vector form. Otherwise, if the scalar value is a constant, /// it is vectorized into a splat. In all other cases, vectorization for /// the pattern currently fails. -/// e. if everything under the root ForStmt in the current pattern vectorizes +/// e. if everything under the root ForInst in the current pattern vectorizes /// properly, we commit that loop to the IR. Otherwise we discard it and /// restore a previously cloned version of the loop. 
Thanks to the /// recursive scoping nature of matchers and captured patterns, this is @@ -668,12 +668,12 @@ namespace { struct VectorizationStrategy { ArrayRef<int> vectorSizes; - DenseMap<ForStmt *, unsigned> loopToVectorDim; + DenseMap<ForInst *, unsigned> loopToVectorDim; }; } // end anonymous namespace -static void vectorizeLoopIfProfitable(ForStmt *loop, unsigned depthInPattern, +static void vectorizeLoopIfProfitable(ForInst *loop, unsigned depthInPattern, unsigned patternDepth, VectorizationStrategy *strategy) { assert(patternDepth > depthInPattern && @@ -705,7 +705,7 @@ static bool analyzeProfitability(MLFunctionMatches matches, unsigned depthInPattern, unsigned patternDepth, VectorizationStrategy *strategy) { for (auto m : matches) { - auto *loop = cast<ForStmt>(m.first); + auto *loop = cast<ForInst>(m.first); bool fail = analyzeProfitability(m.second, depthInPattern + 1, patternDepth, strategy); if (fail) { @@ -721,7 +721,7 @@ static bool analyzeProfitability(MLFunctionMatches matches, namespace { struct VectorizationState { - /// Adds an entry of pre/post vectorization statements in the state. + /// Adds an entry of pre/post vectorization instructions in the state. void registerReplacement(OperationInst *key, OperationInst *value); /// When the current vectorization pattern is successful, this erases the /// instructions that were marked for erasure in the proper order and resets @@ -733,7 +733,7 @@ struct VectorizationState { SmallVector<OperationInst *, 16> toErase; // Set of OperationInst that have been vectorized (the values in the // vectorizationMap for hashed access). The vectorizedSet is used in - // particular to filter the statements that have already been vectorized by + // particular to filter the instructions that have already been vectorized by // this pattern, when iterating over nested loops in this pattern. DenseSet<OperationInst *> vectorizedSet; // Map of old scalar OperationInst to new vectorized OperationInst. 
@@ -747,16 +747,16 @@ struct VectorizationState { // that have been vectorized. They can be retrieved from `vectorizationMap` // but it is convenient to keep track of them in a separate data structure. DenseSet<OperationInst *> roots; - // Terminator statements for the worklist in the vectorizeOperations function. - // They consist of the subset of store operations that have been vectorized. - // They can be retrieved from `vectorizationMap` but it is convenient to keep - // track of them in a separate data structure. Since they do not necessarily - // belong to use-def chains starting from loads (e.g storing a constant), we - // need to handle them in a post-pass. + // Terminator instructions for the worklist in the vectorizeOperations + // function. They consist of the subset of store operations that have been + // vectorized. They can be retrieved from `vectorizationMap` but it is + // convenient to keep track of them in a separate data structure. Since they + // do not necessarily belong to use-def chains starting from loads (e.g + // storing a constant), we need to handle them in a post-pass. DenseSet<OperationInst *> terminators; - // Checks that the type of `stmt` is StoreOp and adds it to the terminators + // Checks that the type of `inst` is StoreOp and adds it to the terminators // set. 
- void registerTerminator(OperationInst *stmt); + void registerTerminator(OperationInst *inst); private: void registerReplacement(const Value *key, Value *value); @@ -784,19 +784,19 @@ void VectorizationState::registerReplacement(OperationInst *key, } } -void VectorizationState::registerTerminator(OperationInst *stmt) { - assert(stmt->isa<StoreOp>() && "terminator must be a StoreOp"); - assert(terminators.count(stmt) == 0 && +void VectorizationState::registerTerminator(OperationInst *inst) { + assert(inst->isa<StoreOp>() && "terminator must be a StoreOp"); + assert(terminators.count(inst) == 0 && "terminator was already inserted previously"); - terminators.insert(stmt); + terminators.insert(inst); } void VectorizationState::finishVectorizationPattern() { while (!toErase.empty()) { - auto *stmt = toErase.pop_back_val(); + auto *inst = toErase.pop_back_val(); LLVM_DEBUG(dbgs() << "\n[early-vect] finishVectorizationPattern erase: "); - LLVM_DEBUG(stmt->print(dbgs())); - stmt->erase(); + LLVM_DEBUG(inst->print(dbgs())); + inst->erase(); } } @@ -832,23 +832,23 @@ static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp, auto vectorType = VectorType::get(state->strategy->vectorSizes, elementType); // Materialize a MemRef with 1 vector. - auto *opStmt = memoryOp->getInstruction(); + auto *opInst = memoryOp->getInstruction(); // For now, vector_transfers must be aligned, operate only on indices with an // identity subset of AffineMap and do not change layout. // TODO(ntv): increase the expressiveness power of vector_transfer operations // as needed by various targets. 
- if (opStmt->template isa<LoadOp>()) { + if (opInst->template isa<LoadOp>()) { auto permutationMap = - makePermutationMap(opStmt, state->strategy->loopToVectorDim); + makePermutationMap(opInst, state->strategy->loopToVectorDim); LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: "); LLVM_DEBUG(permutationMap.print(dbgs())); - FuncBuilder b(opStmt); + FuncBuilder b(opInst); auto transfer = b.create<VectorTransferReadOp>( - opStmt->getLoc(), vectorType, memoryOp->getMemRef(), + opInst->getLoc(), vectorType, memoryOp->getMemRef(), map(makePtrDynCaster<Value>(), memoryOp->getIndices()), permutationMap); - state->registerReplacement(opStmt, transfer->getInstruction()); + state->registerReplacement(opInst, transfer->getInstruction()); } else { - state->registerTerminator(opStmt); + state->registerTerminator(opInst); } return false; } @@ -856,28 +856,29 @@ static bool vectorizeRootOrTerminal(Value *iv, LoadOrStoreOpPointer memoryOp, /// Coarsens the loops bounds and transforms all remaining load and store /// operations into the appropriate vector_transfer. 
-static bool vectorizeForStmt(ForStmt *loop, int64_t step, +static bool vectorizeForInst(ForInst *loop, int64_t step, VectorizationState *state) { using namespace functional; loop->setStep(step); - FilterFunctionType notVectorizedThisPattern = [state](const Statement &stmt) { - if (!matcher::isLoadOrStore(stmt)) { - return false; - } - auto *opStmt = cast<OperationInst>(&stmt); - return state->vectorizationMap.count(opStmt) == 0 && - state->vectorizedSet.count(opStmt) == 0 && - state->roots.count(opStmt) == 0 && - state->terminators.count(opStmt) == 0; - }; + FilterFunctionType notVectorizedThisPattern = + [state](const Instruction &inst) { + if (!matcher::isLoadOrStore(inst)) { + return false; + } + auto *opInst = cast<OperationInst>(&inst); + return state->vectorizationMap.count(opInst) == 0 && + state->vectorizedSet.count(opInst) == 0 && + state->roots.count(opInst) == 0 && + state->terminators.count(opInst) == 0; + }; auto loadAndStores = matcher::Op(notVectorizedThisPattern); auto matches = loadAndStores.match(loop); for (auto ls : matches) { - auto *opStmt = cast<OperationInst>(ls.first); - auto load = opStmt->dyn_cast<LoadOp>(); - auto store = opStmt->dyn_cast<StoreOp>(); - LLVM_DEBUG(opStmt->print(dbgs())); + auto *opInst = cast<OperationInst>(ls.first); + auto load = opInst->dyn_cast<LoadOp>(); + auto store = opInst->dyn_cast<StoreOp>(); + LLVM_DEBUG(opInst->print(dbgs())); auto fail = load ? vectorizeRootOrTerminal(loop, load, state) : vectorizeRootOrTerminal(loop, store, state); if (fail) { @@ -895,8 +896,8 @@ static bool vectorizeForStmt(ForStmt *loop, int64_t step, /// we can build a cost model and a search procedure. 
static FilterFunctionType isVectorizableLoopPtrFactory(unsigned fastestVaryingMemRefDimension) { - return [fastestVaryingMemRefDimension](const Statement &forStmt) { - const auto &loop = cast<ForStmt>(forStmt); + return [fastestVaryingMemRefDimension](const Instruction &forInst) { + const auto &loop = cast<ForInst>(forInst); return isVectorizableLoopAlongFastestVaryingMemRefDim( loop, fastestVaryingMemRefDimension); }; @@ -911,7 +912,7 @@ static bool vectorizeNonRoot(MLFunctionMatches matches, /// recursively in DFS post-order. static bool doVectorize(MLFunctionMatches::EntryType oneMatch, VectorizationState *state) { - ForStmt *loop = cast<ForStmt>(oneMatch.first); + ForInst *loop = cast<ForInst>(oneMatch.first); MLFunctionMatches childrenMatches = oneMatch.second; // 1. DFS postorder recursion, if any of my children fails, I fail too. @@ -938,10 +939,10 @@ static bool doVectorize(MLFunctionMatches::EntryType oneMatch, // exploratory tradeoffs (see top of the file). Apply coarsening, i.e.: // | ub -> ub // | step -> step * vectorSize - LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForStmt by " << vectorSize + LLVM_DEBUG(dbgs() << "\n[early-vect] vectorizeForInst by " << vectorSize << " : "); LLVM_DEBUG(loop->print(dbgs())); - return vectorizeForStmt(loop, loop->getStep() * vectorSize, state); + return vectorizeForInst(loop, loop->getStep() * vectorSize, state); } /// Non-root pattern iterates over the matches at this level, calls doVectorize @@ -963,20 +964,20 @@ static bool vectorizeNonRoot(MLFunctionMatches matches, /// element type. /// If `type` is not a valid vector type or if the scalar constant is not a /// valid vector element type, returns nullptr. 
-static Value *vectorizeConstant(Statement *stmt, const ConstantOp &constant, +static Value *vectorizeConstant(Instruction *inst, const ConstantOp &constant, Type type) { if (!type || !type.isa<VectorType>() || !VectorType::isValidElementType(constant.getType())) { return nullptr; } - FuncBuilder b(stmt); - Location loc = stmt->getLoc(); + FuncBuilder b(inst); + Location loc = inst->getLoc(); auto vectorType = type.cast<VectorType>(); auto attr = SplatElementsAttr::get(vectorType, constant.getValue()); - auto *constantOpStmt = cast<OperationInst>(constant.getInstruction()); + auto *constantOpInst = cast<OperationInst>(constant.getInstruction()); OperationState state( - b.getContext(), loc, constantOpStmt->getName().getStringRef(), {}, + b.getContext(), loc, constantOpInst->getName().getStringRef(), {}, {vectorType}, {make_pair(Identifier::get("value", b.getContext()), attr)}); @@ -985,7 +986,7 @@ static Value *vectorizeConstant(Statement *stmt, const ConstantOp &constant, } /// Returns a uniqu'ed VectorType. -/// In the case `v`'s defining statement is already part of the `state`'s +/// In the case `v`'s defining instruction is already part of the `state`'s /// vectorizedSet, just returns the type of `v`. /// Otherwise, constructs a new VectorType of shape defined by `state.strategy` /// and of elemental type the type of `v`. 
@@ -993,17 +994,17 @@ static Type getVectorType(Value *v, const VectorizationState &state) { if (!VectorType::isValidElementType(v->getType())) { return Type(); } - auto *definingOpStmt = cast<OperationInst>(v->getDefiningInst()); - if (state.vectorizedSet.count(definingOpStmt) > 0) { + auto *definingOpInst = cast<OperationInst>(v->getDefiningInst()); + if (state.vectorizedSet.count(definingOpInst) > 0) { return v->getType().cast<VectorType>(); } return VectorType::get(state.strategy->vectorSizes, v->getType()); }; -/// Tries to vectorize a given operand `op` of Statement `stmt` during def-chain -/// propagation or during terminator vectorization, by applying the following -/// logic: -/// 1. if the defining statement is part of the vectorizedSet (i.e. vectorized +/// Tries to vectorize a given operand `op` of Instruction `inst` during +/// def-chain propagation or during terminator vectorization, by applying the +/// following logic: +/// 1. if the defining instruction is part of the vectorizedSet (i.e. vectorized /// useby -def propagation), `op` is already in the proper vector form; /// 2. otherwise, the `op` may be in some other vector form that fails to /// vectorize atm (i.e. broadcasting required), returns nullptr to indicate @@ -1021,13 +1022,13 @@ static Type getVectorType(Value *v, const VectorizationState &state) { /// vectorization is possible with the above logic. Returns nullptr otherwise. /// /// TODO(ntv): handle more complex cases. -static Value *vectorizeOperand(Value *operand, Statement *stmt, +static Value *vectorizeOperand(Value *operand, Instruction *inst, VectorizationState *state) { LLVM_DEBUG(dbgs() << "\n[early-vect]vectorize operand: "); LLVM_DEBUG(operand->print(dbgs())); - auto *definingStatement = cast<OperationInst>(operand->getDefiningInst()); + auto *definingInstruction = cast<OperationInst>(operand->getDefiningInst()); // 1. If this value has already been vectorized this round, we are done. 
- if (state->vectorizedSet.count(definingStatement) > 0) { + if (state->vectorizedSet.count(definingInstruction) > 0) { LLVM_DEBUG(dbgs() << " -> already vector operand"); return operand; } @@ -1049,7 +1050,7 @@ static Value *vectorizeOperand(Value *operand, Statement *stmt, } // 3. vectorize constant. if (auto constant = operand->getDefiningInst()->dyn_cast<ConstantOp>()) { - return vectorizeConstant(stmt, *constant, + return vectorizeConstant(inst, *constant, getVectorType(operand, *state).cast<VectorType>()); } // 4. currently non-vectorizable. @@ -1068,41 +1069,41 @@ static Value *vectorizeOperand(Value *operand, Statement *stmt, /// Maybe some Ops are not vectorizable or require some tricky logic, we cannot /// do one-off logic here; ideally it would be TableGen'd. static OperationInst *vectorizeOneOperationInst(FuncBuilder *b, - OperationInst *opStmt, + OperationInst *opInst, VectorizationState *state) { // Sanity checks. - assert(!opStmt->isa<LoadOp>() && + assert(!opInst->isa<LoadOp>() && "all loads must have already been fully vectorized independently"); - assert(!opStmt->isa<VectorTransferReadOp>() && + assert(!opInst->isa<VectorTransferReadOp>() && "vector_transfer_read cannot be further vectorized"); - assert(!opStmt->isa<VectorTransferWriteOp>() && + assert(!opInst->isa<VectorTransferWriteOp>() && "vector_transfer_write cannot be further vectorized"); - if (auto store = opStmt->dyn_cast<StoreOp>()) { + if (auto store = opInst->dyn_cast<StoreOp>()) { auto *memRef = store->getMemRef(); auto *value = store->getValueToStore(); - auto *vectorValue = vectorizeOperand(value, opStmt, state); + auto *vectorValue = vectorizeOperand(value, opInst, state); auto indices = map(makePtrDynCaster<Value>(), store->getIndices()); - FuncBuilder b(opStmt); + FuncBuilder b(opInst); auto permutationMap = - makePermutationMap(opStmt, state->strategy->loopToVectorDim); + makePermutationMap(opInst, state->strategy->loopToVectorDim); LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ 
permutationMap: "); LLVM_DEBUG(permutationMap.print(dbgs())); auto transfer = b.create<VectorTransferWriteOp>( - opStmt->getLoc(), vectorValue, memRef, indices, permutationMap); + opInst->getLoc(), vectorValue, memRef, indices, permutationMap); auto *res = cast<OperationInst>(transfer->getInstruction()); LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << *res); // "Terminators" (i.e. StoreOps) are erased on the spot. - opStmt->erase(); + opInst->erase(); return res; } auto types = map([state](Value *v) { return getVectorType(v, *state); }, - opStmt->getResults()); - auto vectorizeOneOperand = [opStmt, state](Value *op) -> Value * { - return vectorizeOperand(op, opStmt, state); + opInst->getResults()); + auto vectorizeOneOperand = [opInst, state](Value *op) -> Value * { + return vectorizeOperand(op, opInst, state); }; - auto operands = map(vectorizeOneOperand, opStmt->getOperands()); + auto operands = map(vectorizeOneOperand, opInst->getOperands()); // Check whether a single operand is null. If so, vectorization failed. bool success = llvm::all_of(operands, [](Value *op) { return op; }); if (!success) { @@ -1116,9 +1117,9 @@ static OperationInst *vectorizeOneOperationInst(FuncBuilder *b, // TODO(ntv): Is it worth considering an OperationInst.clone operation // which changes the type so we can promote an OperationInst with less // boilerplate? 
- OperationState newOp(b->getContext(), opStmt->getLoc(), - opStmt->getName().getStringRef(), operands, types, - opStmt->getAttrs()); + OperationState newOp(b->getContext(), opInst->getLoc(), + opInst->getName().getStringRef(), operands, types, + opInst->getAttrs()); return b->createOperation(newOp); } @@ -1137,13 +1138,13 @@ static bool vectorizeOperations(VectorizationState *state) { auto insertUsesOf = [&worklist, state](OperationInst *vectorized) { for (auto *r : vectorized->getResults()) for (auto &u : r->getUses()) { - auto *stmt = cast<OperationInst>(u.getOwner()); + auto *inst = cast<OperationInst>(u.getOwner()); // Don't propagate to terminals, a separate pass is needed for those. // TODO(ntv)[b/119759136]: use isa<> once Op is implemented. - if (state->terminators.count(stmt) > 0) { + if (state->terminators.count(inst) > 0) { continue; } - worklist.insert(stmt); + worklist.insert(inst); } }; apply(insertUsesOf, state->roots); @@ -1152,15 +1153,15 @@ static bool vectorizeOperations(VectorizationState *state) { // size again. By construction, the order of elements in the worklist is // consistent across iterations. for (unsigned i = 0; i < worklist.size(); ++i) { - auto *stmt = worklist[i]; + auto *inst = worklist[i]; LLVM_DEBUG(dbgs() << "\n[early-vect] vectorize use: "); - LLVM_DEBUG(stmt->print(dbgs())); + LLVM_DEBUG(inst->print(dbgs())); - // 2. Create vectorized form of the statement. - // Insert it just before stmt, on success register stmt as replaced. - FuncBuilder b(stmt); - auto *vectorizedStmt = vectorizeOneOperationInst(&b, stmt, state); - if (!vectorizedStmt) { + // 2. Create vectorized form of the instruction. + // Insert it just before inst, on success register inst as replaced. 
+ FuncBuilder b(inst); + auto *vectorizedInst = vectorizeOneOperationInst(&b, inst, state); + if (!vectorizedInst) { return true; } @@ -1168,11 +1169,11 @@ static bool vectorizeOperations(VectorizationState *state) { // Note that we cannot just call replaceAllUsesWith because it may // result in ops with mixed types, for ops whose operands have not all // yet been vectorized. This would be invalid IR. - state->registerReplacement(stmt, vectorizedStmt); + state->registerReplacement(inst, vectorizedInst); - // 4. Augment the worklist with uses of the statement we just vectorized. + // 4. Augment the worklist with uses of the instruction we just vectorized. // This preserves the proper order in the worklist. - apply(insertUsesOf, ArrayRef<OperationInst *>{stmt}); + apply(insertUsesOf, ArrayRef<OperationInst *>{inst}); } return false; } @@ -1184,7 +1185,7 @@ static bool vectorizeOperations(VectorizationState *state) { static bool vectorizeRootMatches(MLFunctionMatches matches, VectorizationStrategy *strategy) { for (auto m : matches) { - auto *loop = cast<ForStmt>(m.first); + auto *loop = cast<ForInst>(m.first); VectorizationState state; state.strategy = strategy; @@ -1201,7 +1202,7 @@ static bool vectorizeRootMatches(MLFunctionMatches matches, } FuncBuilder builder(loop); // builder to insert in place of loop DenseMap<const Value *, Value *> nomap; - ForStmt *clonedLoop = cast<ForStmt>(builder.clone(*loop, nomap)); + ForInst *clonedLoop = cast<ForInst>(builder.clone(*loop, nomap)); auto fail = doVectorize(m, &state); /// Sets up error handling for this root loop. 
This is how the root match /// maintains a clone for handling failure and restores the proper state via @@ -1230,8 +1231,8 @@ static bool vectorizeRootMatches(MLFunctionMatches matches, auto roots = map(getDefiningInst, map(getKey, state.replacementMap)); // Vectorize the root operations and everything reached by use-def chains - // except the terminators (store statements) that need to be post-processed - // separately. + // except the terminators (store instructions) that need to be + // post-processed separately. fail = vectorizeOperations(&state); if (fail) { LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ failed vectorizeOperations"); @@ -1239,12 +1240,12 @@ static bool vectorizeRootMatches(MLFunctionMatches matches, } // Finally, vectorize the terminators. If anything fails to vectorize, skip. - auto vectorizeOrFail = [&fail, &state](OperationInst *stmt) { + auto vectorizeOrFail = [&fail, &state](OperationInst *inst) { if (fail) { return; } - FuncBuilder b(stmt); - auto *res = vectorizeOneOperationInst(&b, stmt, &state); + FuncBuilder b(inst); + auto *res = vectorizeOneOperationInst(&b, inst, &state); if (res == nullptr) { fail = true; } @@ -1284,7 +1285,7 @@ PassResult Vectorize::runOnMLFunction(Function *f) { if (fail) { continue; } - auto *loop = cast<ForStmt>(m.first); + auto *loop = cast<ForInst>(m.first); vectorizeLoopIfProfitable(loop, 0, patternDepth, &strategy); // TODO(ntv): if pattern does not apply, report it; alter the // cost/benefit. |

