 mlir/lib/IR/StandardTypes.cpp                      | 246 +-
 mlir/test/AffineOps/memref-stride-calculation.mlir |  10 +
 2 files changed, 126 insertions(+), 130 deletions(-)
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index 441b59ed9cd..55d7baaa0e7 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -456,126 +456,73 @@ static AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
auto sym = getAffineSymbolExpr(nSymbols++, context);
expr = expr ? expr + d * sym : d * sym;
}
- return expr;
-}
-
-// Factored out common logic to update `strides` and `seen` for `dim` with value
-// `val`. This handles both saturated and unsaturated cases.
-static void accumulateStrides(MutableArrayRef<int64_t> strides,
- MutableArrayRef<bool> seen, unsigned pos,
- int64_t val) {
- if (!seen[pos]) {
- // Newly seen case, sets value
- strides[pos] = val;
- seen[pos] = true;
- return;
- }
- if (strides[pos] != MemRefType::getDynamicStrideOrOffset())
- // Already seen case accumulates unless they are already saturated.
- strides[pos] += val;
-}
-
-// This sums multiple offsets as they are seen. In the particular case of
-// accumulating a dynamic offset with either a static of dynamic one, this
-// saturates to MemRefType::getDynamicStrideOrOffset().
-static void accumulateOffset(int64_t &offset, bool &seenOffset, int64_t val) {
- if (!seenOffset) {
- // Newly seen case, sets value
- offset = val;
- seenOffset = true;
- return;
- }
- if (offset != MemRefType::getDynamicStrideOrOffset())
- // Already seen case accumulates unless they are already saturated.
- offset += val;
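+ // Simplify before returning so callers receive a canonical expression with
+ // constants folded.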
+ return simplifyAffineExpr(expr, rank, nSymbols);
}
-/// Takes a single AffineExpr `e` and populates the `strides` and `seen` arrays
-/// with the strides values for each dim position and whether a value exists at
-/// that position, respectively.
+// Fallback case for a terminal dim/sym/cst that is not part of a binary op
+// (i.e. a single term). Accumulates the AffineExpr into the existing one.
+static void extractStridesFromTerm(AffineExpr e,
+ AffineExpr multiplicativeFactor,
+ MutableArrayRef<AffineExpr> strides,
+ AffineExpr &offset) {
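+ // A bare dim contributes the running multiplicative factor to its stride;
+ // symbols and constants contribute to the offset instead.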
+ if (auto dim = e.dyn_cast<AffineDimExpr>())
+ strides[dim.getPosition()] =
+ strides[dim.getPosition()] + multiplicativeFactor;
+ else
+ offset = offset + e * multiplicativeFactor;
+}
+
+/// Takes a single AffineExpr `e` and populates the `strides` array with the
+/// stride expressions for each dim position.
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order, making indexing into the result intuitive.
-static void extractStrides(AffineExpr e, MutableArrayRef<int64_t> strides,
- int64_t &offset, MutableArrayRef<bool> seen,
- bool &seenOffset, bool &failed) {
+static LogicalResult extractStrides(AffineExpr e,
+ AffineExpr multiplicativeFactor,
+ MutableArrayRef<AffineExpr> strides,
+ AffineExpr &offset) {
auto bin = e.dyn_cast<AffineBinaryOpExpr>();
- if (!bin)
- return;
+ if (!bin) {
+ extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
+ return success();
+ }
if (bin.getKind() == AffineExprKind::CeilDiv ||
bin.getKind() == AffineExprKind::FloorDiv ||
- bin.getKind() == AffineExprKind::Mod) {
- failed = true;
- return;
- }
+ bin.getKind() == AffineExprKind::Mod)
+ return failure();
+
if (bin.getKind() == AffineExprKind::Mul) {
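+ // Common case: `d * expr` scales the stride of dim `d` by the RHS times
+ // the running factor.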
- // LHS may be more complex than just a single dim (e.g. multiple syms and
- // dims). Bail out for now and revisit when we have evidence this is needed.
auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
- if (!dim) {
- failed = true;
- return;
- }
- auto cst = bin.getRHS().dyn_cast<AffineConstantExpr>();
- if (!cst) {
- strides[dim.getPosition()] = MemRefType::getDynamicStrideOrOffset();
- seen[dim.getPosition()] = true;
- } else {
- accumulateStrides(strides, seen, dim.getPosition(), cst.getValue());
+ if (dim) {
+ strides[dim.getPosition()] =
+ strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
+ return success();
}
- return;
+ // Both operands may be compound expressions. The operand that is symbolic
+ // or constant folds into the multiplicative factor and we recurse into the
+ // other one. At most one operand of an affine multiplication may contain
+ // dims, so this is guaranteed to succeed.
+ if (bin.getLHS().isSymbolicOrConstant())
+ return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
+ strides, offset);
+ return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
+ strides, offset);
}
+
if (bin.getKind() == AffineExprKind::Add) {
- for (auto e : {bin.getLHS(), bin.getRHS()}) {
- if (auto cst = e.dyn_cast<AffineConstantExpr>()) {
- // Independent constants cumulate.
- accumulateOffset(offset, seenOffset, cst.getValue());
- } else if (auto sym = e.dyn_cast<AffineSymbolExpr>()) {
- // Independent symbols saturate.
- offset = MemRefType::getDynamicStrideOrOffset();
- seenOffset = true;
- } else if (auto dim = e.dyn_cast<AffineDimExpr>()) {
- // Independent symbols cumulate 1.
- accumulateStrides(strides, seen, dim.getPosition(), 1);
- }
- // Sum of binary ops dispatch to the respective exprs.
- }
- return;
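+ // Sums distribute: accumulate strides and offset from both operands under
+ // the same multiplicative factor.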
+ auto res1 =
+ extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
+ auto res2 =
+ extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
+ return success(succeeded(res1) && succeeded(res2));
}
- llvm_unreachable("unexpected binary operation");
-}
-// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
-// i.e. single term).
-static void extractStridesFromTerm(AffineExpr e,
- MutableArrayRef<int64_t> strides,
- int64_t &offset, MutableArrayRef<bool> seen,
- bool &seenOffset) {
- if (auto cst = e.dyn_cast<AffineConstantExpr>()) {
- assert(!seenOffset && "unexpected `seen` bit with single term");
- offset = cst.getValue();
- seenOffset = true;
- return;
- }
- if (auto sym = e.dyn_cast<AffineSymbolExpr>()) {
- assert(!seenOffset && "unexpected `seen` bit with single term");
- offset = MemRefType::getDynamicStrideOrOffset();
- seenOffset = true;
- return;
- }
- if (auto dim = e.dyn_cast<AffineDimExpr>()) {
- assert(!seen[dim.getPosition()] &&
- "unexpected `seen` bit with single term");
- strides[dim.getPosition()] = 1;
- seen[dim.getPosition()] = true;
- return;
- }
llvm_unreachable("unexpected binary operation");
}
-LogicalResult mlir::getStridesAndOffset(MemRefType t,
- SmallVectorImpl<int64_t> &strides,
- int64_t &offset) {
+static LogicalResult getStridesAndOffset(MemRefType t,
+ SmallVectorImpl<AffineExpr> &strides,
+ AffineExpr &offset) {
auto affineMaps = t.getAffineMaps();
// For now strides are only computed on a single affine map with a single
// result (i.e. the closed subset of linearization maps that are compatible
@@ -583,39 +530,58 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
// TODO(ntv): support more forms on a per-need basis.
if (affineMaps.size() > 1)
return failure();
- AffineExpr stridedExpr;
- if (affineMaps.empty() || affineMaps[0].isIdentity()) {
- if (t.getRank() == 0) {
- // Handle 0-D corner case.
- offset = 0;
- return success();
- }
- stridedExpr = makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
- } else if (affineMaps[0].getNumResults() == 1) {
- stridedExpr = affineMaps[0].getResult(0);
- }
- if (!stridedExpr)
+ if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1)
return failure();
- bool failed = false;
- strides = SmallVector<int64_t, 4>(t.getRank(), 0);
- bool seenOffset = false;
- SmallVector<bool, 4> seen(t.getRank(), false);
- if (stridedExpr.isa<AffineBinaryOpExpr>()) {
- stridedExpr.walk([&](AffineExpr e) {
- if (!failed)
- extractStrides(e, strides, offset, seen, seenOffset, failed);
- });
- } else {
- extractStridesFromTerm(stridedExpr, strides, offset, seen, seenOffset);
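+ // Seed the offset and every stride with 0; extraction accumulates
+ // contributions additively on top.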
+ auto zero = getAffineConstantExpr(0, t.getContext());
+ auto one = getAffineConstantExpr(1, t.getContext());
+ offset = zero;
+ strides.assign(t.getRank(), zero);
+
+ AffineMap m;
+ if (!affineMaps.empty()) {
+ m = affineMaps.front();
+ assert(!m.isIdentity() && "unexpected identity map");
}
- // Constant offset may not be present in `stridedExpr` which means it is
- // implicitly 0.
- if (!seenOffset)
- offset = 0;
+ // Canonical case for empty map.
+ if (!m) {
+ // 0-D corner case, offset is already 0.
+ if (t.getRank() == 0)
+ return success();
+ auto stridedExpr =
+ makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
+ if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
+ return success();
+ assert(false && "unexpected failure: extract strides in canonical layout");
+ }
+
+ // Non-canonical case requires more work.
+ auto stridedExpr =
+ simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
+ if (failed(extractStrides(stridedExpr, one, strides, offset))) {
+ offset = AffineExpr();
+ strides.clear();
+ return failure();
+ }
- if (failed || !llvm::all_of(seen, [](bool b) { return b; })) {
+ // Simplify results to allow folding to constants and simple checks.
+ unsigned numDims = m.getNumDims();
+ unsigned numSymbols = m.getNumSymbols();
+ offset = simplifyAffineExpr(offset, numDims, numSymbols);
+ for (auto &stride : strides)
+ stride = simplifyAffineExpr(stride, numDims, numSymbols);
+
+ /// In practice, a strided memref must be internally non-aliasing. Test
+ /// against 0 as a proxy.
+ /// TODO(ntv) static cases can have more advanced checks.
+ /// TODO(ntv) dynamic cases would require a way to compare symbolic
+ /// expressions and would probably need an affine set context propagated
+ /// everywhere.
+ if (llvm::any_of(strides, [](AffineExpr e) {
+ return e == getAffineConstantExpr(0, e.getContext());
+ })) {
+ offset = AffineExpr();
strides.clear();
return failure();
}
@@ -623,6 +589,26 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
return success();
}
+LogicalResult mlir::getStridesAndOffset(MemRefType t,
+ SmallVectorImpl<int64_t> &strides,
+ int64_t &offset) {
+ AffineExpr offsetExpr;
+ SmallVector<AffineExpr, 4> strideExprs;
+ if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
+ return failure();
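+ // Lower the symbolic results to integers: constant expressions become
+ // static values, anything symbolic becomes the dynamic sentinel.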
+ if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
+ offset = cst.getValue();
+ else
+ offset = ShapedType::kDynamicStrideOrOffset;
+ for (auto e : strideExprs) {
+ if (auto c = e.dyn_cast<AffineConstantExpr>())
+ strides.push_back(c.getValue());
+ else
+ strides.push_back(ShapedType::kDynamicStrideOrOffset);
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
/// ComplexType
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/AffineOps/memref-stride-calculation.mlir b/mlir/test/AffineOps/memref-stride-calculation.mlir
index 6efd21d01dc..aacd0c776f3 100644
--- a/mlir/test/AffineOps/memref-stride-calculation.mlir
+++ b/mlir/test/AffineOps/memref-stride-calculation.mlir
@@ -67,5 +67,15 @@ func @f(%0: index) {
// CHECK: MemRefType memref<3x4x5xf32, (d0, d1, d2) -> (d0 ceildiv 4 + d1 + d2)> cannot be converted to strided form
%103 = alloc() : memref<3x4x5xf32, (i, j, k)->(i mod 4 + j + k)>
// CHECK: MemRefType memref<3x4x5xf32, (d0, d1, d2) -> (d0 mod 4 + d1 + d2)> cannot be converted to strided form
+
+ %200 = alloc()[%0, %0, %0] : memref<3x4x5xf32, (i, j, k)[M, N, K]->(M * i + N * i + N * j + K * k - (M + N - 20)* i)>
+ // CHECK: MemRefType offset: 0 strides: 20, ?, ?
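+ // The i coefficient folds to M + N - (M + N - 20) = 20; the j and k
+ // strides remain the symbols N and K, hence dynamic.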
+ %201 = alloc()[%0, %0, %0] : memref<3x4x5xf32, (i, j, k)[M, N, K]->(M * i + N * i + N * K * j + K * K * k - (M + N - 20) * (i + 1))>
+ // CHECK: MemRefType offset: ? strides: 20, ?, ?
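+ // Same i coefficient as above, but the expanded -(M + N - 20) term leaves
+ // a symbolic offset.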
+ %202 = alloc()[%0, %0, %0] : memref<3x4x5xf32, (i, j, k)[M, N, K]->(M * (i + 1) + j + k - M)>
+ // CHECK: MemRefType offset: 0 strides: ?, 1, 1
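+ // M * (i + 1) - M simplifies to M * i: the offset folds to 0 while the i
+ // stride stays symbolic.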
+ %203 = alloc()[%0, %0, %0] : memref<3x4x5xf32, (i, j, k)[M, N, K]->(M + M * (i + N * (j + K * k)))>
+ // CHECK: MemRefType offset: ? strides: ?, ?, ?
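+ // Expansion gives offset M and strides M, M * N, M * N * K: all symbolic,
+ // so everything lowers to dynamic.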
+
return
}