diff options
| author | Jacques Pienaar <jpienaar@google.com> | 2019-12-09 08:57:27 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-12-09 08:57:56 -0800 |
| commit | 70aeb4566e35541ffeef28050babc1c9580b43eb (patch) | |
| tree | 5955edd534f872ecc7ef626436e5ac63224acc17 /mlir | |
| parent | 7b19bd5411a68399db4bcf3c2804a67f1d0b3a62 (diff) | |
| download | bcm5719-llvm-70aeb4566e35541ffeef28050babc1c9580b43eb.tar.gz bcm5719-llvm-70aeb4566e35541ffeef28050babc1c9580b43eb.zip | |
Add RegionRange for when need to abstract over different region iteration
Follows ValueRange in representing a generic abstraction over the different
ways to represent a range of Regions. This wrapper is not as ValueRange and only
considers the current cases of interest: MutableArrayRef<Region> and
ArrayRef<std::unique_ptr<Region>> as occurs during op construction vs op region
querying.
Note: ArrayRef<std::unique_ptr<Region>> allows for unset regions, so this range
returns a pointer to a Region instead of a Region.
PiperOrigin-RevId: 284563229
Diffstat (limited to 'mlir')
| -rw-r--r-- | mlir/include/mlir/Analysis/InferTypeOpInterface.td | 2 | ||||
| -rw-r--r-- | mlir/include/mlir/IR/Operation.h | 14 | ||||
| -rw-r--r-- | mlir/include/mlir/IR/Region.h | 53 | ||||
| -rw-r--r-- | mlir/lib/IR/Operation.cpp | 3 | ||||
| -rw-r--r-- | mlir/lib/IR/Region.cpp | 18 | ||||
| -rw-r--r-- | mlir/test/lib/TestDialect/TestDialect.cpp | 2 |
6 files changed, 80 insertions, 12 deletions
diff --git a/mlir/include/mlir/Analysis/InferTypeOpInterface.td b/mlir/include/mlir/Analysis/InferTypeOpInterface.td index 2fa1cc887ca..7f63b2b18bf 100644 --- a/mlir/include/mlir/Analysis/InferTypeOpInterface.td +++ b/mlir/include/mlir/Analysis/InferTypeOpInterface.td @@ -50,7 +50,7 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> { /*args=*/(ins "llvm::Optional<Location>":$location, "ValueRange":$operands, "ArrayRef<NamedAttribute>":$attributes, - "ArrayRef<Region>":$regions, + "RegionRange":$regions, "SmallVectorImpl<Type>&":$inferedReturnTypes) >, StaticInterfaceMethod< diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index f2c94bc539c..75ea9727d4a 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -71,13 +71,11 @@ public: static Operation *create(const OperationState &state); /// Create a new Operation with the specific fields. - static Operation *create(Location location, OperationName name, - ArrayRef<Type> resultTypes, - ArrayRef<Value *> operands, - NamedAttributeList attributes, - ArrayRef<Block *> successors = {}, - ArrayRef<std::unique_ptr<Region>> regions = {}, - bool resizableOperandList = false); + static Operation * + create(Location location, OperationName name, ArrayRef<Type> resultTypes, + ArrayRef<Value *> operands, NamedAttributeList attributes, + ArrayRef<Block *> successors = {}, RegionRange regions = {}, + bool resizableOperandList = false); /// The name of an operation is the key identifier for it. OperationName getName() { return name; } @@ -799,7 +797,7 @@ inline auto Operation::getResultTypes() -> result_type_range { /// This class provides an abstraction over the different types of ranges over /// Value*s. In many cases, this prevents the need to explicitly materialize a /// SmallVector/std::vector. This class should be used in places that are not -/// suitable for a more derived type(e.g. ArrayRef) or a template range +/// suitable for a more derived type (e.g. ArrayRef) or a template range /// parameter. class ValueRange { /// The type representing the owner of this range. This is either a list of diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h index 249ba9562f2..933bf104506 100644 --- a/mlir/include/mlir/IR/Region.h +++ b/mlir/include/mlir/IR/Region.h @@ -160,6 +160,59 @@ private: Operation *container; }; +/// This class provides an abstraction over the different types of ranges over +/// Regions. In many cases, this prevents the need to explicitly materialize a +/// SmallVector/std::vector. This class should be used in places that are not +/// suitable for a more derived type (e.g. ArrayRef) or a template range +/// parameter. +class RegionRange { + /// The type representing the owner of this range. This is either a list of + /// values, operands, or results. + using OwnerT = llvm::PointerUnion<Region *, const std::unique_ptr<Region> *>; + +public: + RegionRange(const RegionRange &) = default; + RegionRange(RegionRange &&) = default; + + RegionRange(MutableArrayRef<Region> regions = llvm::None); + + template <typename Arg, + typename = typename std::enable_if_t<std::is_constructible< + ArrayRef<std::unique_ptr<Region>>, Arg>::value>> + RegionRange(Arg &&arg) + : RegionRange(ArrayRef<std::unique_ptr<Region>>(std::forward<Arg>(arg))) { + } + RegionRange(ArrayRef<std::unique_ptr<Region>> regions); + + /// An iterator element of this range. + class Iterator : public indexed_accessor_iterator<Iterator, OwnerT, Region *, + Region *, Region *> { + public: + Region *operator*() const; + + private: + Iterator(OwnerT owner, unsigned curIndex); + /// Allow access to the constructor. + friend RegionRange; + }; + Iterator begin() const { return Iterator(owner, 0); } + Iterator end() const { return Iterator(owner, count); } + Region *operator[](unsigned index) const { + assert(index < size() && "invalid index for region range"); + return *std::next(begin(), index); + } + /// Return the size of this range. + size_t size() const { return count; } + /// Return if the range is empty. + bool empty() const { return size() == 0; } + +private: + /// The object that owns the provided range of regions. + OwnerT owner; + /// The size from the owning range. + unsigned count; +}; + } // end namespace mlir #endif // MLIR_IR_REGION_H diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 26f20d324f0..ae635d108b2 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -136,8 +136,7 @@ Operation *Operation::create(Location location, OperationName name, ArrayRef<Type> resultTypes, ArrayRef<Value *> operands, NamedAttributeList attributes, - ArrayRef<Block *> successors, - ArrayRef<std::unique_ptr<Region>> regions, + ArrayRef<Block *> successors, RegionRange regions, bool resizableOperandList) { unsigned numRegions = regions.size(); Operation *op = create(location, name, resultTypes, operands, attributes, diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp index a91d36b1e48..a5a19cbcbf2 100644 --- a/mlir/lib/IR/Region.cpp +++ b/mlir/lib/IR/Region.cpp @@ -213,3 +213,21 @@ void llvm::ilist_traits<::mlir::Block>::transferNodesFromList( for (; first != last; ++first) first->parentValidOpOrderPair.setPointer(curParent); } + +//===----------------------------------------------------------------------===// +// RegionRange +//===----------------------------------------------------------------------===// +RegionRange::RegionRange(MutableArrayRef<Region> regions) + : owner(regions.data()), count(regions.size()) {} +RegionRange::RegionRange(ArrayRef<std::unique_ptr<Region>> regions) + : owner(regions.data()), count(regions.size()) {} +RegionRange::Iterator::Iterator(OwnerT owner, unsigned curIndex) + : indexed_accessor_iterator<Iterator, OwnerT, Region *, Region *, Region *>( + owner, curIndex) {} + +Region *RegionRange::Iterator::operator*() const { + if (const std::unique_ptr<Region> *operand = + object.dyn_cast<const std::unique_ptr<Region> *>()) + return operand[index].get(); + return &object.get<Region *>()[index]; +} diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp index f470e6ab674..059cfb3dce7 100644 --- a/mlir/test/lib/TestDialect/TestDialect.cpp +++ b/mlir/test/lib/TestDialect/TestDialect.cpp @@ -291,7 +291,7 @@ LogicalResult TestOpWithVariadicResultsAndFolder::fold( LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes( llvm::Optional<Location> location, ValueRange operands, - ArrayRef<NamedAttribute> attributes, ArrayRef<Region> regions, + ArrayRef<NamedAttribute> attributes, RegionRange regions, SmallVectorImpl<Type> &inferedReturnTypes) { if (operands[0]->getType() != operands[1]->getType()) { return emitOptionalError(location, "operand type mismatch ", |

