summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-11-11 18:18:02 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-11-11 18:18:31 -0800
commit9b9c647cefea0a81fdf7d2bf6586a13f99d9a2cf (patch)
tree2fd6e7aeaa1e41a5e5a6355f860f35bc63ca8d99 /mlir/lib/Dialect
parent5cf6e0ce7f03f9841675b1a9d44232540f3df5cc (diff)
downloadbcm5719-llvm-9b9c647cefea0a81fdf7d2bf6586a13f99d9a2cf.tar.gz
bcm5719-llvm-9b9c647cefea0a81fdf7d2bf6586a13f99d9a2cf.zip
Add support for nested symbol references.
This change allows for adding additional nested references to a SymbolRefAttr to allow for further resolving a symbol if that symbol also defines a SymbolTable. If a referenced symbol also defines a symbol table, a nested reference can be used to refer to a symbol within that table. Nested references are printed after the main reference in the following form: symbol-ref-attribute ::= symbol-ref-id (`::` symbol-ref-id)* Example: module @reference { func @nested_reference() } my_reference_op @reference::@nested_reference Given that SymbolRefAttr is now more general, the existing functionality centered around a single reference is moved to a derived class FlatSymbolRefAttr. Followup commits will add support to lookups, rauw, etc. for scoped references. PiperOrigin-RevId: 279860501
Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--mlir/lib/Dialect/GPU/IR/GPUDialect.cpp3
-rw-r--r--mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp3
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp6
-rw-r--r--mlir/lib/Dialect/SPIRV/SPIRVOps.cpp17
-rw-r--r--mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp2
-rw-r--r--mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp2
-rw-r--r--mlir/lib/Dialect/StandardOps/Ops.cpp10
7 files changed, 23 insertions, 20 deletions
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index d70d51feee7..bfd094d6203 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -533,7 +533,8 @@ unsigned LaunchFuncOp::getNumKernelOperands() {
}
StringRef LaunchFuncOp::getKernelModuleName() {
- return getAttrOfType<SymbolRefAttr>(getKernelModuleAttrName()).getValue();
+ return getAttrOfType<SymbolRefAttr>(getKernelModuleAttrName())
+ .getRootReference();
}
Value *LaunchFuncOp::getKernelOperand(unsigned i) {
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index 672beee56cd..420b2340bc9 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -189,7 +189,8 @@ private:
if (Optional<SymbolTable::UseRange> symbolUses =
SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) {
for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
- StringRef symbolName = symbolUse.getSymbolRef().getValue();
+ StringRef symbolName =
+ symbolUse.getSymbolRef().cast<FlatSymbolRefAttr>().getValue();
if (moduleManager.lookupSymbol(symbolName))
continue;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
index 7a8bc7162af..2dc46bf7b2b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
@@ -359,8 +359,8 @@ public:
// Get a SymbolRefAttr containing the library function name for the LinalgOp.
// If the library function does not exist, insert a declaration.
template <typename LinalgOp>
-static SymbolRefAttr getLibraryCallSymbolRef(Operation *op,
- PatternRewriter &rewriter) {
+static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
+ PatternRewriter &rewriter) {
auto linalgOp = cast<LinalgOp>(op);
auto fnName = linalgOp.getLibraryCallName();
if (fnName.empty()) {
@@ -369,7 +369,7 @@ static SymbolRefAttr getLibraryCallSymbolRef(Operation *op,
}
// fnName is a dynamic std::String, unique it via a SymbolRefAttr.
- SymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
+ FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
auto module = op->getParentOfType<ModuleOp>();
if (module.lookupSymbol(fnName)) {
return fnNameAttr;
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index b3a9f6f7443..3c1563ed515 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -658,7 +658,7 @@ void spirv::AddressOfOp::build(Builder *builder, OperationState &state,
static ParseResult parseAddressOfOp(OpAsmParser &parser,
OperationState &state) {
- SymbolRefAttr varRefAttr;
+ FlatSymbolRefAttr varRefAttr;
Type type;
if (parser.parseAttribute(varRefAttr, Type(), kVariableAttrName,
state.attributes) ||
@@ -1088,7 +1088,7 @@ static ParseResult parseEntryPointOp(OpAsmParser &parser,
SmallVector<Type, 0> idTypes;
SmallVector<Attribute, 4> interfaceVars;
- SymbolRefAttr fn;
+ FlatSymbolRefAttr fn;
if (parseEnumAttribute(execModel, parser, state) ||
parser.parseAttribute(fn, Type(), kFnNameAttrName, state.attributes)) {
return failure();
@@ -1099,7 +1099,7 @@ static ParseResult parseEntryPointOp(OpAsmParser &parser,
do {
// The name of the interface variable attribute isnt important
auto attrName = "var_symbol";
- SymbolRefAttr var;
+ FlatSymbolRefAttr var;
SmallVector<NamedAttribute, 1> attrs;
if (parser.parseAttribute(var, Type(), attrName, attrs)) {
return failure();
@@ -1186,7 +1186,7 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) {
static ParseResult parseFunctionCallOp(OpAsmParser &parser,
OperationState &state) {
- SymbolRefAttr calleeAttr;
+ FlatSymbolRefAttr calleeAttr;
FunctionType type;
SmallVector<OpAsmParser::OperandType, 4> operands;
auto loc = parser.getNameLoc();
@@ -1305,7 +1305,7 @@ static ParseResult parseGlobalVariableOp(OpAsmParser &parser,
// Parse optional initializer
if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) {
- SymbolRefAttr initSymbol;
+ FlatSymbolRefAttr initSymbol;
if (parser.parseLParen() ||
parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
state.attributes) ||
@@ -1361,7 +1361,8 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) {
if (varOp.storageClass() == spirv::StorageClass::Generic)
return varOp.emitOpError("storage class cannot be 'Generic'");
- if (auto init = varOp.getAttrOfType<SymbolRefAttr>(kInitializerAttrName)) {
+ if (auto init =
+ varOp.getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
auto moduleOp = varOp.getParentOfType<spirv::ModuleOp>();
auto *initOp = moduleOp.lookupSymbol(init.getValue());
// TODO: Currently only variable initialization with specialization
@@ -1713,7 +1714,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
}
if (auto interface = entryPointOp.interface()) {
for (Attribute varRef : interface) {
- auto varSymRef = varRef.dyn_cast<SymbolRefAttr>();
+ auto varSymRef = varRef.dyn_cast<FlatSymbolRefAttr>();
if (!varSymRef) {
return entryPointOp.emitError(
"expected symbol reference for interface "
@@ -1790,7 +1791,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
static ParseResult parseReferenceOfOp(OpAsmParser &parser,
OperationState &state) {
- SymbolRefAttr constRefAttr;
+ FlatSymbolRefAttr constRefAttr;
Type type;
if (parser.parseAttribute(constRefAttr, Type(), kSpecConstAttrName,
state.attributes) ||
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index 11660ed4e87..40b53185529 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -970,7 +970,7 @@ LogicalResult Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
wordIndex++;
// Initializer.
- SymbolRefAttr initializer = nullptr;
+ FlatSymbolRefAttr initializer = nullptr;
if (wordIndex < operands.size()) {
auto initializerOp = getGlobalVariable(operands[wordIndex]);
if (!initializerOp) {
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index f92b9ae62be..805a3393b0c 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -1696,7 +1696,7 @@ Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
// Add the interface values.
if (auto interface = op.interface()) {
for (auto var : interface.getValue()) {
- auto id = getVariableID(var.cast<SymbolRefAttr>().getValue());
+ auto id = getVariableID(var.cast<FlatSymbolRefAttr>().getValue());
if (!id) {
return op.emitError("referencing undefined global variable."
"spv.EntryPoint is at the end of spv.module. All "
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index 12029248168..8c08868bc7a 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -528,7 +528,7 @@ void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
//===----------------------------------------------------------------------===//
static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
- SymbolRefAttr calleeAttr;
+ FlatSymbolRefAttr calleeAttr;
FunctionType calleeType;
SmallVector<OpAsmParser::OperandType, 4> operands;
auto calleeLoc = parser.getNameLoc();
@@ -555,7 +555,7 @@ static void print(OpAsmPrinter &p, CallOp op) {
static LogicalResult verify(CallOp op) {
// Check that the callee attribute was specified.
- auto fnAttr = op.getAttrOfType<SymbolRefAttr>("callee");
+ auto fnAttr = op.getAttrOfType<FlatSymbolRefAttr>("callee");
if (!fnAttr)
return op.emitOpError("requires a 'callee' symbol reference attribute");
auto fn =
@@ -608,8 +608,8 @@ struct SimplifyIndirectCallWithKnownCallee
// Replace with a direct call.
SmallVector<Type, 8> callResults(indirectCall.getResultTypes());
SmallVector<Value *, 8> callOperands(indirectCall.getArgOperands());
- rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn.getValue(),
- callResults, callOperands);
+ rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn, callResults,
+ callOperands);
return matchSuccess();
}
};
@@ -1206,7 +1206,7 @@ static LogicalResult verify(ConstantOp &op) {
}
if (type.isa<FunctionType>()) {
- auto fnAttr = value.dyn_cast<SymbolRefAttr>();
+ auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>();
if (!fnAttr)
return op.emitOpError("requires 'value' to be a function reference");
OpenPOWER on IntegriCloud