From 9b9c647cefea0a81fdf7d2bf6586a13f99d9a2cf Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 11 Nov 2019 18:18:02 -0800 Subject: Add support for nested symbol references. This change allows for adding additional nested references to a SymbolRefAttr to allow for further resolving a symbol if that symbol also defines a SymbolTable. If a referenced symbol also defines a symbol table, a nested reference can be used to refer to a symbol within that table. Nested references are printed after the main reference in the following form: symbol-ref-attribute ::= symbol-ref-id (`::` symbol-ref-id)* Example: module @reference { func @nested_reference() } my_reference_op @reference::@nested_reference Given that SymbolRefAttr is now more general, the existing functionality centered around a single reference is moved to a derived class FlatSymbolRefAttr. Followup commits will add support to lookups, rauw, etc. for scoped references. PiperOrigin-RevId: 279860501 --- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 3 ++- mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp | 3 ++- .../Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp | 6 +++--- mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 17 +++++++++-------- mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp | 2 +- mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp | 2 +- mlir/lib/Dialect/StandardOps/Ops.cpp | 10 +++++----- 7 files changed, 23 insertions(+), 20 deletions(-) (limited to 'mlir/lib/Dialect') diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index d70d51feee7..bfd094d6203 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -533,7 +533,8 @@ unsigned LaunchFuncOp::getNumKernelOperands() { } StringRef LaunchFuncOp::getKernelModuleName() { - return getAttrOfType(getKernelModuleAttrName()).getValue(); + return getAttrOfType(getKernelModuleAttrName()) + .getRootReference(); } Value *LaunchFuncOp::getKernelOperand(unsigned i) { diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index 672beee56cd..420b2340bc9 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -189,7 +189,8 @@ private: if (Optional symbolUses = SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) { for (SymbolTable::SymbolUse symbolUse : *symbolUses) { - StringRef symbolName = symbolUse.getSymbolRef().getValue(); + StringRef symbolName = + symbolUse.getSymbolRef().cast().getValue(); if (moduleManager.lookupSymbol(symbolName)) continue; diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp index 7a8bc7162af..2dc46bf7b2b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -359,8 +359,8 @@ public: // Get a SymbolRefAttr containing the library function name for the LinalgOp. // If the library function does not exist, insert a declaration. template -static SymbolRefAttr getLibraryCallSymbolRef(Operation *op, - PatternRewriter &rewriter) { +static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op, + PatternRewriter &rewriter) { auto linalgOp = cast(op); auto fnName = linalgOp.getLibraryCallName(); if (fnName.empty()) { @@ -369,7 +369,7 @@ static SymbolRefAttr getLibraryCallSymbolRef(Operation *op, } // fnName is a dynamic std::String, unique it via a SymbolRefAttr. - SymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName); + FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName); auto module = op->getParentOfType(); if (module.lookupSymbol(fnName)) { return fnNameAttr; diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index b3a9f6f7443..3c1563ed515 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -658,7 +658,7 @@ void spirv::AddressOfOp::build(Builder *builder, OperationState &state, static ParseResult parseAddressOfOp(OpAsmParser &parser, OperationState &state) { - SymbolRefAttr varRefAttr; + FlatSymbolRefAttr varRefAttr; Type type; if (parser.parseAttribute(varRefAttr, Type(), kVariableAttrName, state.attributes) || @@ -1088,7 +1088,7 @@ static ParseResult parseEntryPointOp(OpAsmParser &parser, SmallVector idTypes; SmallVector interfaceVars; - SymbolRefAttr fn; + FlatSymbolRefAttr fn; if (parseEnumAttribute(execModel, parser, state) || parser.parseAttribute(fn, Type(), kFnNameAttrName, state.attributes)) { return failure(); @@ -1099,7 +1099,7 @@ static ParseResult parseEntryPointOp(OpAsmParser &parser, do { // The name of the interface variable attribute isnt important auto attrName = "var_symbol"; - SymbolRefAttr var; + FlatSymbolRefAttr var; SmallVector attrs; if (parser.parseAttribute(var, Type(), attrName, attrs)) { return failure(); @@ -1186,7 +1186,7 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) { static ParseResult parseFunctionCallOp(OpAsmParser &parser, OperationState &state) { - SymbolRefAttr calleeAttr; + FlatSymbolRefAttr calleeAttr; FunctionType type; SmallVector operands; auto loc = parser.getNameLoc(); @@ -1305,7 +1305,7 @@ static ParseResult parseGlobalVariableOp(OpAsmParser &parser, // Parse optional initializer if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) { - SymbolRefAttr initSymbol; + FlatSymbolRefAttr initSymbol; if (parser.parseLParen() || parser.parseAttribute(initSymbol, Type(), kInitializerAttrName, state.attributes) || @@ -1361,7 +1361,8 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) { if (varOp.storageClass() == spirv::StorageClass::Generic) return varOp.emitOpError("storage class cannot be 'Generic'"); - if (auto init = varOp.getAttrOfType(kInitializerAttrName)) { + if (auto init = + varOp.getAttrOfType(kInitializerAttrName)) { auto moduleOp = varOp.getParentOfType(); auto *initOp = moduleOp.lookupSymbol(init.getValue()); // TODO: Currently only variable initialization with specialization @@ -1713,7 +1714,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) { } if (auto interface = entryPointOp.interface()) { for (Attribute varRef : interface) { - auto varSymRef = varRef.dyn_cast(); + auto varSymRef = varRef.dyn_cast(); if (!varSymRef) { return entryPointOp.emitError( "expected symbol reference for interface " @@ -1790,7 +1791,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) { static ParseResult parseReferenceOfOp(OpAsmParser &parser, OperationState &state) { - SymbolRefAttr constRefAttr; + FlatSymbolRefAttr constRefAttr; Type type; if (parser.parseAttribute(constRefAttr, Type(), kSpecConstAttrName, state.attributes) || diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index 11660ed4e87..40b53185529 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -970,7 +970,7 @@ LogicalResult Deserializer::processGlobalVariable(ArrayRef operands) { wordIndex++; // Initializer. - SymbolRefAttr initializer = nullptr; + FlatSymbolRefAttr initializer = nullptr; if (wordIndex < operands.size()) { auto initializerOp = getGlobalVariable(operands[wordIndex]); if (!initializerOp) { diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index f92b9ae62be..805a3393b0c 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -1696,7 +1696,7 @@ Serializer::processOp(spirv::EntryPointOp op) { // Add the interface values. if (auto interface = op.interface()) { for (auto var : interface.getValue()) { - auto id = getVariableID(var.cast().getValue()); + auto id = getVariableID(var.cast().getValue()); if (!id) { return op.emitError("referencing undefined global variable." "spv.EntryPoint is at the end of spv.module. All " diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 12029248168..8c08868bc7a 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -528,7 +528,7 @@ void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results, //===----------------------------------------------------------------------===// static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) { - SymbolRefAttr calleeAttr; + FlatSymbolRefAttr calleeAttr; FunctionType calleeType; SmallVector operands; auto calleeLoc = parser.getNameLoc(); @@ -555,7 +555,7 @@ static void print(OpAsmPrinter &p, CallOp op) { static LogicalResult verify(CallOp op) { // Check that the callee attribute was specified. - auto fnAttr = op.getAttrOfType("callee"); + auto fnAttr = op.getAttrOfType("callee"); if (!fnAttr) return op.emitOpError("requires a 'callee' symbol reference attribute"); auto fn = @@ -608,8 +608,8 @@ struct SimplifyIndirectCallWithKnownCallee // Replace with a direct call. SmallVector callResults(indirectCall.getResultTypes()); SmallVector callOperands(indirectCall.getArgOperands()); - rewriter.replaceOpWithNewOp(indirectCall, calledFn.getValue(), - callResults, callOperands); + rewriter.replaceOpWithNewOp(indirectCall, calledFn, callResults, + callOperands); return matchSuccess(); } }; @@ -1206,7 +1206,7 @@ static LogicalResult verify(ConstantOp &op) { } if (type.isa()) { - auto fnAttr = value.dyn_cast(); + auto fnAttr = value.dyn_cast(); if (!fnAttr) return op.emitOpError("requires 'value' to be a function reference"); -- cgit v1.2.3