summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
diff options
context:
space:
mode:
authorLei Zhang <antiagainst@google.com>2019-08-20 13:33:41 -0700
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-08-20 13:34:13 -0700
commitf4934bcc3e38812051f37a1aadbc4d20913ebadc (patch)
tree9a710091fe78b44983bd446c823af3c98a5726c8 /mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
parent82cf6051ee7157a2883210baab191345cbd075bc (diff)
downloadbcm5719-llvm-f4934bcc3e38812051f37a1aadbc4d20913ebadc.tar.gz
bcm5719-llvm-f4934bcc3e38812051f37a1aadbc4d20913ebadc.zip
Add spv.specConstant and spv._reference_of
Similar to global variables, specialization constants also live in the module scope and can be referenced by instructions in functions in native SPIR-V. A direct modelling would be to allow functions in the SPIR-V dialect to implicit capture, but it means we are losing the ability to write passes for Functions. While in SPIR-V normally we want to process the module as a whole, it's not common to see multiple functions get used so we'd like to leave the door open for those cases. Therefore, similar to global variables, we introduce spv.specConstant to model three SPIR-V instructions: OpSpecConstantTrue, OpSpecConstantFalse, and OpSpecConstant. They do not return SSA value results; instead they have symbols and can only be referenced by the symbols. To use it in a function, we need to have another op spv._reference_of to turn the symbol into an SSA value. This breaks the tie and makes functions still explicit capture. Previously specialization constants were handled similarly as normal constants. That is incorrect given that specialization constant actually acts more like variable (without need to load and store). E.g., they cannot be de-duplicated like normal constants. This CL also refines various documents and comments. PiperOrigin-RevId: 264455172
Diffstat (limited to 'mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp')
-rw-r--r--mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp230
1 files changed, 146 insertions, 84 deletions
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
index bc0b706092c..233e5251492 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -118,17 +118,23 @@ private:
// Module structure
//===--------------------------------------------------------------------===//
- LogicalResult processMemoryModel();
+ uint32_t findSpecConstID(StringRef constName) const {
+ return specConstIDMap.lookup(constName);
+ }
- LogicalResult processConstantOp(spirv::ConstantOp op);
+ uint32_t findVariableID(StringRef varName) const {
+ return globalVarIDMap.lookup(varName);
+ }
uint32_t findFunctionID(StringRef fnName) const {
return funcIDMap.lookup(fnName);
}
- uint32_t findVariableID(StringRef varName) const {
- return globalVarIDMap.lookup(varName);
- }
+ LogicalResult processMemoryModel();
+
+ LogicalResult processConstantOp(spirv::ConstantOp op);
+
+ LogicalResult processSpecConstantOp(spirv::SpecConstantOp op);
/// Emit OpName for the given `resultID`.
LogicalResult processName(uint32_t resultID, StringRef name);
@@ -190,17 +196,15 @@ private:
/// and `valueAttr`. `constType` is needed here because we can interpret the
/// `valueAttr` as a different type than the type of `valueAttr` itself; for
/// example, ArrayAttr, whose type is NoneType, is used for spirv::ArrayType
- /// constants. If `isSpec` is true, then the constant will be serialized as
- /// a specialization constant.
- uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr,
- bool isSpec);
+ /// constants.
+ uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr);
/// Prepares bool ElementsAttr serialization. This method updates `opcode`
/// with a proper OpConstant* instruction and pushes literal values for the
/// constant to `operands`.
LogicalResult prepareBoolVectorConstant(Location loc,
DenseIntElementsAttr elementsAttr,
- bool isSpec, spirv::Opcode &opcode,
+ spirv::Opcode &opcode,
SmallVectorImpl<uint32_t> &operands);
/// Prepares int ElementsAttr serialization. This method updates `opcode` with
@@ -208,7 +212,7 @@ private:
/// constant to `operands`.
LogicalResult prepareIntVectorConstant(Location loc,
DenseIntElementsAttr elementsAttr,
- bool isSpec, spirv::Opcode &opcode,
+ spirv::Opcode &opcode,
SmallVectorImpl<uint32_t> &operands);
/// Prepares float ElementsAttr serialization. This method updates `opcode`
@@ -216,14 +220,24 @@ private:
/// constant to `operands`.
LogicalResult prepareFloatVectorConstant(Location loc,
DenseFPElementsAttr elementsAttr,
- bool isSpec, spirv::Opcode &opcode,
+ spirv::Opcode &opcode,
SmallVectorImpl<uint32_t> &operands);
- uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr, bool isSpec);
+ /// Prepares scalar attribute serialization. This method emits corresponding
+ /// OpConstant* and returns the result <id> associated with it. Returns 0 if
+ /// the attribute is not for a scalar bool/integer/float value. If `isSpec` is
+ /// true, then the constant will be serialized as a specialization constant.
+ uint32_t prepareConstantScalar(Location loc, Attribute valueAttr,
+ bool isSpec = false);
- uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr, bool isSpec);
+ uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr,
+ bool isSpec = false);
- uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr, bool isSpec);
+ uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr,
+ bool isSpec = false);
+
+ uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr,
+ bool isSpec = false);
//===--------------------------------------------------------------------===//
// Operations
@@ -231,9 +245,10 @@ private:
uint32_t findValueID(Value *val) const { return valueIDMap.lookup(val); }
- /// Process spv.addressOf operations.
LogicalResult processAddressOfOp(spirv::AddressOfOp addressOfOp);
+ LogicalResult processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp);
+
/// Main dispatch method for serializing an operation.
LogicalResult processOperation(Operation *op);
@@ -275,19 +290,22 @@ private:
SmallVector<uint32_t, 0> typesGlobalValues;
SmallVector<uint32_t, 0> functions;
- /// Map from type used in SPIR-V module to their <id>s
+ /// Map from type used in SPIR-V module to their <id>s.
DenseMap<Type, uint32_t> typeIDMap;
- /// Map from constant values to their <id>s
+ /// Map from constant values to their <id>s.
DenseMap<Attribute, uint32_t> constIDMap;
- /// Map from FuncOps name to <id>s.
- llvm::StringMap<uint32_t> funcIDMap;
+ /// Map from specialization constant names to their <id>s.
+ llvm::StringMap<uint32_t> specConstIDMap;
- /// Map from GlobalVariableOps name to <id>s
+ /// Map from GlobalVariableOps name to <id>s.
llvm::StringMap<uint32_t> globalVarIDMap;
- /// Map from results of normal operations to their <id>s
+ /// Map from FuncOps name to <id>s.
+ llvm::StringMap<uint32_t> funcIDMap;
+
+ /// Map from results of normal operations to their <id>s.
DenseMap<Value *, uint32_t> valueIDMap;
};
} // namespace
@@ -347,14 +365,22 @@ LogicalResult Serializer::processMemoryModel() {
}
LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
- if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value(),
- op.is_spec_const())) {
+ if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) {
valueIDMap[op.getResult()] = resultID;
return success();
}
return failure();
}
+LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
+ if (auto resultID = prepareConstantScalar(op.getLoc(), op.default_value(),
+ /*isSpec=*/true)) {
+ specConstIDMap[op.sym_name()] = resultID;
+ return processName(resultID, op.sym_name());
+ }
+ return failure();
+}
+
LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
NamedAttribute attr) {
auto attrName = attr.first.strref();
@@ -395,6 +421,8 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
}
LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
+ assert(!name.empty() && "unexpected empty string for OpName");
+
SmallVector<uint32_t, 4> nameOperands;
nameOperands.push_back(resultID);
if (failed(encodeStringLiteralInto(nameOperands, name))) {
@@ -616,8 +644,7 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
}
operands.push_back(elementTypeID);
if (auto elementCountID = prepareConstantInt(
- loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()),
- /*isSpec=*/false)) {
+ loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
operands.push_back(elementCountID);
}
return processTypeDecoration(loc, arrayType, resultID);
@@ -692,17 +719,10 @@ Serializer::prepareFunctionType(Location loc, FunctionType type,
//===----------------------------------------------------------------------===//
uint32_t Serializer::prepareConstant(Location loc, Type constType,
- Attribute valueAttr, bool isSpec) {
- if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) {
- return prepareConstantFp(loc, floatAttr, isSpec);
- }
- if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) {
- return prepareConstantInt(loc, intAttr, isSpec);
- }
- if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) {
- return prepareConstantBool(loc, boolAttr, isSpec);
+ Attribute valueAttr) {
+ if (auto id = prepareConstantScalar(loc, valueAttr)) {
+ return id;
}
-
// This is a composite literal. We need to handle each component separately
// and then emit an OpConstantComposite for the whole.
@@ -723,25 +743,21 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType,
if (auto vectorAttr = valueAttr.dyn_cast<DenseIntElementsAttr>()) {
if (vectorAttr.getType().getElementType().isInteger(1)) {
- if (failed(prepareBoolVectorConstant(loc, vectorAttr, isSpec, opcode,
- operands)))
+ if (failed(prepareBoolVectorConstant(loc, vectorAttr, opcode, operands)))
return 0;
- } else if (failed(prepareIntVectorConstant(loc, vectorAttr, isSpec, opcode,
- operands)))
+ } else if (failed(
+ prepareIntVectorConstant(loc, vectorAttr, opcode, operands)))
return 0;
} else if (auto vectorAttr = valueAttr.dyn_cast<DenseFPElementsAttr>()) {
- if (failed(prepareFloatVectorConstant(loc, vectorAttr, isSpec, opcode,
- operands)))
+ if (failed(prepareFloatVectorConstant(loc, vectorAttr, opcode, operands)))
return 0;
} else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) {
- opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
- : spirv::Opcode::OpConstantComposite;
+ opcode = spirv::Opcode::OpConstantComposite;
operands.reserve(arrayAttr.size() + 2);
auto elementType = constType.cast<spirv::ArrayType>().getElementType();
for (Attribute elementAttr : arrayAttr)
- if (auto elementID =
- prepareConstant(loc, elementType, elementAttr, isSpec)) {
+ if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
operands.push_back(elementID);
} else {
return 0;
@@ -757,8 +773,8 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType,
}
LogicalResult Serializer::prepareBoolVectorConstant(
- Location loc, DenseIntElementsAttr elementsAttr, bool isSpec,
- spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) {
+ Location loc, DenseIntElementsAttr elementsAttr, spirv::Opcode &opcode,
+ SmallVectorImpl<uint32_t> &operands) {
auto type = elementsAttr.getType();
assert(type.hasRank() && type.getRank() == 1 &&
"spv.constant should have verified only vector literal uses "
@@ -773,15 +789,14 @@ LogicalResult Serializer::prepareBoolVectorConstant(
// the splat value is zero.
if (elementsAttr.isSplat()) {
// We can use OpConstantNull if this bool ElementsAttr is splatting false.
- if (!isSpec && !elementsAttr.getSplatValue<bool>()) {
+ if (!elementsAttr.getSplatValue<bool>()) {
opcode = spirv::Opcode::OpConstantNull;
return success();
}
- if (auto id = prepareConstantBool(
- loc, elementsAttr.getSplatValue<BoolAttr>(), isSpec)) {
- opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
- : spirv::Opcode::OpConstantComposite;
+ if (auto id =
+ prepareConstantBool(loc, elementsAttr.getSplatValue<BoolAttr>())) {
+ opcode = spirv::Opcode::OpConstantComposite;
operands.append(count, id);
return success();
}
@@ -791,13 +806,12 @@ LogicalResult Serializer::prepareBoolVectorConstant(
// Otherwise, we need to process each element and compose them with
// OpConstantComposite.
- opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
- : spirv::Opcode::OpConstantComposite;
+ opcode = spirv::Opcode::OpConstantComposite;
for (auto boolAttr : elementsAttr.getValues<BoolAttr>()) {
// We are constructing an BoolAttr for each value here. But given that
// we only use ElementsAttr for vectors with no more than 4 elements, it
// should be fine here.
- if (auto elementID = prepareConstantBool(loc, boolAttr, isSpec)) {
+ if (auto elementID = prepareConstantBool(loc, boolAttr)) {
operands.push_back(elementID);
} else {
return failure();
@@ -807,8 +821,8 @@ LogicalResult Serializer::prepareBoolVectorConstant(
}
LogicalResult Serializer::prepareIntVectorConstant(
- Location loc, DenseIntElementsAttr elementsAttr, bool isSpec,
- spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) {
+ Location loc, DenseIntElementsAttr elementsAttr, spirv::Opcode &opcode,
+ SmallVectorImpl<uint32_t> &operands) {
auto type = elementsAttr.getType();
assert(type.hasRank() && type.getRank() == 1 &&
"spv.constant should have verified only vector literal uses "
@@ -826,14 +840,13 @@ LogicalResult Serializer::prepareIntVectorConstant(
auto splatAttr = elementsAttr.getSplatValue<IntegerAttr>();
// We can use OpConstantNull if this int ElementsAttr is splatting 0.
- if (!isSpec && splatAttr.getValue().isNullValue()) {
+ if (splatAttr.getValue().isNullValue()) {
opcode = spirv::Opcode::OpConstantNull;
return success();
}
- if (auto id = prepareConstantInt(loc, splatAttr, isSpec)) {
- opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
- : spirv::Opcode::OpConstantComposite;
+ if (auto id = prepareConstantInt(loc, splatAttr)) {
+ opcode = spirv::Opcode::OpConstantComposite;
operands.append(count, id);
return success();
}
@@ -842,15 +855,14 @@ LogicalResult Serializer::prepareIntVectorConstant(
// Otherwise, we need to process each element and compose them with
// OpConstantComposite.
- opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
- : spirv::Opcode::OpConstantComposite;
+ opcode = spirv::Opcode::OpConstantComposite;
for (auto intAttr : elementsAttr.getValues<IntegerAttr>()) {
// We are constructing an IntegerAttr for each value here. But given that
// we only use ElementsAttr for vectors with no more than 4 elements, it
// should be fine here.
// TODO(antiagainst): revisit this if special extensions enabling large
// vectors are supported.
- if (auto elementID = prepareConstantInt(loc, intAttr, isSpec)) {
+ if (auto elementID = prepareConstantInt(loc, intAttr)) {
operands.push_back(elementID);
} else {
return failure();
@@ -860,8 +872,8 @@ LogicalResult Serializer::prepareIntVectorConstant(
}
LogicalResult Serializer::prepareFloatVectorConstant(
- Location loc, DenseFPElementsAttr elementsAttr, bool isSpec,
- spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) {
+ Location loc, DenseFPElementsAttr elementsAttr, spirv::Opcode &opcode,
+ SmallVectorImpl<uint32_t> &operands) {
auto type = elementsAttr.getType();
assert(type.hasRank() && type.getRank() == 1 &&
"spv.constant should have verified only vector literal uses "
@@ -872,14 +884,13 @@ LogicalResult Serializer::prepareFloatVectorConstant(
if (elementsAttr.isSplat()) {
FloatAttr splatAttr = elementsAttr.getSplatValue<FloatAttr>();
- if (!isSpec && splatAttr.getValue().isZero()) {
+ if (splatAttr.getValue().isZero()) {
opcode = spirv::Opcode::OpConstantNull;
return success();
}
- if (auto id = prepareConstantFp(loc, splatAttr, isSpec)) {
- opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
- : spirv::Opcode::OpConstantComposite;
+ if (auto id = prepareConstantFp(loc, splatAttr)) {
+ opcode = spirv::Opcode::OpConstantComposite;
operands.append(count, id);
return success();
}
@@ -887,10 +898,9 @@ LogicalResult Serializer::prepareFloatVectorConstant(
return failure();
}
- opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
- : spirv::Opcode::OpConstantComposite;
+ opcode = spirv::Opcode::OpConstantComposite;
for (auto fpAttr : elementsAttr.getValues<FloatAttr>()) {
- if (auto elementID = prepareConstantFp(loc, fpAttr, isSpec)) {
+ if (auto elementID = prepareConstantFp(loc, fpAttr)) {
operands.push_back(elementID);
} else {
return failure();
@@ -899,10 +909,28 @@ LogicalResult Serializer::prepareFloatVectorConstant(
return success();
}
+uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
+ bool isSpec) {
+ if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) {
+ return prepareConstantFp(loc, floatAttr, isSpec);
+ }
+ if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) {
+ return prepareConstantInt(loc, intAttr, isSpec);
+ }
+ if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) {
+ return prepareConstantBool(loc, boolAttr, isSpec);
+ }
+
+ return 0;
+}
+
uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
bool isSpec) {
- if (auto id = findConstantID(boolAttr)) {
- return id;
+ if (!isSpec) {
+ // We can de-duplicate nomral contants, but not specialization constants.
+ if (auto id = findConstantID(boolAttr)) {
+ return id;
+ }
}
// Process the type for this bool literal
@@ -919,13 +947,19 @@ uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
: spirv::Opcode::OpConstantFalse);
encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
- return constIDMap[boolAttr] = resultID;
+ if (!isSpec) {
+ constIDMap[boolAttr] = resultID;
+ }
+ return resultID;
}
uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
bool isSpec) {
- if (auto id = findConstantID(intAttr)) {
- return id;
+ if (!isSpec) {
+ // We can de-duplicate nomral contants, but not specialization constants.
+ if (auto id = findConstantID(intAttr)) {
+ return id;
+ }
}
// Process the type for this integer literal
@@ -972,20 +1006,26 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
} else {
std::string valueStr;
llvm::raw_string_ostream rss(valueStr);
- value.print(rss, /*isSigned*/ false);
+ value.print(rss, /*isSigned=*/false);
emitError(loc, "cannot serialize ")
<< bitwidth << "-bit integer literal: " << rss.str();
return 0;
}
- return constIDMap[intAttr] = resultID;
+ if (!isSpec) {
+ constIDMap[intAttr] = resultID;
+ }
+ return resultID;
}
uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
bool isSpec) {
- if (auto id = findConstantID(floatAttr)) {
- return id;
+ if (!isSpec) {
+ // We can de-duplicate nomral contants, but not specialization constants.
+ if (auto id = findConstantID(floatAttr)) {
+ return id;
+ }
}
// Process the type for this float literal
@@ -1025,7 +1065,10 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
return 0;
}
- return constIDMap[floatAttr] = resultID;
+ if (!isSpec) {
+ constIDMap[floatAttr] = resultID;
+ }
+ return resultID;
}
//===----------------------------------------------------------------------===//
@@ -1043,12 +1086,31 @@ LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
return success();
}
+LogicalResult
+Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
+ auto constName = referenceOfOp.spec_const();
+ auto constID = findSpecConstID(constName);
+ if (!constID) {
+ return referenceOfOp.emitError(
+ "unknown result <id> for specialization constant ")
+ << constName;
+ }
+ valueIDMap[referenceOfOp.reference()] = constID;
+ return success();
+}
+
LogicalResult Serializer::processOperation(Operation *op) {
// First dispatch the methods that do not directly mirror an operation from
// the SPIR-V spec
if (auto constOp = dyn_cast<spirv::ConstantOp>(op)) {
return processConstantOp(constOp);
}
+ if (auto specConstOp = dyn_cast<spirv::SpecConstantOp>(op)) {
+ return processSpecConstantOp(specConstOp);
+ }
+ if (auto refOpOp = dyn_cast<spirv::ReferenceOfOp>(op)) {
+ return processReferenceOfOp(refOpOp);
+ }
if (auto fnOp = dyn_cast<FuncOp>(op)) {
return processFuncOp(fnOp);
}
OpenPOWER on IntegriCloud