summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp')
-rw-r--r--mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp23
1 files changed, 12 insertions, 11 deletions
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 53a40dfa365..035de4f815d 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -136,26 +136,26 @@ LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
signatureConverter, newFuncOp))) {
return failure();
}
- // Create spv.Variable ops for each of the arguments. These need to be bound
- // by the runtime. For now use descriptor_set 0, and arg number as the binding
- // number.
+ // Create spv.globalVariable ops for each of the arguments. These need to be
+ // bound by the runtime. For now use descriptor_set 0, and arg number as the
+ // binding number.
auto module = funcOp.getParentOfType<spirv::ModuleOp>();
if (!module) {
return funcOp.emitError("expected op to be within a spv.module");
}
OpBuilder builder(module.getOperation()->getRegion(0));
- SmallVector<Value *, 4> interface;
+ SmallVector<Attribute, 4> interface;
for (auto &convertedArgType :
llvm::enumerate(signatureConverter.getConvertedTypes())) {
- auto variableOp = builder.create<spirv::VariableOp>(
- funcOp.getLoc(), convertedArgType.value(),
- builder.getI32IntegerAttr(
- static_cast<int32_t>(spirv::StorageClass::StorageBuffer)),
- llvm::None);
+ std::string varName = funcOp.getName().str() + "_arg_" +
+ std::to_string(convertedArgType.index());
+ auto variableOp = builder.create<spirv::GlobalVariableOp>(
+ funcOp.getLoc(), builder.getTypeAttr(convertedArgType.value()),
+ builder.getStringAttr(varName), nullptr);
variableOp.setAttr("descriptor_set", builder.getI32IntegerAttr(0));
variableOp.setAttr("binding",
builder.getI32IntegerAttr(convertedArgType.index()));
- interface.push_back(variableOp.getResult());
+ interface.push_back(builder.getSymbolRefAttr(variableOp.sym_name()));
}
// Create an entry point instruction for this function.
// TODO(ravishankarm) : Add execution mode for the entry function
@@ -164,7 +164,8 @@ LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
funcOp.getLoc(),
builder.getI32IntegerAttr(
static_cast<int32_t>(spirv::ExecutionModel::GLCompute)),
- builder.getSymbolRefAttr(newFuncOp.getName()), interface);
+ builder.getSymbolRefAttr(newFuncOp.getName()),
+ builder.getArrayAttr(interface));
return success();
}
} // namespace mlir
OpenPOWER on IntegriCloud