summaryrefslogtreecommitdiffstats
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/examples/toy/Ch2/include/toy/Ops.td2
-rw-r--r--mlir/examples/toy/Ch3/include/toy/Ops.td2
-rw-r--r--mlir/examples/toy/Ch4/include/toy/Ops.td2
-rw-r--r--mlir/examples/toy/Ch5/include/toy/Ops.td2
-rw-r--r--mlir/examples/toy/Ch6/include/toy/Ops.td2
-rw-r--r--mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp6
-rw-r--r--mlir/examples/toy/Ch7/include/toy/Ops.td2
-rw-r--r--mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp6
-rw-r--r--mlir/g3doc/LangRef.md6
-rw-r--r--mlir/g3doc/Tutorials/Toy/Ch-6.md6
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td4
-rw-r--r--mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td6
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td2
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td2
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td8
-rw-r--r--mlir/include/mlir/Dialect/StandardOps/Ops.td9
-rw-r--r--mlir/include/mlir/IR/Attributes.h102
-rw-r--r--mlir/include/mlir/IR/Builders.h6
-rw-r--r--mlir/include/mlir/IR/Function.h2
-rw-r--r--mlir/include/mlir/IR/OpBase.td16
-rw-r--r--mlir/lib/Analysis/CallGraph.cpp4
-rw-r--r--mlir/lib/Dialect/GPU/IR/GPUDialect.cpp3
-rw-r--r--mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp3
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp6
-rw-r--r--mlir/lib/Dialect/SPIRV/SPIRVOps.cpp17
-rw-r--r--mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp2
-rw-r--r--mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp2
-rw-r--r--mlir/lib/Dialect/StandardOps/Ops.cpp10
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp10
-rw-r--r--mlir/lib/IR/AttributeDetail.h37
-rw-r--r--mlir/lib/IR/Attributes.cpp23
-rw-r--r--mlir/lib/IR/Builders.cpp9
-rw-r--r--mlir/lib/IR/FunctionSupport.cpp2
-rw-r--r--mlir/lib/Parser/Parser.cpp28
-rw-r--r--mlir/lib/Target/LLVMIR/ModuleTranslation.cpp4
-rw-r--r--mlir/test/IR/parser.mlir4
-rw-r--r--mlir/test/lib/TestDialect/TestOps.td6
-rw-r--r--mlir/test/mlir-tblgen/op-attribute.td6
38 files changed, 286 insertions, 83 deletions
diff --git a/mlir/examples/toy/Ch2/include/toy/Ops.td b/mlir/examples/toy/Ch2/include/toy/Ops.td
index 799813b9fed..fb41818bed3 100644
--- a/mlir/examples/toy/Ch2/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch2/include/toy/Ops.td
@@ -123,7 +123,7 @@ def GenericCallOp : Toy_Op<"generic_call"> {
// The generic call operation takes a symbol reference attribute as the
// callee, and inputs for the call.
- let arguments = (ins SymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
+ let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
// The generic call operation returns a single value of TensorType.
let results = (outs F64Tensor);
diff --git a/mlir/examples/toy/Ch3/include/toy/Ops.td b/mlir/examples/toy/Ch3/include/toy/Ops.td
index 0be2b66a54c..25bf06cd294 100644
--- a/mlir/examples/toy/Ch3/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch3/include/toy/Ops.td
@@ -123,7 +123,7 @@ def GenericCallOp : Toy_Op<"generic_call"> {
// The generic call operation takes a symbol reference attribute as the
// callee, and inputs for the call.
- let arguments = (ins SymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
+ let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
// The generic call operation returns a single value of TensorType.
let results = (outs F64Tensor);
diff --git a/mlir/examples/toy/Ch4/include/toy/Ops.td b/mlir/examples/toy/Ch4/include/toy/Ops.td
index 27fdc345c25..bbbcc484bae 100644
--- a/mlir/examples/toy/Ch4/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch4/include/toy/Ops.td
@@ -148,7 +148,7 @@ def GenericCallOp : Toy_Op<"generic_call",
// The generic call operation takes a symbol reference attribute as the
// callee, and inputs for the call.
- let arguments = (ins SymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
+ let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
// The generic call operation returns a single value of TensorType.
let results = (outs F64Tensor);
diff --git a/mlir/examples/toy/Ch5/include/toy/Ops.td b/mlir/examples/toy/Ch5/include/toy/Ops.td
index f609a1469a5..d9306f4cf6b 100644
--- a/mlir/examples/toy/Ch5/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch5/include/toy/Ops.td
@@ -148,7 +148,7 @@ def GenericCallOp : Toy_Op<"generic_call",
// The generic call operation takes a symbol reference attribute as the
// callee, and inputs for the call.
- let arguments = (ins SymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
+ let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
// The generic call operation returns a single value of TensorType.
let results = (outs F64Tensor);
diff --git a/mlir/examples/toy/Ch6/include/toy/Ops.td b/mlir/examples/toy/Ch6/include/toy/Ops.td
index f609a1469a5..d9306f4cf6b 100644
--- a/mlir/examples/toy/Ch6/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch6/include/toy/Ops.td
@@ -148,7 +148,7 @@ def GenericCallOp : Toy_Op<"generic_call",
// The generic call operation takes a symbol reference attribute as the
// callee, and inputs for the call.
- let arguments = (ins SymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
+ let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
// The generic call operation returns a single value of TensorType.
let results = (outs F64Tensor);
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
index 7e300fb702d..091eada5ac3 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
@@ -107,9 +107,9 @@ public:
private:
/// Return a symbol reference to the printf function, inserting it into the
/// module if necessary.
- static SymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
- ModuleOp module,
- LLVM::LLVMDialect *llvmDialect) {
+ static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
+ ModuleOp module,
+ LLVM::LLVMDialect *llvmDialect) {
auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
return SymbolRefAttr::get("printf", context);
diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td
index 5e932bb0a7e..d41406a913b 100644
--- a/mlir/examples/toy/Ch7/include/toy/Ops.td
+++ b/mlir/examples/toy/Ch7/include/toy/Ops.td
@@ -160,7 +160,7 @@ def GenericCallOp : Toy_Op<"generic_call",
// The generic call operation takes a symbol reference attribute as the
// callee, and inputs for the call.
- let arguments = (ins SymbolRefAttr:$callee, Variadic<Toy_Type>:$inputs);
+ let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<Toy_Type>:$inputs);
// The generic call operation returns a single value of TensorType or
// StructType.
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
index 7e300fb702d..091eada5ac3 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -107,9 +107,9 @@ public:
private:
/// Return a symbol reference to the printf function, inserting it into the
/// module if necessary.
- static SymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
- ModuleOp module,
- LLVM::LLVMDialect *llvmDialect) {
+ static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
+ ModuleOp module,
+ LLVM::LLVMDialect *llvmDialect) {
auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
return SymbolRefAttr::get("printf", context);
diff --git a/mlir/g3doc/LangRef.md b/mlir/g3doc/LangRef.md
index 391f77325d4..3409b9fac83 100644
--- a/mlir/g3doc/LangRef.md
+++ b/mlir/g3doc/LangRef.md
@@ -1367,13 +1367,15 @@ A string attribute is an attribute that represents a string literal value.
Syntax:
``` {.ebnf}
-symbol-ref-attribute ::= symbol-ref-id
+symbol-ref-attribute ::= symbol-ref-id (`::` symbol-ref-id)*
```
A symbol reference attribute is a literal attribute that represents a named
reference to an operation that is nested within an operation with the
`OpTrait::SymbolTable` trait. As such, this reference is given meaning by the
-nearest parent operation containing the `OpTrait::SymbolTable` trait.
+nearest parent operation containing the `OpTrait::SymbolTable` trait. It may
+optionally contain a set of nested references that further resolve to a symbol
+nested within a different symbol table.
This attribute can only be held internally by
[array attributes](#array-attribute) and
diff --git a/mlir/g3doc/Tutorials/Toy/Ch-6.md b/mlir/g3doc/Tutorials/Toy/Ch-6.md
index b01dfde5a9f..49114b45dff 100644
--- a/mlir/g3doc/Tutorials/Toy/Ch-6.md
+++ b/mlir/g3doc/Tutorials/Toy/Ch-6.md
@@ -26,9 +26,9 @@ During lowering we can get, or build, the declaration for printf as so:
```c++
/// Return a symbol reference to the printf function, inserting it into the
/// module if necessary.
-static SymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
- ModuleOp module,
- LLVM::LLVMDialect *llvmDialect) {
+static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
+ ModuleOp module,
+ LLVM::LLVMDialect *llvmDialect) {
auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
return SymbolRefAttr::get("printf", context);
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 29509268c7a..ba337e29d30 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -341,7 +341,7 @@ def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "CreateFPTrunc">;
// Call-related operations.
def LLVM_CallOp : LLVM_Op<"call">,
- Arguments<(ins OptionalAttr<SymbolRefAttr>:$callee,
+ Arguments<(ins OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<LLVM_Type>)>,
Results<(outs Variadic<LLVM_Type>)>,
LLVM_TwoBuilders<LLVM_OneResultOpBuilder,
@@ -479,7 +479,7 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> {
// to work correctly).
def LLVM_AddressOfOp
: LLVM_OneResultOp<"mlir.addressof">,
- Arguments<(ins SymbolRefAttr:$global_name)> {
+ Arguments<(ins FlatSymbolRefAttr:$global_name)> {
let builders = [
OpBuilder<"Builder *builder, OperationState &result, LLVMType resType, "
"StringRef name, ArrayRef<NamedAttribute> attrs = {}", [{
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
index 467cd08b98a..dd1649862a8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
@@ -376,7 +376,7 @@ class GenericOpBase<string mnemonic> : LinalgLibraryBase_Op<mnemonic, []> {
I64ArrayAttr:$n_loop_types,
I64ArrayAttr:$n_views,
OptionalAttr<StrAttr>:$doc,
- OptionalAttr<SymbolRefAttr>:$fun,
+ OptionalAttr<FlatSymbolRefAttr>:$fun,
OptionalAttr<StrAttr>:$library_call);
let regions = (region AnyRegion:$region);
let extraClassDeclaration = [{
@@ -464,7 +464,7 @@ def GenericOp : GenericOpBase<"generic"> {
Where #trait_attributes is an alias of a dictionary attribute containing:
- doc [optional]: a documentation string
- - fun: a SymbolRefAttr that must resolve to an existing function symbol.
+ - fun: a FlatSymbolRefAttr that must resolve to an existing function symbol.
To support inplace updates in a generic fashion, the signature of the
function must be:
```
@@ -558,7 +558,7 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
Where #trait_attributes is an alias of a dictionary attribute containing:
- doc [optional]: a documentation string
- - fun: a SymbolRefAttr that must resolve to an existing function symbol.
+ - fun: a FlatSymbolRefAttr that must resolve to an existing function symbol.
To support inplace updates in a generic fashion, the signature of the
function must be:
```
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
index 070d6a6c5ad..8de2aebf0b7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
@@ -250,7 +250,7 @@ def SPV_FunctionCallOp : SPV_Op<"FunctionCall", [
}];
let arguments = (ins
- SymbolRefAttr:$callee,
+ FlatSymbolRefAttr:$callee,
Variadic<SPV_Type>:$arguments
);
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
index 0e135c573f6..9be4898365b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
@@ -273,7 +273,7 @@ def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [InModuleScope]> {
}];
let arguments = (ins
- SymbolRefAttr:$fn,
+ FlatSymbolRefAttr:$fn,
SPV_ExecutionModeAttr:$execution_mode,
I32ArrayAttr:$values
);
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
index aadd17940af..fab97bdd5eb 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -55,7 +55,7 @@ def SPV_AddressOfOp : SPV_Op<"_address_of", [InFunctionScope, NoSideEffect]> {
}];
let arguments = (ins
- SymbolRefAttr:$variable
+ FlatSymbolRefAttr:$variable
);
let results = (outs
@@ -174,7 +174,7 @@ def SPV_EntryPointOp : SPV_Op<"EntryPoint", [InModuleScope]> {
let arguments = (ins
SPV_ExecutionModelAttr:$execution_model,
- SymbolRefAttr:$fn,
+ FlatSymbolRefAttr:$fn,
SymbolRefArrayAttr:$interface
);
@@ -237,7 +237,7 @@ def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [InModuleScope, Symbol]> {
let arguments = (ins
TypeAttr:$type,
StrAttr:$sym_name,
- OptionalAttr<SymbolRefAttr>:$initializer
+ OptionalAttr<FlatSymbolRefAttr>:$initializer
);
let builders = [
@@ -394,7 +394,7 @@ def SPV_ReferenceOfOp : SPV_Op<"_reference_of", [NoSideEffect]> {
}];
let arguments = (ins
- SymbolRefAttr:$spec_const
+ FlatSymbolRefAttr:$spec_const
);
let results = (outs
diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td
index d7de15576d5..fa7230690f1 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -239,14 +239,15 @@ def BranchOp : Std_Op<"br", [Terminator]> {
def CallOp : Std_Op<"call", [CallOpInterface]> {
let summary = "call operation";
let description = [{
- The "call" operation represents a direct call to a function. The operands
- and result types of the call must match the specified function type. The
- callee is encoded as a function attribute named "callee".
+ The "call" operation represents a direct call to a function that is within
+ the same symbol scope as the call. The operands and result types of the
+ call must match the specified function type. The callee is encoded as a
+ function attribute named "callee".
%2 = call @my_add(%0, %1) : (f32, f32) -> f32
}];
- let arguments = (ins SymbolRefAttr:$callee, Variadic<AnyType>:$operands);
+ let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
let results = (outs Variadic<AnyType>);
let builders = [OpBuilder<
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 5d98f6e11f1..8a5e3b5321d 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -44,6 +44,7 @@ struct IntegerSetAttributeStorage;
struct FloatAttributeStorage;
struct OpaqueAttributeStorage;
struct StringAttributeStorage;
+struct SymbolRefAttributeStorage;
struct TypeAttributeStorage;
/// Elements Attributes.
@@ -179,6 +180,10 @@ enum Kind {
};
} // namespace StandardAttributes
+//===----------------------------------------------------------------------===//
+// AffineMapAttr
+//===----------------------------------------------------------------------===//
+
class AffineMapAttr
: public Attribute::AttrBase<AffineMapAttr, Attribute,
detail::AffineMapAttributeStorage> {
@@ -196,6 +201,10 @@ public:
}
};
+//===----------------------------------------------------------------------===//
+// ArrayAttr
+//===----------------------------------------------------------------------===//
+
/// Array attributes are lists of other attributes. They are not necessarily
/// type homogenous given that attributes don't, in general, carry types.
class ArrayAttr : public Attribute::AttrBase<ArrayAttr, Attribute,
@@ -220,6 +229,10 @@ public:
}
};
+//===----------------------------------------------------------------------===//
+// BoolAttr
+//===----------------------------------------------------------------------===//
+
class BoolAttr : public Attribute::AttrBase<BoolAttr, Attribute,
detail::BoolAttributeStorage> {
public:
@@ -234,6 +247,10 @@ public:
static bool kindof(unsigned kind) { return kind == StandardAttributes::Bool; }
};
+//===----------------------------------------------------------------------===//
+// DictionaryAttr
+//===----------------------------------------------------------------------===//
+
/// NamedAttribute is used for dictionary attributes, it holds an identifier for
/// the name and a value for the attribute. The attribute pointer should always
/// be non-null.
@@ -271,6 +288,10 @@ public:
}
};
+//===----------------------------------------------------------------------===//
+// FloatAttr
+//===----------------------------------------------------------------------===//
+
class FloatAttr : public Attribute::AttrBase<FloatAttr, Attribute,
detail::FloatAttributeStorage> {
public:
@@ -308,6 +329,10 @@ public:
Type type, const APFloat &value);
};
+//===----------------------------------------------------------------------===//
+// IntegerAttr
+//===----------------------------------------------------------------------===//
+
class IntegerAttr
: public Attribute::AttrBase<IntegerAttr, Attribute,
detail::IntegerAttributeStorage> {
@@ -328,6 +353,10 @@ public:
}
};
+//===----------------------------------------------------------------------===//
+// IntegerSetAttr
+//===----------------------------------------------------------------------===//
+
class IntegerSetAttr
: public Attribute::AttrBase<IntegerSetAttr, Attribute,
detail::IntegerSetAttributeStorage> {
@@ -345,6 +374,10 @@ public:
}
};
+//===----------------------------------------------------------------------===//
+// OpaqueAttr
+//===----------------------------------------------------------------------===//
+
/// Opaque attributes represent attributes of non-registered dialects. These are
/// attribute represented in their raw string form, and can only usefully be
/// tested for attribute equality.
@@ -380,6 +413,10 @@ public:
}
};
+//===----------------------------------------------------------------------===//
+// StringAttr
+//===----------------------------------------------------------------------===//
+
class StringAttr : public Attribute::AttrBase<StringAttr, Attribute,
detail::StringAttributeStorage> {
public:
@@ -400,19 +437,40 @@ public:
}
};
+//===----------------------------------------------------------------------===//
+// SymbolRefAttr
+//===----------------------------------------------------------------------===//
+
+class FlatSymbolRefAttr;
+
/// A symbol reference attribute represents a symbolic reference to another
/// operation.
class SymbolRefAttr
: public Attribute::AttrBase<SymbolRefAttr, Attribute,
- detail::StringAttributeStorage> {
+ detail::SymbolRefAttributeStorage> {
public:
using Base::Base;
- using ValueType = StringRef;
- static SymbolRefAttr get(StringRef value, MLIRContext *ctx);
+ /// Construct a symbol reference for the given value name.
+ static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx);
- /// Returns the name of the held symbol reference.
- StringRef getValue() const;
+ /// Construct a symbol reference for the given value name, and a set of nested
+ /// references that are further resolve to a nested symbol.
+ static SymbolRefAttr get(StringRef value,
+ ArrayRef<FlatSymbolRefAttr> references,
+ MLIRContext *ctx);
+
+ /// Returns the name of the top level symbol reference, i.e. the root of the
+ /// reference path.
+ StringRef getRootReference() const;
+
+ /// Returns the name of the fully resolved symbol, i.e. the leaf of the
+ /// reference path.
+ StringRef getLeafReference() const;
+
+ /// Returns the set of nested references representing the path to the symbol
+ /// nested under the root reference.
+ ArrayRef<FlatSymbolRefAttr> getNestedReferences() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
@@ -420,6 +478,36 @@ public:
}
};
+/// A symbol reference with a reference path containing a single element. This
+/// is used to refer to an operation within the current symbol table.
+class FlatSymbolRefAttr : public SymbolRefAttr {
+public:
+ using SymbolRefAttr::SymbolRefAttr;
+ using ValueType = StringRef;
+
+ /// Construct a symbol reference for the given value name.
+ static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx) {
+ return SymbolRefAttr::get(value, ctx);
+ }
+
+ /// Returns the name of the held symbol reference.
+ StringRef getValue() const { return getRootReference(); }
+
+ /// Methods for support type inquiry through isa, cast, and dyn_cast.
+ static bool classof(Attribute attr) {
+ SymbolRefAttr refAttr = attr.dyn_cast<SymbolRefAttr>();
+ return refAttr && refAttr.getNestedReferences().empty();
+ }
+
+private:
+ using SymbolRefAttr::get;
+ using SymbolRefAttr::getNestedReferences;
+};
+
+//===----------------------------------------------------------------------===//
+// Type
+//===----------------------------------------------------------------------===//
+
class TypeAttr : public Attribute::AttrBase<TypeAttr, Attribute,
detail::TypeAttributeStorage> {
public:
@@ -434,6 +522,10 @@ public:
static bool kindof(unsigned kind) { return kind == StandardAttributes::Type; }
};
+//===----------------------------------------------------------------------===//
+// UnitAttr
+//===----------------------------------------------------------------------===//
+
/// Unit attributes are attributes that hold no specific value and are given
/// meaning by their existence.
class UnitAttr : public Attribute::AttrBase<UnitAttr> {
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 0005a395e70..01ad38cfc11 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -100,8 +100,10 @@ public:
FloatAttr getFloatAttr(Type type, const APFloat &value);
StringAttr getStringAttr(StringRef bytes);
ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
- SymbolRefAttr getSymbolRefAttr(Operation *value);
- SymbolRefAttr getSymbolRefAttr(StringRef value);
+ FlatSymbolRefAttr getSymbolRefAttr(Operation *value);
+ FlatSymbolRefAttr getSymbolRefAttr(StringRef value);
+ SymbolRefAttr getSymbolRefAttr(StringRef value,
+ ArrayRef<FlatSymbolRefAttr> nestedReferences);
// Returns a 0-valued attribute of the given `type`. This function only
// supports boolean, integer, and 16-/32-/64-bit float types, and vector or
diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h
index 0f435b2a6e6..228b0302ebb 100644
--- a/mlir/include/mlir/IR/Function.h
+++ b/mlir/include/mlir/IR/Function.h
@@ -119,7 +119,7 @@ public:
/// to. This may return null in the case of an external callable object, e.g.
/// an external function.
Region *getCallableRegion(CallInterfaceCallable callable) {
- assert(callable.get<SymbolRefAttr>().getValue() == getName());
+ assert(callable.get<SymbolRefAttr>().getLeafReference() == getName());
return isExternal() ? nullptr : &getBody();
}
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 62d9542fa6e..2f675608dd2 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1139,8 +1139,16 @@ class StructAttr<string name, Dialect dialect,
def SymbolRefAttr : Attr<CPred<"$_self.isa<SymbolRefAttr>()">,
"symbol reference attribute"> {
let storageType = [{ SymbolRefAttr }];
+ let returnType = [{ SymbolRefAttr }];
+ let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
+ let convertFromStorage = "$_self";
+}
+def FlatSymbolRefAttr : Attr<CPred<"$_self.isa<FlatSymbolRefAttr>()">,
+ "flat symbol reference attribute"> {
+ let storageType = [{ FlatSymbolRefAttr }];
let returnType = [{ StringRef }];
let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
+ let convertFromStorage = "$_self.getValue()";
}
def SymbolRefArrayAttr :
@@ -1241,12 +1249,14 @@ class IntArrayNthElemMinValue<int index, int min> : AttrConstraint<
def IsNullAttr : AttrConstraint<
CPred<"!$_self">, "empty attribute (for optional attributes)">;
-// An attribute constraint on SymbolRefAttr that requires the SymbolRefAttr
-// pointing to an op of `opClass` within the closest parent with a symbol table.
+// An attribute constraint on FlatSymbolRefAttr that requires that the
+// reference point to an op of `opClass` within the closest parent with a symbol
+// table.
+// TODO(riverriddle) Add support for nested symbol references.
class ReferToOp<string opClass> : AttrConstraint<
CPred<"isa_and_nonnull<" # opClass # ">("
"::mlir::SymbolTable::lookupNearestSymbolFrom("
- "&$_op, $_self.cast<SymbolRefAttr>().getValue()))">,
+ "&$_op, $_self.cast<FlatSymbolRefAttr>().getValue()))">,
"referencing to a '" # opClass # "' symbol">;
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp
index 3c02802dd35..2b5894ff162 100644
--- a/mlir/lib/Analysis/CallGraph.cpp
+++ b/mlir/lib/Analysis/CallGraph.cpp
@@ -184,7 +184,9 @@ CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable,
// Get the callee operation from the callable.
Operation *callee;
if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>())
- callee = SymbolTable::lookupNearestSymbolFrom(from, symbolRef.getValue());
+ // TODO(riverriddle) Support nested references.
+ callee = SymbolTable::lookupNearestSymbolFrom(from,
+ symbolRef.getRootReference());
else
callee = callable.get<Value *>()->getDefiningOp();
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index d70d51feee7..bfd094d6203 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -533,7 +533,8 @@ unsigned LaunchFuncOp::getNumKernelOperands() {
}
StringRef LaunchFuncOp::getKernelModuleName() {
- return getAttrOfType<SymbolRefAttr>(getKernelModuleAttrName()).getValue();
+ return getAttrOfType<SymbolRefAttr>(getKernelModuleAttrName())
+ .getRootReference();
}
Value *LaunchFuncOp::getKernelOperand(unsigned i) {
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index 672beee56cd..420b2340bc9 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -189,7 +189,8 @@ private:
if (Optional<SymbolTable::UseRange> symbolUses =
SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) {
for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
- StringRef symbolName = symbolUse.getSymbolRef().getValue();
+ StringRef symbolName =
+ symbolUse.getSymbolRef().cast<FlatSymbolRefAttr>().getValue();
if (moduleManager.lookupSymbol(symbolName))
continue;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
index 7a8bc7162af..2dc46bf7b2b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
@@ -359,8 +359,8 @@ public:
// Get a SymbolRefAttr containing the library function name for the LinalgOp.
// If the library function does not exist, insert a declaration.
template <typename LinalgOp>
-static SymbolRefAttr getLibraryCallSymbolRef(Operation *op,
- PatternRewriter &rewriter) {
+static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
+ PatternRewriter &rewriter) {
auto linalgOp = cast<LinalgOp>(op);
auto fnName = linalgOp.getLibraryCallName();
if (fnName.empty()) {
@@ -369,7 +369,7 @@ static SymbolRefAttr getLibraryCallSymbolRef(Operation *op,
}
// fnName is a dynamic std::String, unique it via a SymbolRefAttr.
- SymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
+ FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
auto module = op->getParentOfType<ModuleOp>();
if (module.lookupSymbol(fnName)) {
return fnNameAttr;
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index b3a9f6f7443..3c1563ed515 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -658,7 +658,7 @@ void spirv::AddressOfOp::build(Builder *builder, OperationState &state,
static ParseResult parseAddressOfOp(OpAsmParser &parser,
OperationState &state) {
- SymbolRefAttr varRefAttr;
+ FlatSymbolRefAttr varRefAttr;
Type type;
if (parser.parseAttribute(varRefAttr, Type(), kVariableAttrName,
state.attributes) ||
@@ -1088,7 +1088,7 @@ static ParseResult parseEntryPointOp(OpAsmParser &parser,
SmallVector<Type, 0> idTypes;
SmallVector<Attribute, 4> interfaceVars;
- SymbolRefAttr fn;
+ FlatSymbolRefAttr fn;
if (parseEnumAttribute(execModel, parser, state) ||
parser.parseAttribute(fn, Type(), kFnNameAttrName, state.attributes)) {
return failure();
@@ -1099,7 +1099,7 @@ static ParseResult parseEntryPointOp(OpAsmParser &parser,
do {
// The name of the interface variable attribute isnt important
auto attrName = "var_symbol";
- SymbolRefAttr var;
+ FlatSymbolRefAttr var;
SmallVector<NamedAttribute, 1> attrs;
if (parser.parseAttribute(var, Type(), attrName, attrs)) {
return failure();
@@ -1186,7 +1186,7 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) {
static ParseResult parseFunctionCallOp(OpAsmParser &parser,
OperationState &state) {
- SymbolRefAttr calleeAttr;
+ FlatSymbolRefAttr calleeAttr;
FunctionType type;
SmallVector<OpAsmParser::OperandType, 4> operands;
auto loc = parser.getNameLoc();
@@ -1305,7 +1305,7 @@ static ParseResult parseGlobalVariableOp(OpAsmParser &parser,
// Parse optional initializer
if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) {
- SymbolRefAttr initSymbol;
+ FlatSymbolRefAttr initSymbol;
if (parser.parseLParen() ||
parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
state.attributes) ||
@@ -1361,7 +1361,8 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) {
if (varOp.storageClass() == spirv::StorageClass::Generic)
return varOp.emitOpError("storage class cannot be 'Generic'");
- if (auto init = varOp.getAttrOfType<SymbolRefAttr>(kInitializerAttrName)) {
+ if (auto init =
+ varOp.getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
auto moduleOp = varOp.getParentOfType<spirv::ModuleOp>();
auto *initOp = moduleOp.lookupSymbol(init.getValue());
// TODO: Currently only variable initialization with specialization
@@ -1713,7 +1714,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
}
if (auto interface = entryPointOp.interface()) {
for (Attribute varRef : interface) {
- auto varSymRef = varRef.dyn_cast<SymbolRefAttr>();
+ auto varSymRef = varRef.dyn_cast<FlatSymbolRefAttr>();
if (!varSymRef) {
return entryPointOp.emitError(
"expected symbol reference for interface "
@@ -1790,7 +1791,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
static ParseResult parseReferenceOfOp(OpAsmParser &parser,
OperationState &state) {
- SymbolRefAttr constRefAttr;
+ FlatSymbolRefAttr constRefAttr;
Type type;
if (parser.parseAttribute(constRefAttr, Type(), kSpecConstAttrName,
state.attributes) ||
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index 11660ed4e87..40b53185529 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -970,7 +970,7 @@ LogicalResult Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
wordIndex++;
// Initializer.
- SymbolRefAttr initializer = nullptr;
+ FlatSymbolRefAttr initializer = nullptr;
if (wordIndex < operands.size()) {
auto initializerOp = getGlobalVariable(operands[wordIndex]);
if (!initializerOp) {
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index f92b9ae62be..805a3393b0c 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -1696,7 +1696,7 @@ Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
// Add the interface values.
if (auto interface = op.interface()) {
for (auto var : interface.getValue()) {
- auto id = getVariableID(var.cast<SymbolRefAttr>().getValue());
+ auto id = getVariableID(var.cast<FlatSymbolRefAttr>().getValue());
if (!id) {
return op.emitError("referencing undefined global variable."
"spv.EntryPoint is at the end of spv.module. All "
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index 12029248168..8c08868bc7a 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -528,7 +528,7 @@ void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
//===----------------------------------------------------------------------===//
static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
- SymbolRefAttr calleeAttr;
+ FlatSymbolRefAttr calleeAttr;
FunctionType calleeType;
SmallVector<OpAsmParser::OperandType, 4> operands;
auto calleeLoc = parser.getNameLoc();
@@ -555,7 +555,7 @@ static void print(OpAsmPrinter &p, CallOp op) {
static LogicalResult verify(CallOp op) {
// Check that the callee attribute was specified.
- auto fnAttr = op.getAttrOfType<SymbolRefAttr>("callee");
+ auto fnAttr = op.getAttrOfType<FlatSymbolRefAttr>("callee");
if (!fnAttr)
return op.emitOpError("requires a 'callee' symbol reference attribute");
auto fn =
@@ -608,8 +608,8 @@ struct SimplifyIndirectCallWithKnownCallee
// Replace with a direct call.
SmallVector<Type, 8> callResults(indirectCall.getResultTypes());
SmallVector<Value *, 8> callOperands(indirectCall.getArgOperands());
- rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn.getValue(),
- callResults, callOperands);
+ rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn, callResults,
+ callOperands);
return matchSuccess();
}
};
@@ -1206,7 +1206,7 @@ static LogicalResult verify(ConstantOp &op) {
}
if (type.isa<FunctionType>()) {
- auto fnAttr = value.dyn_cast<SymbolRefAttr>();
+ auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>();
if (!fnAttr)
return op.emitOpError("requires 'value' to be a function reference");
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 43452b2712d..6f77de0e721 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -801,9 +801,15 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
case StandardAttributes::Type:
printType(attr.cast<TypeAttr>().getValue());
break;
- case StandardAttributes::SymbolRef:
- printSymbolReference(attr.cast<SymbolRefAttr>().getValue(), os);
+ case StandardAttributes::SymbolRef: {
+ auto refAttr = attr.dyn_cast<SymbolRefAttr>();
+ printSymbolReference(refAttr.getRootReference(), os);
+ for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
+ os << "::";
+ printSymbolReference(nestedRef.getValue(), os);
+ }
break;
+ }
case StandardAttributes::OpaqueElements: {
auto eltsAttr = attr.cast<OpaqueElementsAttr>();
os << "opaque<\"" << eltsAttr.getDialect()->getNamespace() << "\", ";
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index 21f8b68c265..da4aa69dda4 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -321,6 +321,43 @@ struct StringAttributeStorage : public AttributeStorage {
StringRef value;
};
+/// An attribute representing a symbol reference.
+struct SymbolRefAttributeStorage final
+ : public AttributeStorage,
+ public llvm::TrailingObjects<SymbolRefAttributeStorage,
+ FlatSymbolRefAttr> {
+ using KeyTy = std::pair<StringRef, ArrayRef<FlatSymbolRefAttr>>;
+
+ SymbolRefAttributeStorage(StringRef value, size_t numNestedRefs)
+ : value(value), numNestedRefs(numNestedRefs) {}
+
+ /// Key equality function.
+ bool operator==(const KeyTy &key) const {
+ return key == KeyTy(value, getNestedRefs());
+ }
+
+ /// Construct a new storage instance.
+ static SymbolRefAttributeStorage *
+ construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+ auto size = SymbolRefAttributeStorage::totalSizeToAlloc<FlatSymbolRefAttr>(
+ key.second.size());
+ auto rawMem = allocator.allocate(size, alignof(SymbolRefAttributeStorage));
+ auto result = ::new (rawMem) SymbolRefAttributeStorage(
+ allocator.copyInto(key.first), key.second.size());
+ std::uninitialized_copy(key.second.begin(), key.second.end(),
+ result->getTrailingObjects<FlatSymbolRefAttr>());
+ return result;
+ }
+
+ /// Returns the set of nested references.
+ ArrayRef<FlatSymbolRefAttr> getNestedRefs() const {
+ return {getTrailingObjects<FlatSymbolRefAttr>(), numNestedRefs};
+ }
+
+ StringRef value;
+ size_t numNestedRefs;
+};
+
/// An attribute representing a reference to a type.
struct TypeAttributeStorage : public AttributeStorage {
using KeyTy = Type;
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index d74cacbe695..80ac4a59246 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -249,12 +249,27 @@ FloatAttr::verifyConstructionInvariants(llvm::Optional<Location> loc,
// SymbolRefAttr
//===----------------------------------------------------------------------===//
-SymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
- return Base::get(ctx, StandardAttributes::SymbolRef, value,
- NoneType::get(ctx));
+FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
+ return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None)
+ .cast<FlatSymbolRefAttr>();
}
-StringRef SymbolRefAttr::getValue() const { return getImpl()->value; }
+SymbolRefAttr SymbolRefAttr::get(StringRef value,
+ ArrayRef<FlatSymbolRefAttr> nestedReferences,
+ MLIRContext *ctx) {
+ return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences);
+}
+
+StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
+
+StringRef SymbolRefAttr::getLeafReference() const {
+ ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
+ return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
+}
+
+ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
+ return getImpl()->getNestedRefs();
+}
//===----------------------------------------------------------------------===//
// IntegerAttr
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 24ae2072f77..afdeefd023c 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -150,15 +150,20 @@ ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
return ArrayAttr::get(value, context);
}
-SymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
+FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
auto symName =
value->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
assert(symName && "value does not have a valid symbol name");
return getSymbolRefAttr(symName.getValue());
}
-SymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
+FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
return SymbolRefAttr::get(value, getContext());
}
+SymbolRefAttr
+Builder::getSymbolRefAttr(StringRef value,
+ ArrayRef<FlatSymbolRefAttr> nestedReferences) {
+ return SymbolRefAttr::get(value, nestedReferences, getContext());
+}
ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
auto attrs = functional::map(
diff --git a/mlir/lib/IR/FunctionSupport.cpp b/mlir/lib/IR/FunctionSupport.cpp
index d1ba2d30fa1..29cae177cec 100644
--- a/mlir/lib/IR/FunctionSupport.cpp
+++ b/mlir/lib/IR/FunctionSupport.cpp
@@ -159,7 +159,7 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
auto &builder = parser.getBuilder();
// Parse the name as a symbol reference attribute.
- SymbolRefAttr nameAttr;
+ FlatSymbolRefAttr nameAttr;
if (parser.parseAttribute(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
result.attributes))
return failure();
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 35c694b6a43..2843aae4bb8 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1400,7 +1400,7 @@ static std::string extractSymbolReference(Token tok) {
/// | type
/// | `[` (attribute-value (`,` attribute-value)*)? `]`
/// | `{` (attribute-entry (`,` attribute-entry)*)? `}`
-/// | symbol-ref-id
+/// | symbol-ref-id (`::` symbol-ref-id)*
/// | `dense` `<` attribute-value `>` `:`
/// (tensor-type | vector-type)
/// | `sparse` `<` attribute-value `,` attribute-value `>`
@@ -1509,7 +1509,31 @@ Attribute Parser::parseAttribute(Type type) {
case Token::at_identifier: {
std::string nameStr = extractSymbolReference(getToken());
consumeToken(Token::at_identifier);
- return builder.getSymbolRefAttr(nameStr);
+
+ // Parse any nested references.
+ std::vector<FlatSymbolRefAttr> nestedRefs;
+ while (getToken().is(Token::colon)) {
+ // Check for the '::' prefix.
+ const char *curPointer = getToken().getLoc().getPointer();
+ consumeToken(Token::colon);
+ if (!consumeIf(Token::colon)) {
+ state.lex.resetPointer(curPointer);
+ consumeToken();
+ break;
+ }
+ // Parse the reference itself.
+ auto curLoc = getToken().getLoc();
+ if (getToken().isNot(Token::at_identifier)) {
+ emitError(curLoc, "expected nested symbol reference identifier");
+ return Attribute();
+ }
+
+ std::string nameStr = extractSymbolReference(getToken());
+ consumeToken(Token::at_identifier);
+ nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext()));
+ }
+
+ return builder.getSymbolRefAttr(nameStr, nestedRefs);
}
// Parse a 'unit' attribute.
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 69f7e933d49..7f3ce5a738f 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -52,7 +52,7 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType,
return llvm::ConstantInt::get(llvmType, intAttr.getValue());
if (auto floatAttr = attr.dyn_cast<FloatAttr>())
return llvm::ConstantFP::get(llvmType, floatAttr.getValue());
- if (auto funcAttr = attr.dyn_cast<SymbolRefAttr>())
+ if (auto funcAttr = attr.dyn_cast<FlatSymbolRefAttr>())
return functionMapping.lookup(funcAttr.getValue());
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
auto *sequentialType = cast<llvm::SequentialType>(llvmType);
@@ -194,7 +194,7 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
auto convertCall = [this, &builder](Operation &op) -> llvm::Value * {
auto operands = lookupValues(op.getOperands());
ArrayRef<llvm::Value *> operandsRef(operands);
- if (auto attr = op.getAttrOfType<SymbolRefAttr>("callee")) {
+ if (auto attr = op.getAttrOfType<FlatSymbolRefAttr>("callee")) {
return builder.CreateCall(functionMapping.lookup(attr.getValue()),
operandsRef);
} else {
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 37f85e71ce9..dc85fbb14b3 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1112,3 +1112,7 @@ func @"\"_string_symbol_reference\""() {
"foo.symbol_reference"() {ref = @"\"_string_symbol_reference\""} : () -> ()
return
}
+
+// CHECK-LABEL: func @nested_reference
+// CHECK-NEXT: ref = @some_symbol::@some_nested_symbol
+func @nested_reference() attributes {test.ref = @some_symbol::@some_nested_symbol }
diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td
index 4071f7e232f..2972793bf02 100644
--- a/mlir/test/lib/TestDialect/TestOps.td
+++ b/mlir/test/lib/TestDialect/TestOps.td
@@ -206,7 +206,7 @@ def UpdateFloatElementsAttr : Pat<
def SymbolRefOp : TEST_Op<"symbol_ref_attr"> {
let arguments = (ins
- Confined<SymbolRefAttr, [ReferToOp<"FuncOp">]>:$symbol
+ Confined<FlatSymbolRefAttr, [ReferToOp<"FuncOp">]>:$symbol
);
}
@@ -232,7 +232,7 @@ def SizedRegionOp : TEST_Op<"sized_region_op", []> {
def ConversionCallOp : TEST_Op<"conversion_call_op",
[CallOpInterface]> {
- let arguments = (ins Variadic<AnyType>:$inputs, SymbolRefAttr:$callee);
+ let arguments = (ins Variadic<AnyType>:$inputs, FlatSymbolRefAttr:$callee);
let results = (outs Variadic<AnyType>);
let extraClassDeclaration = [{
@@ -241,7 +241,7 @@ def ConversionCallOp : TEST_Op<"conversion_call_op",
/// Return the callee of this operation.
CallInterfaceCallable getCallableForCallee() {
- return getAttrOfType<SymbolRefAttr>("callee");
+ return getAttrOfType<FlatSymbolRefAttr>("callee");
}
}];
}
diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td
index 82702bfe29e..7fe249b9159 100644
--- a/mlir/test/mlir-tblgen/op-attribute.td
+++ b/mlir/test/mlir-tblgen/op-attribute.td
@@ -72,7 +72,7 @@ def AOp : NS_Op<"a_op", []> {
// CHECK: auto tblgen_cAttr = this->getAttr("cAttr");
// CHECK-NEXT: if (tblgen_cAttr) {
// CHECK-NEXT: if (!((some-condition))) return emitOpError("attribute 'cAttr' failed to satisfy constraint: some attribute kind");
-
+
def SomeTypeAttr : TypeAttrBase<"SomeType", "some type attribute">;
def BOp : NS_Op<"b_op", []> {
@@ -85,7 +85,7 @@ def BOp : NS_Op<"b_op", []> {
F64Attr:$f64_attr,
StrAttr:$str_attr,
ElementsAttr:$elements_attr,
- SymbolRefAttr:$function_attr,
+ FlatSymbolRefAttr:$function_attr,
SomeTypeAttr:$type_attr,
ArrayAttr:$array_attr,
TypedArrayAttrBase<SomeAttr, "SomeAttr array">:$some_attr_array,
@@ -122,7 +122,7 @@ def BOp : NS_Op<"b_op", []> {
// CHECK: if (!(((tblgen_f64_attr.isa<FloatAttr>())) && ((tblgen_f64_attr.cast<FloatAttr>().getType().isF64()))))
// CHECK: if (!((tblgen_str_attr.isa<StringAttr>())))
// CHECK: if (!((tblgen_elements_attr.isa<ElementsAttr>())))
-// CHECK: if (!((tblgen_function_attr.isa<SymbolRefAttr>())))
+// CHECK: if (!((tblgen_function_attr.isa<FlatSymbolRefAttr>())))
// CHECK: if (!(((tblgen_type_attr.isa<TypeAttr>())) && ((tblgen_type_attr.cast<TypeAttr>().getValue().isa<SomeType>()))))
// CHECK: if (!((tblgen_array_attr.isa<ArrayAttr>())))
// CHECK: if (!(((tblgen_some_attr_array.isa<ArrayAttr>())) && (llvm::all_of(tblgen_some_attr_array.cast<ArrayAttr>(), [](Attribute attr) { return (some-condition); }))))
OpenPOWER on IntegriCloud