summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
authorRiver Riddle <riverriddle@google.com>2019-11-11 18:18:02 -0800
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-11-11 18:18:31 -0800
commit9b9c647cefea0a81fdf7d2bf6586a13f99d9a2cf (patch)
tree2fd6e7aeaa1e41a5e5a6355f860f35bc63ca8d99 /mlir/lib
parent5cf6e0ce7f03f9841675b1a9d44232540f3df5cc (diff)
downloadbcm5719-llvm-9b9c647cefea0a81fdf7d2bf6586a13f99d9a2cf.tar.gz
bcm5719-llvm-9b9c647cefea0a81fdf7d2bf6586a13f99d9a2cf.zip
Add support for nested symbol references.
This change allows for adding additional nested references to a SymbolRefAttr to allow for further resolving a symbol if that symbol also defines a SymbolTable. If a referenced symbol also defines a symbol table, a nested reference can be used to refer to a symbol within that table. Nested references are printed after the main reference in the following form: symbol-ref-attribute ::= symbol-ref-id (`::` symbol-ref-id)* Example: module @reference { func @nested_reference() } my_reference_op @reference::@nested_reference Given that SymbolRefAttr is now more general, the existing functionality centered around a single reference is moved to a derived class FlatSymbolRefAttr. Followup commits will add support to lookups, rauw, etc. for scoped references. PiperOrigin-RevId: 279860501
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Analysis/CallGraph.cpp4
-rw-r--r--mlir/lib/Dialect/GPU/IR/GPUDialect.cpp3
-rw-r--r--mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp3
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp6
-rw-r--r--mlir/lib/Dialect/SPIRV/SPIRVOps.cpp17
-rw-r--r--mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp2
-rw-r--r--mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp2
-rw-r--r--mlir/lib/Dialect/StandardOps/Ops.cpp10
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp10
-rw-r--r--mlir/lib/IR/AttributeDetail.h37
-rw-r--r--mlir/lib/IR/Attributes.cpp23
-rw-r--r--mlir/lib/IR/Builders.cpp9
-rw-r--r--mlir/lib/IR/FunctionSupport.cpp2
-rw-r--r--mlir/lib/Parser/Parser.cpp28
-rw-r--r--mlir/lib/Target/LLVMIR/ModuleTranslation.cpp4
15 files changed, 126 insertions, 34 deletions
diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp
index 3c02802dd35..2b5894ff162 100644
--- a/mlir/lib/Analysis/CallGraph.cpp
+++ b/mlir/lib/Analysis/CallGraph.cpp
@@ -184,7 +184,9 @@ CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable,
// Get the callee operation from the callable.
Operation *callee;
if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>())
- callee = SymbolTable::lookupNearestSymbolFrom(from, symbolRef.getValue());
+ // TODO(riverriddle) Support nested references.
+ callee = SymbolTable::lookupNearestSymbolFrom(from,
+ symbolRef.getRootReference());
else
callee = callable.get<Value *>()->getDefiningOp();
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index d70d51feee7..bfd094d6203 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -533,7 +533,8 @@ unsigned LaunchFuncOp::getNumKernelOperands() {
}
StringRef LaunchFuncOp::getKernelModuleName() {
- return getAttrOfType<SymbolRefAttr>(getKernelModuleAttrName()).getValue();
+ return getAttrOfType<SymbolRefAttr>(getKernelModuleAttrName())
+ .getRootReference();
}
Value *LaunchFuncOp::getKernelOperand(unsigned i) {
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index 672beee56cd..420b2340bc9 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -189,7 +189,8 @@ private:
if (Optional<SymbolTable::UseRange> symbolUses =
SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) {
for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
- StringRef symbolName = symbolUse.getSymbolRef().getValue();
+ StringRef symbolName =
+ symbolUse.getSymbolRef().cast<FlatSymbolRefAttr>().getValue();
if (moduleManager.lookupSymbol(symbolName))
continue;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
index 7a8bc7162af..2dc46bf7b2b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
@@ -359,8 +359,8 @@ public:
// Get a SymbolRefAttr containing the library function name for the LinalgOp.
// If the library function does not exist, insert a declaration.
template <typename LinalgOp>
-static SymbolRefAttr getLibraryCallSymbolRef(Operation *op,
- PatternRewriter &rewriter) {
+static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
+ PatternRewriter &rewriter) {
auto linalgOp = cast<LinalgOp>(op);
auto fnName = linalgOp.getLibraryCallName();
if (fnName.empty()) {
@@ -369,7 +369,7 @@ static SymbolRefAttr getLibraryCallSymbolRef(Operation *op,
}
// fnName is a dynamic std::String, unique it via a SymbolRefAttr.
- SymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
+ FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
auto module = op->getParentOfType<ModuleOp>();
if (module.lookupSymbol(fnName)) {
return fnNameAttr;
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index b3a9f6f7443..3c1563ed515 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -658,7 +658,7 @@ void spirv::AddressOfOp::build(Builder *builder, OperationState &state,
static ParseResult parseAddressOfOp(OpAsmParser &parser,
OperationState &state) {
- SymbolRefAttr varRefAttr;
+ FlatSymbolRefAttr varRefAttr;
Type type;
if (parser.parseAttribute(varRefAttr, Type(), kVariableAttrName,
state.attributes) ||
@@ -1088,7 +1088,7 @@ static ParseResult parseEntryPointOp(OpAsmParser &parser,
SmallVector<Type, 0> idTypes;
SmallVector<Attribute, 4> interfaceVars;
- SymbolRefAttr fn;
+ FlatSymbolRefAttr fn;
if (parseEnumAttribute(execModel, parser, state) ||
parser.parseAttribute(fn, Type(), kFnNameAttrName, state.attributes)) {
return failure();
@@ -1099,7 +1099,7 @@ static ParseResult parseEntryPointOp(OpAsmParser &parser,
do {
// The name of the interface variable attribute isnt important
auto attrName = "var_symbol";
- SymbolRefAttr var;
+ FlatSymbolRefAttr var;
SmallVector<NamedAttribute, 1> attrs;
if (parser.parseAttribute(var, Type(), attrName, attrs)) {
return failure();
@@ -1186,7 +1186,7 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) {
static ParseResult parseFunctionCallOp(OpAsmParser &parser,
OperationState &state) {
- SymbolRefAttr calleeAttr;
+ FlatSymbolRefAttr calleeAttr;
FunctionType type;
SmallVector<OpAsmParser::OperandType, 4> operands;
auto loc = parser.getNameLoc();
@@ -1305,7 +1305,7 @@ static ParseResult parseGlobalVariableOp(OpAsmParser &parser,
// Parse optional initializer
if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) {
- SymbolRefAttr initSymbol;
+ FlatSymbolRefAttr initSymbol;
if (parser.parseLParen() ||
parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
state.attributes) ||
@@ -1361,7 +1361,8 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) {
if (varOp.storageClass() == spirv::StorageClass::Generic)
return varOp.emitOpError("storage class cannot be 'Generic'");
- if (auto init = varOp.getAttrOfType<SymbolRefAttr>(kInitializerAttrName)) {
+ if (auto init =
+ varOp.getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
auto moduleOp = varOp.getParentOfType<spirv::ModuleOp>();
auto *initOp = moduleOp.lookupSymbol(init.getValue());
// TODO: Currently only variable initialization with specialization
@@ -1713,7 +1714,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
}
if (auto interface = entryPointOp.interface()) {
for (Attribute varRef : interface) {
- auto varSymRef = varRef.dyn_cast<SymbolRefAttr>();
+ auto varSymRef = varRef.dyn_cast<FlatSymbolRefAttr>();
if (!varSymRef) {
return entryPointOp.emitError(
"expected symbol reference for interface "
@@ -1790,7 +1791,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
static ParseResult parseReferenceOfOp(OpAsmParser &parser,
OperationState &state) {
- SymbolRefAttr constRefAttr;
+ FlatSymbolRefAttr constRefAttr;
Type type;
if (parser.parseAttribute(constRefAttr, Type(), kSpecConstAttrName,
state.attributes) ||
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
index 11660ed4e87..40b53185529 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -970,7 +970,7 @@ LogicalResult Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
wordIndex++;
// Initializer.
- SymbolRefAttr initializer = nullptr;
+ FlatSymbolRefAttr initializer = nullptr;
if (wordIndex < operands.size()) {
auto initializerOp = getGlobalVariable(operands[wordIndex]);
if (!initializerOp) {
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index f92b9ae62be..805a3393b0c 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -1696,7 +1696,7 @@ Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
// Add the interface values.
if (auto interface = op.interface()) {
for (auto var : interface.getValue()) {
- auto id = getVariableID(var.cast<SymbolRefAttr>().getValue());
+ auto id = getVariableID(var.cast<FlatSymbolRefAttr>().getValue());
if (!id) {
return op.emitError("referencing undefined global variable."
"spv.EntryPoint is at the end of spv.module. All "
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
index 12029248168..8c08868bc7a 100644
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -528,7 +528,7 @@ void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
//===----------------------------------------------------------------------===//
static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
- SymbolRefAttr calleeAttr;
+ FlatSymbolRefAttr calleeAttr;
FunctionType calleeType;
SmallVector<OpAsmParser::OperandType, 4> operands;
auto calleeLoc = parser.getNameLoc();
@@ -555,7 +555,7 @@ static void print(OpAsmPrinter &p, CallOp op) {
static LogicalResult verify(CallOp op) {
// Check that the callee attribute was specified.
- auto fnAttr = op.getAttrOfType<SymbolRefAttr>("callee");
+ auto fnAttr = op.getAttrOfType<FlatSymbolRefAttr>("callee");
if (!fnAttr)
return op.emitOpError("requires a 'callee' symbol reference attribute");
auto fn =
@@ -608,8 +608,8 @@ struct SimplifyIndirectCallWithKnownCallee
// Replace with a direct call.
SmallVector<Type, 8> callResults(indirectCall.getResultTypes());
SmallVector<Value *, 8> callOperands(indirectCall.getArgOperands());
- rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn.getValue(),
- callResults, callOperands);
+ rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn, callResults,
+ callOperands);
return matchSuccess();
}
};
@@ -1206,7 +1206,7 @@ static LogicalResult verify(ConstantOp &op) {
}
if (type.isa<FunctionType>()) {
- auto fnAttr = value.dyn_cast<SymbolRefAttr>();
+ auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>();
if (!fnAttr)
return op.emitOpError("requires 'value' to be a function reference");
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 43452b2712d..6f77de0e721 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -801,9 +801,15 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
case StandardAttributes::Type:
printType(attr.cast<TypeAttr>().getValue());
break;
- case StandardAttributes::SymbolRef:
- printSymbolReference(attr.cast<SymbolRefAttr>().getValue(), os);
+ case StandardAttributes::SymbolRef: {
+ auto refAttr = attr.dyn_cast<SymbolRefAttr>();
+ printSymbolReference(refAttr.getRootReference(), os);
+ for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
+ os << "::";
+ printSymbolReference(nestedRef.getValue(), os);
+ }
break;
+ }
case StandardAttributes::OpaqueElements: {
auto eltsAttr = attr.cast<OpaqueElementsAttr>();
os << "opaque<\"" << eltsAttr.getDialect()->getNamespace() << "\", ";
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index 21f8b68c265..da4aa69dda4 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -321,6 +321,43 @@ struct StringAttributeStorage : public AttributeStorage {
StringRef value;
};
+/// An attribute representing a symbol reference.
+struct SymbolRefAttributeStorage final
+ : public AttributeStorage,
+ public llvm::TrailingObjects<SymbolRefAttributeStorage,
+ FlatSymbolRefAttr> {
+ using KeyTy = std::pair<StringRef, ArrayRef<FlatSymbolRefAttr>>;
+
+ SymbolRefAttributeStorage(StringRef value, size_t numNestedRefs)
+ : value(value), numNestedRefs(numNestedRefs) {}
+
+ /// Key equality function.
+ bool operator==(const KeyTy &key) const {
+ return key == KeyTy(value, getNestedRefs());
+ }
+
+ /// Construct a new storage instance.
+ static SymbolRefAttributeStorage *
+ construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+ auto size = SymbolRefAttributeStorage::totalSizeToAlloc<FlatSymbolRefAttr>(
+ key.second.size());
+ auto rawMem = allocator.allocate(size, alignof(SymbolRefAttributeStorage));
+ auto result = ::new (rawMem) SymbolRefAttributeStorage(
+ allocator.copyInto(key.first), key.second.size());
+ std::uninitialized_copy(key.second.begin(), key.second.end(),
+ result->getTrailingObjects<FlatSymbolRefAttr>());
+ return result;
+ }
+
+ /// Returns the set of nested references.
+ ArrayRef<FlatSymbolRefAttr> getNestedRefs() const {
+ return {getTrailingObjects<FlatSymbolRefAttr>(), numNestedRefs};
+ }
+
+ StringRef value;
+ size_t numNestedRefs;
+};
+
/// An attribute representing a reference to a type.
struct TypeAttributeStorage : public AttributeStorage {
using KeyTy = Type;
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index d74cacbe695..80ac4a59246 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -249,12 +249,27 @@ FloatAttr::verifyConstructionInvariants(llvm::Optional<Location> loc,
// SymbolRefAttr
//===----------------------------------------------------------------------===//
-SymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
- return Base::get(ctx, StandardAttributes::SymbolRef, value,
- NoneType::get(ctx));
+FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
+ return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None)
+ .cast<FlatSymbolRefAttr>();
}
-StringRef SymbolRefAttr::getValue() const { return getImpl()->value; }
+SymbolRefAttr SymbolRefAttr::get(StringRef value,
+ ArrayRef<FlatSymbolRefAttr> nestedReferences,
+ MLIRContext *ctx) {
+ return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences);
+}
+
+StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
+
+StringRef SymbolRefAttr::getLeafReference() const {
+ ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
+ return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
+}
+
+ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
+ return getImpl()->getNestedRefs();
+}
//===----------------------------------------------------------------------===//
// IntegerAttr
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 24ae2072f77..afdeefd023c 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -150,15 +150,20 @@ ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
return ArrayAttr::get(value, context);
}
-SymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
+FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
auto symName =
value->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
assert(symName && "value does not have a valid symbol name");
return getSymbolRefAttr(symName.getValue());
}
-SymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
+FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
return SymbolRefAttr::get(value, getContext());
}
+SymbolRefAttr
+Builder::getSymbolRefAttr(StringRef value,
+ ArrayRef<FlatSymbolRefAttr> nestedReferences) {
+ return SymbolRefAttr::get(value, nestedReferences, getContext());
+}
ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
auto attrs = functional::map(
diff --git a/mlir/lib/IR/FunctionSupport.cpp b/mlir/lib/IR/FunctionSupport.cpp
index d1ba2d30fa1..29cae177cec 100644
--- a/mlir/lib/IR/FunctionSupport.cpp
+++ b/mlir/lib/IR/FunctionSupport.cpp
@@ -159,7 +159,7 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
auto &builder = parser.getBuilder();
// Parse the name as a symbol reference attribute.
- SymbolRefAttr nameAttr;
+ FlatSymbolRefAttr nameAttr;
if (parser.parseAttribute(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
result.attributes))
return failure();
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 35c694b6a43..2843aae4bb8 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1400,7 +1400,7 @@ static std::string extractSymbolReference(Token tok) {
/// | type
/// | `[` (attribute-value (`,` attribute-value)*)? `]`
/// | `{` (attribute-entry (`,` attribute-entry)*)? `}`
-/// | symbol-ref-id
+/// | symbol-ref-id (`::` symbol-ref-id)*
/// | `dense` `<` attribute-value `>` `:`
/// (tensor-type | vector-type)
/// | `sparse` `<` attribute-value `,` attribute-value `>`
@@ -1509,7 +1509,31 @@ Attribute Parser::parseAttribute(Type type) {
case Token::at_identifier: {
std::string nameStr = extractSymbolReference(getToken());
consumeToken(Token::at_identifier);
- return builder.getSymbolRefAttr(nameStr);
+
+ // Parse any nested references.
+ std::vector<FlatSymbolRefAttr> nestedRefs;
+ while (getToken().is(Token::colon)) {
+ // Check for the '::' prefix.
+ const char *curPointer = getToken().getLoc().getPointer();
+ consumeToken(Token::colon);
+ if (!consumeIf(Token::colon)) {
+ state.lex.resetPointer(curPointer);
+ consumeToken();
+ break;
+ }
+ // Parse the reference itself.
+ auto curLoc = getToken().getLoc();
+ if (getToken().isNot(Token::at_identifier)) {
+ emitError(curLoc, "expected nested symbol reference identifier");
+ return Attribute();
+ }
+
+ std::string nameStr = extractSymbolReference(getToken());
+ consumeToken(Token::at_identifier);
+ nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext()));
+ }
+
+ return builder.getSymbolRefAttr(nameStr, nestedRefs);
}
// Parse a 'unit' attribute.
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 69f7e933d49..7f3ce5a738f 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -52,7 +52,7 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType,
return llvm::ConstantInt::get(llvmType, intAttr.getValue());
if (auto floatAttr = attr.dyn_cast<FloatAttr>())
return llvm::ConstantFP::get(llvmType, floatAttr.getValue());
- if (auto funcAttr = attr.dyn_cast<SymbolRefAttr>())
+ if (auto funcAttr = attr.dyn_cast<FlatSymbolRefAttr>())
return functionMapping.lookup(funcAttr.getValue());
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
auto *sequentialType = cast<llvm::SequentialType>(llvmType);
@@ -194,7 +194,7 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
auto convertCall = [this, &builder](Operation &op) -> llvm::Value * {
auto operands = lookupValues(op.getOperands());
ArrayRef<llvm::Value *> operandsRef(operands);
- if (auto attr = op.getAttrOfType<SymbolRefAttr>("callee")) {
+ if (auto attr = op.getAttrOfType<FlatSymbolRefAttr>("callee")) {
return builder.CreateCall(functionMapping.lookup(attr.getValue()),
operandsRef);
} else {
OpenPOWER on IntegriCloud