diff options
Diffstat (limited to 'mlir/lib/Dialect/GPU/IR/GPUDialect.cpp')
-rw-r--r-- | mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 55 |
1 files changed, 45 insertions, 10 deletions
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index e750d0fefff..dbca1fb003a 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -72,15 +72,10 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op, // Check that `launch_func` refers to a well-formed GPU kernel module. StringRef kernelModuleName = launchOp.getKernelModuleName(); - auto kernelModule = module.lookupSymbol<ModuleOp>(kernelModuleName); + auto kernelModule = module.lookupSymbol<GPUModuleOp>(kernelModuleName); if (!kernelModule) return launchOp.emitOpError() << "kernel module '" << kernelModuleName << "' is undefined"; - if (!kernelModule.getAttrOfType<UnitAttr>( - GPUDialect::getKernelModuleAttrName())) - return launchOp.emitOpError("module '") - << kernelModuleName << "' is missing the '" - << GPUDialect::getKernelModuleAttrName() << "' attribute"; // Check that `launch_func` refers to a well-formed kernel function. StringRef kernelName = launchOp.kernel(); @@ -517,10 +512,9 @@ void LaunchFuncOp::build(Builder *builder, OperationState &result, result.addOperands(kernelOperands); result.addAttribute(getKernelAttrName(), builder->getStringAttr(kernelFunc.getName())); - auto kernelModule = kernelFunc.getParentOfType<ModuleOp>(); - if (Optional<StringRef> kernelModuleName = kernelModule.getName()) - result.addAttribute(getKernelModuleAttrName(), - builder->getSymbolRefAttr(*kernelModuleName)); + auto kernelModule = kernelFunc.getParentOfType<GPUModuleOp>(); + result.addAttribute(getKernelModuleAttrName(), + builder->getSymbolRefAttr(kernelModule.getName())); } void LaunchFuncOp::build(Builder *builder, OperationState &result, @@ -820,6 +814,47 @@ LogicalResult GPUFuncOp::verifyBody() { return success(); } +//===----------------------------------------------------------------------===// +// GPUModuleOp +//===----------------------------------------------------------------------===// + +void GPUModuleOp::build(Builder *builder, OperationState &result, + StringRef name) { + ensureTerminator(*result.addRegion(), *builder, result.location); + result.attributes.push_back(builder->getNamedAttr( + ::mlir::SymbolTable::getSymbolAttrName(), builder->getStringAttr(name))); +} + +static ParseResult parseGPUModuleOp(OpAsmParser &parser, + OperationState &result) { + StringAttr nameAttr; + if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), + result.attributes)) + return failure(); + + // If module attributes are present, parse them. + if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) + return failure(); + + // Parse the module body. + auto *body = result.addRegion(); + if (parser.parseRegion(*body, None, None)) + return failure(); + + // Ensure that this module has a valid terminator. + GPUModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location); + return success(); +} + +static void print(OpAsmPrinter &p, GPUModuleOp op) { + p << op.getOperationName() << ' '; + p.printSymbolName(op.getName()); + p.printOptionalAttrDictWithKeyword(op.getAttrs(), + {SymbolTable::getSymbolAttrName()}); + p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/false); +} + // Namespace avoids ambiguous ReturnOpOperandAdaptor. namespace mlir { namespace gpu { |