summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/GPU/IR/GPUDialect.cpp')
-rw-r--r--mlir/lib/Dialect/GPU/IR/GPUDialect.cpp201
1 files changed, 195 insertions, 6 deletions
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index bfd094d6203..5fc1cade760 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -45,7 +45,7 @@ bool GPUDialect::isKernel(Operation *op) {
GPUDialect::GPUDialect(MLIRContext *context)
: Dialect(getDialectName(), context) {
- addOperations<LaunchOp, LaunchFuncOp,
+ addOperations<LaunchOp, LaunchFuncOp, GPUFuncOp,
#define GET_OP_LIST
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
>();
@@ -93,7 +93,7 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
// Check that `launch_func` refers to a well-formed kernel function.
StringRef kernelName = launchOp.kernel();
Operation *kernelFunc = kernelModule.lookupSymbol(kernelName);
- auto kernelStdFunction = dyn_cast_or_null<FuncOp>(kernelFunc);
+ auto kernelStdFunction = dyn_cast_or_null<::mlir::FuncOp>(kernelFunc);
auto kernelLLVMFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(kernelFunc);
if (!kernelStdFunction && !kernelLLVMFunction)
return launchOp.emitOpError("kernel function '")
@@ -501,9 +501,10 @@ void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
//===----------------------------------------------------------------------===//
void LaunchFuncOp::build(Builder *builder, OperationState &result,
- FuncOp kernelFunc, Value *gridSizeX, Value *gridSizeY,
- Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY,
- Value *blockSizeZ, ArrayRef<Value *> kernelOperands) {
+ ::mlir::FuncOp kernelFunc, Value *gridSizeX,
+ Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX,
+ Value *blockSizeY, Value *blockSizeZ,
+ ArrayRef<Value *> kernelOperands) {
// Add grid and block sizes as op operands, followed by the data operands.
result.addOperands(
{gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
@@ -517,7 +518,7 @@ void LaunchFuncOp::build(Builder *builder, OperationState &result,
}
void LaunchFuncOp::build(Builder *builder, OperationState &result,
- FuncOp kernelFunc, KernelDim3 gridSize,
+ ::mlir::FuncOp kernelFunc, KernelDim3 gridSize,
KernelDim3 blockSize,
ArrayRef<Value *> kernelOperands) {
build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z,
@@ -572,3 +573,191 @@ LogicalResult LaunchFuncOp::verify() {
return success();
}
+
+//===----------------------------------------------------------------------===//
+// GPUFuncOp
+//===----------------------------------------------------------------------===//
+
+/// Populates `result` with the state of a `gpu.func` operation.
+///
+/// Records the symbol name, the function type and the number of workgroup
+/// attributions as attributes, appends the caller-provided `attrs`, and creates
+/// the body region with a single entry block. The entry block arguments are, in
+/// order: the function inputs, the workgroup attributions and the private
+/// attributions.
+void GPUFuncOp::build(Builder *builder, OperationState &result, StringRef name,
+                      FunctionType type, ArrayRef<Type> workgroupAttributions,
+                      ArrayRef<Type> privateAttributions,
+                      ArrayRef<NamedAttribute> attrs) {
+  // Attributes describing the function itself.
+  auto symbolNameAttr = builder->getStringAttr(name);
+  result.addAttribute(SymbolTable::getSymbolAttrName(), symbolNameAttr);
+  result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
+  auto numWorkgroupAttr =
+      builder->getI64IntegerAttr(workgroupAttributions.size());
+  result.addAttribute(getNumWorkgroupAttributionsAttrName(), numWorkgroupAttr);
+  result.addAttributes(attrs);
+
+  // Create the entry block inside the body region, then give it its arguments.
+  Region *bodyRegion = result.addRegion();
+  Block *entry = new Block;
+  bodyRegion->getBlocks().push_back(entry);
+
+  entry->addArguments(type.getInputs());
+  entry->addArguments(workgroupAttributions);
+  entry->addArguments(privateAttributions);
+}
+
+/// Parses one optional memory attribution list of a GPU function.
+///
+///   memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
+///                          (`private` `(` ssa-id-and-type-list `)`)?
+///
+/// A single call handles only the part introduced by `keyword`. Every parsed
+/// argument and its type are appended to `args` and `argTypes` respectively.
+/// A missing keyword is treated as an empty list and is not an error.
+static ParseResult
+parseAttributions(OpAsmParser &parser, StringRef keyword,
+                  SmallVectorImpl<OpAsmParser::OperandType> &args,
+                  SmallVectorImpl<Type> &argTypes) {
+  // Without the keyword there is nothing to parse.
+  if (failed(parser.parseOptionalKeyword(keyword)))
+    return success();
+
+  if (failed(parser.parseLParen()))
+    return failure();
+
+  // The keyword followed by `()` is an explicitly empty list.
+  if (succeeded(parser.parseOptionalRParen()))
+    return success();
+
+  // Parse comma-separated `ssa-id : type` entries; at least one is expected.
+  for (;;) {
+    OpAsmParser::OperandType attribution;
+    if (failed(parser.parseRegionArgument(attribution)))
+      return failure();
+
+    Type attributionType;
+    if (failed(parser.parseColonType(attributionType)))
+      return failure();
+
+    args.push_back(attribution);
+    argTypes.push_back(attributionType);
+
+    if (failed(parser.parseOptionalComma()))
+      break;
+  }
+
+  return parser.parseRParen();
+}
+
+/// Parses a GPU function from its custom assembly form.
+///
+/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
+/// (`->` function-result-list)? memory-attribution `kernel`?
+/// function-attributes? region
+ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
+ SmallVector<OpAsmParser::OperandType, 8> entryArgs;
+ SmallVector<SmallVector<NamedAttribute, 2>, 1> argAttrs;
+ SmallVector<SmallVector<NamedAttribute, 2>, 1> resultAttrs;
+ SmallVector<Type, 8> argTypes;
+ SmallVector<Type, 4> resultTypes;
+ bool isVariadic; // Only handed to parseFunctionSignature; never read here.
+
+ // Parse the function name and record it as the symbol name attribute.
+ StringAttr nameAttr;
+ if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
+ result.attributes))
+ return failure();
+
+ auto signatureLocation = parser.getCurrentLocation();
+ if (failed(impl::parseFunctionSignature(
+ parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs,
+ isVariadic, resultTypes, resultAttrs)))
+ return failure();
+
+ if (entryArgs.empty() && !argTypes.empty())
+ return parser.emitError(signatureLocation)
+ << "gpu.func requires named arguments";
+
+ // Construct the function type from the signature alone. Attribution types are
+ // appended to argTypes below for the region, but not to the function type.
+ Builder &builder = parser.getBuilder();
+ auto type = builder.getFunctionType(argTypes, resultTypes);
+ result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
+
+ // Parse workgroup memory attributions; they extend entryArgs and argTypes.
+ if (failed(parseAttributions(parser, getWorkgroupKeyword(), entryArgs,
+ argTypes)))
+ return failure();
+
+ // Store the number of types parsed beyond the function inputs as the number
+ // of workgroup memory attributions.
+ unsigned numWorkgroupAttrs = argTypes.size() - type.getNumInputs();
+ result.addAttribute(getNumWorkgroupAttributionsAttrName(),
+ builder.getI64IntegerAttr(numWorkgroupAttrs));
+
+ // Parse private memory attributions; their count is not stored separately.
+ if (failed(
+ parseAttributions(parser, getPrivateKeyword(), entryArgs, argTypes)))
+ return failure();
+
+ // Parse the kernel keyword if present and record it as a unit attribute.
+ if (succeeded(parser.parseOptionalKeyword(getKernelKeyword())))
+ result.addAttribute(GPUDialect::getKernelFuncAttrName(),
+ builder.getUnitAttr());
+
+ // Parse the optional attribute dictionary, then add argument/result attrs.
+ if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
+ return failure();
+ mlir::impl::addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);
+
+ // Parse the region, passing the signature arguments followed by the workgroup
+ // and private attributions (entryArgs/argTypes) as its arguments.
+ auto *body = result.addRegion();
+ return parser.parseRegion(*body, entryArgs, argTypes);
+}
+
+/// Prints a memory attribution list introduced by `keyword`, in the form
+/// ` keyword(<arg> : <type>, ...)`. Nothing is printed when `values` is empty.
+static void printAttributions(OpAsmPrinter &p, StringRef keyword,
+                              ArrayRef<BlockArgument *> values) {
+  if (!values.empty()) {
+    // Prints a single `<arg> : <type>` entry.
+    auto printOne = [&p](BlockArgument *attribution) {
+      p << *attribution << " : " << attribution->getType();
+    };
+
+    p << ' ' << keyword << '(';
+    interleaveComma(values, p.getStream(), printOne);
+    p << ')';
+  }
+}
+
+/// Prints the custom assembly form of a GPU function: operation name, symbol
+/// name, signature, workgroup and private attributions, the optional kernel
+/// keyword, the remaining attributes and finally the body region.
+void GPUFuncOp::print(OpAsmPrinter &p) {
+  Operation *op = this->getOperation();
+  FunctionType funcType = getType();
+
+  p << getOperationName() << ' ';
+  p.printSymbolName(getName());
+  impl::printFunctionSignature(p, op, funcType.getInputs(),
+                               /*isVariadic=*/false, funcType.getResults());
+
+  // Memory attributions follow the signature, workgroup first.
+  printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
+  printAttributions(p, getPrivateKeyword(), getPrivateAttributions());
+
+  if (isKernel())
+    p << ' ' << getKernelKeyword();
+
+  // The attribution count and the kernel marker are already conveyed by the
+  // syntax above, so their attribute names are passed separately here
+  // (presumably as the set of attributes to leave out of the dictionary).
+  impl::printFunctionAttributes(p, op, funcType.getNumInputs(),
+                                funcType.getNumResults(),
+                                {getNumWorkgroupAttributionsAttrName(),
+                                 GPUDialect::getKernelFuncAttrName()});
+
+  // Entry block arguments were printed as part of the signature/attributions.
+  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
+}
+
+/// Hook for the FunctionLike verifier: succeeds when the type attribute holds a
+/// FunctionType and reports an op error otherwise.
+LogicalResult GPUFuncOp::verifyType() {
+  if (getTypeAttr().getValue().isa<FunctionType>())
+    return success();
+
+  return emitOpError("requires '" + getTypeAttrName() +
+                     "' attribute of function type");
+}
+
+/// Verifies the body of the function.
+///
+/// Checks that the body region has an entry block, that the entry block has at
+/// least as many arguments as there are function arguments plus workgroup
+/// attributions, and that the leading entry block arguments have the types
+/// listed in the function type. Any remaining block arguments are not checked
+/// here.
+LogicalResult GPUFuncOp::verifyBody() {
+  // `front()` requires an entry block; diagnose an empty region instead of
+  // dereferencing a non-existent block.
+  if (getBody().empty())
+    return emitOpError() << "expected body with at least one block";
+
+  Block &entryBlock = front();
+  unsigned numFuncArguments = getNumArguments();
+  unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
+  unsigned numBlockArguments = entryBlock.getNumArguments();
+  if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
+    return emitOpError() << "expected at least "
+                         << numFuncArguments + numWorkgroupAttributions
+                         << " arguments to body region";
+
+  // The first block arguments mirror the inputs of the function type.
+  ArrayRef<Type> funcArgTypes = getType().getInputs();
+  for (unsigned i = 0; i < numFuncArguments; ++i) {
+    Type blockArgType = entryBlock.getArgument(i)->getType();
+    if (funcArgTypes[i] != blockArgType)
+      return emitOpError() << "expected body region argument #" << i
+                           << " to be of type " << funcArgTypes[i] << ", got "
+                           << blockArgType;
+  }
+
+  return success();
+}
OpenPOWER on IntegriCloud