diff options
| author | Uday Bondhugula <uday@polymagelabs.com> | 2019-11-22 21:47:47 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-11-22 22:09:31 -0800 |
| commit | 6a101671b040f59b1fe2ae00244c08b20f2766d5 (patch) | |
| tree | e5f949c7d29924371bfbb845dc7131e3b0fc7d34 /mlir | |
| parent | b8ee5634491e3b3b0a52dd50ccd44103c918d3fe (diff) | |
| download | bcm5719-llvm-6a101671b040f59b1fe2ae00244c08b20f2766d5.tar.gz bcm5719-llvm-6a101671b040f59b1fe2ae00244c08b20f2766d5.zip | |
Make isValidSymbol more powerful
The check in isValidSymbol, as far as a DimOp result went, checked if
the dim op was on a top-level memref. However, any alloc'ed, view, or
subview memref would be fine as long as the corresponding dimension of
that memref is either a static one or was in turn created using a valid
symbol in the case of dynamic dimensions.
Reported-by: Jose Gomez
Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>
Closes tensorflow/mlir#252
COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/252 from bondhugula:symbol 7b57dc394df9375e651f497231c6e4525a32a662
PiperOrigin-RevId: 282097114
Diffstat (limited to 'mlir')
| -rw-r--r-- | mlir/include/mlir/Dialect/AffineOps/AffineOps.h | 5 | ||||
| -rw-r--r-- | mlir/include/mlir/Dialect/StandardOps/Ops.td | 3 | ||||
| -rw-r--r-- | mlir/include/mlir/IR/StandardTypes.h | 4 | ||||
| -rw-r--r-- | mlir/lib/Analysis/AffineStructures.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Dialect/AffineOps/AffineOps.cpp | 52 | ||||
| -rw-r--r-- | mlir/lib/IR/StandardTypes.cpp | 6 | ||||
| -rw-r--r-- | mlir/test/AffineOps/ops.mlir | 21 |
7 files changed, 80 insertions, 13 deletions
diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOps.h b/mlir/include/mlir/Dialect/AffineOps/AffineOps.h index 1cd115e89c4..35dd3a29348 100644 --- a/mlir/include/mlir/Dialect/AffineOps/AffineOps.h +++ b/mlir/include/mlir/Dialect/AffineOps/AffineOps.h @@ -39,8 +39,9 @@ class FlatAffineConstraints; class OpBuilder; /// A utility function to check if a value is defined at the top level of a -/// function. A value defined at the top level is always a valid symbol. -bool isTopLevelSymbol(Value *value); +/// function. A value of index type defined at the top level is always a valid +/// symbol. +bool isTopLevelValue(Value *value); class AffineOpsDialect : public Dialect { public: diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td index e335984ed59..efe83ee6f0e 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -204,6 +204,9 @@ def AllocOp : Std_Op<"alloc"> { operand_range getSymbolicOperands() { return {operand_begin() + getType().getNumDynamicDims(), operand_end()}; } + + /// Returns the dynamic sizes for this alloc operation if specified. + operand_range getDynamicSizes() { return getOperands(); } }]; let hasCanonicalizer = 1; diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h index 264b1653234..2d232897428 100644 --- a/mlir/include/mlir/IR/StandardTypes.h +++ b/mlir/include/mlir/IR/StandardTypes.h @@ -226,6 +226,10 @@ public: /// Otherwise, abort. int64_t getDimSize(int64_t i) const; + /// Returns the position of the dynamic dimension relative to just the dynamic + /// dimensions, given its `index` within the shape. + unsigned getDynamicDimIndex(unsigned index) const; + /// Get the total amount of bits occupied by a value of this type. This does /// not take into account any memory layout or widening constraints, e.g. a /// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 4b171f0bede..80dc73755c7 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -822,7 +822,7 @@ void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value *id) { return; // Caller is expected to fully compose map/operands if necessary. - assert((isTopLevelSymbol(id) || isForInductionVar(id)) && + assert((isTopLevelValue(id) || isForInductionVar(id)) && "non-terminal symbol / loop IV expected"); // Outer loop IVs could be used in forOp's bounds. if (auto loop = getForInductionVarOwner(id)) { diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp index 77ee9cfde72..ae219c740b7 100644 --- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp +++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp @@ -105,8 +105,9 @@ static bool isFunctionRegion(Region *region) { } /// A utility function to check if a value is defined at the top level of a -/// function. A value defined at the top level is always a valid symbol. -bool mlir::isTopLevelSymbol(Value *value) { +/// function. A value of index type defined at the top level is always a valid +/// symbol. +bool mlir::isTopLevelValue(Value *value) { if (auto *arg = dyn_cast<BlockArgument>(value)) return isFunctionRegion(arg->getOwner()->getParent()); return isFunctionRegion(value->getDefiningOp()->getParentRegion()); @@ -130,13 +131,46 @@ bool mlir::isValidDim(Value *value) { // The dim op is okay if its operand memref/tensor is defined at the top // level. if (auto dimOp = dyn_cast<DimOp>(op)) - return isTopLevelSymbol(dimOp.getOperand()); + return isTopLevelValue(dimOp.getOperand()); return false; } // This value is a block argument (which also includes 'affine.for' loop IVs). return true; } +/// Returns true if the 'index' dimension of the `memref` defined by +/// `memrefDefOp` is a statically shaped one or defined using a valid symbol. +template <typename AnyMemRefDefOp> +bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index) { + auto memRefType = memrefDefOp.getType(); + // Statically shaped. + if (!ShapedType::isDynamic(memRefType.getDimSize(index))) + return true; + // Get the position of the dimension among dynamic dimensions; + unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index); + return isValidSymbol( + *(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos)); +} + +/// Returns true if the result of the dim op is a valid symbol. +static bool isDimOpValidSymbol(DimOp dimOp) { + // The dim op is okay if its operand memref/tensor is defined at the top + // level. + if (isTopLevelValue(dimOp.getOperand())) + return true; + + // The dim op is also okay if its operand memref/tensor is a view/subview + // whose corresponding size is a valid symbol. + unsigned index = dimOp.getIndex(); + if (auto viewOp = dyn_cast<ViewOp>(dimOp.getOperand()->getDefiningOp())) + return isMemRefSizeValidSymbol<ViewOp>(viewOp, index); + if (auto subViewOp = dyn_cast<SubViewOp>(dimOp.getOperand()->getDefiningOp())) + return isMemRefSizeValidSymbol<SubViewOp>(subViewOp, index); + if (auto allocOp = dyn_cast<AllocOp>(dimOp.getOperand()->getDefiningOp())) + return isMemRefSizeValidSymbol<AllocOp>(allocOp, index); + return false; +} + // Value can be used as a symbol if it is a constant, or it is defined at // the top level, or it is a result of affine apply operation with symbol // arguments. @@ -152,14 +186,12 @@ bool mlir::isValidSymbol(Value *value) { // Affine apply operation is ok if all of its operands are ok. if (auto applyOp = dyn_cast<AffineApplyOp>(op)) return applyOp.isValidSymbol(); - // The dim op is okay if its operand memref/tensor is defined at the top - // level. - if (auto dimOp = dyn_cast<DimOp>(op)) - return isTopLevelSymbol(dimOp.getOperand()); - return false; + if (auto dimOp = dyn_cast<DimOp>(op)) { + return isDimOpValidSymbol(dimOp); + } } - // Otherwise, check that the value is a top level symbol. - return isTopLevelSymbol(value); + // Otherwise, check that the value is a top level value. + return isTopLevelValue(value); } // Returns true if 'value' is a valid index to an affine operation (e.g. diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp index 3f3677cc05d..4347856de36 100644 --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -152,6 +152,12 @@ int64_t ShapedType::getDimSize(int64_t i) const { return getShape()[i]; } +unsigned ShapedType::getDynamicDimIndex(unsigned index) const { + assert(index < getRank() && "invalid index"); + assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index"); + return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic); +} + /// Get the number of bits require to store a value of the given shaped type. /// Compute the value recursively since tensors are allowed to have vectors as /// elements. diff --git a/mlir/test/AffineOps/ops.mlir b/mlir/test/AffineOps/ops.mlir index 795bce96f3e..d78ddd2d76f 100644 --- a/mlir/test/AffineOps/ops.mlir +++ b/mlir/test/AffineOps/ops.mlir @@ -78,3 +78,24 @@ func @affine_min(%arg0 : index, %arg1 : index, %arg2 : index) { %3 = affine.min ()[] -> (77, 78, 79) ()[] return } + +// ----- + +func @valid_symbols(%arg0: index, %arg1: index, %arg2: index) { + %c0 = constant 1 : index + %c1 = constant 0 : index + %0 = alloc(%arg0, %arg1) : memref<?x?xf32> + affine.for %arg3 = 0 to %arg2 step 768 { + %13 = dim %0, 1 : memref<?x?xf32> + affine.for %arg4 = 0 to %13 step 264 { + %18 = dim %0, 0 : memref<?x?xf32> + %20 = std.subview %0[%c0, %c0][%18,%arg4][%c1,%c1] : memref<?x?xf32> + to memref<?x?xf32, (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)> + %24 = dim %20, 0 : memref<?x?xf32, (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)> + affine.for %arg5 = 0 to %24 step 768 { + "foo"() : () -> () + } + } + } + return +} |

