diff options
| author | Uday Bondhugula <uday@polymagelabs.com> | 2019-11-22 21:47:47 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-11-22 22:09:31 -0800 |
| commit | 6a101671b040f59b1fe2ae00244c08b20f2766d5 (patch) | |
| tree | e5f949c7d29924371bfbb845dc7131e3b0fc7d34 /mlir/lib | |
| parent | b8ee5634491e3b3b0a52dd50ccd44103c918d3fe (diff) | |
| download | bcm5719-llvm-6a101671b040f59b1fe2ae00244c08b20f2766d5.tar.gz bcm5719-llvm-6a101671b040f59b1fe2ae00244c08b20f2766d5.zip | |
Make isValidSymbol more powerful
The check in isValidSymbol, as far as a DimOp result went, checked if
the dim op was on a top-level memref. However, any alloc'ed, view, or
subview memref would be fine as long as the corresponding dimension of
that memref is either a static one or was in turn created using a valid
symbol in the case of dynamic dimensions.
Reported-by: Jose Gomez
Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>
Closes tensorflow/mlir#252
COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/252 from bondhugula:symbol 7b57dc394df9375e651f497231c6e4525a32a662
PiperOrigin-RevId: 282097114
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Analysis/AffineStructures.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Dialect/AffineOps/AffineOps.cpp | 52 | ||||
| -rw-r--r-- | mlir/lib/IR/StandardTypes.cpp | 6 |
3 files changed, 49 insertions, 11 deletions
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 4b171f0bede..80dc73755c7 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -822,7 +822,7 @@ void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value *id) { return; // Caller is expected to fully compose map/operands if necessary. - assert((isTopLevelSymbol(id) || isForInductionVar(id)) && + assert((isTopLevelValue(id) || isForInductionVar(id)) && "non-terminal symbol / loop IV expected"); // Outer loop IVs could be used in forOp's bounds. if (auto loop = getForInductionVarOwner(id)) { diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp index 77ee9cfde72..ae219c740b7 100644 --- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp +++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp @@ -105,8 +105,9 @@ static bool isFunctionRegion(Region *region) { } /// A utility function to check if a value is defined at the top level of a -/// function. A value defined at the top level is always a valid symbol. -bool mlir::isTopLevelSymbol(Value *value) { +/// function. A value of index type defined at the top level is always a valid +/// symbol. +bool mlir::isTopLevelValue(Value *value) { if (auto *arg = dyn_cast<BlockArgument>(value)) return isFunctionRegion(arg->getOwner()->getParent()); return isFunctionRegion(value->getDefiningOp()->getParentRegion()); @@ -130,13 +131,46 @@ bool mlir::isValidDim(Value *value) { // The dim op is okay if its operand memref/tensor is defined at the top // level. if (auto dimOp = dyn_cast<DimOp>(op)) - return isTopLevelSymbol(dimOp.getOperand()); + return isTopLevelValue(dimOp.getOperand()); return false; } // This value is a block argument (which also includes 'affine.for' loop IVs). return true; } +/// Returns true if the 'index' dimension of the `memref` defined by +/// `memrefDefOp` is a statically shaped one or defined using a valid symbol. +template <typename AnyMemRefDefOp> +bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index) { + auto memRefType = memrefDefOp.getType(); + // Statically shaped. + if (!ShapedType::isDynamic(memRefType.getDimSize(index))) + return true; + // Get the position of the dimension among dynamic dimensions; + unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index); + return isValidSymbol( + *(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos)); +} + +/// Returns true if the result of the dim op is a valid symbol. +static bool isDimOpValidSymbol(DimOp dimOp) { + // The dim op is okay if its operand memref/tensor is defined at the top + // level. + if (isTopLevelValue(dimOp.getOperand())) + return true; + + // The dim op is also okay if its operand memref/tensor is a view/subview + // whose corresponding size is a valid symbol. + unsigned index = dimOp.getIndex(); + if (auto viewOp = dyn_cast<ViewOp>(dimOp.getOperand()->getDefiningOp())) + return isMemRefSizeValidSymbol<ViewOp>(viewOp, index); + if (auto subViewOp = dyn_cast<SubViewOp>(dimOp.getOperand()->getDefiningOp())) + return isMemRefSizeValidSymbol<SubViewOp>(subViewOp, index); + if (auto allocOp = dyn_cast<AllocOp>(dimOp.getOperand()->getDefiningOp())) + return isMemRefSizeValidSymbol<AllocOp>(allocOp, index); + return false; +} + // Value can be used as a symbol if it is a constant, or it is defined at // the top level, or it is a result of affine apply operation with symbol // arguments. @@ -152,14 +186,12 @@ bool mlir::isValidSymbol(Value *value) { // Affine apply operation is ok if all of its operands are ok. if (auto applyOp = dyn_cast<AffineApplyOp>(op)) return applyOp.isValidSymbol(); - // The dim op is okay if its operand memref/tensor is defined at the top - // level. - if (auto dimOp = dyn_cast<DimOp>(op)) - return isTopLevelSymbol(dimOp.getOperand()); - return false; + if (auto dimOp = dyn_cast<DimOp>(op)) { + return isDimOpValidSymbol(dimOp); + } } - // Otherwise, check that the value is a top level symbol. - return isTopLevelSymbol(value); + // Otherwise, check that the value is a top level value. + return isTopLevelValue(value); } // Returns true if 'value' is a valid index to an affine operation (e.g. diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp index 3f3677cc05d..4347856de36 100644 --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -152,6 +152,12 @@ int64_t ShapedType::getDimSize(int64_t i) const { return getShape()[i]; } +unsigned ShapedType::getDynamicDimIndex(unsigned index) const { + assert(index < getRank() && "invalid index"); + assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index"); + return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic); +} + /// Get the number of bits require to store a value of the given shaped type. /// Compute the value recursively since tensors are allowed to have vectors as /// elements. |

