summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/GPU/IR/GPUDialect.cpp')
-rw-r--r--mlir/lib/Dialect/GPU/IR/GPUDialect.cpp55
1 files changed, 45 insertions, 10 deletions
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index e750d0fefff..dbca1fb003a 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -72,15 +72,10 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
// Check that `launch_func` refers to a well-formed GPU kernel module.
StringRef kernelModuleName = launchOp.getKernelModuleName();
- auto kernelModule = module.lookupSymbol<ModuleOp>(kernelModuleName);
+ auto kernelModule = module.lookupSymbol<GPUModuleOp>(kernelModuleName);
if (!kernelModule)
return launchOp.emitOpError()
<< "kernel module '" << kernelModuleName << "' is undefined";
- if (!kernelModule.getAttrOfType<UnitAttr>(
- GPUDialect::getKernelModuleAttrName()))
- return launchOp.emitOpError("module '")
- << kernelModuleName << "' is missing the '"
- << GPUDialect::getKernelModuleAttrName() << "' attribute";
// Check that `launch_func` refers to a well-formed kernel function.
StringRef kernelName = launchOp.kernel();
@@ -517,10 +512,9 @@ void LaunchFuncOp::build(Builder *builder, OperationState &result,
result.addOperands(kernelOperands);
result.addAttribute(getKernelAttrName(),
builder->getStringAttr(kernelFunc.getName()));
- auto kernelModule = kernelFunc.getParentOfType<ModuleOp>();
- if (Optional<StringRef> kernelModuleName = kernelModule.getName())
- result.addAttribute(getKernelModuleAttrName(),
- builder->getSymbolRefAttr(*kernelModuleName));
+ auto kernelModule = kernelFunc.getParentOfType<GPUModuleOp>();
+ result.addAttribute(getKernelModuleAttrName(),
+ builder->getSymbolRefAttr(kernelModule.getName()));
}
void LaunchFuncOp::build(Builder *builder, OperationState &result,
@@ -820,6 +814,47 @@ LogicalResult GPUFuncOp::verifyBody() {
return success();
}
+//===----------------------------------------------------------------------===//
+// GPUModuleOp
+//===----------------------------------------------------------------------===//
+
+void GPUModuleOp::build(Builder *builder, OperationState &result,
+ StringRef name) {
+ ensureTerminator(*result.addRegion(), *builder, result.location);
+ result.attributes.push_back(builder->getNamedAttr(
+ ::mlir::SymbolTable::getSymbolAttrName(), builder->getStringAttr(name)));
+}
+
+static ParseResult parseGPUModuleOp(OpAsmParser &parser,
+ OperationState &result) {
+ StringAttr nameAttr;
+ if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
+ result.attributes))
+ return failure();
+
+ // If module attributes are present, parse them.
+ if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
+ return failure();
+
+ // Parse the module body.
+ auto *body = result.addRegion();
+ if (parser.parseRegion(*body, None, None))
+ return failure();
+
+ // Ensure that this module has a valid terminator.
+ GPUModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
+ return success();
+}
+
+static void print(OpAsmPrinter &p, GPUModuleOp op) {
+ p << op.getOperationName() << ' ';
+ p.printSymbolName(op.getName());
+ p.printOptionalAttrDictWithKeyword(op.getAttrs(),
+ {SymbolTable::getSymbolAttrName()});
+ p.printRegion(op.getOperation()->getRegion(0), /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/false);
+}
+
// Namespace avoids ambiguous ReturnOpOperandAdaptor.
namespace mlir {
namespace gpu {
OpenPOWER on IntegriCloud