summaryrefslogtreecommitdiffstats
path: root/mlir/lib/IR
diff options
context:
space:
mode:
authorNicolas Vasilache <ntv@google.com>2020-01-02 15:25:21 -0500
committerNicolas Vasilache <ntv@google.com>2020-01-06 09:41:38 -0500
commitd67c4cc2eb4ddc450c886598b934c111e721ab0c (patch)
tree3f4c0712f32c86f6e99e32898c69b53fc94bd45b /mlir/lib/IR
parentd45aafa2fbcf66f3dafdc7c5e0a0ce3709914cbc (diff)
downloadbcm5719-llvm-d67c4cc2eb4ddc450c886598b934c111e721ab0c.tar.gz
bcm5719-llvm-d67c4cc2eb4ddc450c886598b934c111e721ab0c.zip
[mlir][Linalg] Reimplement and extend getStridesAndOffset
Summary: This diff reimplements getStridesAndOffset in a significantly simpler way by operating on the AffineExpr and calling into simplifyAffineExpr instead of rolling its own saturating arithmetic. As a consequence it becomes quite simple to extend the behavior of getStridesAndOffset to encompass more cases by manipulating the AffineExpr more directly. The divisions are still filtered out and continue to yield fully dynamic strides. Simplifying the divisions is left for a later time if compelling use cases arise. Relevant tests are added. Reviewers: ftynse Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D72098
Diffstat (limited to 'mlir/lib/IR')
-rw-r--r--mlir/lib/IR/StandardTypes.cpp246
1 files changed, 116 insertions, 130 deletions
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index 441b59ed9cd..55d7baaa0e7 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -456,126 +456,73 @@ static AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
auto sym = getAffineSymbolExpr(nSymbols++, context);
expr = expr ? expr + d * sym : d * sym;
}
- return expr;
-}
-
-// Factored out common logic to update `strides` and `seen` for `dim` with value
-// `val`. This handles both saturated and unsaturated cases.
-static void accumulateStrides(MutableArrayRef<int64_t> strides,
- MutableArrayRef<bool> seen, unsigned pos,
- int64_t val) {
- if (!seen[pos]) {
- // Newly seen case, sets value
- strides[pos] = val;
- seen[pos] = true;
- return;
- }
- if (strides[pos] != MemRefType::getDynamicStrideOrOffset())
- // Already seen case accumulates unless they are already saturated.
- strides[pos] += val;
-}
-
-// This sums multiple offsets as they are seen. In the particular case of
-// accumulating a dynamic offset with either a static of dynamic one, this
-// saturates to MemRefType::getDynamicStrideOrOffset().
-static void accumulateOffset(int64_t &offset, bool &seenOffset, int64_t val) {
- if (!seenOffset) {
- // Newly seen case, sets value
- offset = val;
- seenOffset = true;
- return;
- }
- if (offset != MemRefType::getDynamicStrideOrOffset())
- // Already seen case accumulates unless they are already saturated.
- offset += val;
+ return simplifyAffineExpr(expr, rank, nSymbols);
}
-/// Takes a single AffineExpr `e` and populates the `strides` and `seen` arrays
-/// with the strides values for each dim position and whether a value exists at
-/// that position, respectively.
+// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
+// i.e. single term). Accumulate the AffineExpr into the existing one.
+static void extractStridesFromTerm(AffineExpr e,
+ AffineExpr multiplicativeFactor,
+ MutableArrayRef<AffineExpr> strides,
+ AffineExpr &offset) {
+ if (auto dim = e.dyn_cast<AffineDimExpr>())
+ strides[dim.getPosition()] =
+ strides[dim.getPosition()] + multiplicativeFactor;
+ else
+ offset = offset + e * multiplicativeFactor;
+}
+
+/// Takes a single AffineExpr `e` and populates the `strides` array with the
+/// strides expressions for each dim position.
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
-static void extractStrides(AffineExpr e, MutableArrayRef<int64_t> strides,
- int64_t &offset, MutableArrayRef<bool> seen,
- bool &seenOffset, bool &failed) {
+static LogicalResult extractStrides(AffineExpr e,
+ AffineExpr multiplicativeFactor,
+ MutableArrayRef<AffineExpr> strides,
+ AffineExpr &offset) {
auto bin = e.dyn_cast<AffineBinaryOpExpr>();
- if (!bin)
- return;
+ if (!bin) {
+ extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
+ return success();
+ }
if (bin.getKind() == AffineExprKind::CeilDiv ||
bin.getKind() == AffineExprKind::FloorDiv ||
- bin.getKind() == AffineExprKind::Mod) {
- failed = true;
- return;
- }
+ bin.getKind() == AffineExprKind::Mod)
+ return failure();
+
if (bin.getKind() == AffineExprKind::Mul) {
- // LHS may be more complex than just a single dim (e.g. multiple syms and
- // dims). Bail out for now and revisit when we have evidence this is needed.
auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
- if (!dim) {
- failed = true;
- return;
- }
- auto cst = bin.getRHS().dyn_cast<AffineConstantExpr>();
- if (!cst) {
- strides[dim.getPosition()] = MemRefType::getDynamicStrideOrOffset();
- seen[dim.getPosition()] = true;
- } else {
- accumulateStrides(strides, seen, dim.getPosition(), cst.getValue());
+ if (dim) {
+ strides[dim.getPosition()] =
+ strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
+ return success();
}
- return;
+ // LHS and RHS may both contain complex expressions of dims. Try one path
+ // and if it fails try the other. This is guaranteed to succeed because
+ // only one path may have a `dim`, otherwise this is not an AffineExpr in
+ // the first place.
+ if (bin.getLHS().isSymbolicOrConstant())
+ return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
+ strides, offset);
+ return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
+ strides, offset);
}
+
if (bin.getKind() == AffineExprKind::Add) {
- for (auto e : {bin.getLHS(), bin.getRHS()}) {
- if (auto cst = e.dyn_cast<AffineConstantExpr>()) {
- // Independent constants cumulate.
- accumulateOffset(offset, seenOffset, cst.getValue());
- } else if (auto sym = e.dyn_cast<AffineSymbolExpr>()) {
- // Independent symbols saturate.
- offset = MemRefType::getDynamicStrideOrOffset();
- seenOffset = true;
- } else if (auto dim = e.dyn_cast<AffineDimExpr>()) {
- // Independent symbols cumulate 1.
- accumulateStrides(strides, seen, dim.getPosition(), 1);
- }
- // Sum of binary ops dispatch to the respective exprs.
- }
- return;
+ auto res1 =
+ extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
+ auto res2 =
+ extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
+ return success(succeeded(res1) && succeeded(res2));
}
- llvm_unreachable("unexpected binary operation");
-}
-// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
-// i.e. single term).
-static void extractStridesFromTerm(AffineExpr e,
- MutableArrayRef<int64_t> strides,
- int64_t &offset, MutableArrayRef<bool> seen,
- bool &seenOffset) {
- if (auto cst = e.dyn_cast<AffineConstantExpr>()) {
- assert(!seenOffset && "unexpected `seen` bit with single term");
- offset = cst.getValue();
- seenOffset = true;
- return;
- }
- if (auto sym = e.dyn_cast<AffineSymbolExpr>()) {
- assert(!seenOffset && "unexpected `seen` bit with single term");
- offset = MemRefType::getDynamicStrideOrOffset();
- seenOffset = true;
- return;
- }
- if (auto dim = e.dyn_cast<AffineDimExpr>()) {
- assert(!seen[dim.getPosition()] &&
- "unexpected `seen` bit with single term");
- strides[dim.getPosition()] = 1;
- seen[dim.getPosition()] = true;
- return;
- }
llvm_unreachable("unexpected binary operation");
}
-LogicalResult mlir::getStridesAndOffset(MemRefType t,
- SmallVectorImpl<int64_t> &strides,
- int64_t &offset) {
+static LogicalResult getStridesAndOffset(MemRefType t,
+ SmallVectorImpl<AffineExpr> &strides,
+ AffineExpr &offset) {
auto affineMaps = t.getAffineMaps();
// For now strides are only computed on a single affine map with a single
// result (i.e. the closed subset of linearization maps that are compatible
@@ -583,39 +530,58 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
// TODO(ntv): support more forms on a per-need basis.
if (affineMaps.size() > 1)
return failure();
- AffineExpr stridedExpr;
- if (affineMaps.empty() || affineMaps[0].isIdentity()) {
- if (t.getRank() == 0) {
- // Handle 0-D corner case.
- offset = 0;
- return success();
- }
- stridedExpr = makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
- } else if (affineMaps[0].getNumResults() == 1) {
- stridedExpr = affineMaps[0].getResult(0);
- }
- if (!stridedExpr)
+ if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1)
return failure();
- bool failed = false;
- strides = SmallVector<int64_t, 4>(t.getRank(), 0);
- bool seenOffset = false;
- SmallVector<bool, 4> seen(t.getRank(), false);
- if (stridedExpr.isa<AffineBinaryOpExpr>()) {
- stridedExpr.walk([&](AffineExpr e) {
- if (!failed)
- extractStrides(e, strides, offset, seen, seenOffset, failed);
- });
- } else {
- extractStridesFromTerm(stridedExpr, strides, offset, seen, seenOffset);
+ auto zero = getAffineConstantExpr(0, t.getContext());
+ auto one = getAffineConstantExpr(1, t.getContext());
+ offset = zero;
+ strides.assign(t.getRank(), zero);
+
+ AffineMap m;
+ if (!affineMaps.empty()) {
+ m = affineMaps.front();
+ assert(!m.isIdentity() && "unexpected identity map");
}
- // Constant offset may not be present in `stridedExpr` which means it is
- // implicitly 0.
- if (!seenOffset)
- offset = 0;
+ // Canonical case for empty map.
+ if (!m) {
+ // 0-D corner case, offset is already 0.
+ if (t.getRank() == 0)
+ return success();
+ auto stridedExpr =
+ makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
+ if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
+ return success();
+ assert(false && "unexpected failure: extract strides in canonical layout");
+ }
+
+ // Non-canonical case requires more work.
+ auto stridedExpr =
+ simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
+ if (failed(extractStrides(stridedExpr, one, strides, offset))) {
+ offset = AffineExpr();
+ strides.clear();
+ return failure();
+ }
- if (failed || !llvm::all_of(seen, [](bool b) { return b; })) {
+ // Simplify results to allow folding to constants and simple checks.
+ unsigned numDims = m.getNumDims();
+ unsigned numSymbols = m.getNumSymbols();
+ offset = simplifyAffineExpr(offset, numDims, numSymbols);
+ for (auto &stride : strides)
+ stride = simplifyAffineExpr(stride, numDims, numSymbols);
+
+ /// In practice, a strided memref must be internally non-aliasing. Test
+ /// against 0 as a proxy.
+ /// TODO(ntv) static cases can have more advanced checks.
+ /// TODO(ntv) dynamic cases would require a way to compare symbolic
+ /// expressions and would probably need an affine set context propagated
+ /// everywhere.
+ if (llvm::any_of(strides, [](AffineExpr e) {
+ return e == getAffineConstantExpr(0, e.getContext());
+ })) {
+ offset = AffineExpr();
strides.clear();
return failure();
}
@@ -623,6 +589,26 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
return success();
}
+LogicalResult mlir::getStridesAndOffset(MemRefType t,
+ SmallVectorImpl<int64_t> &strides,
+ int64_t &offset) {
+ AffineExpr offsetExpr;
+ SmallVector<AffineExpr, 4> strideExprs;
+ if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
+ return failure();
+ if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
+ offset = cst.getValue();
+ else
+ offset = ShapedType::kDynamicStrideOrOffset;
+ for (auto e : strideExprs) {
+ if (auto c = e.dyn_cast<AffineConstantExpr>())
+ strides.push_back(c.getValue());
+ else
+ strides.push_back(ShapedType::kDynamicStrideOrOffset);
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
/// ComplexType
//===----------------------------------------------------------------------===//
OpenPOWER on IntegriCloud