summaryrefslogtreecommitdiffstats
path: root/mlir/lib/IR
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-11-11 18:18:02 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-11-11 18:18:31 -0800
commit9b9c647cefea0a81fdf7d2bf6586a13f99d9a2cf (patch)
tree2fd6e7aeaa1e41a5e5a6355f860f35bc63ca8d99 /mlir/lib/IR
parent5cf6e0ce7f03f9841675b1a9d44232540f3df5cc (diff)
downloadbcm5719-llvm-9b9c647cefea0a81fdf7d2bf6586a13f99d9a2cf.tar.gz
bcm5719-llvm-9b9c647cefea0a81fdf7d2bf6586a13f99d9a2cf.zip
Add support for nested symbol references.
This change allows for adding additional nested references to a SymbolRefAttr to allow for further resolving a symbol if that symbol also defines a SymbolTable. If a referenced symbol also defines a symbol table, a nested reference can be used to refer to a symbol within that table. Nested references are printed after the main reference in the following form: symbol-ref-attribute ::= symbol-ref-id (`::` symbol-ref-id)* Example: module @reference { func @nested_reference() } my_reference_op @reference::@nested_reference Given that SymbolRefAttr is now more general, the existing functionality centered around a single reference is moved to a derived class FlatSymbolRefAttr. Followup commits will add support to lookups, rauw, etc. for scoped references. PiperOrigin-RevId: 279860501
Diffstat (limited to 'mlir/lib/IR')
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp10
-rw-r--r--mlir/lib/IR/AttributeDetail.h37
-rw-r--r--mlir/lib/IR/Attributes.cpp23
-rw-r--r--mlir/lib/IR/Builders.cpp9
-rw-r--r--mlir/lib/IR/FunctionSupport.cpp2
5 files changed, 72 insertions, 9 deletions
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 43452b2712d..6f77de0e721 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -801,9 +801,15 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
case StandardAttributes::Type:
printType(attr.cast<TypeAttr>().getValue());
break;
- case StandardAttributes::SymbolRef:
- printSymbolReference(attr.cast<SymbolRefAttr>().getValue(), os);
+ case StandardAttributes::SymbolRef: {
+ auto refAttr = attr.dyn_cast<SymbolRefAttr>();
+ printSymbolReference(refAttr.getRootReference(), os);
+ for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
+ os << "::";
+ printSymbolReference(nestedRef.getValue(), os);
+ }
break;
+ }
case StandardAttributes::OpaqueElements: {
auto eltsAttr = attr.cast<OpaqueElementsAttr>();
os << "opaque<\"" << eltsAttr.getDialect()->getNamespace() << "\", ";
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index 21f8b68c265..da4aa69dda4 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -321,6 +321,43 @@ struct StringAttributeStorage : public AttributeStorage {
StringRef value;
};
+/// An attribute representing a symbol reference.
+struct SymbolRefAttributeStorage final
+ : public AttributeStorage,
+ public llvm::TrailingObjects<SymbolRefAttributeStorage,
+ FlatSymbolRefAttr> {
+ using KeyTy = std::pair<StringRef, ArrayRef<FlatSymbolRefAttr>>;
+
+ SymbolRefAttributeStorage(StringRef value, size_t numNestedRefs)
+ : value(value), numNestedRefs(numNestedRefs) {}
+
+ /// Key equality function.
+ bool operator==(const KeyTy &key) const {
+ return key == KeyTy(value, getNestedRefs());
+ }
+
+ /// Construct a new storage instance.
+ static SymbolRefAttributeStorage *
+ construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+ auto size = SymbolRefAttributeStorage::totalSizeToAlloc<FlatSymbolRefAttr>(
+ key.second.size());
+ auto rawMem = allocator.allocate(size, alignof(SymbolRefAttributeStorage));
+ auto result = ::new (rawMem) SymbolRefAttributeStorage(
+ allocator.copyInto(key.first), key.second.size());
+ std::uninitialized_copy(key.second.begin(), key.second.end(),
+ result->getTrailingObjects<FlatSymbolRefAttr>());
+ return result;
+ }
+
+ /// Returns the set of nested references.
+ ArrayRef<FlatSymbolRefAttr> getNestedRefs() const {
+ return {getTrailingObjects<FlatSymbolRefAttr>(), numNestedRefs};
+ }
+
+ StringRef value;
+ size_t numNestedRefs;
+};
+
/// An attribute representing a reference to a type.
struct TypeAttributeStorage : public AttributeStorage {
using KeyTy = Type;
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index d74cacbe695..80ac4a59246 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -249,12 +249,27 @@ FloatAttr::verifyConstructionInvariants(llvm::Optional<Location> loc,
// SymbolRefAttr
//===----------------------------------------------------------------------===//
-SymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
- return Base::get(ctx, StandardAttributes::SymbolRef, value,
- NoneType::get(ctx));
+FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
+ return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None)
+ .cast<FlatSymbolRefAttr>();
}
-StringRef SymbolRefAttr::getValue() const { return getImpl()->value; }
+SymbolRefAttr SymbolRefAttr::get(StringRef value,
+ ArrayRef<FlatSymbolRefAttr> nestedReferences,
+ MLIRContext *ctx) {
+ return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences);
+}
+
+StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
+
+StringRef SymbolRefAttr::getLeafReference() const {
+ ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
+ return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
+}
+
+ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
+ return getImpl()->getNestedRefs();
+}
//===----------------------------------------------------------------------===//
// IntegerAttr
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 24ae2072f77..afdeefd023c 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -150,15 +150,20 @@ ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
return ArrayAttr::get(value, context);
}
-SymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
+FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
auto symName =
value->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
assert(symName && "value does not have a valid symbol name");
return getSymbolRefAttr(symName.getValue());
}
-SymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
+FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
return SymbolRefAttr::get(value, getContext());
}
+SymbolRefAttr
+Builder::getSymbolRefAttr(StringRef value,
+ ArrayRef<FlatSymbolRefAttr> nestedReferences) {
+ return SymbolRefAttr::get(value, nestedReferences, getContext());
+}
ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
auto attrs = functional::map(
diff --git a/mlir/lib/IR/FunctionSupport.cpp b/mlir/lib/IR/FunctionSupport.cpp
index d1ba2d30fa1..29cae177cec 100644
--- a/mlir/lib/IR/FunctionSupport.cpp
+++ b/mlir/lib/IR/FunctionSupport.cpp
@@ -159,7 +159,7 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
auto &builder = parser.getBuilder();
// Parse the name as a symbol reference attribute.
- SymbolRefAttr nameAttr;
+ FlatSymbolRefAttr nameAttr;
if (parser.parseAttribute(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
result.attributes))
return failure();
OpenPOWER on IntegriCloud