diff options
| author | Nicolas Vasilache <ntv@google.com> | 2018-12-03 15:21:27 -0800 |
|---|---|---|
| committer | jpienaar <jpienaar@google.com> | 2019-03-29 14:15:25 -0700 |
| commit | b39d1f0bdb5052d447cdb0d8accedf292bb50d6c (patch) | |
| tree | f5218495ea7cecb0304fa96128244c11674f3052 /mlir/lib/Transforms/MaterializeVectors.cpp | |
| parent | bb3ffc1c2226d81155bc5ad01c1397566c2e7ee9 (diff) | |
| download | bcm5719-llvm-b39d1f0bdb5052d447cdb0d8accedf292bb50d6c.tar.gz bcm5719-llvm-b39d1f0bdb5052d447cdb0d8accedf292bb50d6c.zip | |
[MLIR] Add VectorTransferOps
This CL implements and uses VectorTransferOps in lieu of the former custom
call op. Tests are updated accordingly.
VectorTransferOps come in 2 flavors: VectorTransferReadOp and
VectorTransferWriteOp.
VectorTransferOps can be thought of as a backend-independent
pseudo op/library call that needs to be legalized to MLIR (whiteboxed) before
it can be lowered to backend-dependent IR.
Note that the current implementation does not yet support a real permutation
map. Proper support will come in a followup CL.
VectorTransferReadOp
====================
VectorTransferReadOp performs a blocking read from a scalar memref
location into a super-vector of the same elemental type. This operation is
called 'read' by opposition to 'load' because the super-vector granularity
is generally not representable with a single hardware register. As a
consequence, memory transfers will generally be required when lowering
VectorTransferReadOp. A VectorTransferReadOp is thus a mid-level abstraction
that supports super-vectorization with non-effecting padding for full-tile
only code.
A vector transfer read has semantics similar to a vector load, with additional
support for:
1. an optional value of the elemental type of the MemRef. This value
supports non-effecting padding and is inserted in places where the
vector read exceeds the MemRef bounds. If the value is not specified,
the access is statically guaranteed to be within bounds;
2. an attribute of type AffineMap to specify a slice of the original
MemRef access and its transposition into the super-vector shape. The
permutation_map is an unbounded AffineMap that must represent a
permutation from the MemRef dim space projected onto the vector dim
space.
Example:
```mlir
%A = alloc(%size1, %size2, %size3, %size4) : memref<?x?x?x?xf32>
...
%val = `ssa-value` : f32
// let %i, %j, %k, %l be ssa-values of type index
%v0 = vector_transfer_read %src, %i, %j, %k, %l
{permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
(memref<?x?x?x?xf32>, index, index, index, index) ->
vector<16x32x64xf32>
%v1 = vector_transfer_read %src, %i, %j, %k, %l, %val
{permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
(memref<?x?x?x?xf32>, index, index, index, index, f32) ->
vector<16x32x64xf32>
```
VectorTransferWriteOp
=====================
VectorTransferWriteOp performs a blocking write from a super-vector to
a scalar memref of the same elemental type. This operation is
called 'write' by opposition to 'store' because the super-vector
granularity is generally not representable with a single hardware register. As
a consequence, memory transfers will generally be required when lowering
VectorTransferWriteOp. A VectorTransferWriteOp is thus a mid-level
abstraction that supports super-vectorization with non-effecting padding
for full-tile only code.
A vector transfer write has semantics similar to a vector store, with
additional support for handling out-of-bounds situations.
Example:
```mlir
%A = alloc(%size1, %size2, %size3, %size4) : memref<?x?x?x?xf32>.
%val = `ssa-value` : vector<16x32x64xf32>
// let %i, %j, %k, %l be ssa-values of type index
vector_transfer_write %val, %src, %i, %j, %k, %l
{permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
(vector<16x32x64xf32>, memref<?x?x?x?xf32>, index, index, index, index)
```
PiperOrigin-RevId: 223873234
Diffstat (limited to 'mlir/lib/Transforms/MaterializeVectors.cpp')
| -rw-r--r-- | mlir/lib/Transforms/MaterializeVectors.cpp | 245 |
1 files changed, 131 insertions, 114 deletions
diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 60f0c06aad5..400b4fdf934 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -89,6 +89,7 @@ using llvm::SetVector; using namespace mlir; +using functional::makePtrDynCaster; using functional::map; static llvm::cl::list<int> @@ -243,11 +244,11 @@ substitute(SSAValue *v, /// TODO(ntv): support a concrete AffineMap and compose with it. /// TODO(ntv): these implementation details should be captured in a /// vectorization trait at the op level directly. -static SmallVector<MLValue *, 8> -reindexAffineIndices(MLFuncBuilder *b, Type hwVectorType, +static SmallVector<SSAValue *, 8> +reindexAffineIndices(MLFuncBuilder *b, VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance, ArrayRef<SSAValue *> memrefIndices) { - auto vectorShape = hwVectorType.cast<VectorType>().getShape(); + auto vectorShape = hwVectorType.getShape(); assert(hwVectorInstance.size() >= vectorShape.size()); unsigned numIndices = memrefIndices.size(); @@ -287,78 +288,21 @@ reindexAffineIndices(MLFuncBuilder *b, Type hwVectorType, // TODO(ntv): support a concrete map and composition. auto app = b->create<AffineApplyOp>(b->getInsertionPoint()->getLoc(), affineMap, memrefIndices); - unsigned numResults = app->getNumResults(); - SmallVector<MLValue *, 8> res; - for (unsigned i = 0; i < numResults; ++i) { - res.push_back(cast<MLValue>(app->getResult(i))); - } - return res; + return SmallVector<SSAValue *, 8>{app->getResults()}; } -/// Returns the cloned operands of `opStmt` for the instance of -/// `hwVectorInstance` when lowering from a super-vector type to -/// `hwVectorType`. `hwVectorInstance` represents one particular instance of -/// `hwVectorType` int the covering of the super-vector type. For a more -/// detailed description of the problem, see the description of -/// reindexAffineIndices. -static SmallVector<MLValue *, 8> -cloneAndUnrollOperands(OperationStmt *opStmt, Type hwVectorType, - ArrayRef<unsigned> hwVectorInstance, - DenseMap<const MLValue *, MLValue *> *substitutionsMap) { - using functional::map; - - // For Ops that are not vector_transfer_read/vector_transfer_write we can just - // substitute and be done. - if (!isaVectorTransferRead(*opStmt) && !isaVectorTransferWrite(*opStmt)) { - return map([substitutionsMap]( - SSAValue *v) { return substitute(v, *substitutionsMap); }, - opStmt->getOperands()); - } - - // TODO(ntv): this error-prone boilerplate can be removed once we have a - // proper Op for vectr_transfer. - unsigned offset = 0; - unsigned numIndices = 0; - SmallVector<MLValue *, 8> res; - auto operands = opStmt->getOperands(); - if (isaVectorTransferRead(*opStmt)) { - offset = 1; - numIndices = opStmt->getNumOperands() - 1; - } else if (isaVectorTransferWrite(*opStmt)) { - offset = 2; - numIndices = opStmt->getNumOperands() - 2; - } - // Copy as-is the [optional valueToStore], memref. - for (unsigned i = 0; i < offset; ++i) { - res.push_back(substitute(*(operands.begin() + i), *substitutionsMap)); - } - - MLFuncBuilder b(opStmt); - // TODO(ntv): indices extraction is brittle and unsafe before we have an Op. - SmallVector<SSAValue *, 8> indices; - for (auto it = operands.begin() + offset; it != operands.end(); ++it) { - indices.push_back(*it); - } - auto affineValues = - reindexAffineIndices(&b, hwVectorType, hwVectorInstance, indices); - res.append(affineValues.begin(), affineValues.end()); - - return res; -} - -// Returns attributes with the following substitutions applied: -// - splat of `superVectorType` is replaced by splat of `hwVectorType`. -// TODO(ntv): add more substitutions on a per-need basis. -static SmallVector<NamedAttribute, 2> +/// Returns attributes with the following substitutions applied: +/// - splat of `superVectorType` is replaced by splat of `hwVectorType`. +/// TODO(ntv): add more substitutions on a per-need basis. +static SmallVector<NamedAttribute, 1> materializeAttributes(OperationStmt *opStmt, VectorType superVectorType, VectorType hwVectorType) { - SmallVector<NamedAttribute, 2> res; + SmallVector<NamedAttribute, 1> res; for (auto a : opStmt->getAttrs()) { auto splat = a.second.dyn_cast<SplatElementsAttr>(); bool splatOfSuperVectorType = splat && (splat.getType() == superVectorType); if (splatOfSuperVectorType) { - auto attr = SplatElementsAttr::get(hwVectorType.cast<VectorType>(), - splat.getValue()); + auto attr = SplatElementsAttr::get(hwVectorType, splat.getValue()); res.push_back(NamedAttribute(a.first, attr)); } else { res.push_back(a); @@ -367,6 +311,70 @@ materializeAttributes(OperationStmt *opStmt, VectorType superVectorType, return res; } +/// Creates an instantiated version of `opStmt`. +/// Ops other than VectorTransferReadOp/VectorTransferWriteOp require no +/// affine reindexing. Just substitute their SSAValue* operands and be done. For +/// this case the actual instance is irrelevant. Just use the SSA values in +/// substitutionsMap. +static OperationStmt * +instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType superVectorType, + VectorType hwVectorType, + DenseMap<const MLValue *, MLValue *> *substitutionsMap) { + assert(!opStmt->isa<VectorTransferReadOp>() && + "Should call the function specialized for VectorTransferReadOp"); + assert(!opStmt->isa<VectorTransferWriteOp>() && + "Should call the function specialized for VectorTransferWriteOp"); + auto operands = + map([substitutionsMap]( + SSAValue *v) { return substitute(v, *substitutionsMap); }, + opStmt->getOperands()); + return b->createOperation( + opStmt->getLoc(), opStmt->getName(), operands, {hwVectorType}, + materializeAttributes(opStmt, superVectorType, hwVectorType)); +} + +/// Creates an instantiated version of `read` for the instance of +/// `hwVectorInstance` when lowering from a super-vector type to +/// `hwVectorType`. `hwVectorInstance` represents one particular instance of +/// `hwVectorType` int the covering of the super-vector type. For a more +/// detailed description of the problem, see the description of +/// reindexAffineIndices. +static OperationStmt * +instantiate(MLFuncBuilder *b, VectorTransferReadOp *read, + VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance, + DenseMap<const MLValue *, MLValue *> *substitutionsMap) { + SmallVector<SSAValue *, 8> indices = + map(makePtrDynCaster<SSAValue>(), read->getIndices()); + auto affineIndices = + reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices); + auto cloned = b->create<VectorTransferReadOp>( + read->getLoc(), hwVectorType, read->getMemRef(), affineIndices, + makePermutationMap(read->getMemRefType(), hwVectorType), + read->getPaddingValue()); + return cast<OperationStmt>(cloned->getOperation()); +} + +/// Creates an instantiated version of `write` for the instance of +/// `hwVectorInstance` when lowering from a super-vector type to +/// `hwVectorType`. `hwVectorInstance` represents one particular instance of +/// `hwVectorType` int the covering of th3e super-vector type. For a more +/// detailed description of the problem, see the description of +/// reindexAffineIndices. +static OperationStmt * +instantiate(MLFuncBuilder *b, VectorTransferWriteOp *write, + VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance, + DenseMap<const MLValue *, MLValue *> *substitutionsMap) { + SmallVector<SSAValue *, 8> indices = + map(makePtrDynCaster<SSAValue>(), write->getIndices()); + auto affineIndices = + reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices); + auto cloned = b->create<VectorTransferWriteOp>( + write->getLoc(), substitute(write->getVector(), *substitutionsMap), + write->getMemRef(), affineIndices, + makePermutationMap(write->getMemRefType(), hwVectorType)); + return cast<OperationStmt>(cloned->getOperation()); +} + /// Returns `true` if stmt instance is properly cloned and inserted, false /// otherwise. /// The multi-dimensional `hwVectorInstance` belongs to the shapeRatio of @@ -386,45 +394,52 @@ materializeAttributes(OperationStmt *opStmt, VectorType superVectorType, /// type, all operands are substituted according to `substitutions`. Thanks /// to the topological order of a slice, the substitution is always /// possible. -static bool cloneAndInsertHardwareVectorInstance(Statement *stmt, - MaterializationState *state) { - LLVM_DEBUG(dbgs() << "\nclone" << *stmt); - if (auto *opStmt = dyn_cast<OperationStmt>(stmt)) { - // TODO(ntv): Is it worth considering an OperationStmt.clone operation - // which changes the type so we can promote an OperationStmt with less - // boilerplate? - assert(opStmt->getNumResults() <= 1 && "NYI: opStmt has > 1 results"); - auto operands = cloneAndUnrollOperands(opStmt, state->hwVectorType, - state->hwVectorInstance, - state->substitutionsMap); - MLFuncBuilder b(stmt); - if (opStmt->getNumResults() == 0) { - // vector_transfer_write - b.createOperation(stmt->getLoc(), opStmt->getName(), operands, {}, - materializeAttributes(opStmt, state->superVectorType, - state->hwVectorType)); - } else { - // vector_transfer_read - auto *cloned = b.createOperation( - stmt->getLoc(), opStmt->getName(), operands, {state->hwVectorType}, - materializeAttributes(opStmt, state->superVectorType, - state->hwVectorType)); - state->substitutionsMap->insert(std::make_pair( - cast<MLValue>(opStmt->getResult(0)), - cast<MLValue>(cast<OperationStmt>(cloned)->getResult(0)))); - } - return false; - } +static bool instantiateMaterialization(Statement *stmt, + MaterializationState *state) { + LLVM_DEBUG(dbgs() << "\ninstantiate: " << *stmt); + // Fail hard and wake up when needed. if (isa<ForStmt>(stmt)) { - // Fail hard and wake up when needed. stmt->emitError("NYI path ForStmt"); return true; } // Fail hard and wake up when needed. - stmt->emitError("NYI path IfStmt"); - return true; + if (isa<IfStmt>(stmt)) { + stmt->emitError("NYI path IfStmt"); + return true; + } + + // Create a builder here for unroll-and-jam effects. + MLFuncBuilder b(stmt); + auto *opStmt = cast<OperationStmt>(stmt); + if (auto write = opStmt->dyn_cast<VectorTransferWriteOp>()) { + instantiate(&b, &*write, state->hwVectorType, state->hwVectorInstance, + state->substitutionsMap); + return false; + } else if (auto read = opStmt->dyn_cast<VectorTransferReadOp>()) { + auto *clone = instantiate(&b, &*read, state->hwVectorType, + state->hwVectorInstance, state->substitutionsMap); + state->substitutionsMap->insert(std::make_pair( + cast<MLValue>(read->getResult()), cast<MLValue>(clone->getResult(0)))); + return false; + } + // The only op with 0 results reaching this point must, by construction, be + // VectorTransferWriteOps and have been caught above. Ops with >= 2 results + // are not yet supported. So just support 1 result. + if (opStmt->getNumResults() != 1) { + stmt->emitError("NYI: ops with != 1 results"); + return true; + } + if (opStmt->getResult(0)->getType() != state->superVectorType) { + stmt->emitError("Op does not return a supervector."); + return true; + } + auto *clone = instantiate(&b, opStmt, state->superVectorType, + state->hwVectorType, state->substitutionsMap); + state->substitutionsMap->insert(std::make_pair( + cast<MLValue>(opStmt->getResult(0)), cast<MLValue>(clone->getResult(0)))); + return false; } /// Takes a slice and rewrites the operations in it so that occurrences @@ -463,15 +478,22 @@ static void emitSlice(MaterializationState *state, scopedState.substitutionsMap = &substitutionMap; // slice are topologically sorted, we can just clone them in order. for (auto *stmt : *slice) { - auto fail = cloneAndInsertHardwareVectorInstance(stmt, &scopedState); + auto fail = instantiateMaterialization(stmt, &scopedState); (void)fail; assert(!fail && "Unhandled super-vector materialization failure"); } } + + LLVM_DEBUG(dbgs() << "\nMLFunction is now\n"); + LLVM_DEBUG( + cast<OperationStmt>((*slice)[0])->getOperationFunction()->print(dbgs())); + // slice are topologically sorted, we can just erase them in reverse // order. Reverse iterator does not just work simply with an operator* // dereference. for (int idx = slice->size() - 1; idx >= 0; --idx) { + LLVM_DEBUG(dbgs() << "\nErase: "); + LLVM_DEBUG((*slice)[idx]->print(dbgs())); (*slice)[idx]->erase(); } } @@ -497,25 +519,21 @@ static void materialize(MLFunction *f, const SetVector<OperationStmt *> &terminators, MaterializationState *state) { DenseSet<Statement *> seen; - for (auto terminator : terminators) { - LLVM_DEBUG(dbgs() << "\nFrom terminator:" << *terminator); - + for (auto *term : terminators) { // Short-circuit test, a given terminator may have been reached by some // other previous transitive use-def chains. - if (seen.count(terminator) > 0) { + if (seen.count(term) > 0) { continue; } - // Terminators are vector_transfer_write with 0 results by construction atm. - assert(isaVectorTransferWrite(*terminator) && ""); - assert(terminator->getNumResults() == 0 && - "NYI: terminators must have 0 results"); + auto terminator = term->cast<VectorTransferWriteOp>(); + LLVM_DEBUG(dbgs() << "\nFrom terminator:" << *term); // Get the transitive use-defs starting from terminator, limited to the // current enclosing scope of the terminator. See the top of the function // Note for the justification of this restriction. // TODO(ntv): relax scoping constraints. - auto *enclosingScope = terminator->getParentStmt(); + auto *enclosingScope = term->getParentStmt(); auto keepIfInSameScope = [enclosingScope](Statement *stmt) { assert(stmt && "NULL stmt"); if (!enclosingScope) { @@ -525,7 +543,7 @@ static void materialize(MLFunction *f, return properlyDominates(*enclosingScope, *stmt); }; SetVector<Statement *> slice = - getSlice(terminator, keepIfInSameScope, keepIfInSameScope); + getSlice(term, keepIfInSameScope, keepIfInSameScope); assert(!slice.empty()); // Sanity checks: transitive slice must be completely disjoint from @@ -540,10 +558,9 @@ static void materialize(MLFunction *f, // Emit the current slice. // Set scoped super-vector and corresponding hw vector types. - state->superVectorType = - terminator->getOperand(0)->getType().cast<VectorType>(); + state->superVectorType = terminator->getVectorType(); assert((state->superVectorType.getElementType() == - Type::getF32(terminator->getContext())) && + Type::getF32(term->getContext())) && "Only f32 supported for now"); state->hwVectorType = VectorType::get( state->hwVectorSize, state->superVectorType.getElementType()); @@ -568,7 +585,7 @@ PassResult MaterializeVectors::runOnMLFunction(MLFunction *f) { // super-vector of subVectorType. auto filter = [subVectorType](const Statement &stmt) { const auto &opStmt = cast<OperationStmt>(stmt); - if (!isaVectorTransferWrite(opStmt)) { + if (!opStmt.isa<VectorTransferWriteOp>()) { return false; } return matcher::operatesOnStrictSuperVectors(opStmt, subVectorType); |

