summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Dialect/GPU/IR/GPUDialect.cpp201
-rw-r--r--mlir/lib/IR/FunctionSupport.cpp102
2 files changed, 256 insertions, 47 deletions
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index bfd094d6203..5fc1cade760 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -45,7 +45,7 @@ bool GPUDialect::isKernel(Operation *op) {
GPUDialect::GPUDialect(MLIRContext *context)
: Dialect(getDialectName(), context) {
- addOperations<LaunchOp, LaunchFuncOp,
+ addOperations<LaunchOp, LaunchFuncOp, GPUFuncOp,
#define GET_OP_LIST
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
>();
@@ -93,7 +93,7 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
// Check that `launch_func` refers to a well-formed kernel function.
StringRef kernelName = launchOp.kernel();
Operation *kernelFunc = kernelModule.lookupSymbol(kernelName);
- auto kernelStdFunction = dyn_cast_or_null<FuncOp>(kernelFunc);
+ auto kernelStdFunction = dyn_cast_or_null<::mlir::FuncOp>(kernelFunc);
auto kernelLLVMFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(kernelFunc);
if (!kernelStdFunction && !kernelLLVMFunction)
return launchOp.emitOpError("kernel function '")
@@ -501,9 +501,10 @@ void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
//===----------------------------------------------------------------------===//
void LaunchFuncOp::build(Builder *builder, OperationState &result,
- FuncOp kernelFunc, Value *gridSizeX, Value *gridSizeY,
- Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY,
- Value *blockSizeZ, ArrayRef<Value *> kernelOperands) {
+ ::mlir::FuncOp kernelFunc, Value *gridSizeX,
+ Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX,
+ Value *blockSizeY, Value *blockSizeZ,
+ ArrayRef<Value *> kernelOperands) {
// Add grid and block sizes as op operands, followed by the data operands.
result.addOperands(
{gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
@@ -517,7 +518,7 @@ void LaunchFuncOp::build(Builder *builder, OperationState &result,
}
void LaunchFuncOp::build(Builder *builder, OperationState &result,
- FuncOp kernelFunc, KernelDim3 gridSize,
+ ::mlir::FuncOp kernelFunc, KernelDim3 gridSize,
KernelDim3 blockSize,
ArrayRef<Value *> kernelOperands) {
build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z,
@@ -572,3 +573,191 @@ LogicalResult LaunchFuncOp::verify() {
return success();
}
+
+//===----------------------------------------------------------------------===//
+// GPUFuncOp
+//===----------------------------------------------------------------------===//
+
+void GPUFuncOp::build(Builder *builder, OperationState &result, StringRef name,
+ FunctionType type, ArrayRef<Type> workgroupAttributions,
+ ArrayRef<Type> privateAttributions,
+ ArrayRef<NamedAttribute> attrs) {
+ result.addAttribute(SymbolTable::getSymbolAttrName(),
+ builder->getStringAttr(name));
+ result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
+ result.addAttribute(getNumWorkgroupAttributionsAttrName(),
+ builder->getI64IntegerAttr(workgroupAttributions.size()));
+ result.addAttributes(attrs);
+ Region *body = result.addRegion();
+ Block *entryBlock = new Block;
+ entryBlock->addArguments(type.getInputs());
+ entryBlock->addArguments(workgroupAttributions);
+ entryBlock->addArguments(privateAttributions);
+
+ body->getBlocks().push_back(entryBlock);
+}
+
+/// Parses a GPU function memory attribution.
+///
+/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
+/// (`private` `(` ssa-id-and-type-list `)`)?
+///
+/// Note that this function parses only one of the two similar parts, with the
+/// keyword provided as argument.
+static ParseResult
+parseAttributions(OpAsmParser &parser, StringRef keyword,
+ SmallVectorImpl<OpAsmParser::OperandType> &args,
+ SmallVectorImpl<Type> &argTypes) {
+ // If we could not parse the keyword, just assume empty list and succeed.
+ if (failed(parser.parseOptionalKeyword(keyword)))
+ return success();
+
+ if (failed(parser.parseLParen()))
+ return failure();
+
+ // Early exit for an empty list.
+ if (succeeded(parser.parseOptionalRParen()))
+ return success();
+
+ do {
+ OpAsmParser::OperandType arg;
+ Type type;
+
+ if (parser.parseRegionArgument(arg) || parser.parseColonType(type))
+ return failure();
+
+ args.push_back(arg);
+ argTypes.push_back(type);
+ } while (succeeded(parser.parseOptionalComma()));
+
+ return parser.parseRParen();
+}
+
+/// Parses a GPU function.
+///
+/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
+/// (`->` function-result-list)? memory-attribution `kernel`?
+/// function-attributes? region
+ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
+ SmallVector<OpAsmParser::OperandType, 8> entryArgs;
+ SmallVector<SmallVector<NamedAttribute, 2>, 1> argAttrs;
+ SmallVector<SmallVector<NamedAttribute, 2>, 1> resultAttrs;
+ SmallVector<Type, 8> argTypes;
+ SmallVector<Type, 4> resultTypes;
+ bool isVariadic;
+
+ // Parse the function name.
+ StringAttr nameAttr;
+ if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
+ result.attributes))
+ return failure();
+
+ auto signatureLocation = parser.getCurrentLocation();
+ if (failed(impl::parseFunctionSignature(
+ parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs,
+ isVariadic, resultTypes, resultAttrs)))
+ return failure();
+
+ if (entryArgs.empty() && !argTypes.empty())
+ return parser.emitError(signatureLocation)
+ << "gpu.func requires named arguments";
+
+ // Construct the function type. More types will be added to the region, but
+ // not to the functiont type.
+ Builder &builder = parser.getBuilder();
+ auto type = builder.getFunctionType(argTypes, resultTypes);
+ result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
+
+ // Parse workgroup memory attributions.
+ if (failed(parseAttributions(parser, getWorkgroupKeyword(), entryArgs,
+ argTypes)))
+ return failure();
+
+ // Store the number of operands we just parsed as the number of workgroup
+ // memory attributions.
+ unsigned numWorkgroupAttrs = argTypes.size() - type.getNumInputs();
+ result.addAttribute(getNumWorkgroupAttributionsAttrName(),
+ builder.getI64IntegerAttr(numWorkgroupAttrs));
+
+ // Parse private memory attributions.
+ if (failed(
+ parseAttributions(parser, getPrivateKeyword(), entryArgs, argTypes)))
+ return failure();
+
+ // Parse the kernel attribute if present.
+ if (succeeded(parser.parseOptionalKeyword(getKernelKeyword())))
+ result.addAttribute(GPUDialect::getKernelFuncAttrName(),
+ builder.getUnitAttr());
+
+ // Parse attributes.
+ if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
+ return failure();
+ mlir::impl::addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);
+
+ // Parse the region. If no argument names were provided, take all names
+ // (including those of attributions) from the entry block.
+ auto *body = result.addRegion();
+ return parser.parseRegion(*body, entryArgs, argTypes);
+}
+
+static void printAttributions(OpAsmPrinter &p, StringRef keyword,
+ ArrayRef<BlockArgument *> values) {
+ if (values.empty())
+ return;
+
+ p << ' ' << keyword << '(';
+ interleaveComma(values, p.getStream(),
+ [&p](BlockArgument *v) { p << *v << " : " << v->getType(); });
+ p << ')';
+}
+
+void GPUFuncOp::print(OpAsmPrinter &p) {
+ p << getOperationName() << ' ';
+ p.printSymbolName(getName());
+
+ FunctionType type = getType();
+ impl::printFunctionSignature(p, this->getOperation(), type.getInputs(),
+ /*isVariadic=*/false, type.getResults());
+
+ printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
+ printAttributions(p, getPrivateKeyword(), getPrivateAttributions());
+ if (isKernel())
+ p << ' ' << getKernelKeyword();
+
+ impl::printFunctionAttributes(p, this->getOperation(), type.getNumInputs(),
+ type.getNumResults(),
+ {getNumWorkgroupAttributionsAttrName(),
+ GPUDialect::getKernelFuncAttrName()});
+ p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
+}
+
+/// Hook for FunctionLike verifier.
+LogicalResult GPUFuncOp::verifyType() {
+ Type type = getTypeAttr().getValue();
+ if (!type.isa<FunctionType>())
+ return emitOpError("requires '" + getTypeAttrName() +
+ "' attribute of function type");
+ return success();
+}
+
+/// Verifies the body of the function.
+LogicalResult GPUFuncOp::verifyBody() {
+ unsigned numFuncArguments = getNumArguments();
+ unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
+ unsigned numBlockArguments = front().getNumArguments();
+ if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
+ return emitOpError() << "expected at least "
+ << numFuncArguments + numWorkgroupAttributions
+ << " arguments to body region";
+
+ ArrayRef<Type> funcArgTypes = getType().getInputs();
+ for (unsigned i = 0; i < numFuncArguments; ++i) {
+ Type blockArgType = front().getArgument(i)->getType();
+ if (funcArgTypes[i] != blockArgType)
+ return emitOpError() << "expected body region argument #" << i
+ << " to be of type " << funcArgTypes[i] << ", got "
+ << blockArgType;
+ }
+
+ return success();
+}
diff --git a/mlir/lib/IR/FunctionSupport.cpp b/mlir/lib/IR/FunctionSupport.cpp
index 1f39575331c..c6f2673ef2a 100644
--- a/mlir/lib/IR/FunctionSupport.cpp
+++ b/mlir/lib/IR/FunctionSupport.cpp
@@ -128,9 +128,11 @@ static ParseResult parseFunctionResultList(
return parser.parseRParen();
}
-/// Parse a function signature, starting with a name and including the
-/// parameter list.
-static ParseResult parseFunctionSignature(
+/// Parses a function signature using `parser`. The `allowVariadic` argument
+/// indicates whether functions with variadic arguments are supported. The
+/// trailing arguments are populated by this function with names, types and
+/// attributes of the arguments and those of the results.
+ParseResult mlir::impl::parseFunctionSignature(
OpAsmParser &parser, bool allowVariadic,
SmallVectorImpl<OpAsmParser::OperandType> &argNames,
SmallVectorImpl<Type> &argTypes,
@@ -145,6 +147,24 @@ static ParseResult parseFunctionSignature(
return success();
}
+void mlir::impl::addArgAndResultAttrs(
+ Builder &builder, OperationState &result,
+ ArrayRef<SmallVector<NamedAttribute, 2>> argAttrs,
+ ArrayRef<SmallVector<NamedAttribute, 2>> resultAttrs) {
+ // Add the attributes to the function arguments.
+ SmallString<8> attrNameBuf;
+ for (unsigned i = 0, e = argAttrs.size(); i != e; ++i)
+ if (!argAttrs[i].empty())
+ result.addAttribute(getArgAttrName(i, attrNameBuf),
+ builder.getDictionaryAttr(argAttrs[i]));
+
+ // Add the attributes to the function results.
+ for (unsigned i = 0, e = resultAttrs.size(); i != e; ++i)
+ if (!resultAttrs[i].empty())
+ result.addAttribute(getResultAttrName(i, attrNameBuf),
+ builder.getDictionaryAttr(resultAttrs[i]));
+}
+
/// Parser implementation for function-like operations. Uses `funcTypeBuilder`
/// to construct the custom function type given lists of input and output types.
ParseResult
@@ -158,7 +178,7 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
SmallVector<Type, 4> resultTypes;
auto &builder = parser.getBuilder();
- // Parse the name as a symbol reference attribute.
+ // Parse the name as a symbol.
StringAttr nameAttr;
if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
result.attributes))
@@ -185,26 +205,14 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
return failure();
// Add the attributes to the function arguments.
- SmallString<8> attrNameBuf;
- for (unsigned i = 0, e = argTypes.size(); i != e; ++i)
- if (!argAttrs[i].empty())
- result.addAttribute(getArgAttrName(i, attrNameBuf),
- builder.getDictionaryAttr(argAttrs[i]));
-
- // Add the attributes to the function results.
- for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
- if (!resultAttrs[i].empty())
- result.addAttribute(getResultAttrName(i, attrNameBuf),
- builder.getDictionaryAttr(resultAttrs[i]));
+ assert(argAttrs.size() == argTypes.size());
+ assert(resultAttrs.size() == resultTypes.size());
+ addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);
// Parse the optional function body.
auto *body = result.addRegion();
- if (parser.parseOptionalRegion(*body, entryArgs,
- entryArgs.empty() ? llvm::ArrayRef<Type>()
- : argTypes))
- return failure();
-
- return success();
+ return parser.parseOptionalRegion(
+ *body, entryArgs, entryArgs.empty() ? llvm::ArrayRef<Type>() : argTypes);
}
// Print a function result list.
@@ -227,9 +235,10 @@ static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types,
/// Print the signature of the function-like operation `op`. Assumes `op` has
/// the FunctionLike trait and passed the verification.
-static void printSignature(OpAsmPrinter &p, Operation *op,
- ArrayRef<Type> argTypes, bool isVariadic,
- ArrayRef<Type> resultTypes) {
+void mlir::impl::printFunctionSignature(OpAsmPrinter &p, Operation *op,
+ ArrayRef<Type> argTypes,
+ bool isVariadic,
+ ArrayRef<Type> resultTypes) {
Region &body = op->getRegion(0);
bool isExternal = body.empty();
@@ -264,42 +273,53 @@ static void printSignature(OpAsmPrinter &p, Operation *op,
}
}
-/// Printer implementation for function-like operations. Accepts lists of
-/// argument and result types to use while printing.
-void mlir::impl::printFunctionLikeOp(OpAsmPrinter &p, Operation *op,
- ArrayRef<Type> argTypes, bool isVariadic,
- ArrayRef<Type> resultTypes) {
- // Print the operation and the function name.
- auto funcName =
- op->getAttrOfType<StringAttr>(::mlir::SymbolTable::getSymbolAttrName())
- .getValue();
- p << op->getName() << ' ';
- p.printSymbolName(funcName);
-
- // Print the signature.
- printSignature(p, op, argTypes, isVariadic, resultTypes);
-
+/// Prints the list of function prefixed with the "attributes" keyword. The
+/// attributes with names listed in "elided" as well as those used by the
+/// function-like operation internally are not printed. Nothing is printed
+/// if all attributes are elided. Assumes `op` has the `FunctionLike` trait and
+/// passed the verification.
+void mlir::impl::printFunctionAttributes(OpAsmPrinter &p, Operation *op,
+ unsigned numInputs,
+ unsigned numResults,
+ ArrayRef<StringRef> elided) {
// Print out function attributes, if present.
SmallVector<StringRef, 2> ignoredAttrs = {
::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName()};
+ ignoredAttrs.append(elided.begin(), elided.end());
SmallString<8> attrNameBuf;
// Ignore any argument attributes.
std::vector<SmallString<8>> argAttrStorage;
- for (unsigned i = 0, e = argTypes.size(); i != e; ++i)
+ for (unsigned i = 0; i != numInputs; ++i)
if (op->getAttr(getArgAttrName(i, attrNameBuf)))
argAttrStorage.emplace_back(attrNameBuf);
ignoredAttrs.append(argAttrStorage.begin(), argAttrStorage.end());
// Ignore any result attributes.
std::vector<SmallString<8>> resultAttrStorage;
- for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
+ for (unsigned i = 0; i != numResults; ++i)
if (op->getAttr(getResultAttrName(i, attrNameBuf)))
resultAttrStorage.emplace_back(attrNameBuf);
ignoredAttrs.append(resultAttrStorage.begin(), resultAttrStorage.end());
p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs);
+}
+
+/// Printer implementation for function-like operations. Accepts lists of
+/// argument and result types to use while printing.
+void mlir::impl::printFunctionLikeOp(OpAsmPrinter &p, Operation *op,
+ ArrayRef<Type> argTypes, bool isVariadic,
+ ArrayRef<Type> resultTypes) {
+ // Print the operation and the function name.
+ auto funcName =
+ op->getAttrOfType<StringAttr>(::mlir::SymbolTable::getSymbolAttrName())
+ .getValue();
+ p << op->getName() << ' ';
+ p.printSymbolName(funcName);
+
+ printFunctionSignature(p, op, argTypes, isVariadic, resultTypes);
+ printFunctionAttributes(p, op, argTypes.size(), resultTypes.size());
// Print the body if this is not an external function.
Region &body = op->getRegion(0);
OpenPOWER on IntegriCloud