summaryrefslogtreecommitdiffstats
path: root/mlir
diff options
context:
space:
mode:
authorJacques Pienaar <jpienaar@google.com>2019-12-09 08:57:27 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-12-09 08:57:56 -0800
commit70aeb4566e35541ffeef28050babc1c9580b43eb (patch)
tree5955edd534f872ecc7ef626436e5ac63224acc17 /mlir
parent7b19bd5411a68399db4bcf3c2804a67f1d0b3a62 (diff)
downloadbcm5719-llvm-70aeb4566e35541ffeef28050babc1c9580b43eb.tar.gz
bcm5719-llvm-70aeb4566e35541ffeef28050babc1c9580b43eb.zip
Add RegionRange for when need to abstract over different region iteration
Follows ValueRange in representing a generic abstraction over the different ways to represent a range of Regions. This wrapper is not as ValueRange and only considers the current cases of interest: MutableArrayRef<Region> and ArrayRef<std::unique_ptr<Region>> as occurs during op construction vs op region querying. Note: ArrayRef<std::unique_ptr<Region>> allows for unset regions, so this range returns a pointer to a Region instead of a Region. PiperOrigin-RevId: 284563229
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Analysis/InferTypeOpInterface.td2
-rw-r--r--mlir/include/mlir/IR/Operation.h14
-rw-r--r--mlir/include/mlir/IR/Region.h53
-rw-r--r--mlir/lib/IR/Operation.cpp3
-rw-r--r--mlir/lib/IR/Region.cpp18
-rw-r--r--mlir/test/lib/TestDialect/TestDialect.cpp2
6 files changed, 80 insertions, 12 deletions
diff --git a/mlir/include/mlir/Analysis/InferTypeOpInterface.td b/mlir/include/mlir/Analysis/InferTypeOpInterface.td
index 2fa1cc887ca..7f63b2b18bf 100644
--- a/mlir/include/mlir/Analysis/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Analysis/InferTypeOpInterface.td
@@ -50,7 +50,7 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
/*args=*/(ins "llvm::Optional<Location>":$location,
"ValueRange":$operands,
"ArrayRef<NamedAttribute>":$attributes,
- "ArrayRef<Region>":$regions,
+ "RegionRange":$regions,
"SmallVectorImpl<Type>&":$inferedReturnTypes)
>,
StaticInterfaceMethod<
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index f2c94bc539c..75ea9727d4a 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -71,13 +71,11 @@ public:
static Operation *create(const OperationState &state);
/// Create a new Operation with the specific fields.
- static Operation *create(Location location, OperationName name,
- ArrayRef<Type> resultTypes,
- ArrayRef<Value *> operands,
- NamedAttributeList attributes,
- ArrayRef<Block *> successors = {},
- ArrayRef<std::unique_ptr<Region>> regions = {},
- bool resizableOperandList = false);
+ static Operation *
+ create(Location location, OperationName name, ArrayRef<Type> resultTypes,
+ ArrayRef<Value *> operands, NamedAttributeList attributes,
+ ArrayRef<Block *> successors = {}, RegionRange regions = {},
+ bool resizableOperandList = false);
/// The name of an operation is the key identifier for it.
OperationName getName() { return name; }
@@ -799,7 +797,7 @@ inline auto Operation::getResultTypes() -> result_type_range {
/// This class provides an abstraction over the different types of ranges over
/// Value*s. In many cases, this prevents the need to explicitly materialize a
/// SmallVector/std::vector. This class should be used in places that are not
-/// suitable for a more derived type(e.g. ArrayRef) or a template range
+/// suitable for a more derived type (e.g. ArrayRef) or a template range
/// parameter.
class ValueRange {
/// The type representing the owner of this range. This is either a list of
diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
index 249ba9562f2..933bf104506 100644
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -160,6 +160,59 @@ private:
Operation *container;
};
+/// This class provides an abstraction over the different types of ranges over
+/// Regions. In many cases, this prevents the need to explicitly materialize a
+/// SmallVector/std::vector. This class should be used in places that are not
+/// suitable for a more derived type (e.g. ArrayRef) or a template range
+/// parameter.
+class RegionRange {
+ /// The type representing the owner of this range. This is either a list of
+ /// values, operands, or results.
+ using OwnerT = llvm::PointerUnion<Region *, const std::unique_ptr<Region> *>;
+
+public:
+ RegionRange(const RegionRange &) = default;
+ RegionRange(RegionRange &&) = default;
+
+ RegionRange(MutableArrayRef<Region> regions = llvm::None);
+
+ template <typename Arg,
+ typename = typename std::enable_if_t<std::is_constructible<
+ ArrayRef<std::unique_ptr<Region>>, Arg>::value>>
+ RegionRange(Arg &&arg)
+ : RegionRange(ArrayRef<std::unique_ptr<Region>>(std::forward<Arg>(arg))) {
+ }
+ RegionRange(ArrayRef<std::unique_ptr<Region>> regions);
+
+ /// An iterator element of this range.
+ class Iterator : public indexed_accessor_iterator<Iterator, OwnerT, Region *,
+ Region *, Region *> {
+ public:
+ Region *operator*() const;
+
+ private:
+ Iterator(OwnerT owner, unsigned curIndex);
+ /// Allow access to the constructor.
+ friend RegionRange;
+ };
+ Iterator begin() const { return Iterator(owner, 0); }
+ Iterator end() const { return Iterator(owner, count); }
+ Region *operator[](unsigned index) const {
+ assert(index < size() && "invalid index for region range");
+ return *std::next(begin(), index);
+ }
+ /// Return the size of this range.
+ size_t size() const { return count; }
+ /// Return if the range is empty.
+ bool empty() const { return size() == 0; }
+
+private:
+ /// The object that owns the provided range of regions.
+ OwnerT owner;
+ /// The size from the owning range.
+ unsigned count;
+};
+
} // end namespace mlir
#endif // MLIR_IR_REGION_H
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 26f20d324f0..ae635d108b2 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -136,8 +136,7 @@ Operation *Operation::create(Location location, OperationName name,
ArrayRef<Type> resultTypes,
ArrayRef<Value *> operands,
NamedAttributeList attributes,
- ArrayRef<Block *> successors,
- ArrayRef<std::unique_ptr<Region>> regions,
+ ArrayRef<Block *> successors, RegionRange regions,
bool resizableOperandList) {
unsigned numRegions = regions.size();
Operation *op = create(location, name, resultTypes, operands, attributes,
diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp
index a91d36b1e48..a5a19cbcbf2 100644
--- a/mlir/lib/IR/Region.cpp
+++ b/mlir/lib/IR/Region.cpp
@@ -213,3 +213,21 @@ void llvm::ilist_traits<::mlir::Block>::transferNodesFromList(
for (; first != last; ++first)
first->parentValidOpOrderPair.setPointer(curParent);
}
+
+//===----------------------------------------------------------------------===//
+// RegionRange
+//===----------------------------------------------------------------------===//
+RegionRange::RegionRange(MutableArrayRef<Region> regions)
+ : owner(regions.data()), count(regions.size()) {}
+RegionRange::RegionRange(ArrayRef<std::unique_ptr<Region>> regions)
+ : owner(regions.data()), count(regions.size()) {}
+RegionRange::Iterator::Iterator(OwnerT owner, unsigned curIndex)
+ : indexed_accessor_iterator<Iterator, OwnerT, Region *, Region *, Region *>(
+ owner, curIndex) {}
+
+Region *RegionRange::Iterator::operator*() const {
+ if (const std::unique_ptr<Region> *operand =
+ object.dyn_cast<const std::unique_ptr<Region> *>())
+ return operand[index].get();
+ return &object.get<Region *>()[index];
+}
diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp
index f470e6ab674..059cfb3dce7 100644
--- a/mlir/test/lib/TestDialect/TestDialect.cpp
+++ b/mlir/test/lib/TestDialect/TestDialect.cpp
@@ -291,7 +291,7 @@ LogicalResult TestOpWithVariadicResultsAndFolder::fold(
LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
llvm::Optional<Location> location, ValueRange operands,
- ArrayRef<NamedAttribute> attributes, ArrayRef<Region> regions,
+ ArrayRef<NamedAttribute> attributes, RegionRange regions,
SmallVectorImpl<Type> &inferedReturnTypes) {
if (operands[0]->getType() != operands[1]->getType()) {
return emitOptionalError(location, "operand type mismatch ",
OpenPOWER on IntegriCloud