summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
diff options
context:
space:
mode:
authorLei Zhang <antiagainst@google.com>2019-08-01 14:12:58 -0700
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-08-01 14:13:37 -0700
commit00a7b6706d4ff8c8f4e4fe9bfbddf1ae47c8c658 (patch)
treecb6b6692809d0aa21fac461e773207235b9e5388 /mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
parentb5fd117b2314c39361cc417c032f2bed6d26e03f (diff)
downloadbcm5719-llvm-00a7b6706d4ff8c8f4e4fe9bfbddf1ae47c8c658.tar.gz
bcm5719-llvm-00a7b6706d4ff8c8f4e4fe9bfbddf1ae47c8c658.zip
[spirv] Add support for specialization constant
This CL extends the existing spv.constant op to also support specialization constant by adding an extra unit attribute on it. PiperOrigin-RevId: 261194869
Diffstat (limited to 'mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp')
-rw-r--r--mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp133
1 files changed, 79 insertions, 54 deletions
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index 35c4088fa0a..188b08d35cd 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -168,15 +168,17 @@ private:
/// and `valueAttr`. `constType` is needed here because we can interpret the
/// `valueAttr` as a different type than the type of `valueAttr` itself; for
/// example, ArrayAttr, whose type is NoneType, is used for spirv::ArrayType
- /// constants.
- uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr);
+ /// constants. If `isSpec` is true, then the constant will be serialized as
+ /// a specialization constant.
+ uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr,
+ bool isSpec);
/// Prepares bool ElementsAttr serialization. This method updates `opcode`
/// with a proper OpConstant* instruction and pushes literal values for the
/// constant to `operands`.
LogicalResult prepareBoolVectorConstant(Location loc,
DenseIntElementsAttr elementsAttr,
- spirv::Opcode &opcode,
+ bool isSpec, spirv::Opcode &opcode,
SmallVectorImpl<uint32_t> &operands);
/// Prepares int ElementsAttr serialization. This method updates `opcode` with
@@ -184,7 +186,7 @@ private:
/// constant to `operands`.
LogicalResult prepareIntVectorConstant(Location loc,
DenseIntElementsAttr elementsAttr,
- spirv::Opcode &opcode,
+ bool isSpec, spirv::Opcode &opcode,
SmallVectorImpl<uint32_t> &operands);
/// Prepares float ElementsAttr serialization. This method updates `opcode`
@@ -192,14 +194,14 @@ private:
/// constant to `operands`.
LogicalResult prepareFloatVectorConstant(Location loc,
DenseFPElementsAttr elementsAttr,
- spirv::Opcode &opcode,
+ bool isSpec, spirv::Opcode &opcode,
SmallVectorImpl<uint32_t> &operands);
- uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr);
+ uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr, bool isSpec);
- uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr);
+ uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr, bool isSpec);
- uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr);
+ uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr, bool isSpec);
//===--------------------------------------------------------------------===//
// Operations
@@ -317,7 +319,8 @@ LogicalResult Serializer::processMemoryModel() {
}
LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
- if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) {
+ if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value(),
+ op.is_spec_const())) {
valueIDMap[op.getResult()] = resultID;
return success();
}
@@ -484,7 +487,8 @@ Serializer::prepareBasicType(Location loc, Type type, spirv::Opcode &typeEnum,
}
operands.push_back(elementTypeID);
if (auto elementCountID = prepareConstantInt(
- loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
+ loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()),
+ /*isSpec=*/false)) {
operands.push_back(elementCountID);
return success();
}
@@ -535,15 +539,15 @@ Serializer::prepareFunctionType(Location loc, FunctionType type,
//===----------------------------------------------------------------------===//
uint32_t Serializer::prepareConstant(Location loc, Type constType,
- Attribute valueAttr) {
+ Attribute valueAttr, bool isSpec) {
if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) {
- return prepareConstantFp(loc, floatAttr);
+ return prepareConstantFp(loc, floatAttr, isSpec);
}
if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) {
- return prepareConstantInt(loc, intAttr);
+ return prepareConstantInt(loc, intAttr, isSpec);
}
if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) {
- return prepareConstantBool(loc, boolAttr);
+ return prepareConstantBool(loc, boolAttr, isSpec);
}
// This is a composite literal. We need to handle each component separately
@@ -566,21 +570,25 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType,
if (auto vectorAttr = valueAttr.dyn_cast<DenseIntElementsAttr>()) {
if (vectorAttr.getType().getElementType().isInteger(1)) {
- if (failed(prepareBoolVectorConstant(loc, vectorAttr, opcode, operands)))
+ if (failed(prepareBoolVectorConstant(loc, vectorAttr, isSpec, opcode,
+ operands)))
return 0;
- } else if (failed(
- prepareIntVectorConstant(loc, vectorAttr, opcode, operands)))
+ } else if (failed(prepareIntVectorConstant(loc, vectorAttr, isSpec, opcode,
+ operands)))
return 0;
} else if (auto vectorAttr = valueAttr.dyn_cast<DenseFPElementsAttr>()) {
- if (failed(prepareFloatVectorConstant(loc, vectorAttr, opcode, operands)))
+ if (failed(prepareFloatVectorConstant(loc, vectorAttr, isSpec, opcode,
+ operands)))
return 0;
} else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) {
- opcode = spirv::Opcode::OpConstantComposite;
+ opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
+ : spirv::Opcode::OpConstantComposite;
operands.reserve(arrayAttr.size() + 2);
auto elementType = constType.cast<spirv::ArrayType>().getElementType();
for (Attribute elementAttr : arrayAttr)
- if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
+ if (auto elementID =
+ prepareConstant(loc, elementType, elementAttr, isSpec)) {
operands.push_back(elementID);
} else {
return 0;
@@ -596,8 +604,8 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType,
}
LogicalResult Serializer::prepareBoolVectorConstant(
- Location loc, DenseIntElementsAttr elementsAttr, spirv::Opcode &opcode,
- SmallVectorImpl<uint32_t> &operands) {
+ Location loc, DenseIntElementsAttr elementsAttr, bool isSpec,
+ spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) {
auto type = elementsAttr.getType();
assert(type.hasRank() && type.getRank() == 1 &&
"spv.constant should have verified only vector literal uses "
@@ -612,13 +620,15 @@ LogicalResult Serializer::prepareBoolVectorConstant(
// the splat value is zero.
if (Attribute splatAttr = elementsAttr.getSplatValue()) {
// We can use OpConstantNull if this bool ElementsAttr is splatting false.
- if (!splatAttr.cast<BoolAttr>().getValue()) {
+ if (!isSpec && !splatAttr.cast<BoolAttr>().getValue()) {
opcode = spirv::Opcode::OpConstantNull;
return success();
}
- if (auto id = prepareConstantBool(loc, splatAttr.cast<BoolAttr>())) {
- opcode = spirv::Opcode::OpConstantComposite;
+ if (auto id =
+ prepareConstantBool(loc, splatAttr.cast<BoolAttr>(), isSpec)) {
+ opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
+ : spirv::Opcode::OpConstantComposite;
operands.append(count, id);
return success();
}
@@ -628,13 +638,14 @@ LogicalResult Serializer::prepareBoolVectorConstant(
// Otherwise, we need to process each element and compose them with
// OpConstantComposite.
- opcode = spirv::Opcode::OpConstantComposite;
+ opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
+ : spirv::Opcode::OpConstantComposite;
for (APInt intValue : elementsAttr) {
// We are constructing an BoolAttr for each APInt here. But given that
// we only use ElementsAttr for vectors with no more than 4 elements, it
// should be fine here.
auto boolAttr = mlirBuilder.getBoolAttr(intValue.isOneValue());
- if (auto elementID = prepareConstantBool(loc, boolAttr)) {
+ if (auto elementID = prepareConstantBool(loc, boolAttr, isSpec)) {
operands.push_back(elementID);
} else {
return failure();
@@ -644,8 +655,8 @@ LogicalResult Serializer::prepareBoolVectorConstant(
}
LogicalResult Serializer::prepareIntVectorConstant(
- Location loc, DenseIntElementsAttr elementsAttr, spirv::Opcode &opcode,
- SmallVectorImpl<uint32_t> &operands) {
+ Location loc, DenseIntElementsAttr elementsAttr, bool isSpec,
+ spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) {
auto type = elementsAttr.getType();
assert(type.hasRank() && type.getRank() == 1 &&
"spv.constant should have verified only vector literal uses "
@@ -661,13 +672,15 @@ LogicalResult Serializer::prepareIntVectorConstant(
// the splat value is zero.
if (Attribute splatAttr = elementsAttr.getSplatValue()) {
// We can use OpConstantNull if this int ElementsAttr is splatting 0.
- if (splatAttr.cast<IntegerAttr>().getValue().isNullValue()) {
+ if (!isSpec && splatAttr.cast<IntegerAttr>().getValue().isNullValue()) {
opcode = spirv::Opcode::OpConstantNull;
return success();
}
- if (auto id = prepareConstantInt(loc, splatAttr.cast<IntegerAttr>())) {
- opcode = spirv::Opcode::OpConstantComposite;
+ if (auto id =
+ prepareConstantInt(loc, splatAttr.cast<IntegerAttr>(), isSpec)) {
+ opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
+ : spirv::Opcode::OpConstantComposite;
operands.append(count, id);
return success();
}
@@ -676,7 +689,8 @@ LogicalResult Serializer::prepareIntVectorConstant(
// Otherwise, we need to process each element and compose them with
// OpConstantComposite.
- opcode = spirv::Opcode::OpConstantComposite;
+ opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
+ : spirv::Opcode::OpConstantComposite;
for (APInt intValue : elementsAttr) {
// We are constructing an IntegerAttr for each APInt here. But given that
// we only use ElementsAttr for vectors with no more than 4 elements, it
@@ -684,7 +698,7 @@ LogicalResult Serializer::prepareIntVectorConstant(
// TODO(antiagainst): revisit this if special extensions enabling large
// vectors are supported.
auto intAttr = mlirBuilder.getIntegerAttr(elementType, intValue);
- if (auto elementID = prepareConstantInt(loc, intAttr)) {
+ if (auto elementID = prepareConstantInt(loc, intAttr, isSpec)) {
operands.push_back(elementID);
} else {
return failure();
@@ -694,8 +708,8 @@ LogicalResult Serializer::prepareIntVectorConstant(
}
LogicalResult Serializer::prepareFloatVectorConstant(
- Location loc, DenseFPElementsAttr elementsAttr, spirv::Opcode &opcode,
- SmallVectorImpl<uint32_t> &operands) {
+ Location loc, DenseFPElementsAttr elementsAttr, bool isSpec,
+ spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) {
auto type = elementsAttr.getType();
assert(type.hasRank() && type.getRank() == 1 &&
"spv.constant should have verified only vector literal uses "
@@ -706,13 +720,14 @@ LogicalResult Serializer::prepareFloatVectorConstant(
operands.reserve(count + 2);
if (Attribute splatAttr = elementsAttr.getSplatValue()) {
- if (splatAttr.cast<FloatAttr>().getValue().isZero()) {
+ if (!isSpec && splatAttr.cast<FloatAttr>().getValue().isZero()) {
opcode = spirv::Opcode::OpConstantNull;
return success();
}
- if (auto id = prepareConstantFp(loc, splatAttr.cast<FloatAttr>())) {
- opcode = spirv::Opcode::OpConstantComposite;
+ if (auto id = prepareConstantFp(loc, splatAttr.cast<FloatAttr>(), isSpec)) {
+ opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
+ : spirv::Opcode::OpConstantComposite;
operands.append(count, id);
return success();
}
@@ -720,10 +735,11 @@ LogicalResult Serializer::prepareFloatVectorConstant(
return failure();
}
- opcode = spirv::Opcode::OpConstantComposite;
+ opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
+ : spirv::Opcode::OpConstantComposite;
for (APFloat floatValue : elementsAttr) {
auto fpAttr = mlirBuilder.getFloatAttr(elementType, floatValue);
- if (auto elementID = prepareConstantFp(loc, fpAttr)) {
+ if (auto elementID = prepareConstantFp(loc, fpAttr, isSpec)) {
operands.push_back(elementID);
} else {
return failure();
@@ -732,7 +748,8 @@ LogicalResult Serializer::prepareFloatVectorConstant(
return success();
}
-uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr) {
+uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
+ bool isSpec) {
if (auto id = findConstantID(boolAttr)) {
return id;
}
@@ -744,14 +761,18 @@ uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr) {
}
auto resultID = getNextID();
- auto opcode = boolAttr.getValue() ? spirv::Opcode::OpConstantTrue
- : spirv::Opcode::OpConstantFalse;
+ auto opcode = boolAttr.getValue()
+ ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
+ : spirv::Opcode::OpConstantTrue)
+ : (isSpec ? spirv::Opcode::OpSpecConstantFalse
+ : spirv::Opcode::OpConstantFalse);
encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
return constIDMap[boolAttr] = resultID;
}
-uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) {
+uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
+ bool isSpec) {
if (auto id = findConstantID(intAttr)) {
return id;
}
@@ -767,6 +788,9 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) {
unsigned bitwidth = value.getBitWidth();
bool isSigned = value.isSignedIntN(bitwidth);
+ auto opcode =
+ isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
+
// According to SPIR-V spec, "When the type's bit width is less than 32-bits,
// the literal's value appears in the low-order bits of the word, and the
// high-order bits must be 0 for a floating-point type, or 0 for an integer
@@ -778,8 +802,7 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) {
} else {
word = static_cast<uint32_t>(value.getZExtValue());
}
- encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
- {typeID, resultID, word});
+ encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
}
// According to SPIR-V spec: "When the type's bit width is larger than one
// word, the literal’s low-order words appear first."
@@ -793,7 +816,7 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) {
} else {
words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
}
- encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
+ encodeInstructionInto(typesGlobalValues, opcode,
{typeID, resultID, words.word1, words.word2});
} else {
std::string valueStr;
@@ -808,7 +831,8 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) {
return constIDMap[intAttr] = resultID;
}
-uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr) {
+uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
+ bool isSpec) {
if (auto id = findConstantID(floatAttr)) {
return id;
}
@@ -823,22 +847,23 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr) {
APFloat value = floatAttr.getValue();
APInt intValue = value.bitcastToAPInt();
+ auto opcode =
+ isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
+
if (&value.getSemantics() == &APFloat::IEEEsingle()) {
uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
- encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
- {typeID, resultID, word});
+ encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
} else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
struct DoubleWord {
uint32_t word1;
uint32_t word2;
} words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
- encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
+ encodeInstructionInto(typesGlobalValues, opcode,
{typeID, resultID, words.word1, words.word2});
} else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
uint32_t word =
static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
- encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
- {typeID, resultID, word});
+ encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
} else {
std::string valueStr;
llvm::raw_string_ostream rss(valueStr);
OpenPOWER on IntegriCloud