diff options
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Analysis/AffineStructures.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Dialect/AffineOps/AffineOps.cpp | 52 | ||||
| -rw-r--r-- | mlir/lib/IR/StandardTypes.cpp | 6 |
3 files changed, 49 insertions, 11 deletions
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 4b171f0bede..80dc73755c7 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -822,7 +822,7 @@ void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value *id) { return; // Caller is expected to fully compose map/operands if necessary. - assert((isTopLevelSymbol(id) || isForInductionVar(id)) && + assert((isTopLevelValue(id) || isForInductionVar(id)) && "non-terminal symbol / loop IV expected"); // Outer loop IVs could be used in forOp's bounds. if (auto loop = getForInductionVarOwner(id)) { diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp index 77ee9cfde72..ae219c740b7 100644 --- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp +++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp @@ -105,8 +105,9 @@ static bool isFunctionRegion(Region *region) { } /// A utility function to check if a value is defined at the top level of a -/// function. A value defined at the top level is always a valid symbol. -bool mlir::isTopLevelSymbol(Value *value) { +/// function. A value of index type defined at the top level is always a valid +/// symbol. +bool mlir::isTopLevelValue(Value *value) { if (auto *arg = dyn_cast<BlockArgument>(value)) return isFunctionRegion(arg->getOwner()->getParent()); return isFunctionRegion(value->getDefiningOp()->getParentRegion()); @@ -130,13 +131,46 @@ bool mlir::isValidDim(Value *value) { // The dim op is okay if its operand memref/tensor is defined at the top // level. if (auto dimOp = dyn_cast<DimOp>(op)) - return isTopLevelSymbol(dimOp.getOperand()); + return isTopLevelValue(dimOp.getOperand()); return false; } // This value is a block argument (which also includes 'affine.for' loop IVs). return true; } +/// Returns true if the 'index' dimension of the `memref` defined by +/// `memrefDefOp` is a statically shaped one or defined using a valid symbol. +template <typename AnyMemRefDefOp> +bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index) { + auto memRefType = memrefDefOp.getType(); + // Statically shaped. + if (!ShapedType::isDynamic(memRefType.getDimSize(index))) + return true; + // Get the position of the dimension among dynamic dimensions; + unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index); + return isValidSymbol( + *(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos)); +} + +/// Returns true if the result of the dim op is a valid symbol. +static bool isDimOpValidSymbol(DimOp dimOp) { + // The dim op is okay if its operand memref/tensor is defined at the top + // level. + if (isTopLevelValue(dimOp.getOperand())) + return true; + + // The dim op is also okay if its operand memref/tensor is a view/subview + // whose corresponding size is a valid symbol. + unsigned index = dimOp.getIndex(); + if (auto viewOp = dyn_cast<ViewOp>(dimOp.getOperand()->getDefiningOp())) + return isMemRefSizeValidSymbol<ViewOp>(viewOp, index); + if (auto subViewOp = dyn_cast<SubViewOp>(dimOp.getOperand()->getDefiningOp())) + return isMemRefSizeValidSymbol<SubViewOp>(subViewOp, index); + if (auto allocOp = dyn_cast<AllocOp>(dimOp.getOperand()->getDefiningOp())) + return isMemRefSizeValidSymbol<AllocOp>(allocOp, index); + return false; +} + // Value can be used as a symbol if it is a constant, or it is defined at // the top level, or it is a result of affine apply operation with symbol // arguments. @@ -152,14 +186,12 @@ bool mlir::isValidSymbol(Value *value) { // Affine apply operation is ok if all of its operands are ok. if (auto applyOp = dyn_cast<AffineApplyOp>(op)) return applyOp.isValidSymbol(); - // The dim op is okay if its operand memref/tensor is defined at the top - // level. - if (auto dimOp = dyn_cast<DimOp>(op)) - return isTopLevelSymbol(dimOp.getOperand()); - return false; + if (auto dimOp = dyn_cast<DimOp>(op)) { + return isDimOpValidSymbol(dimOp); + } } - // Otherwise, check that the value is a top level symbol. - return isTopLevelSymbol(value); + // Otherwise, check that the value is a top level value. + return isTopLevelValue(value); } // Returns true if 'value' is a valid index to an affine operation (e.g. diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp index 3f3677cc05d..4347856de36 100644 --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -152,6 +152,12 @@ int64_t ShapedType::getDimSize(int64_t i) const { return getShape()[i]; } +unsigned ShapedType::getDynamicDimIndex(unsigned index) const { + assert(index < getRank() && "invalid index"); + assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index"); + return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic); +} + /// Get the number of bits require to store a value of the given shaped type. /// Compute the value recursively since tensors are allowed to have vectors as /// elements. |

