diff options
Diffstat (limited to 'mlir/lib/Transforms/MaterializeVectors.cpp')
| -rw-r--r-- | mlir/lib/Transforms/MaterializeVectors.cpp | 245 |
1 files changed, 131 insertions, 114 deletions
diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 60f0c06aad5..400b4fdf934 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -89,6 +89,7 @@ using llvm::SetVector; using namespace mlir; +using functional::makePtrDynCaster; using functional::map; static llvm::cl::list<int> @@ -243,11 +244,11 @@ substitute(SSAValue *v, /// TODO(ntv): support a concrete AffineMap and compose with it. /// TODO(ntv): these implementation details should be captured in a /// vectorization trait at the op level directly. -static SmallVector<MLValue *, 8> -reindexAffineIndices(MLFuncBuilder *b, Type hwVectorType, +static SmallVector<SSAValue *, 8> +reindexAffineIndices(MLFuncBuilder *b, VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance, ArrayRef<SSAValue *> memrefIndices) { - auto vectorShape = hwVectorType.cast<VectorType>().getShape(); + auto vectorShape = hwVectorType.getShape(); assert(hwVectorInstance.size() >= vectorShape.size()); unsigned numIndices = memrefIndices.size(); @@ -287,78 +288,21 @@ reindexAffineIndices(MLFuncBuilder *b, Type hwVectorType, // TODO(ntv): support a concrete map and composition. auto app = b->create<AffineApplyOp>(b->getInsertionPoint()->getLoc(), affineMap, memrefIndices); - unsigned numResults = app->getNumResults(); - SmallVector<MLValue *, 8> res; - for (unsigned i = 0; i < numResults; ++i) { - res.push_back(cast<MLValue>(app->getResult(i))); - } - return res; + return SmallVector<SSAValue *, 8>{app->getResults()}; } -/// Returns the cloned operands of `opStmt` for the instance of -/// `hwVectorInstance` when lowering from a super-vector type to -/// `hwVectorType`. `hwVectorInstance` represents one particular instance of -/// `hwVectorType` int the covering of the super-vector type. For a more -/// detailed description of the problem, see the description of -/// reindexAffineIndices. -static SmallVector<MLValue *, 8> -cloneAndUnrollOperands(OperationStmt *opStmt, Type hwVectorType, - ArrayRef<unsigned> hwVectorInstance, - DenseMap<const MLValue *, MLValue *> *substitutionsMap) { - using functional::map; - - // For Ops that are not vector_transfer_read/vector_transfer_write we can just - // substitute and be done. - if (!isaVectorTransferRead(*opStmt) && !isaVectorTransferWrite(*opStmt)) { - return map([substitutionsMap]( - SSAValue *v) { return substitute(v, *substitutionsMap); }, - opStmt->getOperands()); - } - - // TODO(ntv): this error-prone boilerplate can be removed once we have a - // proper Op for vectr_transfer. - unsigned offset = 0; - unsigned numIndices = 0; - SmallVector<MLValue *, 8> res; - auto operands = opStmt->getOperands(); - if (isaVectorTransferRead(*opStmt)) { - offset = 1; - numIndices = opStmt->getNumOperands() - 1; - } else if (isaVectorTransferWrite(*opStmt)) { - offset = 2; - numIndices = opStmt->getNumOperands() - 2; - } - // Copy as-is the [optional valueToStore], memref. - for (unsigned i = 0; i < offset; ++i) { - res.push_back(substitute(*(operands.begin() + i), *substitutionsMap)); - } - - MLFuncBuilder b(opStmt); - // TODO(ntv): indices extraction is brittle and unsafe before we have an Op. - SmallVector<SSAValue *, 8> indices; - for (auto it = operands.begin() + offset; it != operands.end(); ++it) { - indices.push_back(*it); - } - auto affineValues = - reindexAffineIndices(&b, hwVectorType, hwVectorInstance, indices); - res.append(affineValues.begin(), affineValues.end()); - - return res; -} - -// Returns attributes with the following substitutions applied: -// - splat of `superVectorType` is replaced by splat of `hwVectorType`. -// TODO(ntv): add more substitutions on a per-need basis. -static SmallVector<NamedAttribute, 2> +/// Returns attributes with the following substitutions applied: +/// - splat of `superVectorType` is replaced by splat of `hwVectorType`. +/// TODO(ntv): add more substitutions on a per-need basis. +static SmallVector<NamedAttribute, 1> materializeAttributes(OperationStmt *opStmt, VectorType superVectorType, VectorType hwVectorType) { - SmallVector<NamedAttribute, 2> res; + SmallVector<NamedAttribute, 1> res; for (auto a : opStmt->getAttrs()) { auto splat = a.second.dyn_cast<SplatElementsAttr>(); bool splatOfSuperVectorType = splat && (splat.getType() == superVectorType); if (splatOfSuperVectorType) { - auto attr = SplatElementsAttr::get(hwVectorType.cast<VectorType>(), - splat.getValue()); + auto attr = SplatElementsAttr::get(hwVectorType, splat.getValue()); res.push_back(NamedAttribute(a.first, attr)); } else { res.push_back(a); @@ -367,6 +311,70 @@ materializeAttributes(OperationStmt *opStmt, VectorType superVectorType, return res; } +/// Creates an instantiated version of `opStmt`. +/// Ops other than VectorTransferReadOp/VectorTransferWriteOp require no +/// affine reindexing. Just substitute their SSAValue* operands and be done. For +/// this case the actual instance is irrelevant. Just use the SSA values in +/// substitutionsMap. +static OperationStmt * +instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType superVectorType, + VectorType hwVectorType, + DenseMap<const MLValue *, MLValue *> *substitutionsMap) { + assert(!opStmt->isa<VectorTransferReadOp>() && + "Should call the function specialized for VectorTransferReadOp"); + assert(!opStmt->isa<VectorTransferWriteOp>() && + "Should call the function specialized for VectorTransferWriteOp"); + auto operands = + map([substitutionsMap]( + SSAValue *v) { return substitute(v, *substitutionsMap); }, + opStmt->getOperands()); + return b->createOperation( + opStmt->getLoc(), opStmt->getName(), operands, {hwVectorType}, + materializeAttributes(opStmt, superVectorType, hwVectorType)); +} + +/// Creates an instantiated version of `read` for the instance of +/// `hwVectorInstance` when lowering from a super-vector type to +/// `hwVectorType`. `hwVectorInstance` represents one particular instance of +/// `hwVectorType` int the covering of the super-vector type. For a more +/// detailed description of the problem, see the description of +/// reindexAffineIndices. +static OperationStmt * +instantiate(MLFuncBuilder *b, VectorTransferReadOp *read, + VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance, + DenseMap<const MLValue *, MLValue *> *substitutionsMap) { + SmallVector<SSAValue *, 8> indices = + map(makePtrDynCaster<SSAValue>(), read->getIndices()); + auto affineIndices = + reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices); + auto cloned = b->create<VectorTransferReadOp>( + read->getLoc(), hwVectorType, read->getMemRef(), affineIndices, + makePermutationMap(read->getMemRefType(), hwVectorType), + read->getPaddingValue()); + return cast<OperationStmt>(cloned->getOperation()); +} + +/// Creates an instantiated version of `write` for the instance of +/// `hwVectorInstance` when lowering from a super-vector type to +/// `hwVectorType`. `hwVectorInstance` represents one particular instance of +/// `hwVectorType` int the covering of th3e super-vector type. For a more +/// detailed description of the problem, see the description of +/// reindexAffineIndices. +static OperationStmt * +instantiate(MLFuncBuilder *b, VectorTransferWriteOp *write, + VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance, + DenseMap<const MLValue *, MLValue *> *substitutionsMap) { + SmallVector<SSAValue *, 8> indices = + map(makePtrDynCaster<SSAValue>(), write->getIndices()); + auto affineIndices = + reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices); + auto cloned = b->create<VectorTransferWriteOp>( + write->getLoc(), substitute(write->getVector(), *substitutionsMap), + write->getMemRef(), affineIndices, + makePermutationMap(write->getMemRefType(), hwVectorType)); + return cast<OperationStmt>(cloned->getOperation()); +} + /// Returns `true` if stmt instance is properly cloned and inserted, false /// otherwise. /// The multi-dimensional `hwVectorInstance` belongs to the shapeRatio of @@ -386,45 +394,52 @@ materializeAttributes(OperationStmt *opStmt, VectorType superVectorType, /// type, all operands are substituted according to `substitutions`. Thanks /// to the topological order of a slice, the substitution is always /// possible. -static bool cloneAndInsertHardwareVectorInstance(Statement *stmt, - MaterializationState *state) { - LLVM_DEBUG(dbgs() << "\nclone" << *stmt); - if (auto *opStmt = dyn_cast<OperationStmt>(stmt)) { - // TODO(ntv): Is it worth considering an OperationStmt.clone operation - // which changes the type so we can promote an OperationStmt with less - // boilerplate? - assert(opStmt->getNumResults() <= 1 && "NYI: opStmt has > 1 results"); - auto operands = cloneAndUnrollOperands(opStmt, state->hwVectorType, - state->hwVectorInstance, - state->substitutionsMap); - MLFuncBuilder b(stmt); - if (opStmt->getNumResults() == 0) { - // vector_transfer_write - b.createOperation(stmt->getLoc(), opStmt->getName(), operands, {}, - materializeAttributes(opStmt, state->superVectorType, - state->hwVectorType)); - } else { - // vector_transfer_read - auto *cloned = b.createOperation( - stmt->getLoc(), opStmt->getName(), operands, {state->hwVectorType}, - materializeAttributes(opStmt, state->superVectorType, - state->hwVectorType)); - state->substitutionsMap->insert(std::make_pair( - cast<MLValue>(opStmt->getResult(0)), - cast<MLValue>(cast<OperationStmt>(cloned)->getResult(0)))); - } - return false; - } +static bool instantiateMaterialization(Statement *stmt, + MaterializationState *state) { + LLVM_DEBUG(dbgs() << "\ninstantiate: " << *stmt); + // Fail hard and wake up when needed. if (isa<ForStmt>(stmt)) { - // Fail hard and wake up when needed. stmt->emitError("NYI path ForStmt"); return true; } // Fail hard and wake up when needed. - stmt->emitError("NYI path IfStmt"); - return true; + if (isa<IfStmt>(stmt)) { + stmt->emitError("NYI path IfStmt"); + return true; + } + + // Create a builder here for unroll-and-jam effects. + MLFuncBuilder b(stmt); + auto *opStmt = cast<OperationStmt>(stmt); + if (auto write = opStmt->dyn_cast<VectorTransferWriteOp>()) { + instantiate(&b, &*write, state->hwVectorType, state->hwVectorInstance, + state->substitutionsMap); + return false; + } else if (auto read = opStmt->dyn_cast<VectorTransferReadOp>()) { + auto *clone = instantiate(&b, &*read, state->hwVectorType, + state->hwVectorInstance, state->substitutionsMap); + state->substitutionsMap->insert(std::make_pair( + cast<MLValue>(read->getResult()), cast<MLValue>(clone->getResult(0)))); + return false; + } + // The only op with 0 results reaching this point must, by construction, be + // VectorTransferWriteOps and have been caught above. Ops with >= 2 results + // are not yet supported. So just support 1 result. + if (opStmt->getNumResults() != 1) { + stmt->emitError("NYI: ops with != 1 results"); + return true; + } + if (opStmt->getResult(0)->getType() != state->superVectorType) { + stmt->emitError("Op does not return a supervector."); + return true; + } + auto *clone = instantiate(&b, opStmt, state->superVectorType, + state->hwVectorType, state->substitutionsMap); + state->substitutionsMap->insert(std::make_pair( + cast<MLValue>(opStmt->getResult(0)), cast<MLValue>(clone->getResult(0)))); + return false; } /// Takes a slice and rewrites the operations in it so that occurrences @@ -463,15 +478,22 @@ static void emitSlice(MaterializationState *state, scopedState.substitutionsMap = &substitutionMap; // slice are topologically sorted, we can just clone them in order. for (auto *stmt : *slice) { - auto fail = cloneAndInsertHardwareVectorInstance(stmt, &scopedState); + auto fail = instantiateMaterialization(stmt, &scopedState); (void)fail; assert(!fail && "Unhandled super-vector materialization failure"); } } + + LLVM_DEBUG(dbgs() << "\nMLFunction is now\n"); + LLVM_DEBUG( + cast<OperationStmt>((*slice)[0])->getOperationFunction()->print(dbgs())); + // slice are topologically sorted, we can just erase them in reverse // order. Reverse iterator does not just work simply with an operator* // dereference. for (int idx = slice->size() - 1; idx >= 0; --idx) { + LLVM_DEBUG(dbgs() << "\nErase: "); + LLVM_DEBUG((*slice)[idx]->print(dbgs())); (*slice)[idx]->erase(); } } @@ -497,25 +519,21 @@ static void materialize(MLFunction *f, const SetVector<OperationStmt *> &terminators, MaterializationState *state) { DenseSet<Statement *> seen; - for (auto terminator : terminators) { - LLVM_DEBUG(dbgs() << "\nFrom terminator:" << *terminator); - + for (auto *term : terminators) { // Short-circuit test, a given terminator may have been reached by some // other previous transitive use-def chains. - if (seen.count(terminator) > 0) { + if (seen.count(term) > 0) { continue; } - // Terminators are vector_transfer_write with 0 results by construction atm. - assert(isaVectorTransferWrite(*terminator) && ""); - assert(terminator->getNumResults() == 0 && - "NYI: terminators must have 0 results"); + auto terminator = term->cast<VectorTransferWriteOp>(); + LLVM_DEBUG(dbgs() << "\nFrom terminator:" << *term); // Get the transitive use-defs starting from terminator, limited to the // current enclosing scope of the terminator. See the top of the function // Note for the justification of this restriction. // TODO(ntv): relax scoping constraints. - auto *enclosingScope = terminator->getParentStmt(); + auto *enclosingScope = term->getParentStmt(); auto keepIfInSameScope = [enclosingScope](Statement *stmt) { assert(stmt && "NULL stmt"); if (!enclosingScope) { @@ -525,7 +543,7 @@ static void materialize(MLFunction *f, return properlyDominates(*enclosingScope, *stmt); }; SetVector<Statement *> slice = - getSlice(terminator, keepIfInSameScope, keepIfInSameScope); + getSlice(term, keepIfInSameScope, keepIfInSameScope); assert(!slice.empty()); // Sanity checks: transitive slice must be completely disjoint from @@ -540,10 +558,9 @@ static void materialize(MLFunction *f, // Emit the current slice. // Set scoped super-vector and corresponding hw vector types. - state->superVectorType = - terminator->getOperand(0)->getType().cast<VectorType>(); + state->superVectorType = terminator->getVectorType(); assert((state->superVectorType.getElementType() == - Type::getF32(terminator->getContext())) && + Type::getF32(term->getContext())) && "Only f32 supported for now"); state->hwVectorType = VectorType::get( state->hwVectorSize, state->superVectorType.getElementType()); @@ -568,7 +585,7 @@ PassResult MaterializeVectors::runOnMLFunction(MLFunction *f) { // super-vector of subVectorType. auto filter = [subVectorType](const Statement &stmt) { const auto &opStmt = cast<OperationStmt>(stmt); - if (!isaVectorTransferWrite(opStmt)) { + if (!opStmt.isa<VectorTransferWriteOp>()) { return false; } return matcher::operatesOnStrictSuperVectors(opStmt, subVectorType); |

