diff options
Diffstat (limited to 'mlir/lib')
| -rw-r--r-- | mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp | 3 | ||||
| -rw-r--r-- | mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 29 | ||||
| -rw-r--r-- | mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp | 4 |
3 files changed, 32 insertions, 4 deletions
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp index 62cabf66a0d..4a3d25fbd38 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -145,8 +145,7 @@ spirv::AccessChainOp getElementPtr(OpBuilder &builder, Location loc, // Need to add a '0' at the beginning of the index list for accessing into the // struct that wraps the nested array types. - Value *zero = builder.create<spirv::ConstantOp>( - loc, indexType, builder.getIntegerAttr(indexType, 0)); + Value *zero = spirv::ConstantOp::getZero(indexType, loc, &builder); SmallVector<Value *, 4> accessIndices; accessIndices.reserve(1 + indices.size()); accessIndices.push_back(zero); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index ae7643fa915..e82420022ea 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1169,6 +1169,35 @@ bool spirv::ConstantOp::isBuildableWith(Type type) { return true; } +spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc, + OpBuilder *builder) { + if (auto intType = type.dyn_cast<IntegerType>()) { + unsigned width = intType.getWidth(); + Attribute val; + if (width == 1) + return builder->create<spirv::ConstantOp>(loc, type, + builder->getBoolAttr(false)); + return builder->create<spirv::ConstantOp>( + loc, type, builder->getIntegerAttr(type, APInt(width, 0))); + } + + llvm_unreachable("unimplemented types for ConstantOp::getZero()"); +} + +spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc, + OpBuilder *builder) { + if (auto intType = type.dyn_cast<IntegerType>()) { + unsigned width = intType.getWidth(); + if (width == 1) + return builder->create<spirv::ConstantOp>(loc, type, + builder->getBoolAttr(true)); + return builder->create<spirv::ConstantOp>( + loc, type, builder->getIntegerAttr(type, APInt(width, 1))); + } + + llvm_unreachable("unimplemented types for ConstantOp::getOne()"); +} + //===----------------------------------------------------------------------===// // spv.ControlBarrier //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index e9d36f66369..d48b31fe491 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -194,8 +194,8 @@ FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands, if (isScalarOrVectorType(argType.value())) { auto indexType = typeConverter.convertType(IndexType::get(funcOp.getContext())); - auto zero = rewriter.create<spirv::ConstantOp>( - funcOp.getLoc(), indexType, rewriter.getIntegerAttr(indexType, 0)); + auto zero = + spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), &rewriter); auto loadPtr = rewriter.create<spirv::AccessChainOp>( funcOp.getLoc(), replacement, zero.constant()); replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr, |

