summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
authorUday Bondhugula <uday@polymagelabs.com>2019-11-22 21:47:47 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-11-22 22:09:31 -0800
commit6a101671b040f59b1fe2ae00244c08b20f2766d5 (patch)
treee5f949c7d29924371bfbb845dc7131e3b0fc7d34 /mlir/lib
parentb8ee5634491e3b3b0a52dd50ccd44103c918d3fe (diff)
downloadbcm5719-llvm-6a101671b040f59b1fe2ae00244c08b20f2766d5.tar.gz
bcm5719-llvm-6a101671b040f59b1fe2ae00244c08b20f2766d5.zip
Make isValidSymbol more powerful
The check in isValidSymbol, as far as a DimOp result went, checked if the dim op was on a top-level memref. However, any alloc'ed, view, or subview memref would be fine as long as the corresponding dimension of that memref is either a static one or was in turn created using a valid symbol in the case of dynamic dimensions. Reported-by: Jose Gomez Signed-off-by: Uday Bondhugula <uday@polymagelabs.com> Closes tensorflow/mlir#252 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/252 from bondhugula:symbol 7b57dc394df9375e651f497231c6e4525a32a662 PiperOrigin-RevId: 282097114
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Analysis/AffineStructures.cpp2
-rw-r--r--mlir/lib/Dialect/AffineOps/AffineOps.cpp52
-rw-r--r--mlir/lib/IR/StandardTypes.cpp6
3 files changed, 49 insertions, 11 deletions
diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index 4b171f0bede..80dc73755c7 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -822,7 +822,7 @@ void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value *id) {
return;
// Caller is expected to fully compose map/operands if necessary.
- assert((isTopLevelSymbol(id) || isForInductionVar(id)) &&
+ assert((isTopLevelValue(id) || isForInductionVar(id)) &&
"non-terminal symbol / loop IV expected");
// Outer loop IVs could be used in forOp's bounds.
if (auto loop = getForInductionVarOwner(id)) {
diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp
index 77ee9cfde72..ae219c740b7 100644
--- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp
+++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp
@@ -105,8 +105,9 @@ static bool isFunctionRegion(Region *region) {
}
/// A utility function to check if a value is defined at the top level of a
-/// function. A value defined at the top level is always a valid symbol.
-bool mlir::isTopLevelSymbol(Value *value) {
+/// function. A value of index type defined at the top level is always a valid
+/// symbol.
+bool mlir::isTopLevelValue(Value *value) {
if (auto *arg = dyn_cast<BlockArgument>(value))
return isFunctionRegion(arg->getOwner()->getParent());
return isFunctionRegion(value->getDefiningOp()->getParentRegion());
@@ -130,13 +131,46 @@ bool mlir::isValidDim(Value *value) {
// The dim op is okay if its operand memref/tensor is defined at the top
// level.
if (auto dimOp = dyn_cast<DimOp>(op))
- return isTopLevelSymbol(dimOp.getOperand());
+ return isTopLevelValue(dimOp.getOperand());
return false;
}
// This value is a block argument (which also includes 'affine.for' loop IVs).
return true;
}
+/// Returns true if the 'index' dimension of the `memref` defined by
+/// `memrefDefOp` is a statically shaped one or defined using a valid symbol.
+template <typename AnyMemRefDefOp>
+bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index) {
+ auto memRefType = memrefDefOp.getType();
+ // Statically shaped.
+ if (!ShapedType::isDynamic(memRefType.getDimSize(index)))
+ return true;
+ // Get the position of the dimension among dynamic dimensions;
+ unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
+ return isValidSymbol(
+ *(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos));
+}
+
+/// Returns true if the result of the dim op is a valid symbol.
+static bool isDimOpValidSymbol(DimOp dimOp) {
+ // The dim op is okay if its operand memref/tensor is defined at the top
+ // level.
+ if (isTopLevelValue(dimOp.getOperand()))
+ return true;
+
+ // The dim op is also okay if its operand memref/tensor is a view/subview
+ // whose corresponding size is a valid symbol.
+ unsigned index = dimOp.getIndex();
+ if (auto viewOp = dyn_cast<ViewOp>(dimOp.getOperand()->getDefiningOp()))
+ return isMemRefSizeValidSymbol<ViewOp>(viewOp, index);
+ if (auto subViewOp = dyn_cast<SubViewOp>(dimOp.getOperand()->getDefiningOp()))
+ return isMemRefSizeValidSymbol<SubViewOp>(subViewOp, index);
+ if (auto allocOp = dyn_cast<AllocOp>(dimOp.getOperand()->getDefiningOp()))
+ return isMemRefSizeValidSymbol<AllocOp>(allocOp, index);
+ return false;
+}
+
// Value can be used as a symbol if it is a constant, or it is defined at
// the top level, or it is a result of affine apply operation with symbol
// arguments.
@@ -152,14 +186,12 @@ bool mlir::isValidSymbol(Value *value) {
// Affine apply operation is ok if all of its operands are ok.
if (auto applyOp = dyn_cast<AffineApplyOp>(op))
return applyOp.isValidSymbol();
- // The dim op is okay if its operand memref/tensor is defined at the top
- // level.
- if (auto dimOp = dyn_cast<DimOp>(op))
- return isTopLevelSymbol(dimOp.getOperand());
- return false;
+ if (auto dimOp = dyn_cast<DimOp>(op)) {
+ return isDimOpValidSymbol(dimOp);
+ }
}
- // Otherwise, check that the value is a top level symbol.
- return isTopLevelSymbol(value);
+ // Otherwise, check that the value is a top level value.
+ return isTopLevelValue(value);
}
// Returns true if 'value' is a valid index to an affine operation (e.g.
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index 3f3677cc05d..4347856de36 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -152,6 +152,12 @@ int64_t ShapedType::getDimSize(int64_t i) const {
return getShape()[i];
}
+unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
+ assert(index < getRank() && "invalid index");
+ assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
+ return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
+}
+
/// Get the number of bits require to store a value of the given shaped type.
/// Compute the value recursively since tensors are allowed to have vectors as
/// elements.
OpenPOWER on IntegriCloud