summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp')
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp1394
1 files changed, 1394 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
new file mode 100644
index 00000000000..906cf344347
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -0,0 +1,1394 @@
+//===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines the types and operation details for the LLVM IR dialect in
+// MLIR, and the LLVM IR dialect. It also registers the dialect.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/StandardTypes.h"
+
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/Mutex.h"
+#include "llvm/Support/SourceMgr.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+#include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::CmpOp.
+//===----------------------------------------------------------------------===//
+static void printICmpOp(OpAsmPrinter *p, ICmpOp &op) {
+ *p << op.getOperationName() << " \"" << stringifyICmpPredicate(op.predicate())
+ << "\" " << *op.getOperand(0) << ", " << *op.getOperand(1);
+ p->printOptionalAttrDict(op.getAttrs(), {"predicate"});
+ *p << " : " << op.lhs()->getType();
+}
+
+static void printFCmpOp(OpAsmPrinter *p, FCmpOp &op) {
+ *p << op.getOperationName() << " \"" << stringifyFCmpPredicate(op.predicate())
+ << "\" " << *op.getOperand(0) << ", " << *op.getOperand(1);
+ p->printOptionalAttrDict(op.getAttrs(), {"predicate"});
+ *p << " : " << op.lhs()->getType();
+}
+
+// <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
+// attribute-dict? `:` type
+// <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use
+// attribute-dict? `:` type
+template <typename CmpPredicateType>
+static ParseResult parseCmpOp(OpAsmParser *parser, OperationState *result) {
+ Builder &builder = parser->getBuilder();
+
+ Attribute predicate;
+ SmallVector<NamedAttribute, 4> attrs;
+ OpAsmParser::OperandType lhs, rhs;
+ Type type;
+ llvm::SMLoc predicateLoc, trailingTypeLoc;
+ if (parser->getCurrentLocation(&predicateLoc) ||
+ parser->parseAttribute(predicate, "predicate", attrs) ||
+ parser->parseOperand(lhs) || parser->parseComma() ||
+ parser->parseOperand(rhs) || parser->parseOptionalAttributeDict(attrs) ||
+ parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) ||
+ parser->parseType(type) ||
+ parser->resolveOperand(lhs, type, result->operands) ||
+ parser->resolveOperand(rhs, type, result->operands))
+ return failure();
+
+ // Replace the string attribute `predicate` with an integer attribute.
+ auto predicateStr = predicate.dyn_cast<StringAttr>();
+ if (!predicateStr)
+ return parser->emitError(predicateLoc,
+ "expected 'predicate' attribute of string type");
+
+ int64_t predicateValue = 0;
+ if (std::is_same<CmpPredicateType, ICmpPredicate>()) {
+ Optional<ICmpPredicate> predicate =
+ symbolizeICmpPredicate(predicateStr.getValue());
+ if (!predicate)
+ return parser->emitError(predicateLoc)
+ << "'" << predicateStr.getValue()
+ << "' is an incorrect value of the 'predicate' attribute";
+ predicateValue = static_cast<int64_t>(predicate.getValue());
+ } else {
+ Optional<FCmpPredicate> predicate =
+ symbolizeFCmpPredicate(predicateStr.getValue());
+ if (!predicate)
+ return parser->emitError(predicateLoc)
+ << "'" << predicateStr.getValue()
+ << "' is an incorrect value of the 'predicate' attribute";
+ predicateValue = static_cast<int64_t>(predicate.getValue());
+ }
+
+ attrs[0].second = parser->getBuilder().getI64IntegerAttr(predicateValue);
+
+ // The result type is either i1 or a vector type <? x i1> if the inputs are
+ // vectors.
+ auto *dialect = builder.getContext()->getRegisteredDialect<LLVMDialect>();
+ auto resultType = LLVMType::getInt1Ty(dialect);
+ auto argType = type.dyn_cast<LLVM::LLVMType>();
+ if (!argType)
+ return parser->emitError(trailingTypeLoc, "expected LLVM IR dialect type");
+ if (argType.getUnderlyingType()->isVectorTy())
+ resultType = LLVMType::getVectorTy(
+ resultType, argType.getUnderlyingType()->getVectorNumElements());
+
+ result->attributes = attrs;
+ result->addTypes({resultType});
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::AllocaOp.
+//===----------------------------------------------------------------------===//
+
+static void printAllocaOp(OpAsmPrinter *p, AllocaOp &op) {
+ auto elemTy = op.getType().cast<LLVM::LLVMType>().getPointerElementTy();
+
+ auto funcTy = FunctionType::get({op.arraySize()->getType()}, {op.getType()},
+ op.getContext());
+
+ *p << op.getOperationName() << ' ' << *op.arraySize() << " x " << elemTy;
+ if (op.alignment().hasValue() && op.alignment()->getSExtValue() != 0)
+ p->printOptionalAttrDict(op.getAttrs());
+ else
+ p->printOptionalAttrDict(op.getAttrs(), {"alignment"});
+ *p << " : " << funcTy;
+}
+
+// <operation> ::= `llvm.alloca` ssa-use `x` type attribute-dict?
+// `:` type `,` type
+static ParseResult parseAllocaOp(OpAsmParser *parser, OperationState *result) {
+ SmallVector<NamedAttribute, 4> attrs;
+ OpAsmParser::OperandType arraySize;
+ Type type, elemType;
+ llvm::SMLoc trailingTypeLoc;
+ if (parser->parseOperand(arraySize) || parser->parseKeyword("x") ||
+ parser->parseType(elemType) ||
+ parser->parseOptionalAttributeDict(attrs) || parser->parseColon() ||
+ parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type))
+ return failure();
+
+ // Extract the result type from the trailing function type.
+ auto funcType = type.dyn_cast<FunctionType>();
+ if (!funcType || funcType.getNumInputs() != 1 ||
+ funcType.getNumResults() != 1)
+ return parser->emitError(
+ trailingTypeLoc,
+ "expected trailing function type with one argument and one result");
+
+ if (parser->resolveOperand(arraySize, funcType.getInput(0), result->operands))
+ return failure();
+
+ result->attributes = attrs;
+ result->addTypes({funcType.getResult(0)});
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::GEPOp.
+//===----------------------------------------------------------------------===//
+
+static void printGEPOp(OpAsmPrinter *p, GEPOp &op) {
+ SmallVector<Type, 8> types(op.getOperandTypes());
+ auto funcTy = FunctionType::get(types, op.getType(), op.getContext());
+
+ *p << op.getOperationName() << ' ' << *op.base() << '[';
+ p->printOperands(std::next(op.operand_begin()), op.operand_end());
+ *p << ']';
+ p->printOptionalAttrDict(op.getAttrs());
+ *p << " : " << funcTy;
+}
+
+// <operation> ::= `llvm.getelementptr` ssa-use `[` ssa-use-list `]`
+// attribute-dict? `:` type
+static ParseResult parseGEPOp(OpAsmParser *parser, OperationState *result) {
+ SmallVector<NamedAttribute, 4> attrs;
+ OpAsmParser::OperandType base;
+ SmallVector<OpAsmParser::OperandType, 8> indices;
+ Type type;
+ llvm::SMLoc trailingTypeLoc;
+ if (parser->parseOperand(base) ||
+ parser->parseOperandList(indices, OpAsmParser::Delimiter::Square) ||
+ parser->parseOptionalAttributeDict(attrs) || parser->parseColon() ||
+ parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type))
+ return failure();
+
+ // Deconstruct the trailing function type to extract the types of the base
+ // pointer and result (same type) and the types of the indices.
+ auto funcType = type.dyn_cast<FunctionType>();
+ if (!funcType || funcType.getNumResults() != 1 ||
+ funcType.getNumInputs() == 0)
+ return parser->emitError(trailingTypeLoc,
+ "expected trailing function type with at least "
+ "one argument and one result");
+
+ if (parser->resolveOperand(base, funcType.getInput(0), result->operands) ||
+ parser->resolveOperands(indices, funcType.getInputs().drop_front(),
+ parser->getNameLoc(), result->operands))
+ return failure();
+
+ result->attributes = attrs;
+ result->addTypes(funcType.getResults());
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::LoadOp.
+//===----------------------------------------------------------------------===//
+
+static void printLoadOp(OpAsmPrinter *p, LoadOp &op) {
+ *p << op.getOperationName() << ' ' << *op.addr();
+ p->printOptionalAttrDict(op.getAttrs());
+ *p << " : " << op.addr()->getType();
+}
+
+// Extract the pointee type from the LLVM pointer type wrapped in MLIR. Return
+// the resulting type wrapped in MLIR, or nullptr on error.
+static Type getLoadStoreElementType(OpAsmParser *parser, Type type,
+ llvm::SMLoc trailingTypeLoc) {
+ auto llvmTy = type.dyn_cast<LLVM::LLVMType>();
+ if (!llvmTy)
+ return parser->emitError(trailingTypeLoc, "expected LLVM IR dialect type"),
+ nullptr;
+ if (!llvmTy.getUnderlyingType()->isPointerTy())
+ return parser->emitError(trailingTypeLoc, "expected LLVM pointer type"),
+ nullptr;
+ return llvmTy.getPointerElementTy();
+}
+
+// <operation> ::= `llvm.load` ssa-use attribute-dict? `:` type
+static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *result) {
+ SmallVector<NamedAttribute, 4> attrs;
+ OpAsmParser::OperandType addr;
+ Type type;
+ llvm::SMLoc trailingTypeLoc;
+
+ if (parser->parseOperand(addr) || parser->parseOptionalAttributeDict(attrs) ||
+ parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) ||
+ parser->parseType(type) ||
+ parser->resolveOperand(addr, type, result->operands))
+ return failure();
+
+ Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
+
+ result->attributes = attrs;
+ result->addTypes(elemTy);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::StoreOp.
+//===----------------------------------------------------------------------===//
+
+static void printStoreOp(OpAsmPrinter *p, StoreOp &op) {
+ *p << op.getOperationName() << ' ' << *op.value() << ", " << *op.addr();
+ p->printOptionalAttrDict(op.getAttrs());
+ *p << " : " << op.addr()->getType();
+}
+
+// <operation> ::= `llvm.store` ssa-use `,` ssa-use attribute-dict? `:` type
+static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *result) {
+ SmallVector<NamedAttribute, 4> attrs;
+ OpAsmParser::OperandType addr, value;
+ Type type;
+ llvm::SMLoc trailingTypeLoc;
+
+ if (parser->parseOperand(value) || parser->parseComma() ||
+ parser->parseOperand(addr) || parser->parseOptionalAttributeDict(attrs) ||
+ parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) ||
+ parser->parseType(type))
+ return failure();
+
+ Type elemTy = getLoadStoreElementType(parser, type, trailingTypeLoc);
+ if (!elemTy)
+ return failure();
+
+ if (parser->resolveOperand(value, elemTy, result->operands) ||
+ parser->resolveOperand(addr, type, result->operands))
+ return failure();
+
+ result->attributes = attrs;
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::CallOp.
+//===----------------------------------------------------------------------===//
+
+static void printCallOp(OpAsmPrinter *p, CallOp &op) {
+ auto callee = op.callee();
+ bool isDirect = callee.hasValue();
+
+ // Print the direct callee if present as a function attribute, or an indirect
+ // callee (first operand) otherwise.
+ *p << op.getOperationName() << ' ';
+ if (isDirect)
+ *p << '@' << callee.getValue();
+ else
+ *p << *op.getOperand(0);
+
+ *p << '(';
+ p->printOperands(llvm::drop_begin(op.getOperands(), isDirect ? 0 : 1));
+ *p << ')';
+
+ p->printOptionalAttrDict(op.getAttrs(), {"callee"});
+
+ // Reconstruct the function MLIR function type from operand and result types.
+ SmallVector<Type, 1> resultTypes(op.getResultTypes());
+ SmallVector<Type, 8> argTypes(
+ llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1));
+
+ *p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext());
+}
+
+// <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`
+// attribute-dict? `:` function-type
+static ParseResult parseCallOp(OpAsmParser *parser, OperationState *result) {
+ SmallVector<NamedAttribute, 4> attrs;
+ SmallVector<OpAsmParser::OperandType, 8> operands;
+ Type type;
+ SymbolRefAttr funcAttr;
+ llvm::SMLoc trailingTypeLoc;
+
+ // Parse an operand list that will, in practice, contain 0 or 1 operand. In
+ // case of an indirect call, there will be 1 operand before `(`. In case of a
+ // direct call, there will be no operands and the parser will stop at the
+ // function identifier without complaining.
+ if (parser->parseOperandList(operands))
+ return failure();
+ bool isDirect = operands.empty();
+
+ // Optionally parse a function identifier.
+ if (isDirect)
+ if (parser->parseAttribute(funcAttr, "callee", attrs))
+ return failure();
+
+ if (parser->parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
+ parser->parseOptionalAttributeDict(attrs) || parser->parseColon() ||
+ parser->getCurrentLocation(&trailingTypeLoc) || parser->parseType(type))
+ return failure();
+
+ auto funcType = type.dyn_cast<FunctionType>();
+ if (!funcType)
+ return parser->emitError(trailingTypeLoc, "expected function type");
+ if (isDirect) {
+ // Make sure types match.
+ if (parser->resolveOperands(operands, funcType.getInputs(),
+ parser->getNameLoc(), result->operands))
+ return failure();
+ result->addTypes(funcType.getResults());
+ } else {
+ // Construct the LLVM IR Dialect function type that the first operand
+ // should match.
+ if (funcType.getNumResults() > 1)
+ return parser->emitError(trailingTypeLoc,
+ "expected function with 0 or 1 result");
+
+ Builder &builder = parser->getBuilder();
+ auto *llvmDialect =
+ builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
+ LLVM::LLVMType llvmResultType;
+ if (funcType.getNumResults() == 0) {
+ llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect);
+ } else {
+ llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
+ if (!llvmResultType)
+ return parser->emitError(trailingTypeLoc,
+ "expected result to have LLVM type");
+ }
+
+ SmallVector<LLVM::LLVMType, 8> argTypes;
+ argTypes.reserve(funcType.getNumInputs());
+ for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) {
+ auto argType = funcType.getInput(i).dyn_cast<LLVM::LLVMType>();
+ if (!argType)
+ return parser->emitError(trailingTypeLoc,
+ "expected LLVM types as inputs");
+ argTypes.push_back(argType);
+ }
+ auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes,
+ /*isVarArg=*/false);
+ auto wrappedFuncType = llvmFuncType.getPointerTo();
+
+ auto funcArguments =
+ ArrayRef<OpAsmParser::OperandType>(operands).drop_front();
+
+ // Make sure that the first operand (indirect callee) matches the wrapped
+ // LLVM IR function type, and that the types of the other call operands
+ // match the types of the function arguments.
+ if (parser->resolveOperand(operands[0], wrappedFuncType,
+ result->operands) ||
+ parser->resolveOperands(funcArguments, funcType.getInputs(),
+ parser->getNameLoc(), result->operands))
+ return failure();
+
+ result->addTypes(llvmResultType);
+ }
+
+ result->attributes = attrs;
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::ExtractElementOp.
+//===----------------------------------------------------------------------===//
+// Expects vector to be of wrapped LLVM vector type and position to be of
+// wrapped LLVM i32 type.
+void LLVM::ExtractElementOp::build(Builder *b, OperationState *result,
+ Value *vector, Value *position,
+ ArrayRef<NamedAttribute> attrs) {
+ auto wrappedVectorType = vector->getType().cast<LLVM::LLVMType>();
+ auto llvmType = wrappedVectorType.getVectorElementType();
+ build(b, result, llvmType, vector, position);
+ result->addAttributes(attrs);
+}
+
+static void printExtractElementOp(OpAsmPrinter *p, ExtractElementOp &op) {
+ *p << op.getOperationName() << ' ' << *op.vector() << ", " << *op.position();
+ p->printOptionalAttrDict(op.getAttrs());
+ *p << " : " << op.vector()->getType();
+}
+
+// <operation> ::= `llvm.extractelement` ssa-use `, ` ssa-use
+// attribute-dict? `:` type
+static ParseResult parseExtractElementOp(OpAsmParser *parser,
+ OperationState *result) {
+ llvm::SMLoc loc;
+ OpAsmParser::OperandType vector, position;
+ auto *llvmDialect = parser->getBuilder()
+ .getContext()
+ ->getRegisteredDialect<LLVM::LLVMDialect>();
+ Type type, i32Type = LLVMType::getInt32Ty(llvmDialect);
+ if (parser->getCurrentLocation(&loc) || parser->parseOperand(vector) ||
+ parser->parseComma() || parser->parseOperand(position) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(type) ||
+ parser->resolveOperand(vector, type, result->operands) ||
+ parser->resolveOperand(position, i32Type, result->operands))
+ return failure();
+ auto wrappedVectorType = type.dyn_cast<LLVM::LLVMType>();
+ if (!wrappedVectorType ||
+ !wrappedVectorType.getUnderlyingType()->isVectorTy())
+ return parser->emitError(
+ loc, "expected LLVM IR dialect vector type for operand #1");
+ result->addTypes(wrappedVectorType.getVectorElementType());
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::ExtractValueOp.
+//===----------------------------------------------------------------------===//
+
+static void printExtractValueOp(OpAsmPrinter *p, ExtractValueOp &op) {
+ *p << op.getOperationName() << ' ' << *op.container() << op.position();
+ p->printOptionalAttrDict(op.getAttrs(), {"position"});
+ *p << " : " << op.container()->getType();
+}
+
+// Extract the type at `position` in the wrapped LLVM IR aggregate type
+// `containerType`. Position is an integer array attribute where each value
+// is a zero-based position of the element in the aggregate type. Return the
+// resulting type wrapped in MLIR, or nullptr on error.
+static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser *parser,
+ Type containerType,
+ Attribute positionAttr,
+ llvm::SMLoc attributeLoc,
+ llvm::SMLoc typeLoc) {
+ auto wrappedContainerType = containerType.dyn_cast<LLVM::LLVMType>();
+ if (!wrappedContainerType)
+ return parser->emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr;
+
+ auto positionArrayAttr = positionAttr.dyn_cast<ArrayAttr>();
+ if (!positionArrayAttr)
+ return parser->emitError(attributeLoc, "expected an array attribute"),
+ nullptr;
+
+ // Infer the element type from the structure type: iteratively step inside the
+ // type by taking the element type, indexed by the position attribute for
+ // stuctures. Check the position index before accessing, it is supposed to be
+ // in bounds.
+ for (Attribute subAttr : positionArrayAttr) {
+ auto positionElementAttr = subAttr.dyn_cast<IntegerAttr>();
+ if (!positionElementAttr)
+ return parser->emitError(attributeLoc,
+ "expected an array of integer literals"),
+ nullptr;
+ int position = positionElementAttr.getInt();
+ auto *llvmContainerType = wrappedContainerType.getUnderlyingType();
+ if (llvmContainerType->isArrayTy()) {
+ if (position < 0 || static_cast<unsigned>(position) >=
+ llvmContainerType->getArrayNumElements())
+ return parser->emitError(attributeLoc, "position out of bounds"),
+ nullptr;
+ wrappedContainerType = wrappedContainerType.getArrayElementType();
+ } else if (llvmContainerType->isStructTy()) {
+ if (position < 0 || static_cast<unsigned>(position) >=
+ llvmContainerType->getStructNumElements())
+ return parser->emitError(attributeLoc, "position out of bounds"),
+ nullptr;
+ wrappedContainerType =
+ wrappedContainerType.getStructElementType(position);
+ } else {
+ return parser->emitError(typeLoc,
+ "expected wrapped LLVM IR structure/array type"),
+ nullptr;
+ }
+ }
+ return wrappedContainerType;
+}
+
+// <operation> ::= `llvm.extractvalue` ssa-use
+// `[` integer-literal (`,` integer-literal)* `]`
+// attribute-dict? `:` type
+static ParseResult parseExtractValueOp(OpAsmParser *parser,
+ OperationState *result) {
+ SmallVector<NamedAttribute, 4> attrs;
+ OpAsmParser::OperandType container;
+ Type containerType;
+ Attribute positionAttr;
+ llvm::SMLoc attributeLoc, trailingTypeLoc;
+
+ if (parser->parseOperand(container) ||
+ parser->getCurrentLocation(&attributeLoc) ||
+ parser->parseAttribute(positionAttr, "position", attrs) ||
+ parser->parseOptionalAttributeDict(attrs) || parser->parseColon() ||
+ parser->getCurrentLocation(&trailingTypeLoc) ||
+ parser->parseType(containerType) ||
+ parser->resolveOperand(container, containerType, result->operands))
+ return failure();
+
+ auto elementType = getInsertExtractValueElementType(
+ parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
+ if (!elementType)
+ return failure();
+
+ result->attributes = attrs;
+ result->addTypes(elementType);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::InsertElementOp.
+//===----------------------------------------------------------------------===//
+
+static void printInsertElementOp(OpAsmPrinter *p, InsertElementOp &op) {
+ *p << op.getOperationName() << ' ' << *op.vector() << ", " << *op.value()
+ << ", " << *op.position();
+ p->printOptionalAttrDict(op.getAttrs());
+ *p << " : " << op.vector()->getType();
+}
+
+// <operation> ::= `llvm.insertelement` ssa-use `,` ssa-use `,` ssa-use
+// attribute-dict? `:` type
+static ParseResult parseInsertElementOp(OpAsmParser *parser,
+ OperationState *result) {
+ llvm::SMLoc loc;
+ OpAsmParser::OperandType vector, value, position;
+ auto *llvmDialect = parser->getBuilder()
+ .getContext()
+ ->getRegisteredDialect<LLVM::LLVMDialect>();
+ Type vectorType, i32Type = LLVMType::getInt32Ty(llvmDialect);
+ if (parser->getCurrentLocation(&loc) || parser->parseOperand(vector) ||
+ parser->parseComma() || parser->parseOperand(value) ||
+ parser->parseComma() || parser->parseOperand(position) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(vectorType))
+ return failure();
+
+ auto wrappedVectorType = vectorType.dyn_cast<LLVM::LLVMType>();
+ if (!wrappedVectorType ||
+ !wrappedVectorType.getUnderlyingType()->isVectorTy())
+ return parser->emitError(
+ loc, "expected LLVM IR dialect vector type for operand #1");
+ auto valueType = wrappedVectorType.getVectorElementType();
+ if (!valueType)
+ return failure();
+
+ if (parser->resolveOperand(vector, vectorType, result->operands) ||
+ parser->resolveOperand(value, valueType, result->operands) ||
+ parser->resolveOperand(position, i32Type, result->operands))
+ return failure();
+
+ result->addTypes(vectorType);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::InsertValueOp.
+//===----------------------------------------------------------------------===//
+
+static void printInsertValueOp(OpAsmPrinter *p, InsertValueOp &op) {
+ *p << op.getOperationName() << ' ' << *op.value() << ", " << *op.container()
+ << op.position();
+ p->printOptionalAttrDict(op.getAttrs(), {"position"});
+ *p << " : " << op.container()->getType();
+}
+
+// <operation> ::= `llvm.insertvaluevalue` ssa-use `,` ssa-use
+// `[` integer-literal (`,` integer-literal)* `]`
+// attribute-dict? `:` type
+static ParseResult parseInsertValueOp(OpAsmParser *parser,
+ OperationState *result) {
+ OpAsmParser::OperandType container, value;
+ Type containerType;
+ Attribute positionAttr;
+ llvm::SMLoc attributeLoc, trailingTypeLoc;
+
+ if (parser->parseOperand(value) || parser->parseComma() ||
+ parser->parseOperand(container) ||
+ parser->getCurrentLocation(&attributeLoc) ||
+ parser->parseAttribute(positionAttr, "position", result->attributes) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColon() || parser->getCurrentLocation(&trailingTypeLoc) ||
+ parser->parseType(containerType))
+ return failure();
+
+ auto valueType = getInsertExtractValueElementType(
+ parser, containerType, positionAttr, attributeLoc, trailingTypeLoc);
+ if (!valueType)
+ return failure();
+
+ if (parser->resolveOperand(container, containerType, result->operands) ||
+ parser->resolveOperand(value, valueType, result->operands))
+ return failure();
+
+ result->addTypes(containerType);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::SelectOp.
+//===----------------------------------------------------------------------===//
+
+static void printSelectOp(OpAsmPrinter *p, SelectOp &op) {
+ *p << op.getOperationName() << ' ' << *op.condition() << ", "
+ << *op.trueValue() << ", " << *op.falseValue();
+ p->printOptionalAttrDict(op.getAttrs());
+ *p << " : " << op.condition()->getType() << ", " << op.trueValue()->getType();
+}
+
+// <operation> ::= `llvm.select` ssa-use `,` ssa-use `,` ssa-use
+// attribute-dict? `:` type, type
+static ParseResult parseSelectOp(OpAsmParser *parser, OperationState *result) {
+ OpAsmParser::OperandType condition, trueValue, falseValue;
+ Type conditionType, argType;
+
+ if (parser->parseOperand(condition) || parser->parseComma() ||
+ parser->parseOperand(trueValue) || parser->parseComma() ||
+ parser->parseOperand(falseValue) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(conditionType) || parser->parseComma() ||
+ parser->parseType(argType))
+ return failure();
+
+ if (parser->resolveOperand(condition, conditionType, result->operands) ||
+ parser->resolveOperand(trueValue, argType, result->operands) ||
+ parser->resolveOperand(falseValue, argType, result->operands))
+ return failure();
+
+ result->addTypes(argType);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::BrOp.
+//===----------------------------------------------------------------------===//
+
+static void printBrOp(OpAsmPrinter *p, BrOp &op) {
+ *p << op.getOperationName() << ' ';
+ p->printSuccessorAndUseList(op.getOperation(), 0);
+ p->printOptionalAttrDict(op.getAttrs());
+}
+
+// <operation> ::= `llvm.br` bb-id (`[` ssa-use-and-type-list `]`)?
+// attribute-dict?
+static ParseResult parseBrOp(OpAsmParser *parser, OperationState *result) {
+ Block *dest;
+ SmallVector<Value *, 4> operands;
+ if (parser->parseSuccessorAndUseList(dest, operands) ||
+ parser->parseOptionalAttributeDict(result->attributes))
+ return failure();
+
+ result->addSuccessor(dest, operands);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::CondBrOp.
+//===----------------------------------------------------------------------===//
+
+static void printCondBrOp(OpAsmPrinter *p, CondBrOp &op) {
+ *p << op.getOperationName() << ' ' << *op.getOperand(0) << ", ";
+ p->printSuccessorAndUseList(op.getOperation(), 0);
+ *p << ", ";
+ p->printSuccessorAndUseList(op.getOperation(), 1);
+ p->printOptionalAttrDict(op.getAttrs());
+}
+
+// <operation> ::= `llvm.cond_br` ssa-use `,`
+// bb-id (`[` ssa-use-and-type-list `]`)? `,`
+// bb-id (`[` ssa-use-and-type-list `]`)? attribute-dict?
+static ParseResult parseCondBrOp(OpAsmParser *parser, OperationState *result) {
+ Block *trueDest;
+ Block *falseDest;
+ SmallVector<Value *, 4> trueOperands;
+ SmallVector<Value *, 4> falseOperands;
+ OpAsmParser::OperandType condition;
+
+ Builder &builder = parser->getBuilder();
+ auto *llvmDialect =
+ builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
+ auto i1Type = LLVM::LLVMType::getInt1Ty(llvmDialect);
+
+ if (parser->parseOperand(condition) || parser->parseComma() ||
+ parser->parseSuccessorAndUseList(trueDest, trueOperands) ||
+ parser->parseComma() ||
+ parser->parseSuccessorAndUseList(falseDest, falseOperands) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->resolveOperand(condition, i1Type, result->operands))
+ return failure();
+
+ result->addSuccessor(trueDest, trueOperands);
+ result->addSuccessor(falseDest, falseOperands);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::ReturnOp.
+//===----------------------------------------------------------------------===//
+
+static void printReturnOp(OpAsmPrinter *p, ReturnOp &op) {
+ *p << op.getOperationName();
+ p->printOptionalAttrDict(op.getAttrs());
+ assert(op.getNumOperands() <= 1);
+
+ if (op.getNumOperands() == 0)
+ return;
+
+ *p << ' ' << *op.getOperand(0) << " : " << op.getOperand(0)->getType();
+}
+
+// <operation> ::= `llvm.return` ssa-use-list attribute-dict? `:`
+// type-list-no-parens
+static ParseResult parseReturnOp(OpAsmParser *parser, OperationState *result) {
+ SmallVector<OpAsmParser::OperandType, 1> operands;
+ Type type;
+
+ if (parser->parseOperandList(operands) ||
+ parser->parseOptionalAttributeDict(result->attributes))
+ return failure();
+ if (operands.empty())
+ return success();
+
+ if (parser->parseColonType(type) ||
+ parser->resolveOperand(operands[0], type, result->operands))
+ return failure();
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::UndefOp.
+//===----------------------------------------------------------------------===//
+
+static void printUndefOp(OpAsmPrinter *p, UndefOp &op) {
+ *p << op.getOperationName();
+ p->printOptionalAttrDict(op.getAttrs());
+ *p << " : " << op.res()->getType();
+}
+
+// <operation> ::= `llvm.undef` attribute-dict? : type
+static ParseResult parseUndefOp(OpAsmParser *parser, OperationState *result) {
+ Type type;
+
+ if (parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(type))
+ return failure();
+
+ result->addTypes(type);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printer, parser and verifier for LLVM::AddressOfOp.
+//===----------------------------------------------------------------------===//
+
+GlobalOp AddressOfOp::getGlobal() {
+ auto module = getParentOfType<ModuleOp>();
+ assert(module && "unexpected operation outside of a module");
+ return module.lookupSymbol<LLVM::GlobalOp>(global_name());
+}
+
+static void printAddressOfOp(OpAsmPrinter *p, AddressOfOp op) {
+ *p << op.getOperationName() << " @" << op.global_name();
+ p->printOptionalAttrDict(op.getAttrs(), {"global_name"});
+ *p << " : " << op.getResult()->getType();
+}
+
+static ParseResult parseAddressOfOp(OpAsmParser *parser,
+ OperationState *result) {
+ Attribute symRef;
+ Type type;
+ if (parser->parseAttribute(symRef, "global_name", result->attributes) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(type) ||
+ parser->addTypeToList(type, result->types))
+ return failure();
+
+ if (!symRef.isa<SymbolRefAttr>())
+ return parser->emitError(parser->getNameLoc(), "expected symbol reference");
+ return success();
+}
+
+static LogicalResult verify(AddressOfOp op) {
+ auto global = op.getGlobal();
+ if (!global)
+ return op.emitOpError("must reference a global defined by 'llvm.global'");
+
+ if (global.getType().getPointerTo() != op.getResult()->getType())
+ return op.emitOpError(
+ "the type must be a pointer to the type of the referred global");
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::ConstantOp.
+//===----------------------------------------------------------------------===//
+
+static void printConstantOp(OpAsmPrinter *p, ConstantOp &op) {
+ *p << op.getOperationName() << '(' << op.value() << ')';
+ p->printOptionalAttrDict(op.getAttrs(), {"value"});
+ *p << " : " << op.res()->getType();
+}
+
+// <operation> ::= `llvm.constant` `(` attribute `)` attribute-list? : type
+static ParseResult parseConstantOp(OpAsmParser *parser,
+ OperationState *result) {
+ Attribute valueAttr;
+ Type type;
+
+ if (parser->parseLParen() ||
+ parser->parseAttribute(valueAttr, "value", result->attributes) ||
+ parser->parseRParen() ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(type))
+ return failure();
+
+ result->addTypes(type);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Builder, printer and verifier for LLVM::GlobalOp.
+//===----------------------------------------------------------------------===//
+
+void GlobalOp::build(Builder *builder, OperationState *result, LLVMType type,
+ bool isConstant, StringRef name, Attribute value,
+ ArrayRef<NamedAttribute> attrs) {
+ result->addAttribute(SymbolTable::getSymbolAttrName(),
+ builder->getStringAttr(name));
+ result->addAttribute("type", builder->getTypeAttr(type));
+ if (isConstant)
+ result->addAttribute("constant", builder->getUnitAttr());
+ result->addAttribute("value", value);
+ result->attributes.append(attrs.begin(), attrs.end());
+}
+
+static void printGlobalOp(OpAsmPrinter *p, GlobalOp op) {
+ *p << op.getOperationName() << ' ';
+ if (op.constant())
+ *p << "constant ";
+ *p << '@' << op.sym_name() << '(';
+ p->printAttribute(op.value());
+ *p << ')';
+ p->printOptionalAttrDict(op.getAttrs(), {SymbolTable::getSymbolAttrName(),
+ "type", "constant", "value"});
+
+ // Print the trailing type unless it's a string global.
+ if (op.value().isa<StringAttr>())
+ return;
+ *p << " : ";
+ p->printType(op.type());
+}
+
+// <operation> ::= `llvm.global` `constant`? `@` identifier `(` attribute `)`
+// attribute-list? (`:` type)?
+//
+// The type can be omitted for string attributes, in which case it will be
+// inferred from the value of the string as [strlen(value) x i8].
+static ParseResult parseGlobalOp(OpAsmParser *parser, OperationState *result) {
+ if (succeeded(parser->parseOptionalKeyword("constant")))
+ result->addAttribute("constant", parser->getBuilder().getUnitAttr());
+
+ Attribute value;
+ StringAttr name;
+ SmallVector<Type, 1> types;
+ if (parser->parseSymbolName(name, SymbolTable::getSymbolAttrName(),
+ result->attributes) ||
+ parser->parseLParen() ||
+ parser->parseAttribute(value, "value", result->attributes) ||
+ parser->parseRParen() ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseOptionalColonTypeList(types))
+ return failure();
+
+ if (types.size() > 1)
+ return parser->emitError(parser->getNameLoc(), "expected zero or one type");
+
+ if (types.empty()) {
+ if (auto strAttr = value.dyn_cast<StringAttr>()) {
+ MLIRContext *context = parser->getBuilder().getContext();
+ auto *dialect = context->getRegisteredDialect<LLVMDialect>();
+ auto arrayType = LLVM::LLVMType::getArrayTy(
+ LLVM::LLVMType::getInt8Ty(dialect), strAttr.getValue().size());
+ types.push_back(arrayType);
+ } else {
+ return parser->emitError(parser->getNameLoc(),
+ "type can only be omitted for string globals");
+ }
+ }
+
+ result->addAttribute("type", parser->getBuilder().getTypeAttr(types[0]));
+ return success();
+}
+
+static LogicalResult verify(GlobalOp op) {
+ if (!llvm::PointerType::isValidElementType(op.getType().getUnderlyingType()))
+ return op.emitOpError(
+ "expects type to be a valid element type for an LLVM pointer");
+ if (op.getParentOp() && !isa<ModuleOp>(op.getParentOp()))
+ return op.emitOpError("must appear at the module level");
+ if (auto strAttr = op.value().dyn_cast<StringAttr>()) {
+ auto type = op.getType();
+ if (!type.getUnderlyingType()->isArrayTy() ||
+ !type.getArrayElementType().getUnderlyingType()->isIntegerTy(8) ||
+ type.getArrayNumElements() != strAttr.getValue().size())
+ return op.emitOpError(
+ "requires an i8 array type of the length equal to that of the string "
+ "attribute");
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing/parsing for LLVM::ShuffleVectorOp.
+//===----------------------------------------------------------------------===//
+// Expects vector to be of wrapped LLVM vector type and position to be of
+// wrapped LLVM i32 type.
+void LLVM::ShuffleVectorOp::build(Builder *b, OperationState *result, Value *v1,
+ Value *v2, ArrayAttr mask,
+ ArrayRef<NamedAttribute> attrs) {
+ auto wrappedContainerType1 = v1->getType().cast<LLVM::LLVMType>();
+ auto vType = LLVMType::getVectorTy(
+ wrappedContainerType1.getVectorElementType(), mask.size());
+ build(b, result, vType, v1, v2, mask);
+ result->addAttributes(attrs);
+}
+
+static void printShuffleVectorOp(OpAsmPrinter *p, ShuffleVectorOp &op) {
+ *p << op.getOperationName() << ' ' << *op.v1() << ", " << *op.v2() << " "
+ << op.mask();
+ p->printOptionalAttrDict(op.getAttrs(), {"mask"});
+ *p << " : " << op.v1()->getType() << ", " << op.v2()->getType();
+}
+
+// <operation> ::= `llvm.shufflevector` ssa-use `, ` ssa-use
+// `[` integer-literal (`,` integer-literal)* `]`
+// attribute-dict? `:` type
+static ParseResult parseShuffleVectorOp(OpAsmParser *parser,
+ OperationState *result) {
+ llvm::SMLoc loc;
+ SmallVector<NamedAttribute, 4> attrs;
+ OpAsmParser::OperandType v1, v2;
+ Attribute maskAttr;
+ Type typeV1, typeV2;
+ if (parser->getCurrentLocation(&loc) || parser->parseOperand(v1) ||
+ parser->parseComma() || parser->parseOperand(v2) ||
+ parser->parseAttribute(maskAttr, "mask", attrs) ||
+ parser->parseOptionalAttributeDict(attrs) ||
+ parser->parseColonType(typeV1) || parser->parseComma() ||
+ parser->parseType(typeV2) ||
+ parser->resolveOperand(v1, typeV1, result->operands) ||
+ parser->resolveOperand(v2, typeV2, result->operands))
+ return failure();
+ auto wrappedContainerType1 = typeV1.dyn_cast<LLVM::LLVMType>();
+ if (!wrappedContainerType1 ||
+ !wrappedContainerType1.getUnderlyingType()->isVectorTy())
+ return parser->emitError(
+ loc, "expected LLVM IR dialect vector type for operand #1");
+ auto vType =
+ LLVMType::getVectorTy(wrappedContainerType1.getVectorElementType(),
+ maskAttr.cast<ArrayAttr>().size());
+ result->attributes = attrs;
+ result->addTypes(vType);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Builder, printer and verifier for LLVM::LLVMFuncOp.
+//===----------------------------------------------------------------------===//
+
+void LLVMFuncOp::build(Builder *builder, OperationState *result, StringRef name,
+ LLVMType type, ArrayRef<NamedAttribute> attrs,
+ ArrayRef<NamedAttributeList> argAttrs) {
+ result->addRegion();
+ result->addAttribute(SymbolTable::getSymbolAttrName(),
+ builder->getStringAttr(name));
+ result->addAttribute("type", builder->getTypeAttr(type));
+ result->attributes.append(attrs.begin(), attrs.end());
+ if (argAttrs.empty())
+ return;
+
+ unsigned numInputs = type.getUnderlyingType()->getFunctionNumParams();
+ assert(numInputs == argAttrs.size() &&
+ "expected as many argument attribute lists as arguments");
+ SmallString<8> argAttrName;
+ for (unsigned i = 0; i < numInputs; ++i)
+ if (auto argDict = argAttrs[i].getDictionary())
+ result->addAttribute(getArgAttrName(i, argAttrName), argDict);
+}
+
+// Build an LLVM function type from the given lists of input and output types.
+// Returns a null type if any of the types provided are non-LLVM types, or if
+// there is more than one output type.
+static Type buildLLVMFunctionType(Builder &b, ArrayRef<Type> inputs,
+ ArrayRef<Type> outputs,
+ impl::VariadicFlag variadicFlag,
+ std::string &errorMessage) {
+ if (outputs.size() > 1) {
+ errorMessage = "expected zero or one function result";
+ return {};
+ }
+
+ // Convert inputs to LLVM types, exit early on error.
+ SmallVector<LLVMType, 4> llvmInputs;
+ for (auto t : inputs) {
+ auto llvmTy = t.dyn_cast<LLVMType>();
+ if (!llvmTy) {
+ errorMessage = "expected LLVM type for function arguments";
+ return {};
+ }
+ llvmInputs.push_back(llvmTy);
+ }
+
+ // Get the dialect from the input type, if any exist. Look it up in the
+ // context otherwise.
+ LLVMDialect *dialect =
+ llvmInputs.empty() ? b.getContext()->getRegisteredDialect<LLVMDialect>()
+ : &llvmInputs.front().getDialect();
+
+ // No output is denoted as "void" in LLVM type system.
+ LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(dialect)
+ : outputs.front().dyn_cast<LLVMType>();
+ if (!llvmOutput) {
+ errorMessage = "expected LLVM type for function results";
+ return {};
+ }
+ return LLVMType::getFunctionTy(llvmOutput, llvmInputs,
+ variadicFlag.isVariadic());
+}
+
+// Print the LLVMFuncOp. Collects argument and result types and passes them
+// to the trait printer. Drops "void" result since it cannot be parsed back.
+static void printLLVMFuncOp(OpAsmPrinter *p, LLVMFuncOp op) {
+ LLVMType fnType = op.getType();
+ SmallVector<Type, 8> argTypes;
+ SmallVector<Type, 1> resTypes;
+ argTypes.reserve(fnType.getFunctionNumParams());
+ for (unsigned i = 0, e = fnType.getFunctionNumParams(); i < e; ++i)
+ argTypes.push_back(fnType.getFunctionParamType(i));
+
+ LLVMType returnType = fnType.getFunctionResultType();
+ if (!returnType.getUnderlyingType()->isVoidTy())
+ resTypes.push_back(returnType);
+
+ impl::printFunctionLikeOp(p, op, argTypes, op.isVarArg(), resTypes);
+}
+
+// Hook for OpTrait::FunctionLike, called after verifying that the 'type'
+// attribute is present. This can check for preconditions of the
+// getNumArguments hook not failing.
+LogicalResult LLVMFuncOp::verifyType() {
+ auto llvmType = getTypeAttr().getValue().dyn_cast_or_null<LLVMType>();
+ if (!llvmType || !llvmType.getUnderlyingType()->isFunctionTy())
+ return emitOpError("requires '" + getTypeAttrName() +
+ "' attribute of wrapped LLVM function type");
+
+ return success();
+}
+
+// Hook for OpTrait::FunctionLike, returns the number of function arguments.
+// Depends on the type attribute being correct as checked by verifyType
+unsigned LLVMFuncOp::getNumFuncArguments() {
+ return getType().getUnderlyingType()->getFunctionNumParams();
+}
+
+static LogicalResult verify(LLVMFuncOp op) {
+ if (op.isExternal())
+ return success();
+
+ if (op.isVarArg())
+ return op.emitOpError("only external functions can be variadic");
+
+ auto *funcType = cast<llvm::FunctionType>(op.getType().getUnderlyingType());
+ unsigned numArguments = funcType->getNumParams();
+ Block &entryBlock = op.front();
+ for (unsigned i = 0; i < numArguments; ++i) {
+ Type argType = entryBlock.getArgument(i)->getType();
+ auto argLLVMType = argType.dyn_cast<LLVMType>();
+ if (!argLLVMType)
+ return op.emitOpError("entry block argument #")
+ << i << " is not of LLVM type";
+ if (funcType->getParamType(i) != argLLVMType.getUnderlyingType())
+ return op.emitOpError("the type of entry block argument #")
+ << i << " does not match the function signature";
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// LLVMDialect initialization, type parsing, and registration.
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace LLVM {
+namespace detail {
+struct LLVMDialectImpl {
+ LLVMDialectImpl() : module("LLVMDialectModule", llvmContext) {}
+
+ llvm::LLVMContext llvmContext;
+ llvm::Module module;
+
+ /// A set of LLVMTypes that are cached on construction to avoid any lookups or
+ /// locking.
+ LLVMType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
+ LLVMType doubleTy, floatTy, halfTy;
+ LLVMType voidTy;
+
+ /// A smart mutex to lock access to the llvm context. Unlike MLIR, LLVM is not
+ /// multi-threaded and requires locked access to prevent race conditions.
+ llvm::sys::SmartMutex<true> mutex;
+};
+} // end namespace detail
+} // end namespace LLVM
+} // end namespace mlir
+
+LLVMDialect::LLVMDialect(MLIRContext *context)
+ : Dialect(getDialectNamespace(), context),
+ impl(new detail::LLVMDialectImpl()) {
+ addTypes<LLVMType>();
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
+ >();
+
+ // Support unknown operations because not all LLVM operations are registered.
+ allowUnknownOperations();
+
+ // Cache some of the common LLVM types to avoid the need for lookups/locking.
+ auto &llvmContext = impl->llvmContext;
+ /// Integer Types.
+ impl->int1Ty = LLVMType::get(context, llvm::Type::getInt1Ty(llvmContext));
+ impl->int8Ty = LLVMType::get(context, llvm::Type::getInt8Ty(llvmContext));
+ impl->int16Ty = LLVMType::get(context, llvm::Type::getInt16Ty(llvmContext));
+ impl->int32Ty = LLVMType::get(context, llvm::Type::getInt32Ty(llvmContext));
+ impl->int64Ty = LLVMType::get(context, llvm::Type::getInt64Ty(llvmContext));
+ impl->int128Ty = LLVMType::get(context, llvm::Type::getInt128Ty(llvmContext));
+ /// Float Types.
+ impl->doubleTy = LLVMType::get(context, llvm::Type::getDoubleTy(llvmContext));
+ impl->floatTy = LLVMType::get(context, llvm::Type::getFloatTy(llvmContext));
+ impl->halfTy = LLVMType::get(context, llvm::Type::getHalfTy(llvmContext));
+ /// Other Types.
+ impl->voidTy = LLVMType::get(context, llvm::Type::getVoidTy(llvmContext));
+}
+
+LLVMDialect::~LLVMDialect() {}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
+
+llvm::LLVMContext &LLVMDialect::getLLVMContext() { return impl->llvmContext; }
+llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; }
+
+/// Parse a type registered to this dialect.
+Type LLVMDialect::parseType(StringRef tyData, Location loc) const {
+ // LLVM is not thread-safe, so lock access to it.
+ llvm::sys::SmartScopedLock<true> lock(impl->mutex);
+
+ llvm::SMDiagnostic errorMessage;
+ llvm::Type *type = llvm::parseType(tyData, errorMessage, impl->module);
+ if (!type)
+ return (emitError(loc, errorMessage.getMessage()), nullptr);
+ return LLVMType::get(getContext(), type);
+}
+
+/// Print a type registered to this dialect.
+void LLVMDialect::printType(Type type, raw_ostream &os) const {
+ auto llvmType = type.dyn_cast<LLVMType>();
+ assert(llvmType && "printing wrong type");
+ assert(llvmType.getUnderlyingType() && "no underlying LLVM type");
+ llvmType.getUnderlyingType()->print(os);
+}
+
+/// Verify LLVMIR function argument attributes.
+LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
+ unsigned regionIdx,
+ unsigned argIdx,
+ NamedAttribute argAttr) {
+ // Check that llvm.noalias is a boolean attribute.
+ if (argAttr.first == "llvm.noalias" && !argAttr.second.isa<BoolAttr>())
+ return op->emitError()
+ << "llvm.noalias argument attribute of non boolean type";
+ return success();
+}
+
+static DialectRegistration<LLVMDialect> llvmDialect;
+
+//===----------------------------------------------------------------------===//
+// LLVMType.
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace LLVM {
+namespace detail {
+struct LLVMTypeStorage : public ::mlir::TypeStorage {
+ LLVMTypeStorage(llvm::Type *ty) : underlyingType(ty) {}
+
+ // LLVM types are pointer-unique.
+ using KeyTy = llvm::Type *;
+ bool operator==(const KeyTy &key) const { return key == underlyingType; }
+
+ static LLVMTypeStorage *construct(TypeStorageAllocator &allocator,
+ llvm::Type *ty) {
+ return new (allocator.allocate<LLVMTypeStorage>()) LLVMTypeStorage(ty);
+ }
+
+ llvm::Type *underlyingType;
+};
+} // end namespace detail
+} // end namespace LLVM
+} // end namespace mlir
+
+LLVMType LLVMType::get(MLIRContext *context, llvm::Type *llvmType) {
+ return Base::get(context, FIRST_LLVM_TYPE, llvmType);
+}
+
+/// Get an LLVMType with an llvm type that may cause changes to the underlying
+/// llvm context when constructed.
+LLVMType LLVMType::getLocked(LLVMDialect *dialect,
+ llvm::function_ref<llvm::Type *()> typeBuilder) {
+ // Lock access to the llvm context and build the type.
+ llvm::sys::SmartScopedLock<true> lock(dialect->impl->mutex);
+ return get(dialect->getContext(), typeBuilder());
+}
+
+LLVMDialect &LLVMType::getDialect() {
+ return static_cast<LLVMDialect &>(Type::getDialect());
+}
+
+llvm::Type *LLVMType::getUnderlyingType() const {
+ return getImpl()->underlyingType;
+}
+
+/// Array type utilities.
+LLVMType LLVMType::getArrayElementType() {
+ return get(getContext(), getUnderlyingType()->getArrayElementType());
+}
+unsigned LLVMType::getArrayNumElements() {
+ return getUnderlyingType()->getArrayNumElements();
+}
+
+/// Vector type utilities.
+LLVMType LLVMType::getVectorElementType() {
+ return get(getContext(), getUnderlyingType()->getVectorElementType());
+}
+
+/// Function type utilities.
+LLVMType LLVMType::getFunctionParamType(unsigned argIdx) {
+ return get(getContext(), getUnderlyingType()->getFunctionParamType(argIdx));
+}
+unsigned LLVMType::getFunctionNumParams() {
+ return getUnderlyingType()->getFunctionNumParams();
+}
+LLVMType LLVMType::getFunctionResultType() {
+ return get(
+ getContext(),
+ llvm::cast<llvm::FunctionType>(getUnderlyingType())->getReturnType());
+}
+
+/// Pointer type utilities.
+LLVMType LLVMType::getPointerTo(unsigned addrSpace) {
+ // Lock access to the dialect as this may modify the LLVM context.
+ return getLocked(&getDialect(), [=] {
+ return getUnderlyingType()->getPointerTo(addrSpace);
+ });
+}
+LLVMType LLVMType::getPointerElementTy() {
+ return get(getContext(), getUnderlyingType()->getPointerElementType());
+}
+
+/// Struct type utilities.
+LLVMType LLVMType::getStructElementType(unsigned i) {
+ return get(getContext(), getUnderlyingType()->getStructElementType(i));
+}
+
+/// Utilities used to generate floating point types.
+LLVMType LLVMType::getDoubleTy(LLVMDialect *dialect) {
+ return dialect->impl->doubleTy;
+}
+LLVMType LLVMType::getFloatTy(LLVMDialect *dialect) {
+ return dialect->impl->floatTy;
+}
+LLVMType LLVMType::getHalfTy(LLVMDialect *dialect) {
+ return dialect->impl->halfTy;
+}
+
+/// Utilities used to generate integer types.
+LLVMType LLVMType::getIntNTy(LLVMDialect *dialect, unsigned numBits) {
+ switch (numBits) {
+ case 1:
+ return dialect->impl->int1Ty;
+ case 8:
+ return dialect->impl->int8Ty;
+ case 16:
+ return dialect->impl->int16Ty;
+ case 32:
+ return dialect->impl->int32Ty;
+ case 64:
+ return dialect->impl->int64Ty;
+ case 128:
+ return dialect->impl->int128Ty;
+ default:
+ break;
+ }
+
+ // Lock access to the dialect as this may modify the LLVM context.
+ return getLocked(dialect, [=] {
+ return llvm::Type::getIntNTy(dialect->getLLVMContext(), numBits);
+ });
+}
+
+/// Utilities used to generate other miscellaneous types.
+LLVMType LLVMType::getArrayTy(LLVMType elementType, uint64_t numElements) {
+ // Lock access to the dialect as this may modify the LLVM context.
+ return getLocked(&elementType.getDialect(), [=] {
+ return llvm::ArrayType::get(elementType.getUnderlyingType(), numElements);
+ });
+}
+LLVMType LLVMType::getFunctionTy(LLVMType result, ArrayRef<LLVMType> params,
+ bool isVarArg) {
+ SmallVector<llvm::Type *, 8> llvmParams;
+ for (auto param : params)
+ llvmParams.push_back(param.getUnderlyingType());
+
+ // Lock access to the dialect as this may modify the LLVM context.
+ return getLocked(&result.getDialect(), [=] {
+ return llvm::FunctionType::get(result.getUnderlyingType(), llvmParams,
+ isVarArg);
+ });
+}
+LLVMType LLVMType::getStructTy(LLVMDialect *dialect,
+ ArrayRef<LLVMType> elements, bool isPacked) {
+ SmallVector<llvm::Type *, 8> llvmElements;
+ for (auto elt : elements)
+ llvmElements.push_back(elt.getUnderlyingType());
+
+ // Lock access to the dialect as this may modify the LLVM context.
+ return getLocked(dialect, [=] {
+ return llvm::StructType::get(dialect->getLLVMContext(), llvmElements,
+ isPacked);
+ });
+}
+LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) {
+ // Lock access to the dialect as this may modify the LLVM context.
+ return getLocked(&elementType.getDialect(), [=] {
+ return llvm::VectorType::get(elementType.getUnderlyingType(), numElements);
+ });
+}
+LLVMType LLVMType::getVoidTy(LLVMDialect *dialect) {
+ return dialect->impl->voidTy;
+}
OpenPOWER on IntegriCloud