diff options
Diffstat (limited to 'mlir/lib/IR/SymbolTable.cpp')
-rw-r--r-- | mlir/lib/IR/SymbolTable.cpp | 50 |
1 files changed, 50 insertions, 0 deletions
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp index 057aeded242..ed9de48c400 100644 --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -10,6 +10,7 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallString.h" +#include "llvm/ADT/StringSwitch.h" using namespace mlir; @@ -179,6 +180,38 @@ void SymbolTable::setSymbolName(Operation *symbol, StringRef name) { StringAttr::get(name, symbol->getContext())); } +/// Returns the visibility of the given symbol operation. +SymbolTable::Visibility SymbolTable::getSymbolVisibility(Operation *symbol) { + // If the attribute doesn't exist, assume public. + StringAttr vis = symbol->getAttrOfType<StringAttr>(getVisibilityAttrName()); + if (!vis) + return Visibility::Public; + + // Otherwise, switch on the string value. + return llvm::StringSwitch<Visibility>(vis.getValue()) + .Case("private", Visibility::Private) + .Case("nested", Visibility::Nested) + .Case("public", Visibility::Public); +} +/// Sets the visibility of the given symbol operation. +void SymbolTable::setSymbolVisibility(Operation *symbol, Visibility vis) { + MLIRContext *ctx = symbol->getContext(); + + // If the visibility is public, just drop the attribute as this is the + // default. + if (vis == Visibility::Public) { + symbol->removeAttr(Identifier::get(getVisibilityAttrName(), ctx)); + return; + } + + // Otherwise, update the attribute. + assert((vis == Visibility::Private || vis == Visibility::Nested) && + "unknown symbol visibility kind"); + + StringRef visName = vis == Visibility::Private ? "private" : "nested"; + symbol->setAttr(getVisibilityAttrName(), StringAttr::get(visName, ctx)); +} + /// Returns the operation registered with the given symbol name with the /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation /// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol @@ -272,9 +305,26 @@ LogicalResult OpTrait::impl::verifySymbolTable(Operation *op) { } LogicalResult OpTrait::impl::verifySymbol(Operation *op) { + // Verify the name attribute. if (!op->getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName())) return op->emitOpError() << "requires string attribute '" << mlir::SymbolTable::getSymbolAttrName() << "'"; + + // Verify the visibility attribute. + if (Attribute vis = op->getAttr(mlir::SymbolTable::getVisibilityAttrName())) { + StringAttr visStrAttr = vis.dyn_cast<StringAttr>(); + if (!visStrAttr) + return op->emitOpError() << "requires visibility attribute '" + << mlir::SymbolTable::getVisibilityAttrName() + << "' to be a string attribute, but got " << vis; + + if (!llvm::is_contained(ArrayRef<StringRef>{"public", "private", "nested"}, + visStrAttr.getValue())) + return op->emitOpError() + << "visibility expected to be one of [\"public\", \"private\", " + "\"nested\"], but got " + << visStrAttr; + } return success(); } |