diff options
Diffstat (limited to 'mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp')
-rw-r--r-- | mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp | 23 |
1 files changed, 12 insertions, 11 deletions
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp index 53a40dfa365..035de4f815d 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -136,26 +136,26 @@ LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands, signatureConverter, newFuncOp))) { return failure(); } - // Create spv.Variable ops for each of the arguments. These need to be bound - // by the runtime. For now use descriptor_set 0, and arg number as the binding - // number. + // Create spv.globalVariable ops for each of the arguments. These need to be + // bound by the runtime. For now use descriptor_set 0, and arg number as the + // binding number. auto module = funcOp.getParentOfType<spirv::ModuleOp>(); if (!module) { return funcOp.emitError("expected op to be within a spv.module"); } OpBuilder builder(module.getOperation()->getRegion(0)); - SmallVector<Value *, 4> interface; + SmallVector<Attribute, 4> interface; for (auto &convertedArgType : llvm::enumerate(signatureConverter.getConvertedTypes())) { - auto variableOp = builder.create<spirv::VariableOp>( - funcOp.getLoc(), convertedArgType.value(), - builder.getI32IntegerAttr( - static_cast<int32_t>(spirv::StorageClass::StorageBuffer)), - llvm::None); + std::string varName = funcOp.getName().str() + "_arg_" + + std::to_string(convertedArgType.index()); + auto variableOp = builder.create<spirv::GlobalVariableOp>( + funcOp.getLoc(), builder.getTypeAttr(convertedArgType.value()), + builder.getStringAttr(varName), nullptr); variableOp.setAttr("descriptor_set", builder.getI32IntegerAttr(0)); variableOp.setAttr("binding", builder.getI32IntegerAttr(convertedArgType.index())); - interface.push_back(variableOp.getResult()); + interface.push_back(builder.getSymbolRefAttr(variableOp.sym_name())); } // Create an entry point instruction for this function. // TODO(ravishankarm) : Add execution mode for the entry function @@ -164,7 +164,8 @@ LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands, funcOp.getLoc(), builder.getI32IntegerAttr( static_cast<int32_t>(spirv::ExecutionModel::GLCompute)), - builder.getSymbolRefAttr(newFuncOp.getName()), interface); + builder.getSymbolRefAttr(newFuncOp.getName()), + builder.getArrayAttr(interface)); return success(); } } // namespace mlir |