diff options
| author | River Riddle <riverriddle@google.com> | 2019-11-11 18:18:02 -0800 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-11-11 18:18:31 -0800 |
| commit | 9b9c647cefea0a81fdf7d2bf6586a13f99d9a2cf (patch) | |
| tree | 2fd6e7aeaa1e41a5e5a6355f860f35bc63ca8d99 /mlir/lib/IR | |
| parent | 5cf6e0ce7f03f9841675b1a9d44232540f3df5cc (diff) | |
| download | bcm5719-llvm-9b9c647cefea0a81fdf7d2bf6586a13f99d9a2cf.tar.gz bcm5719-llvm-9b9c647cefea0a81fdf7d2bf6586a13f99d9a2cf.zip | |
Add support for nested symbol references.
This change allows for adding additional nested references to a SymbolRefAttr to allow for further resolving a symbol if that symbol also defines a SymbolTable. If a referenced symbol also defines a symbol table, a nested reference can be used to refer to a symbol within that table. Nested references are printed after the main reference in the following form:
symbol-ref-attribute ::= symbol-ref-id (`::` symbol-ref-id)*
Example:
module @reference {
func @nested_reference()
}
my_reference_op @reference::@nested_reference
Given that SymbolRefAttr is now more general, the existing functionality centered around a single reference is moved to a derived class FlatSymbolRefAttr. Followup commits will add support to lookups, rauw, etc. for scoped references.
PiperOrigin-RevId: 279860501
Diffstat (limited to 'mlir/lib/IR')
| -rw-r--r-- | mlir/lib/IR/AsmPrinter.cpp | 10 | ||||
| -rw-r--r-- | mlir/lib/IR/AttributeDetail.h | 37 | ||||
| -rw-r--r-- | mlir/lib/IR/Attributes.cpp | 23 | ||||
| -rw-r--r-- | mlir/lib/IR/Builders.cpp | 9 | ||||
| -rw-r--r-- | mlir/lib/IR/FunctionSupport.cpp | 2 |
5 files changed, 72 insertions, 9 deletions
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 43452b2712d..6f77de0e721 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -801,9 +801,15 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) { case StandardAttributes::Type: printType(attr.cast<TypeAttr>().getValue()); break; - case StandardAttributes::SymbolRef: - printSymbolReference(attr.cast<SymbolRefAttr>().getValue(), os); + case StandardAttributes::SymbolRef: { + auto refAttr = attr.dyn_cast<SymbolRefAttr>(); + printSymbolReference(refAttr.getRootReference(), os); + for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) { + os << "::"; + printSymbolReference(nestedRef.getValue(), os); + } break; + } case StandardAttributes::OpaqueElements: { auto eltsAttr = attr.cast<OpaqueElementsAttr>(); os << "opaque<\"" << eltsAttr.getDialect()->getNamespace() << "\", "; diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h index 21f8b68c265..da4aa69dda4 100644 --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -321,6 +321,43 @@ struct StringAttributeStorage : public AttributeStorage { StringRef value; }; +/// An attribute representing a symbol reference. +struct SymbolRefAttributeStorage final + : public AttributeStorage, + public llvm::TrailingObjects<SymbolRefAttributeStorage, + FlatSymbolRefAttr> { + using KeyTy = std::pair<StringRef, ArrayRef<FlatSymbolRefAttr>>; + + SymbolRefAttributeStorage(StringRef value, size_t numNestedRefs) + : value(value), numNestedRefs(numNestedRefs) {} + + /// Key equality function. + bool operator==(const KeyTy &key) const { + return key == KeyTy(value, getNestedRefs()); + } + + /// Construct a new storage instance. + static SymbolRefAttributeStorage * + construct(AttributeStorageAllocator &allocator, const KeyTy &key) { + auto size = SymbolRefAttributeStorage::totalSizeToAlloc<FlatSymbolRefAttr>( + key.second.size()); + auto rawMem = allocator.allocate(size, alignof(SymbolRefAttributeStorage)); + auto result = ::new (rawMem) SymbolRefAttributeStorage( + allocator.copyInto(key.first), key.second.size()); + std::uninitialized_copy(key.second.begin(), key.second.end(), + result->getTrailingObjects<FlatSymbolRefAttr>()); + return result; + } + + /// Returns the set of nested references. + ArrayRef<FlatSymbolRefAttr> getNestedRefs() const { + return {getTrailingObjects<FlatSymbolRefAttr>(), numNestedRefs}; + } + + StringRef value; + size_t numNestedRefs; +}; + /// An attribute representing a reference to a type. struct TypeAttributeStorage : public AttributeStorage { using KeyTy = Type; diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index d74cacbe695..80ac4a59246 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -249,12 +249,27 @@ FloatAttr::verifyConstructionInvariants(llvm::Optional<Location> loc, // SymbolRefAttr //===----------------------------------------------------------------------===// -SymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) { - return Base::get(ctx, StandardAttributes::SymbolRef, value, - NoneType::get(ctx)); +FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) { + return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None) + .cast<FlatSymbolRefAttr>(); } -StringRef SymbolRefAttr::getValue() const { return getImpl()->value; } +SymbolRefAttr SymbolRefAttr::get(StringRef value, + ArrayRef<FlatSymbolRefAttr> nestedReferences, + MLIRContext *ctx) { + return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences); +} + +StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; } + +StringRef SymbolRefAttr::getLeafReference() const { + ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); + return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue(); +} + +ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const { + return getImpl()->getNestedRefs(); +} //===----------------------------------------------------------------------===// // IntegerAttr diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 24ae2072f77..afdeefd023c 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -150,15 +150,20 @@ ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) { return ArrayAttr::get(value, context); } -SymbolRefAttr Builder::getSymbolRefAttr(Operation *value) { +FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) { auto symName = value->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()); assert(symName && "value does not have a valid symbol name"); return getSymbolRefAttr(symName.getValue()); } -SymbolRefAttr Builder::getSymbolRefAttr(StringRef value) { +FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) { return SymbolRefAttr::get(value, getContext()); } +SymbolRefAttr +Builder::getSymbolRefAttr(StringRef value, + ArrayRef<FlatSymbolRefAttr> nestedReferences) { + return SymbolRefAttr::get(value, nestedReferences, getContext()); +} ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) { auto attrs = functional::map( diff --git a/mlir/lib/IR/FunctionSupport.cpp b/mlir/lib/IR/FunctionSupport.cpp index d1ba2d30fa1..29cae177cec 100644 --- a/mlir/lib/IR/FunctionSupport.cpp +++ b/mlir/lib/IR/FunctionSupport.cpp @@ -159,7 +159,7 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result, auto &builder = parser.getBuilder(); // Parse the name as a symbol reference attribute. - SymbolRefAttr nameAttr; + FlatSymbolRefAttr nameAttr; if (parser.parseAttribute(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), result.attributes)) return failure(); |

