diff options
Diffstat (limited to 'mlir/lib/Transforms/MaterializeVectors.cpp')
| -rw-r--r-- | mlir/lib/Transforms/MaterializeVectors.cpp | 57 |
1 files changed, 41 insertions, 16 deletions
diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 7cd6a0ad273..3bf4305ca0c 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -209,23 +209,25 @@ struct MaterializeVectors : public FunctionPass { char MaterializeVectors::passID = 0; -// Returns the distance, in number of elements, between a slice in a dimension -// and the next slice in the same dimension. -// e.g. shape[3, 4, 5] -> strides[20, 5, 1] +/// Given a shape with sizes greater than 0 along all dimensions, +/// returns the distance, in number of elements, between a slice in a dimension +/// and the next slice in the same dimension. +/// e.g. shape[3, 4, 5] -> strides[20, 5, 1] static SmallVector<unsigned, 8> makeStrides(ArrayRef<unsigned> shape) { SmallVector<unsigned, 8> tmp; tmp.reserve(shape.size()); unsigned running = 1; for (auto rit = shape.rbegin(), reit = shape.rend(); rit != reit; ++rit) { - // TODO(ntv): emitError instead of NYI assert. - assert(*rit > 0 && "NYI: symbolic or null shape dimension"); + assert(*rit > 0 && "size must be greater than 0 along all dimensions of " + "shape"); tmp.push_back(running); running *= *rit; } return SmallVector<unsigned, 8>(tmp.rbegin(), tmp.rend()); } -// Returns the linearized expression. +/// Given a shape with sizes greater than 0 along all dimensions, returns the +/// delinearized components of linearIndex along shape. static SmallVector<unsigned, 8> delinearize(unsigned linearIndex, ArrayRef<unsigned> shape) { SmallVector<unsigned, 8> res; @@ -256,6 +258,8 @@ instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType, /// insertion. /// For now, this is limited to ConstantOp because we do not vectorize loop /// indices and will need to be extended in the future. +/// +/// If substitution fails, returns nullptr. static MLValue * substitute(SSAValue *v, VectorType hwVectorType, DenseMap<const MLValue *, MLValue *> *substitutionsMap) { @@ -271,7 +275,7 @@ substitute(SSAValue *v, VectorType hwVectorType, return res.first->second; } v->getDefiningOperation()->emitError("Missing substitution"); - assert(false); + return nullptr; } return it->second; } @@ -400,6 +404,8 @@ materializeAttributes(OperationStmt *opStmt, VectorType hwVectorType) { /// affine reindexing. Just substitute their SSAValue* operands and be done. For /// this case the actual instance is irrelevant. Just use the SSA values in /// substitutionsMap. +/// +/// If the underlying substitution fails, this fails too and returns nullptr. static OperationStmt * instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType, DenseMap<const MLValue *, MLValue *> *substitutionsMap) { @@ -407,11 +413,18 @@ instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType, "Should call the function specialized for VectorTransferReadOp"); assert(!opStmt->isa<VectorTransferWriteOp>() && "Should call the function specialized for VectorTransferWriteOp"); + bool fail = false; auto operands = map( - [hwVectorType, substitutionsMap](SSAValue *v) { - return substitute(v, hwVectorType, substitutionsMap); + [hwVectorType, substitutionsMap, &fail](SSAValue *v) { + auto *res = + fail ? nullptr : substitute(v, hwVectorType, substitutionsMap); + fail |= !res; + return res; }, opStmt->getOperands()); + if (fail) { + return nullptr; + } auto attrs = materializeAttributes(opStmt, hwVectorType); return b->createOperation(opStmt->getLoc(), opStmt->getName(), operands, {hwVectorType}, attrs); @@ -452,7 +465,7 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy *transfer, auto permutationMap = transfer->getPermutationMap(); LLVM_DEBUG(projectionMap.print(dbgs() << "\nprojectionMap: ")); LLVM_DEBUG(permutationMap.print(dbgs() << "\npermutationMap: ")); - return composeUnboundedMaps(projectionMap, transfer->getPermutationMap()); + return composeUnboundedMaps(projectionMap, permutationMap); } /// Creates an instantiated version of `read` for the instance of @@ -516,6 +529,8 @@ instantiate(MLFuncBuilder *b, VectorTransferWriteOp *write, /// type, all operands are substituted according to `substitutions`. Thanks /// to the topological order of a slice, the substitution is always /// possible. +/// +/// Returns true on failure. static bool instantiateMaterialization(Statement *stmt, MaterializationState *state) { LLVM_DEBUG(dbgs() << "\ninstantiate: " << *stmt); @@ -557,6 +572,9 @@ static bool instantiateMaterialization(Statement *stmt, } auto *clone = instantiate(&b, opStmt, state->hwVectorType, state->substitutionsMap); + if (!clone) { + return true; + } state->substitutionsMap->insert(std::make_pair( cast<MLValue>(opStmt->getResult(0)), cast<MLValue>(clone->getResult(0)))); return false; @@ -578,10 +596,12 @@ static bool instantiateMaterialization(Statement *stmt, /// equivalent of loop strip-mining + loop sinking and encoded this in the /// vector type. /// +/// Returns true on failure. +/// /// TODO(ntv): materialized allocs. /// TODO(ntv): full loops + materialized allocs. /// TODO(ntv): partial unrolling + materialized allocs. -static void emitSlice(MaterializationState *state, +static bool emitSlice(MaterializationState *state, SetVector<Statement *> *slice) { auto ratio = shapeRatio(state->superVectorType, state->hwVectorType); assert(ratio.hasValue() && @@ -601,7 +621,7 @@ static void emitSlice(MaterializationState *state, auto fail = instantiateMaterialization(stmt, &scopedState); if (fail) { stmt->emitError("Unhandled super-vector materialization failure"); - assert(!fail); + return true; } } } @@ -618,6 +638,7 @@ static void emitSlice(MaterializationState *state, LLVM_DEBUG((*slice)[idx]->print(dbgs())); (*slice)[idx]->erase(); } + return false; } /// Materializes super-vector types into concrete hw vector types as follows: @@ -637,7 +658,7 @@ static void emitSlice(MaterializationState *state, /// Additionally, this set is limited to statements in the same lexical scope /// because we currently disallow vectorization of defs that come from another /// scope. -static void materialize(MLFunction *f, +static bool materialize(MLFunction *f, const SetVector<OperationStmt *> &terminators, MaterializationState *state) { DenseSet<Statement *> seen; @@ -686,10 +707,14 @@ static void materialize(MLFunction *f, "Only f32 supported for now"); state->hwVectorType = VectorType::get( state->hwVectorSize, state->superVectorType.getElementType()); - emitSlice(state, &slice); + auto fail = emitSlice(state, &slice); + if (fail) { + return true; + } LLVM_DEBUG(dbgs() << "\nMLFunction is now\n"); LLVM_DEBUG(f->print(dbgs())); } + return false; } PassResult MaterializeVectors::runOnMLFunction(MLFunction *f) { @@ -720,9 +745,9 @@ PassResult MaterializeVectors::runOnMLFunction(MLFunction *f) { } // Call materialization. - materialize(f, terminators, &state); + auto fail = materialize(f, terminators, &state); - return PassResult::Success; + return fail ? PassResult::Failure : PassResult::Success; } FunctionPass *mlir::createMaterializeVectors() { |

