summaryrefslogtreecommitdiffstats
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h2
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h2
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td7
-rw-r--r--mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp3
-rw-r--r--mlir/lib/Dialect/SPIRV/SPIRVOps.cpp29
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp4
6 files changed, 41 insertions, 6 deletions
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
index a5b3fc27413..8faa90cb134 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
@@ -56,8 +56,6 @@ public:
protected:
/// Type lowering class.
SPIRVTypeConverter &typeConverter;
-
-private:
};
#include "mlir/Dialect/SPIRV/SPIRVLowering.h.inc"
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
index 104a4798e7c..353004b6c76 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
@@ -26,6 +26,8 @@
#include "mlir/IR/Function.h"
namespace mlir {
+class OpBuilder;
+
namespace spirv {
#define GET_OP_CLASSES
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
index 1ec825aab5c..34b386ebc17 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
@@ -118,6 +118,13 @@ def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> {
let extraClassDeclaration = [{
// Returns true if a constant can be built for the given `type`.
static bool isBuildableWith(Type type);
+
+ // Creates a constant zero/one of the given `type` at the current insertion
+ // point of `builder` and returns it.
+ static spirv::ConstantOp getZero(Type type, Location loc,
+ OpBuilder *builder);
+ static spirv::ConstantOp getOne(Type type, Location loc,
+ OpBuilder *builder);
}];
let hasOpcode = 0;
diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
index 62cabf66a0d..4a3d25fbd38 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -145,8 +145,7 @@ spirv::AccessChainOp getElementPtr(OpBuilder &builder, Location loc,
// Need to add a '0' at the beginning of the index list for accessing into the
// struct that wraps the nested array types.
- Value *zero = builder.create<spirv::ConstantOp>(
- loc, indexType, builder.getIntegerAttr(indexType, 0));
+ Value *zero = spirv::ConstantOp::getZero(indexType, loc, &builder);
SmallVector<Value *, 4> accessIndices;
accessIndices.reserve(1 + indices.size());
accessIndices.push_back(zero);
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index ae7643fa915..e82420022ea 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -1169,6 +1169,35 @@ bool spirv::ConstantOp::isBuildableWith(Type type) {
return true;
}
+spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
+ OpBuilder *builder) {
+ if (auto intType = type.dyn_cast<IntegerType>()) {
+ unsigned width = intType.getWidth();
+ Attribute val;
+ if (width == 1)
+ return builder->create<spirv::ConstantOp>(loc, type,
+ builder->getBoolAttr(false));
+ return builder->create<spirv::ConstantOp>(
+ loc, type, builder->getIntegerAttr(type, APInt(width, 0)));
+ }
+
+ llvm_unreachable("unimplemented types for ConstantOp::getZero()");
+}
+
+spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
+ OpBuilder *builder) {
+ if (auto intType = type.dyn_cast<IntegerType>()) {
+ unsigned width = intType.getWidth();
+ if (width == 1)
+ return builder->create<spirv::ConstantOp>(loc, type,
+ builder->getBoolAttr(true));
+ return builder->create<spirv::ConstantOp>(
+ loc, type, builder->getIntegerAttr(type, APInt(width, 1)));
+ }
+
+ llvm_unreachable("unimplemented types for ConstantOp::getOne()");
+}
+
//===----------------------------------------------------------------------===//
// spv.ControlBarrier
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index e9d36f66369..d48b31fe491 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -194,8 +194,8 @@ FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
if (isScalarOrVectorType(argType.value())) {
auto indexType =
typeConverter.convertType(IndexType::get(funcOp.getContext()));
- auto zero = rewriter.create<spirv::ConstantOp>(
- funcOp.getLoc(), indexType, rewriter.getIntegerAttr(indexType, 0));
+ auto zero =
+ spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), &rewriter);
auto loadPtr = rewriter.create<spirv::AccessChainOp>(
funcOp.getLoc(), replacement, zero.constant());
replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr,
OpenPOWER on IntegriCloud