summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/MaterializeVectors.cpp
diff options
context:
space:
mode:
authorNicolas Vasilache <ntv@google.com>2018-12-06 11:37:53 -0800
committerjpienaar <jpienaar@google.com>2019-03-29 14:20:37 -0700
commit5b610630b2d016ffce00a5f5bf78efc45b335002 (patch)
tree1244038621d792ac7a9e5a6641c27fe15fffe045 /mlir/lib/Transforms/MaterializeVectors.cpp
parent4adc169bd00c17e53d0f79a4e2f9b1105ae730cc (diff)
downloadbcm5719-llvm-5b610630b2d016ffce00a5f5bf78efc45b335002.tar.gz
bcm5719-llvm-5b610630b2d016ffce00a5f5bf78efc45b335002.zip
[MLIR] Error handling in MaterializeVectors
This removes assertions as a means to capture NYI behavior and propagates errors up. PiperOrigin-RevId: 224376935
Diffstat (limited to 'mlir/lib/Transforms/MaterializeVectors.cpp')
-rw-r--r--mlir/lib/Transforms/MaterializeVectors.cpp57
1 files changed, 41 insertions, 16 deletions
diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp
index 7cd6a0ad273..3bf4305ca0c 100644
--- a/mlir/lib/Transforms/MaterializeVectors.cpp
+++ b/mlir/lib/Transforms/MaterializeVectors.cpp
@@ -209,23 +209,25 @@ struct MaterializeVectors : public FunctionPass {
char MaterializeVectors::passID = 0;
-// Returns the distance, in number of elements, between a slice in a dimension
-// and the next slice in the same dimension.
-// e.g. shape[3, 4, 5] -> strides[20, 5, 1]
+/// Given a shape with sizes greater than 0 along all dimensions,
+/// returns the distance, in number of elements, between a slice in a dimension
+/// and the next slice in the same dimension.
+/// e.g. shape[3, 4, 5] -> strides[20, 5, 1]
static SmallVector<unsigned, 8> makeStrides(ArrayRef<unsigned> shape) {
SmallVector<unsigned, 8> tmp;
tmp.reserve(shape.size());
unsigned running = 1;
for (auto rit = shape.rbegin(), reit = shape.rend(); rit != reit; ++rit) {
- // TODO(ntv): emitError instead of NYI assert.
- assert(*rit > 0 && "NYI: symbolic or null shape dimension");
+ assert(*rit > 0 && "size must be greater than 0 along all dimensions of "
+ "shape");
tmp.push_back(running);
running *= *rit;
}
return SmallVector<unsigned, 8>(tmp.rbegin(), tmp.rend());
}
-// Returns the linearized expression.
+/// Given a shape with sizes greater than 0 along all dimensions, returns the
+/// delinearized components of linearIndex along shape.
static SmallVector<unsigned, 8> delinearize(unsigned linearIndex,
ArrayRef<unsigned> shape) {
SmallVector<unsigned, 8> res;
@@ -256,6 +258,8 @@ instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
/// insertion.
/// For now, this is limited to ConstantOp because we do not vectorize loop
/// indices and will need to be extended in the future.
+///
+/// If substitution fails, returns nullptr.
static MLValue *
substitute(SSAValue *v, VectorType hwVectorType,
DenseMap<const MLValue *, MLValue *> *substitutionsMap) {
@@ -271,7 +275,7 @@ substitute(SSAValue *v, VectorType hwVectorType,
return res.first->second;
}
v->getDefiningOperation()->emitError("Missing substitution");
- assert(false);
+ return nullptr;
}
return it->second;
}
@@ -400,6 +404,8 @@ materializeAttributes(OperationStmt *opStmt, VectorType hwVectorType) {
/// affine reindexing. Just substitute their SSAValue* operands and be done. For
/// this case the actual instance is irrelevant. Just use the SSA values in
/// substitutionsMap.
+///
+/// If the underlying substitution fails, this fails too and returns nullptr.
static OperationStmt *
instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
DenseMap<const MLValue *, MLValue *> *substitutionsMap) {
@@ -407,11 +413,18 @@ instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType hwVectorType,
"Should call the function specialized for VectorTransferReadOp");
assert(!opStmt->isa<VectorTransferWriteOp>() &&
"Should call the function specialized for VectorTransferWriteOp");
+ bool fail = false;
auto operands = map(
- [hwVectorType, substitutionsMap](SSAValue *v) {
- return substitute(v, hwVectorType, substitutionsMap);
+ [hwVectorType, substitutionsMap, &fail](SSAValue *v) {
+ auto *res =
+ fail ? nullptr : substitute(v, hwVectorType, substitutionsMap);
+ fail |= !res;
+ return res;
},
opStmt->getOperands());
+ if (fail) {
+ return nullptr;
+ }
auto attrs = materializeAttributes(opStmt, hwVectorType);
return b->createOperation(opStmt->getLoc(), opStmt->getName(), operands,
{hwVectorType}, attrs);
@@ -452,7 +465,7 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy *transfer,
auto permutationMap = transfer->getPermutationMap();
LLVM_DEBUG(projectionMap.print(dbgs() << "\nprojectionMap: "));
LLVM_DEBUG(permutationMap.print(dbgs() << "\npermutationMap: "));
- return composeUnboundedMaps(projectionMap, transfer->getPermutationMap());
+ return composeUnboundedMaps(projectionMap, permutationMap);
}
/// Creates an instantiated version of `read` for the instance of
@@ -516,6 +529,8 @@ instantiate(MLFuncBuilder *b, VectorTransferWriteOp *write,
/// type, all operands are substituted according to `substitutions`. Thanks
/// to the topological order of a slice, the substitution is always
/// possible.
+///
+/// Returns true on failure.
static bool instantiateMaterialization(Statement *stmt,
MaterializationState *state) {
LLVM_DEBUG(dbgs() << "\ninstantiate: " << *stmt);
@@ -557,6 +572,9 @@ static bool instantiateMaterialization(Statement *stmt,
}
auto *clone =
instantiate(&b, opStmt, state->hwVectorType, state->substitutionsMap);
+ if (!clone) {
+ return true;
+ }
state->substitutionsMap->insert(std::make_pair(
cast<MLValue>(opStmt->getResult(0)), cast<MLValue>(clone->getResult(0))));
return false;
@@ -578,10 +596,12 @@ static bool instantiateMaterialization(Statement *stmt,
/// equivalent of loop strip-mining + loop sinking and encoded this in the
/// vector type.
///
+/// Returns true on failure.
+///
/// TODO(ntv): materialized allocs.
/// TODO(ntv): full loops + materialized allocs.
/// TODO(ntv): partial unrolling + materialized allocs.
-static void emitSlice(MaterializationState *state,
+static bool emitSlice(MaterializationState *state,
SetVector<Statement *> *slice) {
auto ratio = shapeRatio(state->superVectorType, state->hwVectorType);
assert(ratio.hasValue() &&
@@ -601,7 +621,7 @@ static void emitSlice(MaterializationState *state,
auto fail = instantiateMaterialization(stmt, &scopedState);
if (fail) {
stmt->emitError("Unhandled super-vector materialization failure");
- assert(!fail);
+ return true;
}
}
}
@@ -618,6 +638,7 @@ static void emitSlice(MaterializationState *state,
LLVM_DEBUG((*slice)[idx]->print(dbgs()));
(*slice)[idx]->erase();
}
+ return false;
}
/// Materializes super-vector types into concrete hw vector types as follows:
@@ -637,7 +658,7 @@ static void emitSlice(MaterializationState *state,
/// Additionally, this set is limited to statements in the same lexical scope
/// because we currently disallow vectorization of defs that come from another
/// scope.
-static void materialize(MLFunction *f,
+static bool materialize(MLFunction *f,
const SetVector<OperationStmt *> &terminators,
MaterializationState *state) {
DenseSet<Statement *> seen;
@@ -686,10 +707,14 @@ static void materialize(MLFunction *f,
"Only f32 supported for now");
state->hwVectorType = VectorType::get(
state->hwVectorSize, state->superVectorType.getElementType());
- emitSlice(state, &slice);
+ auto fail = emitSlice(state, &slice);
+ if (fail) {
+ return true;
+ }
LLVM_DEBUG(dbgs() << "\nMLFunction is now\n");
LLVM_DEBUG(f->print(dbgs()));
}
+ return false;
}
PassResult MaterializeVectors::runOnMLFunction(MLFunction *f) {
@@ -720,9 +745,9 @@ PassResult MaterializeVectors::runOnMLFunction(MLFunction *f) {
}
// Call materialization.
- materialize(f, terminators, &state);
+ auto fail = materialize(f, terminators, &state);
- return PassResult::Success;
+ return fail ? PassResult::Failure : PassResult::Success;
}
FunctionPass *mlir::createMaterializeVectors() {
OpenPOWER on IntegriCloud