summaryrefslogtreecommitdiffstats
path: root/mlir/lib/IR/SymbolTable.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/IR/SymbolTable.cpp')
-rw-r--r--mlir/lib/IR/SymbolTable.cpp50
1 files changed, 50 insertions, 0 deletions
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 057aeded242..ed9de48c400 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -10,6 +10,7 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/StringSwitch.h"
using namespace mlir;
@@ -179,6 +180,38 @@ void SymbolTable::setSymbolName(Operation *symbol, StringRef name) {
StringAttr::get(name, symbol->getContext()));
}
+/// Returns the visibility of the given symbol operation.
+SymbolTable::Visibility SymbolTable::getSymbolVisibility(Operation *symbol) {
+ // If the attribute doesn't exist, assume public.
+ StringAttr vis = symbol->getAttrOfType<StringAttr>(getVisibilityAttrName());
+ if (!vis)
+ return Visibility::Public;
+
+ // Otherwise, switch on the string value.
+ return llvm::StringSwitch<Visibility>(vis.getValue())
+ .Case("private", Visibility::Private)
+ .Case("nested", Visibility::Nested)
+ .Case("public", Visibility::Public);
+}
+/// Sets the visibility of the given symbol operation.
+void SymbolTable::setSymbolVisibility(Operation *symbol, Visibility vis) {
+ MLIRContext *ctx = symbol->getContext();
+
+ // If the visibility is public, just drop the attribute as this is the
+ // default.
+ if (vis == Visibility::Public) {
+ symbol->removeAttr(Identifier::get(getVisibilityAttrName(), ctx));
+ return;
+ }
+
+ // Otherwise, update the attribute.
+ assert((vis == Visibility::Private || vis == Visibility::Nested) &&
+ "unknown symbol visibility kind");
+
+ StringRef visName = vis == Visibility::Private ? "private" : "nested";
+ symbol->setAttr(getVisibilityAttrName(), StringAttr::get(visName, ctx));
+}
+
/// Returns the operation registered with the given symbol name with the
/// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
/// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol
@@ -272,9 +305,26 @@ LogicalResult OpTrait::impl::verifySymbolTable(Operation *op) {
}
LogicalResult OpTrait::impl::verifySymbol(Operation *op) {
+ // Verify the name attribute.
if (!op->getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName()))
return op->emitOpError() << "requires string attribute '"
<< mlir::SymbolTable::getSymbolAttrName() << "'";
+
+ // Verify the visibility attribute.
+ if (Attribute vis = op->getAttr(mlir::SymbolTable::getVisibilityAttrName())) {
+ StringAttr visStrAttr = vis.dyn_cast<StringAttr>();
+ if (!visStrAttr)
+ return op->emitOpError() << "requires visibility attribute '"
+ << mlir::SymbolTable::getVisibilityAttrName()
+ << "' to be a string attribute, but got " << vis;
+
+ if (!llvm::is_contained(ArrayRef<StringRef>{"public", "private", "nested"},
+ visStrAttr.getValue()))
+ return op->emitOpError()
+ << "visibility expected to be one of [\"public\", \"private\", "
+ "\"nested\"], but got "
+ << visStrAttr;
+ }
return success();
}
OpenPOWER on IntegriCloud