diff options
Diffstat (limited to 'mlir')
6 files changed, 41 insertions, 6 deletions
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h index a5b3fc27413..8faa90cb134 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h @@ -56,8 +56,6 @@ public: protected: /// Type lowering class. SPIRVTypeConverter &typeConverter; - -private: }; #include "mlir/Dialect/SPIRV/SPIRVLowering.h.inc" diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h index 104a4798e7c..353004b6c76 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h @@ -26,6 +26,8 @@ #include "mlir/IR/Function.h" namespace mlir { +class OpBuilder; + namespace spirv { #define GET_OP_CLASSES diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td index 1ec825aab5c..34b386ebc17 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td @@ -118,6 +118,13 @@ def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> { let extraClassDeclaration = [{ // Returns true if a constant can be built for the given `type`. static bool isBuildableWith(Type type); + + // Creates a constant zero/one of the given `type` at the current insertion + // point of `builder` and returns it. + static spirv::ConstantOp getZero(Type type, Location loc, + OpBuilder *builder); + static spirv::ConstantOp getOne(Type type, Location loc, + OpBuilder *builder); }]; let hasOpcode = 0; diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp index 62cabf66a0d..4a3d25fbd38 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -145,8 +145,7 @@ spirv::AccessChainOp getElementPtr(OpBuilder &builder, Location loc, // Need to add a '0' at the beginning of the index list for accessing into the // struct that wraps the nested array types. - Value *zero = builder.create<spirv::ConstantOp>( - loc, indexType, builder.getIntegerAttr(indexType, 0)); + Value *zero = spirv::ConstantOp::getZero(indexType, loc, &builder); SmallVector<Value *, 4> accessIndices; accessIndices.reserve(1 + indices.size()); accessIndices.push_back(zero); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index ae7643fa915..e82420022ea 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1169,6 +1169,35 @@ bool spirv::ConstantOp::isBuildableWith(Type type) { return true; } +spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc, + OpBuilder *builder) { + if (auto intType = type.dyn_cast<IntegerType>()) { + unsigned width = intType.getWidth(); + Attribute val; + if (width == 1) + return builder->create<spirv::ConstantOp>(loc, type, + builder->getBoolAttr(false)); + return builder->create<spirv::ConstantOp>( + loc, type, builder->getIntegerAttr(type, APInt(width, 0))); + } + + llvm_unreachable("unimplemented types for ConstantOp::getZero()"); +} + +spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc, + OpBuilder *builder) { + if (auto intType = type.dyn_cast<IntegerType>()) { + unsigned width = intType.getWidth(); + if (width == 1) + return builder->create<spirv::ConstantOp>(loc, type, + builder->getBoolAttr(true)); + return builder->create<spirv::ConstantOp>( + loc, type, builder->getIntegerAttr(type, APInt(width, 1))); + } + + llvm_unreachable("unimplemented types for ConstantOp::getOne()"); +} + //===----------------------------------------------------------------------===// // spv.ControlBarrier //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index e9d36f66369..d48b31fe491 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -194,8 +194,8 @@ FuncOpLowering::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands, if (isScalarOrVectorType(argType.value())) { auto indexType = typeConverter.convertType(IndexType::get(funcOp.getContext())); - auto zero = rewriter.create<spirv::ConstantOp>( - funcOp.getLoc(), indexType, rewriter.getIntegerAttr(indexType, 0)); + auto zero = + spirv::ConstantOp::getZero(indexType, funcOp.getLoc(), &rewriter); auto loadPtr = rewriter.create<spirv::AccessChainOp>( funcOp.getLoc(), replacement, zero.constant()); replacement = rewriter.create<spirv::LoadOp>(funcOp.getLoc(), loadPtr, |

