summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/IR/SymbolTable.h38
-rw-r--r--mlir/lib/IR/Module.cpp7
-rw-r--r--mlir/lib/IR/SymbolTable.cpp50
-rw-r--r--mlir/test/IR/traits.mlir24
-rw-r--r--mlir/test/lib/TestDialect/TestOps.td8
5 files changed, 124 insertions, 3 deletions
diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
index 2df39ea1b73..1b93b66b687 100644
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -46,10 +46,31 @@ public:
/// Returns the associated operation.
Operation *getOp() const { return symbolTableOp; }
+ /// Return the name of the attribute used for symbol visibility.
+ static StringRef getVisibilityAttrName() { return "sym_visibility"; }
+
//===--------------------------------------------------------------------===//
// Symbol Utilities
//===--------------------------------------------------------------------===//
+ /// An enumeration detailing the different visibility types that a symbol may
+ /// have.
+ enum class Visibility {
+ /// The symbol is public and may be referenced anywhere internal or external
+ /// to the visible references in the IR.
+ Public,
+
+ /// The symbol is private and may only be referenced by SymbolRefAttrs local
+ /// to the operations within the current symbol table.
+ Private,
+
+ /// The symbol is visible to the current IR, which may include operations in
+ /// symbol tables above the one that owns the current symbol. `Nested`
+ /// visibility allows for referencing a symbol outside of its current symbol
+ /// table, while retaining the ability to observe all uses.
+ Nested,
+ };
+
/// Returns true if the given operation defines a symbol.
static bool isSymbol(Operation *op);
@@ -58,6 +79,11 @@ public:
/// Sets the name of the given symbol operation.
static void setSymbolName(Operation *symbol, StringRef name);
+ /// Returns the visibility of the given symbol operation.
+ static Visibility getSymbolVisibility(Operation *symbol);
+ /// Sets the visibility of the given symbol operation.
+ static void setSymbolVisibility(Operation *symbol, Visibility vis);
+
/// Returns the operation registered with the given symbol name with the
/// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
/// with the 'OpTrait::SymbolTable' trait.
@@ -200,6 +226,8 @@ public:
template <typename ConcreteType>
class Symbol : public TraitBase<ConcreteType, Symbol> {
public:
+ using Visibility = mlir::SymbolTable::Visibility;
+
static LogicalResult verifyTrait(Operation *op) {
return impl::verifySymbol(op);
}
@@ -219,6 +247,16 @@ public:
StringAttr::get(name, this->getOperation()->getContext()));
}
+ /// Returns the visibility of the current symbol.
+ Visibility getVisibility() {
+ return mlir::SymbolTable::getSymbolVisibility(this->getOperation());
+ }
+
+ /// Sets the visibility of the current symbol.
+ void setVisibility(Visibility vis) {
+ mlir::SymbolTable::setSymbolVisibility(this->getOperation(), vis);
+ }
+
/// Get all of the uses of the current symbol that are nested within the given
/// operation 'from'.
/// Note: See mlir::SymbolTable::getSymbolUses for more details.
diff --git a/mlir/lib/IR/Module.cpp b/mlir/lib/IR/Module.cpp
index c5af227459c..e0caeda5ed8 100644
--- a/mlir/lib/IR/Module.cpp
+++ b/mlir/lib/IR/Module.cpp
@@ -82,10 +82,13 @@ LogicalResult ModuleOp::verify() {
return emitOpError("expected body to have no arguments");
// Check that none of the attributes are non-dialect attributes, except for
- // the symbol name attribute.
+ // the symbol related attributes.
for (auto attr : getOperation()->getAttrList().getAttrs()) {
if (!attr.first.strref().contains('.') &&
- attr.first.strref() != mlir::SymbolTable::getSymbolAttrName())
+ !llvm::is_contained(
+ ArrayRef<StringRef>{mlir::SymbolTable::getSymbolAttrName(),
+ mlir::SymbolTable::getVisibilityAttrName()},
+ attr.first.strref()))
return emitOpError(
"can only contain dialect-specific attributes, found: '")
<< attr.first << "'";
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 057aeded242..ed9de48c400 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -10,6 +10,7 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/StringSwitch.h"
using namespace mlir;
@@ -179,6 +180,38 @@ void SymbolTable::setSymbolName(Operation *symbol, StringRef name) {
StringAttr::get(name, symbol->getContext()));
}
+/// Returns the visibility of the given symbol operation.
+SymbolTable::Visibility SymbolTable::getSymbolVisibility(Operation *symbol) {
+ // If the attribute doesn't exist, assume public.
+ StringAttr vis = symbol->getAttrOfType<StringAttr>(getVisibilityAttrName());
+ if (!vis)
+ return Visibility::Public;
+
+ // Otherwise, switch on the string value.
+ return llvm::StringSwitch<Visibility>(vis.getValue())
+ .Case("private", Visibility::Private)
+ .Case("nested", Visibility::Nested)
+ .Case("public", Visibility::Public);
+}
+/// Sets the visibility of the given symbol operation.
+void SymbolTable::setSymbolVisibility(Operation *symbol, Visibility vis) {
+ MLIRContext *ctx = symbol->getContext();
+
+ // If the visibility is public, just drop the attribute as this is the
+ // default.
+ if (vis == Visibility::Public) {
+ symbol->removeAttr(Identifier::get(getVisibilityAttrName(), ctx));
+ return;
+ }
+
+ // Otherwise, update the attribute.
+ assert((vis == Visibility::Private || vis == Visibility::Nested) &&
+ "unknown symbol visibility kind");
+
+ StringRef visName = vis == Visibility::Private ? "private" : "nested";
+ symbol->setAttr(getVisibilityAttrName(), StringAttr::get(visName, ctx));
+}
+
/// Returns the operation registered with the given symbol name with the
/// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
/// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol
@@ -272,9 +305,26 @@ LogicalResult OpTrait::impl::verifySymbolTable(Operation *op) {
}
LogicalResult OpTrait::impl::verifySymbol(Operation *op) {
+ // Verify the name attribute.
if (!op->getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName()))
return op->emitOpError() << "requires string attribute '"
<< mlir::SymbolTable::getSymbolAttrName() << "'";
+
+ // Verify the visibility attribute.
+ if (Attribute vis = op->getAttr(mlir::SymbolTable::getVisibilityAttrName())) {
+ StringAttr visStrAttr = vis.dyn_cast<StringAttr>();
+ if (!visStrAttr)
+ return op->emitOpError() << "requires visibility attribute '"
+ << mlir::SymbolTable::getVisibilityAttrName()
+ << "' to be a string attribute, but got " << vis;
+
+ if (!llvm::is_contained(ArrayRef<StringRef>{"public", "private", "nested"},
+ visStrAttr.getValue()))
+ return op->emitOpError()
+ << "visibility expected to be one of [\"public\", \"private\", "
+ "\"nested\"], but got "
+ << visStrAttr;
+ }
return success();
}
diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
index 794ed4cd4f7..42044bde5dc 100644
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -207,6 +207,30 @@ func @failedSingleBlockImplicitTerminator_missing_terminator() {
// -----
+// Test the invariants of operations with the Symbol Trait.
+
+// expected-error@+1 {{requires string attribute 'sym_name'}}
+"test.symbol"() {} : () -> ()
+
+// -----
+
+// expected-error@+1 {{requires visibility attribute 'sym_visibility' to be a string attribute}}
+"test.symbol"() {sym_name = "foo_2", sym_visibility} : () -> ()
+
+// -----
+
+// expected-error@+1 {{visibility expected to be one of ["public", "private", "nested"]}}
+"test.symbol"() {sym_name = "foo_2", sym_visibility = "foo"} : () -> ()
+
+// -----
+
+"test.symbol"() {sym_name = "foo_3", sym_visibility = "nested"} : () -> ()
+"test.symbol"() {sym_name = "foo_4", sym_visibility = "private"} : () -> ()
+"test.symbol"() {sym_name = "foo_5", sym_visibility = "public"} : () -> ()
+"test.symbol"() {sym_name = "foo_6"} : () -> ()
+
+// -----
+
// Test that operation with the SymbolTable Trait define a new symbol scope.
"test.symbol_scope"() ({
func @foo() {
diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td
index f10991dfe5b..0dd32d74a75 100644
--- a/mlir/test/lib/TestDialect/TestOps.td
+++ b/mlir/test/lib/TestDialect/TestOps.td
@@ -74,9 +74,15 @@ def MultiTensorRankOf : TEST_Op<"multi_tensor_rank_of"> {
}
//===----------------------------------------------------------------------===//
-// Test Operands
+// Test Symbols
//===----------------------------------------------------------------------===//
+def SymbolOp : TEST_Op<"symbol", [Symbol]> {
+ let summary = "operation which defines a new symbol";
+ let arguments = (ins StrAttr:$sym_name,
+ OptionalAttr<StrAttr>:$sym_visibility);
+}
+
def SymbolScopeOp : TEST_Op<"symbol_scope",
[SymbolTable, SingleBlockImplicitTerminator<"TerminatorOp">]> {
let summary = "operation which defines a new symbol table";
OpenPOWER on IntegriCloud