summaryrefslogtreecommitdiffstats
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Dialect/AffineOps/AffineOps.h5
-rw-r--r--mlir/include/mlir/Dialect/StandardOps/Ops.td3
-rw-r--r--mlir/include/mlir/IR/StandardTypes.h4
-rw-r--r--mlir/lib/Analysis/AffineStructures.cpp2
-rw-r--r--mlir/lib/Dialect/AffineOps/AffineOps.cpp52
-rw-r--r--mlir/lib/IR/StandardTypes.cpp6
-rw-r--r--mlir/test/AffineOps/ops.mlir21
7 files changed, 80 insertions, 13 deletions
diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOps.h b/mlir/include/mlir/Dialect/AffineOps/AffineOps.h
index 1cd115e89c4..35dd3a29348 100644
--- a/mlir/include/mlir/Dialect/AffineOps/AffineOps.h
+++ b/mlir/include/mlir/Dialect/AffineOps/AffineOps.h
@@ -39,8 +39,9 @@ class FlatAffineConstraints;
class OpBuilder;
/// A utility function to check if a value is defined at the top level of a
-/// function. A value defined at the top level is always a valid symbol.
-bool isTopLevelSymbol(Value *value);
+/// function. A value of index type defined at the top level is always a valid
+/// symbol.
+bool isTopLevelValue(Value *value);
class AffineOpsDialect : public Dialect {
public:
diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td
index e335984ed59..efe83ee6f0e 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -204,6 +204,9 @@ def AllocOp : Std_Op<"alloc"> {
operand_range getSymbolicOperands() {
return {operand_begin() + getType().getNumDynamicDims(), operand_end()};
}
+
+ /// Returns the dynamic sizes for this alloc operation if specified.
+ operand_range getDynamicSizes() { return getOperands(); }
}];
let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
index 264b1653234..2d232897428 100644
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -226,6 +226,10 @@ public:
/// Otherwise, abort.
int64_t getDimSize(int64_t i) const;
+ /// Returns the position of the dynamic dimension relative to just the dynamic
+ /// dimensions, given its `index` within the shape.
+ unsigned getDynamicDimIndex(unsigned index) const;
+
/// Get the total amount of bits occupied by a value of this type. This does
/// not take into account any memory layout or widening constraints, e.g. a
/// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index 4b171f0bede..80dc73755c7 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -822,7 +822,7 @@ void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value *id) {
return;
// Caller is expected to fully compose map/operands if necessary.
- assert((isTopLevelSymbol(id) || isForInductionVar(id)) &&
+ assert((isTopLevelValue(id) || isForInductionVar(id)) &&
"non-terminal symbol / loop IV expected");
// Outer loop IVs could be used in forOp's bounds.
if (auto loop = getForInductionVarOwner(id)) {
diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp
index 77ee9cfde72..ae219c740b7 100644
--- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp
+++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp
@@ -105,8 +105,9 @@ static bool isFunctionRegion(Region *region) {
}
/// A utility function to check if a value is defined at the top level of a
-/// function. A value defined at the top level is always a valid symbol.
-bool mlir::isTopLevelSymbol(Value *value) {
+/// function. A value of index type defined at the top level is always a valid
+/// symbol.
+bool mlir::isTopLevelValue(Value *value) {
if (auto *arg = dyn_cast<BlockArgument>(value))
return isFunctionRegion(arg->getOwner()->getParent());
return isFunctionRegion(value->getDefiningOp()->getParentRegion());
@@ -130,13 +131,46 @@ bool mlir::isValidDim(Value *value) {
// The dim op is okay if its operand memref/tensor is defined at the top
// level.
if (auto dimOp = dyn_cast<DimOp>(op))
- return isTopLevelSymbol(dimOp.getOperand());
+ return isTopLevelValue(dimOp.getOperand());
return false;
}
// This value is a block argument (which also includes 'affine.for' loop IVs).
return true;
}
+/// Returns true if the 'index' dimension of the `memref` defined by
+/// `memrefDefOp` is a statically shaped one or defined using a valid symbol.
+template <typename AnyMemRefDefOp>
+bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index) {
+ auto memRefType = memrefDefOp.getType();
+ // Statically shaped.
+ if (!ShapedType::isDynamic(memRefType.getDimSize(index)))
+ return true;
+ // Get the position of the dimension among dynamic dimensions;
+ unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
+ return isValidSymbol(
+ *(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos));
+}
+
+/// Returns true if the result of the dim op is a valid symbol.
+static bool isDimOpValidSymbol(DimOp dimOp) {
+ // The dim op is okay if its operand memref/tensor is defined at the top
+ // level.
+ if (isTopLevelValue(dimOp.getOperand()))
+ return true;
+
+ // The dim op is also okay if its operand memref/tensor is a view/subview
+ // whose corresponding size is a valid symbol.
+ unsigned index = dimOp.getIndex();
+ if (auto viewOp = dyn_cast<ViewOp>(dimOp.getOperand()->getDefiningOp()))
+ return isMemRefSizeValidSymbol<ViewOp>(viewOp, index);
+ if (auto subViewOp = dyn_cast<SubViewOp>(dimOp.getOperand()->getDefiningOp()))
+ return isMemRefSizeValidSymbol<SubViewOp>(subViewOp, index);
+ if (auto allocOp = dyn_cast<AllocOp>(dimOp.getOperand()->getDefiningOp()))
+ return isMemRefSizeValidSymbol<AllocOp>(allocOp, index);
+ return false;
+}
+
// Value can be used as a symbol if it is a constant, or it is defined at
// the top level, or it is a result of affine apply operation with symbol
// arguments.
@@ -152,14 +186,12 @@ bool mlir::isValidSymbol(Value *value) {
// Affine apply operation is ok if all of its operands are ok.
if (auto applyOp = dyn_cast<AffineApplyOp>(op))
return applyOp.isValidSymbol();
- // The dim op is okay if its operand memref/tensor is defined at the top
- // level.
- if (auto dimOp = dyn_cast<DimOp>(op))
- return isTopLevelSymbol(dimOp.getOperand());
- return false;
+ if (auto dimOp = dyn_cast<DimOp>(op)) {
+ return isDimOpValidSymbol(dimOp);
+ }
}
- // Otherwise, check that the value is a top level symbol.
- return isTopLevelSymbol(value);
+ // Otherwise, check that the value is a top level value.
+ return isTopLevelValue(value);
}
// Returns true if 'value' is a valid index to an affine operation (e.g.
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index 3f3677cc05d..4347856de36 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -152,6 +152,12 @@ int64_t ShapedType::getDimSize(int64_t i) const {
return getShape()[i];
}
+unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
+ assert(index < getRank() && "invalid index");
+ assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
+ return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
+}
+
/// Get the number of bits require to store a value of the given shaped type.
/// Compute the value recursively since tensors are allowed to have vectors as
/// elements.
diff --git a/mlir/test/AffineOps/ops.mlir b/mlir/test/AffineOps/ops.mlir
index 795bce96f3e..d78ddd2d76f 100644
--- a/mlir/test/AffineOps/ops.mlir
+++ b/mlir/test/AffineOps/ops.mlir
@@ -78,3 +78,24 @@ func @affine_min(%arg0 : index, %arg1 : index, %arg2 : index) {
%3 = affine.min ()[] -> (77, 78, 79) ()[]
return
}
+
+// -----
+
+func @valid_symbols(%arg0: index, %arg1: index, %arg2: index) {
+ %c0 = constant 1 : index
+ %c1 = constant 0 : index
+ %0 = alloc(%arg0, %arg1) : memref<?x?xf32>
+ affine.for %arg3 = 0 to %arg2 step 768 {
+ %13 = dim %0, 1 : memref<?x?xf32>
+ affine.for %arg4 = 0 to %13 step 264 {
+ %18 = dim %0, 0 : memref<?x?xf32>
+ %20 = std.subview %0[%c0, %c0][%18,%arg4][%c1,%c1] : memref<?x?xf32>
+ to memref<?x?xf32, (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>
+ %24 = dim %20, 0 : memref<?x?xf32, (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>
+ affine.for %arg5 = 0 to %24 step 768 {
+ "foo"() : () -> ()
+ }
+ }
+ }
+ return
+}
OpenPOWER on IntegriCloud