summaryrefslogtreecommitdiffstats
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp3
-rw-r--r--mlir/lib/Dialect/SPIRV/SPIRVOps.cpp29
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp4
3 files changed, 32 insertions, 4 deletions
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 62cabf66a0d..4a3d25fbd38 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -145,8 +145,7 @@ spirv::AccessChainOp getElementPtr(OpBuilder &builder, Location loc,
// Need to add a '0' at the beginning of the index list for accessing into the
// struct that wraps the nested array types.
- Value *zero = builder.create<spirv::ConstantOp>(
- loc, indexType, builder.getIntegerAttr(indexType, 0));
+ Value *zero = spirv::ConstantOp::getZero(indexType, loc, &builder);
SmallVector<Value *, 4> accessIndices;
accessIndices.reserve(1 + indices.size());
accessIndices.push_back(zero);
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index ae7643fa915..e82420022ea 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1169,6 +1169,35 @@ bool spirv::ConstantOp::isBuildableWith(Type type) {
return true;
}
+spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
+ OpBuilder *builder) {
+ if (auto intType = type.dyn_cast<IntegerType>()) {
+ unsigned width = intType.getWidth();
+ Attribute val;
+ if (width == 1)
+ return builder->create<spirv::ConstantOp>(loc, type,
+ builder->getBoolAttr(false));
+ return builder->create<spirv::ConstantOp>(
+ loc, type, builder->getIntegerAttr(type, APInt(width, 0)));
+ }
+
+ llvm_unreachable("unimplemented types for ConstantOp::getZero()");
+}
+
+spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
+ OpBuilder *builder) {
+ if (auto intType = type.dyn_cast<IntegerType>()) {
+ unsigned width = intType.getWidth();
+ if (width == 1)
+ return builder->create<spirv::ConstantOp>(loc, type,
+ builder->getBoolAttr(true));
+ return builder->create<spirv::ConstantOp>(
+ loc, type, builder->getIntegerAttr(type, APInt(width, 1)));
+ }
+
+ llvm_unreachable("unimplemented types for ConstantOp::getOne()");
+}
+
//===----------------------------------------------------------------------===//
// spv.ControlBarrier
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index e9d36f66369..d48b31fe491 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -194,8 +194,8 @@ FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
if (isScalarOrVectorType(argType.value())) {
auto indexType =
typeConverter.convertType(IndexType::get(funcOp.getContext()));
- auto zero = rewriter.create<spirv::ConstantOp>(
- funcOp.getLoc(), indexType, rewriter.getIntegerAttr(indexType, 0));
+ auto zero =
+ spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), &rewriter);
auto loadPtr = rewriter.create<spirv::AccessChainOp>(
funcOp.getLoc(), replacement, zero.constant());
replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr,
OpenPOWER on IntegriCloud