summaryrefslogtreecommitdiffstats
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/g3doc/Dialects/GPU.md78
-rw-r--r--mlir/include/mlir/Dialect/GPU/GPUDialect.h84
-rw-r--r--mlir/include/mlir/IR/FunctionSupport.h75
-rw-r--r--mlir/lib/Dialect/GPU/IR/GPUDialect.cpp201
-rw-r--r--mlir/lib/IR/FunctionSupport.cpp102
-rw-r--r--mlir/test/Dialect/GPU/invalid.mlir22
-rw-r--r--mlir/test/Dialect/GPU/ops.mlir49
7 files changed, 552 insertions, 59 deletions
diff --git a/mlir/g3doc/Dialects/GPU.md b/mlir/g3doc/Dialects/GPU.md
index 7d27e1555e1..b1cc30e510f 100644
--- a/mlir/g3doc/Dialects/GPU.md
+++ b/mlir/g3doc/Dialects/GPU.md
@@ -47,6 +47,84 @@ Example:
%gDimZ = "gpu.grid_dim"() {dimension = "z"} : () -> (index)
```
+### `gpu.func`
+
+Defines a function that can be executed on a GPU. This supports memory
+attribution and its body has a particular execution model.
+
+GPU functions are either kernels (as indicated by the `kernel` attribute) or
+regular functions. The former can be launched from the host side, while the
+latter are device side only.
+
+The memory attribution defines SSA values that correspond to memory buffers
+allocated in the memory hierarchy of the GPU (see below).
+
+The operation has one attached region that corresponds to the body of the
+function. The region arguments consist of the function arguments without
+modification, followed by buffers defined in memory annotations. The body of a
+GPU function, when launched, is executed by multiple work items. There are no
+guarantees on the order in which work items execute, or on the connection
+between them. In particular, work items are not necessarily executed in
+lock-step. Synchronization ops such as "gpu.barrier" should be used to
+coordinate work items. Declarations of GPU functions, i.e. not having the body
+region, are not supported.
+
+#### Memory attribution
+
+Memory buffers are defined at the function level, either in "gpu.launch" or in
+"gpu.func" ops. This encoding makes it clear where the memory belongs and makes
+the lifetime of the memory visible. The memory is only accessible while the
+kernel is launched/the function is currently invoked. The latter is more strict
+than actual GPU implementations but using static memory at the function level is
+just for convenience. It is also always possible to pass pointers to the
+workgroup memory into other functions, provided they expect the correct memory
+space.
+
+The buffers are considered live throughout the execution of the GPU function
+body. The absence of memory attribution syntax means that the function does not
+require special buffers. Rationale: although the underlying models declare
+memory buffers at the module level, we chose to do it at the function level to
+provide some structuring for the lifetime of those buffers; this avoids the
+incentive to use the buffers for communicating between different kernels or
+launches of the same kernel, which should be done through function arguments
+intead; we chose not to use `alloca`-style approach that would require more
+complex lifetime analysis following the principles of MLIR that promote
+structure and representing analysis results in the IR.
+
+Syntax:
+
+``` {.ebnf}
+op ::= `gpu.func` symbol-ref-id `(` argument-list `)` (`->`
+function-result-list)?
+ memory-attribution `kernel`? function-attributes? region
+
+memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
+ (`private` `(` ssa-id-and-type-list `)`)?
+```
+
+Example:
+
+```mlir {.mlir}
+gpu.func @foo(%arg0: index)
+ workgroup(%workgroup: memref<32xf32, 3>)
+ private(%private: memref<1xf32, 5>)
+ kernel
+ attributes {qux: "quux"} {
+ gpu.return
+}
+```
+
+The generic form illustrates the concept
+
+```mlir {.mlir}
+"gpu.func"(%arg: index) {sym_name: "foo", kernel, qux: "quux"} ({
+^bb0(%arg0: index, %workgroup: memref<32xf32, 3>, %private: memref<1xf32, 5>):
+ "gpu.return"() : () -> ()
+}) : (index) -> ()
+```
+
+Note the non-default memory spaces used in memref types in memory-attribution.
+
### `gpu.launch`
Launch a kernel on the specified grid of thread blocks. The body of the kernel
diff --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
index f59cd32ff87..fb906b2ace5 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
@@ -24,7 +24,9 @@
#define MLIR_DIALECT_GPU_GPUDIALECT_H
#include "mlir/IR/Dialect.h"
+#include "mlir/IR/FunctionSupport.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/SymbolTable.h"
namespace mlir {
class FuncOp;
@@ -191,6 +193,88 @@ private:
static StringRef getKernelModuleAttrName() { return "kernel_module"; }
};
+class GPUFuncOp : public Op<GPUFuncOp, OpTrait::FunctionLike,
+ OpTrait::IsIsolatedFromAbove, OpTrait::Symbol> {
+public:
+ using Op::Op;
+
+ /// Returns the name of the operation.
+ static StringRef getOperationName() { return "gpu.func"; }
+
+ /// Constructs a FuncOp, hook for Builder methods.
+ static void build(Builder *builder, OperationState &result, StringRef name,
+ FunctionType type, ArrayRef<Type> workgroupAttributions,
+ ArrayRef<Type> privateAttributions,
+ ArrayRef<NamedAttribute> attrs);
+
+ /// Prints the Op in custom format.
+ void print(OpAsmPrinter &p);
+
+ /// Parses the Op in custom format.
+ static ParseResult parse(OpAsmParser &parser, OperationState &result);
+
+ /// Returns `true` if the GPU function defined by this Op is a kernel, i.e.
+ /// it is intended to be launched from host.
+ bool isKernel() {
+ return getAttrOfType<UnitAttr>(GPUDialect::getKernelFuncAttrName()) !=
+ nullptr;
+ }
+
+ /// Returns the type of the function this Op defines.
+ FunctionType getType() {
+ return getTypeAttr().getValue().cast<FunctionType>();
+ }
+
+ /// Returns the number of buffers located in the workgroup memory.
+ unsigned getNumWorkgroupAttributions() {
+ return getAttrOfType<IntegerAttr>(getNumWorkgroupAttributionsAttrName())
+ .getInt();
+ }
+
+ /// Returns a list of block arguments that correspond to buffers located in
+ /// the workgroup memory
+ ArrayRef<BlockArgument *> getWorkgroupAttributions() {
+ auto begin =
+ std::next(getBody().front().args_begin(), getType().getNumInputs());
+ auto end = std::next(begin, getNumWorkgroupAttributions());
+ return {begin, end};
+ }
+
+ /// Returns a list of block arguments that correspond to buffers located in
+ /// the private memory.
+ ArrayRef<BlockArgument *> getPrivateAttributions() {
+ auto begin =
+ std::next(getBody().front().args_begin(),
+ getType().getNumInputs() + getNumWorkgroupAttributions());
+ return {begin, getBody().front().args_end()};
+ }
+
+private:
+ // FunctionLike trait needs access to the functions below.
+ friend class OpTrait::FunctionLike<GPUFuncOp>;
+
+ /// Hooks for the input/output type enumeration in FunctionLike .
+ unsigned getNumFuncArguments() { return getType().getNumInputs(); }
+ unsigned getNumFuncResults() { return getType().getNumResults(); }
+
+ /// Returns the name of the attribute containing the number of buffers located
+ /// in the workgroup memory.
+ static StringRef getNumWorkgroupAttributionsAttrName() {
+ return "workgroup_attibutions";
+ }
+
+ /// Returns the keywords used in the custom syntax for this Op.
+ static StringRef getWorkgroupKeyword() { return "workgroup"; }
+ static StringRef getPrivateKeyword() { return "private"; }
+ static StringRef getKernelKeyword() { return "kernel"; }
+
+ /// Hook for FunctionLike verifier.
+ LogicalResult verifyType();
+
+ /// Verifies the body of the function.
+ LogicalResult verifyBody();
+};
+
#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/GPUOps.h.inc"
diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h
index 6ce03f0b8cd..38e406e8f08 100644
--- a/mlir/include/mlir/IR/FunctionSupport.h
+++ b/mlir/include/mlir/IR/FunctionSupport.h
@@ -24,6 +24,7 @@
#define MLIR_IR_FUNCTIONSUPPORT_H
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/SmallString.h"
namespace mlir {
@@ -83,6 +84,14 @@ private:
bool variadic;
};
+/// Adds argument and result attributes, provided as `argAttrs` and
+/// `resultAttrs` arguments, to the list of operation attributes in `result`.
+/// Internally, argument and result attributes are stored as dict attributes
+/// with special names given by getResultAttrName, getArgumentAttrName.
+void addArgAndResultAttrs(Builder &builder, OperationState &result,
+ ArrayRef<SmallVector<NamedAttribute, 2>> argAttrs,
+ ArrayRef<SmallVector<NamedAttribute, 2>> resultAttrs);
+
/// Callback type for `parseFunctionLikeOp`, the callback should produce the
/// type that will be associated with a function-like operation from lists of
/// function arguments and results, VariadicFlag indicates whether the function
@@ -91,6 +100,18 @@ private:
using FuncTypeBuilder = llvm::function_ref<Type(
Builder &, ArrayRef<Type>, ArrayRef<Type>, VariadicFlag, std::string &)>;
+/// Parses a function signature using `parser`. The `allowVariadic` argument
+/// indicates whether functions with variadic arguments are supported. The
+/// trailing arguments are populated by this function with names, types and
+/// attributes of the arguments and those of the results.
+ParseResult parseFunctionSignature(
+ OpAsmParser &parser, bool allowVariadic,
+ SmallVectorImpl<OpAsmParser::OperandType> &argNames,
+ SmallVectorImpl<Type> &argTypes,
+ SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs, bool &isVariadic,
+ SmallVectorImpl<Type> &resultTypes,
+ SmallVectorImpl<SmallVector<NamedAttribute, 2>> &resultAttrs);
+
/// Parser implementation for function-like operations. Uses
/// `funcTypeBuilder` to construct the custom function type given lists of
/// input and output types. If `allowVariadic` is set, the parser will accept
@@ -108,6 +129,21 @@ void printFunctionLikeOp(OpAsmPrinter &p, Operation *op,
ArrayRef<Type> argTypes, bool isVariadic,
ArrayRef<Type> resultTypes);
+/// Prints the signature of the function-like operation `op`. Assumes `op` has
+/// the FunctionLike trait and passed the verification.
+void printFunctionSignature(OpAsmPrinter &p, Operation *op,
+ ArrayRef<Type> argTypes, bool isVariadic,
+ ArrayRef<Type> resultTypes);
+
+/// Prints the list of function prefixed with the "attributes" keyword. The
+/// attributes with names listed in "elided" as well as those used by the
+/// function-like operation internally are not printed. Nothing is printed
+/// if all attributes are elided. Assumes `op` has the `FunctionLike` trait and
+/// passed the verification.
+void printFunctionAttributes(OpAsmPrinter &p, Operation *op, unsigned numInputs,
+ unsigned numResults,
+ ArrayRef<StringRef> elided = {});
+
} // namespace impl
namespace OpTrait {
@@ -117,7 +153,7 @@ namespace OpTrait {
/// - Ops have a single region with multiple blocks that corresponds to the body
/// of the function;
/// - the absence of a region corresponds to an external function;
-/// - arguments of the first block of the region are treated as function
+/// - leading arguments of the first block of the region are treated as function
/// arguments;
/// - they can have argument attributes that are stored in a dictionary
/// attribute on the Op itself.
@@ -137,6 +173,9 @@ namespace OpTrait {
/// redefine the `verifyType()` hook that will be called after verifying the
/// presence of the `type` attribute and before any call to
/// `getNumFuncArguments`/`getNumFuncResults` from the verifier.
+/// - To verify that the body respects op-specific invariants, concrete ops may
+/// redefine the `verifyBody()` hook that will be called after verifying the
+/// function type and the presence of the (potentially empty) body region.
template <typename ConcreteType>
class FunctionLike : public OpTrait::TraitBase<ConcreteType, FunctionLike> {
public:
@@ -178,6 +217,11 @@ public:
Block &back() { return getBody().back(); }
Block &front() { return getBody().front(); }
+ /// Hook for concrete ops to verify the contents of the body. Called as a
+ /// part of trait verification, after type verification and ensuring that a
+ /// region exists.
+ LogicalResult verifyBody();
+
//===--------------------------------------------------------------------===//
// Type Attribute Handling
//===--------------------------------------------------------------------===//
@@ -384,6 +428,23 @@ protected:
LogicalResult verifyType() { return success(); }
};
+/// Default verifier checks that if the entry block exists, it has the same
+/// number of arguments as the function-like operation.
+template <typename ConcreteType>
+LogicalResult FunctionLike<ConcreteType>::verifyBody() {
+ auto funcOp = cast<ConcreteType>(this->getOperation());
+
+ if (funcOp.isExternal())
+ return success();
+
+ unsigned numArguments = funcOp.getNumArguments();
+ if (funcOp.front().getNumArguments() != numArguments)
+ return funcOp.emitOpError("entry block must have ")
+ << numArguments << " arguments to match function signature";
+
+ return success();
+}
+
template <typename ConcreteType>
LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) {
MLIRContext *ctx = op->getContext();
@@ -433,17 +494,7 @@ LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) {
if (op->getNumRegions() != 1)
return funcOp.emitOpError("expects one region");
- // Check that if the entry block exists, it has the same number of arguments
- // as the function-like operation.
- if (funcOp.isExternal())
- return success();
-
- unsigned numArguments = funcOp.getNumArguments();
- if (funcOp.front().getNumArguments() != numArguments)
- return funcOp.emitOpError("entry block must have ")
- << numArguments << " arguments to match function signature";
-
- return success();
+ return funcOp.verifyBody();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index bfd094d6203..5fc1cade760 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -45,7 +45,7 @@ bool GPUDialect::isKernel(Operation *op) {
GPUDialect::GPUDialect(MLIRContext *context)
: Dialect(getDialectName(), context) {
- addOperations<LaunchOp, LaunchFuncOp,
+ addOperations<LaunchOp, LaunchFuncOp, GPUFuncOp,
#define GET_OP_LIST
#include "mlir/Dialect/GPU/GPUOps.cpp.inc"
>();
@@ -93,7 +93,7 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
// Check that `launch_func` refers to a well-formed kernel function.
StringRef kernelName = launchOp.kernel();
Operation *kernelFunc = kernelModule.lookupSymbol(kernelName);
- auto kernelStdFunction = dyn_cast_or_null<FuncOp>(kernelFunc);
+ auto kernelStdFunction = dyn_cast_or_null<::mlir::FuncOp>(kernelFunc);
auto kernelLLVMFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(kernelFunc);
if (!kernelStdFunction && !kernelLLVMFunction)
return launchOp.emitOpError("kernel function '")
@@ -501,9 +501,10 @@ void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
//===----------------------------------------------------------------------===//
void LaunchFuncOp::build(Builder *builder, OperationState &result,
- FuncOp kernelFunc, Value *gridSizeX, Value *gridSizeY,
- Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY,
- Value *blockSizeZ, ArrayRef<Value *> kernelOperands) {
+ ::mlir::FuncOp kernelFunc, Value *gridSizeX,
+ Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX,
+ Value *blockSizeY, Value *blockSizeZ,
+ ArrayRef<Value *> kernelOperands) {
// Add grid and block sizes as op operands, followed by the data operands.
result.addOperands(
{gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
@@ -517,7 +518,7 @@ void LaunchFuncOp::build(Builder *builder, OperationState &result,
}
void LaunchFuncOp::build(Builder *builder, OperationState &result,
- FuncOp kernelFunc, KernelDim3 gridSize,
+ ::mlir::FuncOp kernelFunc, KernelDim3 gridSize,
KernelDim3 blockSize,
ArrayRef<Value *> kernelOperands) {
build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z,
@@ -572,3 +573,191 @@ LogicalResult LaunchFuncOp::verify() {
return success();
}
+
+//===----------------------------------------------------------------------===//
+// GPUFuncOp
+//===----------------------------------------------------------------------===//
+
+void GPUFuncOp::build(Builder *builder, OperationState &result, StringRef name,
+ FunctionType type, ArrayRef<Type> workgroupAttributions,
+ ArrayRef<Type> privateAttributions,
+ ArrayRef<NamedAttribute> attrs) {
+ result.addAttribute(SymbolTable::getSymbolAttrName(),
+ builder->getStringAttr(name));
+ result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
+ result.addAttribute(getNumWorkgroupAttributionsAttrName(),
+ builder->getI64IntegerAttr(workgroupAttributions.size()));
+ result.addAttributes(attrs);
+ Region *body = result.addRegion();
+ Block *entryBlock = new Block;
+ entryBlock->addArguments(type.getInputs());
+ entryBlock->addArguments(workgroupAttributions);
+ entryBlock->addArguments(privateAttributions);
+
+ body->getBlocks().push_back(entryBlock);
+}
+
+/// Parses a GPU function memory attribution.
+///
+/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)?
+/// (`private` `(` ssa-id-and-type-list `)`)?
+///
+/// Note that this function parses only one of the two similar parts, with the
+/// keyword provided as argument.
+static ParseResult
+parseAttributions(OpAsmParser &parser, StringRef keyword,
+ SmallVectorImpl<OpAsmParser::OperandType> &args,
+ SmallVectorImpl<Type> &argTypes) {
+ // If we could not parse the keyword, just assume empty list and succeed.
+ if (failed(parser.parseOptionalKeyword(keyword)))
+ return success();
+
+ if (failed(parser.parseLParen()))
+ return failure();
+
+ // Early exit for an empty list.
+ if (succeeded(parser.parseOptionalRParen()))
+ return success();
+
+ do {
+ OpAsmParser::OperandType arg;
+ Type type;
+
+ if (parser.parseRegionArgument(arg) || parser.parseColonType(type))
+ return failure();
+
+ args.push_back(arg);
+ argTypes.push_back(type);
+ } while (succeeded(parser.parseOptionalComma()));
+
+ return parser.parseRParen();
+}
+
+/// Parses a GPU function.
+///
+/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)`
+/// (`->` function-result-list)? memory-attribution `kernel`?
+/// function-attributes? region
+ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) {
+ SmallVector<OpAsmParser::OperandType, 8> entryArgs;
+ SmallVector<SmallVector<NamedAttribute, 2>, 1> argAttrs;
+ SmallVector<SmallVector<NamedAttribute, 2>, 1> resultAttrs;
+ SmallVector<Type, 8> argTypes;
+ SmallVector<Type, 4> resultTypes;
+ bool isVariadic;
+
+ // Parse the function name.
+ StringAttr nameAttr;
+ if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
+ result.attributes))
+ return failure();
+
+ auto signatureLocation = parser.getCurrentLocation();
+ if (failed(impl::parseFunctionSignature(
+ parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs,
+ isVariadic, resultTypes, resultAttrs)))
+ return failure();
+
+ if (entryArgs.empty() && !argTypes.empty())
+ return parser.emitError(signatureLocation)
+ << "gpu.func requires named arguments";
+
+ // Construct the function type. More types will be added to the region, but
+ // not to the functiont type.
+ Builder &builder = parser.getBuilder();
+ auto type = builder.getFunctionType(argTypes, resultTypes);
+ result.addAttribute(getTypeAttrName(), TypeAttr::get(type));
+
+ // Parse workgroup memory attributions.
+ if (failed(parseAttributions(parser, getWorkgroupKeyword(), entryArgs,
+ argTypes)))
+ return failure();
+
+ // Store the number of operands we just parsed as the number of workgroup
+ // memory attributions.
+ unsigned numWorkgroupAttrs = argTypes.size() - type.getNumInputs();
+ result.addAttribute(getNumWorkgroupAttributionsAttrName(),
+ builder.getI64IntegerAttr(numWorkgroupAttrs));
+
+ // Parse private memory attributions.
+ if (failed(
+ parseAttributions(parser, getPrivateKeyword(), entryArgs, argTypes)))
+ return failure();
+
+ // Parse the kernel attribute if present.
+ if (succeeded(parser.parseOptionalKeyword(getKernelKeyword())))
+ result.addAttribute(GPUDialect::getKernelFuncAttrName(),
+ builder.getUnitAttr());
+
+ // Parse attributes.
+ if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
+ return failure();
+ mlir::impl::addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);
+
+ // Parse the region. If no argument names were provided, take all names
+ // (including those of attributions) from the entry block.
+ auto *body = result.addRegion();
+ return parser.parseRegion(*body, entryArgs, argTypes);
+}
+
+static void printAttributions(OpAsmPrinter &p, StringRef keyword,
+ ArrayRef<BlockArgument *> values) {
+ if (values.empty())
+ return;
+
+ p << ' ' << keyword << '(';
+ interleaveComma(values, p.getStream(),
+ [&p](BlockArgument *v) { p << *v << " : " << v->getType(); });
+ p << ')';
+}
+
+void GPUFuncOp::print(OpAsmPrinter &p) {
+ p << getOperationName() << ' ';
+ p.printSymbolName(getName());
+
+ FunctionType type = getType();
+ impl::printFunctionSignature(p, this->getOperation(), type.getInputs(),
+ /*isVariadic=*/false, type.getResults());
+
+ printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
+ printAttributions(p, getPrivateKeyword(), getPrivateAttributions());
+ if (isKernel())
+ p << ' ' << getKernelKeyword();
+
+ impl::printFunctionAttributes(p, this->getOperation(), type.getNumInputs(),
+ type.getNumResults(),
+ {getNumWorkgroupAttributionsAttrName(),
+ GPUDialect::getKernelFuncAttrName()});
+ p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
+}
+
+/// Hook for FunctionLike verifier.
+LogicalResult GPUFuncOp::verifyType() {
+ Type type = getTypeAttr().getValue();
+ if (!type.isa<FunctionType>())
+ return emitOpError("requires '" + getTypeAttrName() +
+ "' attribute of function type");
+ return success();
+}
+
+/// Verifies the body of the function.
+LogicalResult GPUFuncOp::verifyBody() {
+ unsigned numFuncArguments = getNumArguments();
+ unsigned numWorkgroupAttributions = getNumWorkgroupAttributions();
+ unsigned numBlockArguments = front().getNumArguments();
+ if (numBlockArguments < numFuncArguments + numWorkgroupAttributions)
+ return emitOpError() << "expected at least "
+ << numFuncArguments + numWorkgroupAttributions
+ << " arguments to body region";
+
+ ArrayRef<Type> funcArgTypes = getType().getInputs();
+ for (unsigned i = 0; i < numFuncArguments; ++i) {
+ Type blockArgType = front().getArgument(i)->getType();
+ if (funcArgTypes[i] != blockArgType)
+ return emitOpError() << "expected body region argument #" << i
+ << " to be of type " << funcArgTypes[i] << ", got "
+ << blockArgType;
+ }
+
+ return success();
+}
diff --git a/mlir/lib/IR/FunctionSupport.cpp b/mlir/lib/IR/FunctionSupport.cpp
index 1f39575331c..c6f2673ef2a 100644
--- a/mlir/lib/IR/FunctionSupport.cpp
+++ b/mlir/lib/IR/FunctionSupport.cpp
@@ -128,9 +128,11 @@ static ParseResult parseFunctionResultList(
return parser.parseRParen();
}
-/// Parse a function signature, starting with a name and including the
-/// parameter list.
-static ParseResult parseFunctionSignature(
+/// Parses a function signature using `parser`. The `allowVariadic` argument
+/// indicates whether functions with variadic arguments are supported. The
+/// trailing arguments are populated by this function with names, types and
+/// attributes of the arguments and those of the results.
+ParseResult mlir::impl::parseFunctionSignature(
OpAsmParser &parser, bool allowVariadic,
SmallVectorImpl<OpAsmParser::OperandType> &argNames,
SmallVectorImpl<Type> &argTypes,
@@ -145,6 +147,24 @@ static ParseResult parseFunctionSignature(
return success();
}
+void mlir::impl::addArgAndResultAttrs(
+ Builder &builder, OperationState &result,
+ ArrayRef<SmallVector<NamedAttribute, 2>> argAttrs,
+ ArrayRef<SmallVector<NamedAttribute, 2>> resultAttrs) {
+ // Add the attributes to the function arguments.
+ SmallString<8> attrNameBuf;
+ for (unsigned i = 0, e = argAttrs.size(); i != e; ++i)
+ if (!argAttrs[i].empty())
+ result.addAttribute(getArgAttrName(i, attrNameBuf),
+ builder.getDictionaryAttr(argAttrs[i]));
+
+ // Add the attributes to the function results.
+ for (unsigned i = 0, e = resultAttrs.size(); i != e; ++i)
+ if (!resultAttrs[i].empty())
+ result.addAttribute(getResultAttrName(i, attrNameBuf),
+ builder.getDictionaryAttr(resultAttrs[i]));
+}
+
/// Parser implementation for function-like operations. Uses `funcTypeBuilder`
/// to construct the custom function type given lists of input and output types.
ParseResult
@@ -158,7 +178,7 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
SmallVector<Type, 4> resultTypes;
auto &builder = parser.getBuilder();
- // Parse the name as a symbol reference attribute.
+ // Parse the name as a symbol.
StringAttr nameAttr;
if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
result.attributes))
@@ -185,26 +205,14 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
return failure();
// Add the attributes to the function arguments.
- SmallString<8> attrNameBuf;
- for (unsigned i = 0, e = argTypes.size(); i != e; ++i)
- if (!argAttrs[i].empty())
- result.addAttribute(getArgAttrName(i, attrNameBuf),
- builder.getDictionaryAttr(argAttrs[i]));
-
- // Add the attributes to the function results.
- for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
- if (!resultAttrs[i].empty())
- result.addAttribute(getResultAttrName(i, attrNameBuf),
- builder.getDictionaryAttr(resultAttrs[i]));
+ assert(argAttrs.size() == argTypes.size());
+ assert(resultAttrs.size() == resultTypes.size());
+ addArgAndResultAttrs(builder, result, argAttrs, resultAttrs);
// Parse the optional function body.
auto *body = result.addRegion();
- if (parser.parseOptionalRegion(*body, entryArgs,
- entryArgs.empty() ? llvm::ArrayRef<Type>()
- : argTypes))
- return failure();
-
- return success();
+ return parser.parseOptionalRegion(
+ *body, entryArgs, entryArgs.empty() ? llvm::ArrayRef<Type>() : argTypes);
}
// Print a function result list.
@@ -227,9 +235,10 @@ static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types,
/// Print the signature of the function-like operation `op`. Assumes `op` has
/// the FunctionLike trait and passed the verification.
-static void printSignature(OpAsmPrinter &p, Operation *op,
- ArrayRef<Type> argTypes, bool isVariadic,
- ArrayRef<Type> resultTypes) {
+void mlir::impl::printFunctionSignature(OpAsmPrinter &p, Operation *op,
+ ArrayRef<Type> argTypes,
+ bool isVariadic,
+ ArrayRef<Type> resultTypes) {
Region &body = op->getRegion(0);
bool isExternal = body.empty();
@@ -264,42 +273,53 @@ static void printSignature(OpAsmPrinter &p, Operation *op,
}
}
-/// Printer implementation for function-like operations. Accepts lists of
-/// argument and result types to use while printing.
-void mlir::impl::printFunctionLikeOp(OpAsmPrinter &p, Operation *op,
- ArrayRef<Type> argTypes, bool isVariadic,
- ArrayRef<Type> resultTypes) {
- // Print the operation and the function name.
- auto funcName =
- op->getAttrOfType<StringAttr>(::mlir::SymbolTable::getSymbolAttrName())
- .getValue();
- p << op->getName() << ' ';
- p.printSymbolName(funcName);
-
- // Print the signature.
- printSignature(p, op, argTypes, isVariadic, resultTypes);
-
+/// Prints the list of function prefixed with the "attributes" keyword. The
+/// attributes with names listed in "elided" as well as those used by the
+/// function-like operation internally are not printed. Nothing is printed
+/// if all attributes are elided. Assumes `op` has the `FunctionLike` trait and
+/// passed the verification.
+void mlir::impl::printFunctionAttributes(OpAsmPrinter &p, Operation *op,
+ unsigned numInputs,
+ unsigned numResults,
+ ArrayRef<StringRef> elided) {
// Print out function attributes, if present.
SmallVector<StringRef, 2> ignoredAttrs = {
::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName()};
+ ignoredAttrs.append(elided.begin(), elided.end());
SmallString<8> attrNameBuf;
// Ignore any argument attributes.
std::vector<SmallString<8>> argAttrStorage;
- for (unsigned i = 0, e = argTypes.size(); i != e; ++i)
+ for (unsigned i = 0; i != numInputs; ++i)
if (op->getAttr(getArgAttrName(i, attrNameBuf)))
argAttrStorage.emplace_back(attrNameBuf);
ignoredAttrs.append(argAttrStorage.begin(), argAttrStorage.end());
// Ignore any result attributes.
std::vector<SmallString<8>> resultAttrStorage;
- for (unsigned i = 0, e = resultTypes.size(); i != e; ++i)
+ for (unsigned i = 0; i != numResults; ++i)
if (op->getAttr(getResultAttrName(i, attrNameBuf)))
resultAttrStorage.emplace_back(attrNameBuf);
ignoredAttrs.append(resultAttrStorage.begin(), resultAttrStorage.end());
p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs);
+}
+
+/// Printer implementation for function-like operations. Accepts lists of
+/// argument and result types to use while printing.
+void mlir::impl::printFunctionLikeOp(OpAsmPrinter &p, Operation *op,
+ ArrayRef<Type> argTypes, bool isVariadic,
+ ArrayRef<Type> resultTypes) {
+ // Print the operation and the function name.
+ auto funcName =
+ op->getAttrOfType<StringAttr>(::mlir::SymbolTable::getSymbolAttrName())
+ .getValue();
+ p << op->getName() << ' ';
+ p.printSymbolName(funcName);
+
+ printFunctionSignature(p, op, argTypes, isVariadic, resultTypes);
+ printFunctionAttributes(p, op, argTypes.size(), resultTypes.size());
// Print the body if this is not an external function.
Region &body = op->getRegion(0);
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 9dace254b42..6565c628377 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -360,3 +360,25 @@ func @reduce_incorrect_yield(%arg0 : f32) {
}) : (f32) -> (f32)
}
+// -----
+
+module {
+ module @gpu_funcs attributes {gpu.kernel_module} {
+ // expected-error @+1 {{custom op 'gpu.func' gpu.func requires named arguments}}
+ gpu.func @kernel_1(f32, f32) {
+ ^bb0(%arg0: f32):
+ gpu.return
+ }
+ }
+}
+
+// -----
+
+module {
+ module @gpu_funcs attributes {gpu.kernel_module} {
+ // expected-error @+1 {{requires 'type' attribute of function type}}
+ "gpu.func"() ({
+ gpu.return
+ }) {sym_name="kernel_1", type=f32} : () -> ()
+ }
+}
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index bfc0f154309..e2fd26f254b 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -112,4 +112,53 @@ module attributes {gpu.container_module} {
return
}
+ module @gpu_funcs attributes {gpu.kernel_module} {
+ // CHECK-LABEL: gpu.func @kernel_1({{.*}}: f32) -> f32
+ // CHECK: workgroup
+ // CHECK: private
+ // CHECK: attributes
+ gpu.func @kernel_1(%arg0: f32) -> f32
+ workgroup(%arg1: memref<42xf32, 3>)
+ private(%arg2: memref<2xf32, 5>, %arg3: memref<1xf32, 5>)
+ kernel
+ attributes {foo="bar"} {
+ "use"(%arg1) : (memref<42xf32, 3>) -> ()
+ "use"(%arg2) : (memref<2xf32, 5>) -> ()
+ "use"(%arg3) : (memref<1xf32, 5>) -> ()
+ gpu.return
+ }
+
+ // CHECK-LABEL: gpu.func @no_attribution
+ // CHECK: {
+ gpu.func @no_attribution(%arg0: f32) {
+ gpu.return
+ }
+
+ // CHECK-LABEL: @no_attribution_attrs
+ // CHECK: attributes
+ // CHECK: {
+ gpu.func @no_attribution_attrs(%arg0: f32) attributes {foo="bar"} {
+ gpu.return
+ }
+
+ // CHECK-LABEL: @workgroup_only
+ // CHECK: workgroup({{.*}}: {{.*}})
+ // CHECK: {
+ gpu.func @workgroup_only() workgroup(%arg0: memref<42xf32, 3>) {
+ gpu.return
+ }
+ // CHECK-LABEL: @private_only
+ // CHECK: private({{.*}}: {{.*}})
+ // CHECK: {
+ gpu.func @private_only() private(%arg0: memref<2xf32, 5>) {
+ gpu.return
+ }
+
+ // CHECK-LABEL: @empty_attribution
+ // CHECK: {
+ gpu.func @empty_attribution(%arg0: f32) workgroup() private() {
+ gpu.return
+ }
+ }
+
}
OpenPOWER on IntegriCloud