diff options
| author | Lei Zhang <antiagainst@google.com> | 2019-08-01 14:12:58 -0700 |
|---|---|---|
| committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-08-01 14:13:37 -0700 |
| commit | 00a7b6706d4ff8c8f4e4fe9bfbddf1ae47c8c658 (patch) | |
| tree | cb6b6692809d0aa21fac461e773207235b9e5388 /mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp | |
| parent | b5fd117b2314c39361cc417c032f2bed6d26e03f (diff) | |
| download | bcm5719-llvm-00a7b6706d4ff8c8f4e4fe9bfbddf1ae47c8c658.tar.gz bcm5719-llvm-00a7b6706d4ff8c8f4e4fe9bfbddf1ae47c8c658.zip | |
[spirv] Add support for specialization constant
This CL extends the existing spv.constant op to also support
specialization constant by adding an extra unit attribute
on it.
PiperOrigin-RevId: 261194869
Diffstat (limited to 'mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp')
| -rw-r--r-- | mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp | 133 |
1 files changed, 79 insertions, 54 deletions
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 35c4088fa0a..188b08d35cd 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -168,15 +168,17 @@ private: /// and `valueAttr`. `constType` is needed here because we can interpret the /// `valueAttr` as a different type than the type of `valueAttr` itself; for /// example, ArrayAttr, whose type is NoneType, is used for spirv::ArrayType - /// constants. - uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr); + /// constants. If `isSpec` is true, then the constant will be serialized as + /// a specialization constant. + uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr, + bool isSpec); /// Prepares bool ElementsAttr serialization. This method updates `opcode` /// with a proper OpConstant* instruction and pushes literal values for the /// constant to `operands`. LogicalResult prepareBoolVectorConstant(Location loc, DenseIntElementsAttr elementsAttr, - spirv::Opcode &opcode, + bool isSpec, spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands); /// Prepares int ElementsAttr serialization. This method updates `opcode` with @@ -184,7 +186,7 @@ private: /// constant to `operands`. LogicalResult prepareIntVectorConstant(Location loc, DenseIntElementsAttr elementsAttr, - spirv::Opcode &opcode, + bool isSpec, spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands); /// Prepares float ElementsAttr serialization. This method updates `opcode` @@ -192,14 +194,14 @@ private: /// constant to `operands`. LogicalResult prepareFloatVectorConstant(Location loc, DenseFPElementsAttr elementsAttr, - spirv::Opcode &opcode, + bool isSpec, spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands); - uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr); + uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr, bool isSpec); - uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr); + uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr, bool isSpec); - uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr); + uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr, bool isSpec); //===--------------------------------------------------------------------===// // Operations @@ -317,7 +319,8 @@ LogicalResult Serializer::processMemoryModel() { } LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) { - if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) { + if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value(), + op.is_spec_const())) { valueIDMap[op.getResult()] = resultID; return success(); } @@ -484,7 +487,8 @@ Serializer::prepareBasicType(Location loc, Type type, spirv::Opcode &typeEnum, } operands.push_back(elementTypeID); if (auto elementCountID = prepareConstantInt( - loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) { + loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()), + /*isSpec=*/false)) { operands.push_back(elementCountID); return success(); } @@ -535,15 +539,15 @@ Serializer::prepareFunctionType(Location loc, FunctionType type, //===----------------------------------------------------------------------===// uint32_t Serializer::prepareConstant(Location loc, Type constType, - Attribute valueAttr) { + Attribute valueAttr, bool isSpec) { if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) { - return prepareConstantFp(loc, floatAttr); + return prepareConstantFp(loc, floatAttr, isSpec); } if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) { - return prepareConstantInt(loc, intAttr); + return prepareConstantInt(loc, intAttr, isSpec); } if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) { - return prepareConstantBool(loc, boolAttr); + return prepareConstantBool(loc, boolAttr, isSpec); } // This is a composite literal. We need to handle each component separately @@ -566,21 +570,25 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType, if (auto vectorAttr = valueAttr.dyn_cast<DenseIntElementsAttr>()) { if (vectorAttr.getType().getElementType().isInteger(1)) { - if (failed(prepareBoolVectorConstant(loc, vectorAttr, opcode, operands))) + if (failed(prepareBoolVectorConstant(loc, vectorAttr, isSpec, opcode, + operands))) return 0; - } else if (failed( - prepareIntVectorConstant(loc, vectorAttr, opcode, operands))) + } else if (failed(prepareIntVectorConstant(loc, vectorAttr, isSpec, opcode, + operands))) return 0; } else if (auto vectorAttr = valueAttr.dyn_cast<DenseFPElementsAttr>()) { - if (failed(prepareFloatVectorConstant(loc, vectorAttr, opcode, operands))) + if (failed(prepareFloatVectorConstant(loc, vectorAttr, isSpec, opcode, + operands))) return 0; } else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) { - opcode = spirv::Opcode::OpConstantComposite; + opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite + : spirv::Opcode::OpConstantComposite; operands.reserve(arrayAttr.size() + 2); auto elementType = constType.cast<spirv::ArrayType>().getElementType(); for (Attribute elementAttr : arrayAttr) - if (auto elementID = prepareConstant(loc, elementType, elementAttr)) { + if (auto elementID = + prepareConstant(loc, elementType, elementAttr, isSpec)) { operands.push_back(elementID); } else { return 0; @@ -596,8 +604,8 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType, } LogicalResult Serializer::prepareBoolVectorConstant( - Location loc, DenseIntElementsAttr elementsAttr, spirv::Opcode &opcode, - SmallVectorImpl<uint32_t> &operands) { + Location loc, DenseIntElementsAttr elementsAttr, bool isSpec, + spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) { auto type = elementsAttr.getType(); assert(type.hasRank() && type.getRank() == 1 && "spv.constant should have verified only vector literal uses " @@ -612,13 +620,15 @@ LogicalResult Serializer::prepareBoolVectorConstant( // the splat value is zero. if (Attribute splatAttr = elementsAttr.getSplatValue()) { // We can use OpConstantNull if this bool ElementsAttr is splatting false. - if (!splatAttr.cast<BoolAttr>().getValue()) { + if (!isSpec && !splatAttr.cast<BoolAttr>().getValue()) { opcode = spirv::Opcode::OpConstantNull; return success(); } - if (auto id = prepareConstantBool(loc, splatAttr.cast<BoolAttr>())) { - opcode = spirv::Opcode::OpConstantComposite; + if (auto id = + prepareConstantBool(loc, splatAttr.cast<BoolAttr>(), isSpec)) { + opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite + : spirv::Opcode::OpConstantComposite; operands.append(count, id); return success(); } @@ -628,13 +638,14 @@ LogicalResult Serializer::prepareBoolVectorConstant( // Otherwise, we need to process each element and compose them with // OpConstantComposite. - opcode = spirv::Opcode::OpConstantComposite; + opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite + : spirv::Opcode::OpConstantComposite; for (APInt intValue : elementsAttr) { // We are constructing an BoolAttr for each APInt here. But given that // we only use ElementsAttr for vectors with no more than 4 elements, it // should be fine here. auto boolAttr = mlirBuilder.getBoolAttr(intValue.isOneValue()); - if (auto elementID = prepareConstantBool(loc, boolAttr)) { + if (auto elementID = prepareConstantBool(loc, boolAttr, isSpec)) { operands.push_back(elementID); } else { return failure(); @@ -644,8 +655,8 @@ LogicalResult Serializer::prepareBoolVectorConstant( } LogicalResult Serializer::prepareIntVectorConstant( - Location loc, DenseIntElementsAttr elementsAttr, spirv::Opcode &opcode, - SmallVectorImpl<uint32_t> &operands) { + Location loc, DenseIntElementsAttr elementsAttr, bool isSpec, + spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) { auto type = elementsAttr.getType(); assert(type.hasRank() && type.getRank() == 1 && "spv.constant should have verified only vector literal uses " @@ -661,13 +672,15 @@ LogicalResult Serializer::prepareIntVectorConstant( // the splat value is zero. if (Attribute splatAttr = elementsAttr.getSplatValue()) { // We can use OpConstantNull if this int ElementsAttr is splatting 0. - if (splatAttr.cast<IntegerAttr>().getValue().isNullValue()) { + if (!isSpec && splatAttr.cast<IntegerAttr>().getValue().isNullValue()) { opcode = spirv::Opcode::OpConstantNull; return success(); } - if (auto id = prepareConstantInt(loc, splatAttr.cast<IntegerAttr>())) { - opcode = spirv::Opcode::OpConstantComposite; + if (auto id = + prepareConstantInt(loc, splatAttr.cast<IntegerAttr>(), isSpec)) { + opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite + : spirv::Opcode::OpConstantComposite; operands.append(count, id); return success(); } @@ -676,7 +689,8 @@ LogicalResult Serializer::prepareIntVectorConstant( // Otherwise, we need to process each element and compose them with // OpConstantComposite. - opcode = spirv::Opcode::OpConstantComposite; + opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite + : spirv::Opcode::OpConstantComposite; for (APInt intValue : elementsAttr) { // We are constructing an IntegerAttr for each APInt here. But given that // we only use ElementsAttr for vectors with no more than 4 elements, it @@ -684,7 +698,7 @@ LogicalResult Serializer::prepareIntVectorConstant( // TODO(antiagainst): revisit this if special extensions enabling large // vectors are supported. auto intAttr = mlirBuilder.getIntegerAttr(elementType, intValue); - if (auto elementID = prepareConstantInt(loc, intAttr)) { + if (auto elementID = prepareConstantInt(loc, intAttr, isSpec)) { operands.push_back(elementID); } else { return failure(); @@ -694,8 +708,8 @@ LogicalResult Serializer::prepareIntVectorConstant( } LogicalResult Serializer::prepareFloatVectorConstant( - Location loc, DenseFPElementsAttr elementsAttr, spirv::Opcode &opcode, - SmallVectorImpl<uint32_t> &operands) { + Location loc, DenseFPElementsAttr elementsAttr, bool isSpec, + spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) { auto type = elementsAttr.getType(); assert(type.hasRank() && type.getRank() == 1 && "spv.constant should have verified only vector literal uses " @@ -706,13 +720,14 @@ LogicalResult Serializer::prepareFloatVectorConstant( operands.reserve(count + 2); if (Attribute splatAttr = elementsAttr.getSplatValue()) { - if (splatAttr.cast<FloatAttr>().getValue().isZero()) { + if (!isSpec && splatAttr.cast<FloatAttr>().getValue().isZero()) { opcode = spirv::Opcode::OpConstantNull; return success(); } - if (auto id = prepareConstantFp(loc, splatAttr.cast<FloatAttr>())) { - opcode = spirv::Opcode::OpConstantComposite; + if (auto id = prepareConstantFp(loc, splatAttr.cast<FloatAttr>(), isSpec)) { + opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite + : spirv::Opcode::OpConstantComposite; operands.append(count, id); return success(); } @@ -720,10 +735,11 @@ LogicalResult Serializer::prepareFloatVectorConstant( return failure(); } - opcode = spirv::Opcode::OpConstantComposite; + opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite + : spirv::Opcode::OpConstantComposite; for (APFloat floatValue : elementsAttr) { auto fpAttr = mlirBuilder.getFloatAttr(elementType, floatValue); - if (auto elementID = prepareConstantFp(loc, fpAttr)) { + if (auto elementID = prepareConstantFp(loc, fpAttr, isSpec)) { operands.push_back(elementID); } else { return failure(); @@ -732,7 +748,8 @@ LogicalResult Serializer::prepareFloatVectorConstant( return success(); } -uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr) { +uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr, + bool isSpec) { if (auto id = findConstantID(boolAttr)) { return id; } @@ -744,14 +761,18 @@ uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr) { } auto resultID = getNextID(); - auto opcode = boolAttr.getValue() ? spirv::Opcode::OpConstantTrue - : spirv::Opcode::OpConstantFalse; + auto opcode = boolAttr.getValue() + ? (isSpec ? spirv::Opcode::OpSpecConstantTrue + : spirv::Opcode::OpConstantTrue) + : (isSpec ? spirv::Opcode::OpSpecConstantFalse + : spirv::Opcode::OpConstantFalse); encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID}); return constIDMap[boolAttr] = resultID; } -uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) { +uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr, + bool isSpec) { if (auto id = findConstantID(intAttr)) { return id; } @@ -767,6 +788,9 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) { unsigned bitwidth = value.getBitWidth(); bool isSigned = value.isSignedIntN(bitwidth); + auto opcode = + isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; + // According to SPIR-V spec, "When the type's bit width is less than 32-bits, // the literal's value appears in the low-order bits of the word, and the // high-order bits must be 0 for a floating-point type, or 0 for an integer @@ -778,8 +802,7 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) { } else { word = static_cast<uint32_t>(value.getZExtValue()); } - encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant, - {typeID, resultID, word}); + encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); } // According to SPIR-V spec: "When the type's bit width is larger than one // word, the literal’s low-order words appear first." @@ -793,7 +816,7 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) { } else { words = llvm::bit_cast<DoubleWord>(value.getZExtValue()); } - encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant, + encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, words.word1, words.word2}); } else { std::string valueStr; @@ -808,7 +831,8 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) { return constIDMap[intAttr] = resultID; } -uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr) { +uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, + bool isSpec) { if (auto id = findConstantID(floatAttr)) { return id; } @@ -823,22 +847,23 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr) { APFloat value = floatAttr.getValue(); APInt intValue = value.bitcastToAPInt(); + auto opcode = + isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; + if (&value.getSemantics() == &APFloat::IEEEsingle()) { uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat()); - encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant, - {typeID, resultID, word}); + encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); } else if (&value.getSemantics() == &APFloat::IEEEdouble()) { struct DoubleWord { uint32_t word1; uint32_t word2; } words = llvm::bit_cast<DoubleWord>(value.convertToDouble()); - encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant, + encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, words.word1, words.word2}); } else if (&value.getSemantics() == &APFloat::IEEEhalf()) { uint32_t word = static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue()); - encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant, - {typeID, resultID, word}); + encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word}); } else { std::string valueStr; llvm::raw_string_ostream rss(valueStr); |

