diff options
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 201 | ||||
| -rw-r--r-- | mlir/lib/IR/FunctionSupport.cpp | 102 |
2 files changed, 256 insertions, 47 deletions
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index bfd094d6203..5fc1cade760 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -45,7 +45,7 @@ bool GPUDialect::isKernel(Operation *op) { GPUDialect::GPUDialect(MLIRContext *context) : Dialect(getDialectName(), context) { - addOperations<LaunchOp, LaunchFuncOp, + addOperations<LaunchOp, LaunchFuncOp, GPUFuncOp, #define GET_OP_LIST #include "mlir/Dialect/GPU/GPUOps.cpp.inc" >(); @@ -93,7 +93,7 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op, // Check that `launch_func` refers to a well-formed kernel function. StringRef kernelName = launchOp.kernel(); Operation *kernelFunc = kernelModule.lookupSymbol(kernelName); - auto kernelStdFunction = dyn_cast_or_null<FuncOp>(kernelFunc); + auto kernelStdFunction = dyn_cast_or_null<::mlir::FuncOp>(kernelFunc); auto kernelLLVMFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(kernelFunc); if (!kernelStdFunction && !kernelLLVMFunction) return launchOp.emitOpError("kernel function '") @@ -501,9 +501,10 @@ void LaunchOp::getCanonicalizationPatterns(OwningRewritePatternList &results, //===----------------------------------------------------------------------===// void LaunchFuncOp::build(Builder *builder, OperationState &result, - FuncOp kernelFunc, Value *gridSizeX, Value *gridSizeY, - Value *gridSizeZ, Value *blockSizeX, Value *blockSizeY, - Value *blockSizeZ, ArrayRef<Value *> kernelOperands) { + ::mlir::FuncOp kernelFunc, Value *gridSizeX, + Value *gridSizeY, Value *gridSizeZ, Value *blockSizeX, + Value *blockSizeY, Value *blockSizeZ, + ArrayRef<Value *> kernelOperands) { // Add grid and block sizes as op operands, followed by the data operands. result.addOperands( {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ}); @@ -517,7 +518,7 @@ void LaunchFuncOp::build(Builder *builder, OperationState &result, } void LaunchFuncOp::build(Builder *builder, OperationState &result, - FuncOp kernelFunc, KernelDim3 gridSize, + ::mlir::FuncOp kernelFunc, KernelDim3 gridSize, KernelDim3 blockSize, ArrayRef<Value *> kernelOperands) { build(builder, result, kernelFunc, gridSize.x, gridSize.y, gridSize.z, @@ -572,3 +573,191 @@ LogicalResult LaunchFuncOp::verify() { return success(); } + +//===----------------------------------------------------------------------===// +// GPUFuncOp +//===----------------------------------------------------------------------===// + +void GPUFuncOp::build(Builder *builder, OperationState &result, StringRef name, + FunctionType type, ArrayRef<Type> workgroupAttributions, + ArrayRef<Type> privateAttributions, + ArrayRef<NamedAttribute> attrs) { + result.addAttribute(SymbolTable::getSymbolAttrName(), + builder->getStringAttr(name)); + result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); + result.addAttribute(getNumWorkgroupAttributionsAttrName(), + builder->getI64IntegerAttr(workgroupAttributions.size())); + result.addAttributes(attrs); + Region *body = result.addRegion(); + Block *entryBlock = new Block; + entryBlock->addArguments(type.getInputs()); + entryBlock->addArguments(workgroupAttributions); + entryBlock->addArguments(privateAttributions); + + body->getBlocks().push_back(entryBlock); +} + +/// Parses a GPU function memory attribution. +/// +/// memory-attribution ::= (`workgroup` `(` ssa-id-and-type-list `)`)? +/// (`private` `(` ssa-id-and-type-list `)`)? +/// +/// Note that this function parses only one of the two similar parts, with the +/// keyword provided as argument. +static ParseResult +parseAttributions(OpAsmParser &parser, StringRef keyword, + SmallVectorImpl<OpAsmParser::OperandType> &args, + SmallVectorImpl<Type> &argTypes) { + // If we could not parse the keyword, just assume empty list and succeed. + if (failed(parser.parseOptionalKeyword(keyword))) + return success(); + + if (failed(parser.parseLParen())) + return failure(); + + // Early exit for an empty list. + if (succeeded(parser.parseOptionalRParen())) + return success(); + + do { + OpAsmParser::OperandType arg; + Type type; + + if (parser.parseRegionArgument(arg) || parser.parseColonType(type)) + return failure(); + + args.push_back(arg); + argTypes.push_back(type); + } while (succeeded(parser.parseOptionalComma())); + + return parser.parseRParen(); +} + +/// Parses a GPU function. +/// +/// <operation> ::= `gpu.func` symbol-ref-id `(` argument-list `)` +/// (`->` function-result-list)? memory-attribution `kernel`? +/// function-attributes? region +ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) { + SmallVector<OpAsmParser::OperandType, 8> entryArgs; + SmallVector<SmallVector<NamedAttribute, 2>, 1> argAttrs; + SmallVector<SmallVector<NamedAttribute, 2>, 1> resultAttrs; + SmallVector<Type, 8> argTypes; + SmallVector<Type, 4> resultTypes; + bool isVariadic; + + // Parse the function name. + StringAttr nameAttr; + if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), + result.attributes)) + return failure(); + + auto signatureLocation = parser.getCurrentLocation(); + if (failed(impl::parseFunctionSignature( + parser, /*allowVariadic=*/false, entryArgs, argTypes, argAttrs, + isVariadic, resultTypes, resultAttrs))) + return failure(); + + if (entryArgs.empty() && !argTypes.empty()) + return parser.emitError(signatureLocation) + << "gpu.func requires named arguments"; + + // Construct the function type. More types will be added to the region, but + // not to the functiont type. + Builder &builder = parser.getBuilder(); + auto type = builder.getFunctionType(argTypes, resultTypes); + result.addAttribute(getTypeAttrName(), TypeAttr::get(type)); + + // Parse workgroup memory attributions. + if (failed(parseAttributions(parser, getWorkgroupKeyword(), entryArgs, + argTypes))) + return failure(); + + // Store the number of operands we just parsed as the number of workgroup + // memory attributions. + unsigned numWorkgroupAttrs = argTypes.size() - type.getNumInputs(); + result.addAttribute(getNumWorkgroupAttributionsAttrName(), + builder.getI64IntegerAttr(numWorkgroupAttrs)); + + // Parse private memory attributions. + if (failed( + parseAttributions(parser, getPrivateKeyword(), entryArgs, argTypes))) + return failure(); + + // Parse the kernel attribute if present. + if (succeeded(parser.parseOptionalKeyword(getKernelKeyword()))) + result.addAttribute(GPUDialect::getKernelFuncAttrName(), + builder.getUnitAttr()); + + // Parse attributes. + if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) + return failure(); + mlir::impl::addArgAndResultAttrs(builder, result, argAttrs, resultAttrs); + + // Parse the region. If no argument names were provided, take all names + // (including those of attributions) from the entry block. + auto *body = result.addRegion(); + return parser.parseRegion(*body, entryArgs, argTypes); +} + +static void printAttributions(OpAsmPrinter &p, StringRef keyword, + ArrayRef<BlockArgument *> values) { + if (values.empty()) + return; + + p << ' ' << keyword << '('; + interleaveComma(values, p.getStream(), + [&p](BlockArgument *v) { p << *v << " : " << v->getType(); }); + p << ')'; +} + +void GPUFuncOp::print(OpAsmPrinter &p) { + p << getOperationName() << ' '; + p.printSymbolName(getName()); + + FunctionType type = getType(); + impl::printFunctionSignature(p, this->getOperation(), type.getInputs(), + /*isVariadic=*/false, type.getResults()); + + printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions()); + printAttributions(p, getPrivateKeyword(), getPrivateAttributions()); + if (isKernel()) + p << ' ' << getKernelKeyword(); + + impl::printFunctionAttributes(p, this->getOperation(), type.getNumInputs(), + type.getNumResults(), + {getNumWorkgroupAttributionsAttrName(), + GPUDialect::getKernelFuncAttrName()}); + p.printRegion(getBody(), /*printEntryBlockArgs=*/false); +} + +/// Hook for FunctionLike verifier. +LogicalResult GPUFuncOp::verifyType() { + Type type = getTypeAttr().getValue(); + if (!type.isa<FunctionType>()) + return emitOpError("requires '" + getTypeAttrName() + + "' attribute of function type"); + return success(); +} + +/// Verifies the body of the function. +LogicalResult GPUFuncOp::verifyBody() { + unsigned numFuncArguments = getNumArguments(); + unsigned numWorkgroupAttributions = getNumWorkgroupAttributions(); + unsigned numBlockArguments = front().getNumArguments(); + if (numBlockArguments < numFuncArguments + numWorkgroupAttributions) + return emitOpError() << "expected at least " + << numFuncArguments + numWorkgroupAttributions + << " arguments to body region"; + + ArrayRef<Type> funcArgTypes = getType().getInputs(); + for (unsigned i = 0; i < numFuncArguments; ++i) { + Type blockArgType = front().getArgument(i)->getType(); + if (funcArgTypes[i] != blockArgType) + return emitOpError() << "expected body region argument #" << i + << " to be of type " << funcArgTypes[i] << ", got " + << blockArgType; + } + + return success(); +} diff --git a/mlir/lib/IR/FunctionSupport.cpp b/mlir/lib/IR/FunctionSupport.cpp index 1f39575331c..c6f2673ef2a 100644 --- a/mlir/lib/IR/FunctionSupport.cpp +++ b/mlir/lib/IR/FunctionSupport.cpp @@ -128,9 +128,11 @@ static ParseResult parseFunctionResultList( return parser.parseRParen(); } -/// Parse a function signature, starting with a name and including the -/// parameter list. -static ParseResult parseFunctionSignature( +/// Parses a function signature using `parser`. The `allowVariadic` argument +/// indicates whether functions with variadic arguments are supported. The +/// trailing arguments are populated by this function with names, types and +/// attributes of the arguments and those of the results. +ParseResult mlir::impl::parseFunctionSignature( OpAsmParser &parser, bool allowVariadic, SmallVectorImpl<OpAsmParser::OperandType> &argNames, SmallVectorImpl<Type> &argTypes, @@ -145,6 +147,24 @@ static ParseResult parseFunctionSignature( return success(); } +void mlir::impl::addArgAndResultAttrs( + Builder &builder, OperationState &result, + ArrayRef<SmallVector<NamedAttribute, 2>> argAttrs, + ArrayRef<SmallVector<NamedAttribute, 2>> resultAttrs) { + // Add the attributes to the function arguments. + SmallString<8> attrNameBuf; + for (unsigned i = 0, e = argAttrs.size(); i != e; ++i) + if (!argAttrs[i].empty()) + result.addAttribute(getArgAttrName(i, attrNameBuf), + builder.getDictionaryAttr(argAttrs[i])); + + // Add the attributes to the function results. + for (unsigned i = 0, e = resultAttrs.size(); i != e; ++i) + if (!resultAttrs[i].empty()) + result.addAttribute(getResultAttrName(i, attrNameBuf), + builder.getDictionaryAttr(resultAttrs[i])); +} + /// Parser implementation for function-like operations. Uses `funcTypeBuilder` /// to construct the custom function type given lists of input and output types. ParseResult @@ -158,7 +178,7 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result, SmallVector<Type, 4> resultTypes; auto &builder = parser.getBuilder(); - // Parse the name as a symbol reference attribute. + // Parse the name as a symbol. StringAttr nameAttr; if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), result.attributes)) @@ -185,26 +205,14 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result, return failure(); // Add the attributes to the function arguments. - SmallString<8> attrNameBuf; - for (unsigned i = 0, e = argTypes.size(); i != e; ++i) - if (!argAttrs[i].empty()) - result.addAttribute(getArgAttrName(i, attrNameBuf), - builder.getDictionaryAttr(argAttrs[i])); - - // Add the attributes to the function results. - for (unsigned i = 0, e = resultTypes.size(); i != e; ++i) - if (!resultAttrs[i].empty()) - result.addAttribute(getResultAttrName(i, attrNameBuf), - builder.getDictionaryAttr(resultAttrs[i])); + assert(argAttrs.size() == argTypes.size()); + assert(resultAttrs.size() == resultTypes.size()); + addArgAndResultAttrs(builder, result, argAttrs, resultAttrs); // Parse the optional function body. auto *body = result.addRegion(); - if (parser.parseOptionalRegion(*body, entryArgs, - entryArgs.empty() ? llvm::ArrayRef<Type>() - : argTypes)) - return failure(); - - return success(); + return parser.parseOptionalRegion( + *body, entryArgs, entryArgs.empty() ? llvm::ArrayRef<Type>() : argTypes); } // Print a function result list. @@ -227,9 +235,10 @@ static void printFunctionResultList(OpAsmPrinter &p, ArrayRef<Type> types, /// Print the signature of the function-like operation `op`. Assumes `op` has /// the FunctionLike trait and passed the verification. -static void printSignature(OpAsmPrinter &p, Operation *op, - ArrayRef<Type> argTypes, bool isVariadic, - ArrayRef<Type> resultTypes) { +void mlir::impl::printFunctionSignature(OpAsmPrinter &p, Operation *op, + ArrayRef<Type> argTypes, + bool isVariadic, + ArrayRef<Type> resultTypes) { Region &body = op->getRegion(0); bool isExternal = body.empty(); @@ -264,42 +273,53 @@ static void printSignature(OpAsmPrinter &p, Operation *op, } } -/// Printer implementation for function-like operations. Accepts lists of -/// argument and result types to use while printing. -void mlir::impl::printFunctionLikeOp(OpAsmPrinter &p, Operation *op, - ArrayRef<Type> argTypes, bool isVariadic, - ArrayRef<Type> resultTypes) { - // Print the operation and the function name. - auto funcName = - op->getAttrOfType<StringAttr>(::mlir::SymbolTable::getSymbolAttrName()) - .getValue(); - p << op->getName() << ' '; - p.printSymbolName(funcName); - - // Print the signature. - printSignature(p, op, argTypes, isVariadic, resultTypes); - +/// Prints the list of function prefixed with the "attributes" keyword. The +/// attributes with names listed in "elided" as well as those used by the +/// function-like operation internally are not printed. Nothing is printed +/// if all attributes are elided. Assumes `op` has the `FunctionLike` trait and +/// passed the verification. +void mlir::impl::printFunctionAttributes(OpAsmPrinter &p, Operation *op, + unsigned numInputs, + unsigned numResults, + ArrayRef<StringRef> elided) { // Print out function attributes, if present. SmallVector<StringRef, 2> ignoredAttrs = { ::mlir::SymbolTable::getSymbolAttrName(), getTypeAttrName()}; + ignoredAttrs.append(elided.begin(), elided.end()); SmallString<8> attrNameBuf; // Ignore any argument attributes. std::vector<SmallString<8>> argAttrStorage; - for (unsigned i = 0, e = argTypes.size(); i != e; ++i) + for (unsigned i = 0; i != numInputs; ++i) if (op->getAttr(getArgAttrName(i, attrNameBuf))) argAttrStorage.emplace_back(attrNameBuf); ignoredAttrs.append(argAttrStorage.begin(), argAttrStorage.end()); // Ignore any result attributes. std::vector<SmallString<8>> resultAttrStorage; - for (unsigned i = 0, e = resultTypes.size(); i != e; ++i) + for (unsigned i = 0; i != numResults; ++i) if (op->getAttr(getResultAttrName(i, attrNameBuf))) resultAttrStorage.emplace_back(attrNameBuf); ignoredAttrs.append(resultAttrStorage.begin(), resultAttrStorage.end()); p.printOptionalAttrDictWithKeyword(op->getAttrs(), ignoredAttrs); +} + +/// Printer implementation for function-like operations. Accepts lists of +/// argument and result types to use while printing. +void mlir::impl::printFunctionLikeOp(OpAsmPrinter &p, Operation *op, + ArrayRef<Type> argTypes, bool isVariadic, + ArrayRef<Type> resultTypes) { + // Print the operation and the function name. + auto funcName = + op->getAttrOfType<StringAttr>(::mlir::SymbolTable::getSymbolAttrName()) + .getValue(); + p << op->getName() << ' '; + p.printSymbolName(funcName); + + printFunctionSignature(p, op, argTypes, isVariadic, resultTypes); + printFunctionAttributes(p, op, argTypes.size(), resultTypes.size()); // Print the body if this is not an external function. Region &body = op->getRegion(0); |

