summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLei Zhang <antiagainst@google.com>2019-12-27 16:24:33 -0500
committerLei Zhang <antiagainst@google.com>2019-12-27 16:25:09 -0500
commitb30d87a90ba983d76f8a6cd334ac38244bbf9ded (patch)
tree1de94d9458e552ff6c90dadb621f66521659ca1d
parentc3dbd782f1e0578c7ebc342f2e92f54d9644cff7 (diff)
downloadbcm5719-llvm-b30d87a90ba983d76f8a6cd334ac38244bbf9ded.tar.gz
bcm5719-llvm-b30d87a90ba983d76f8a6cd334ac38244bbf9ded.zip
[mlir][spirv] Add basic definitions for supporting availability
SPIR-V has a few mechanisms to control op availability: version, extension, and capabilities. These mechanisms are considered as different availability classes. This commit introduces basic definitions for modelling SPIR-V availability classes. Specifically, an `Availability` class is added to SPIRVBase.td, along with two subclasses: MinVersion and MaxVersion for versioning. SPV_Op is extended to take a list of `Availability`. Each `Availability` instance carries information for generating op interfaces for the corresponding availability class and also the concrete availability requirements. With the availability spec on ops, we can now auto-generate the op interfaces of all SPIR-V availability classes and also synthesize the op's implementations of these interfaces. The interface generation is done via new TableGen backends -gen-avail-interface-{decls|defs}. The op's implementation is done via -gen-spirv-avail-impls. Differential Revision: https://reviews.llvm.org/D71930
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt16
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td7
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/SPIRVAvailability.td86
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td154
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td7
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h23
-rw-r--r--mlir/lib/Dialect/SPIRV/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/SPIRV/SPIRVOps.cpp8
-rw-r--r--mlir/test/CMakeLists.txt1
-rw-r--r--mlir/test/Dialect/CMakeLists.txt1
-rw-r--r--mlir/test/Dialect/SPIRV/CMakeLists.txt14
-rw-r--r--mlir/test/Dialect/SPIRV/TestAvailability.cpp73
-rw-r--r--mlir/test/Dialect/SPIRV/availability.mlir31
-rw-r--r--mlir/tools/mlir-opt/CMakeLists.txt1
-rw-r--r--mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp328
15 files changed, 728 insertions, 23 deletions
diff --git a/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt b/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt
index fc7180de6cb..52464789439 100644
--- a/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt
@@ -1,8 +1,3 @@
-set(LLVM_TARGET_DEFINITIONS SPIRVLowering.td)
-mlir_tablegen(SPIRVLowering.h.inc -gen-struct-attr-decls)
-mlir_tablegen(SPIRVLowering.cpp.inc -gen-struct-attr-defs)
-add_public_tablegen_target(MLIRSPIRVLoweringStructGen)
-
add_mlir_dialect(SPIRVOps SPIRVOps)
set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
@@ -11,9 +6,20 @@ mlir_tablegen(SPIRVEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRSPIRVEnumsIncGen)
set(LLVM_TARGET_DEFINITIONS SPIRVOps.td)
+mlir_tablegen(SPIRVAvailability.h.inc -gen-avail-interface-decls)
+mlir_tablegen(SPIRVAvailability.cpp.inc -gen-avail-interface-defs)
+mlir_tablegen(SPIRVOpAvailabilityImpl.inc -gen-spirv-avail-impls)
+add_public_tablegen_target(MLIRSPIRVAvailabilityIncGen)
+
+set(LLVM_TARGET_DEFINITIONS SPIRVOps.td)
mlir_tablegen(SPIRVSerialization.inc -gen-spirv-serialization)
add_public_tablegen_target(MLIRSPIRVSerializationGen)
set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
mlir_tablegen(SPIRVOpUtils.inc -gen-spirv-op-utils)
add_public_tablegen_target(MLIRSPIRVOpUtilsGen)
+
+set(LLVM_TARGET_DEFINITIONS SPIRVLowering.td)
+mlir_tablegen(SPIRVLowering.h.inc -gen-struct-attr-decls)
+mlir_tablegen(SPIRVLowering.cpp.inc -gen-struct-attr-defs)
+add_public_tablegen_target(MLIRSPIRVLoweringStructGen)
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td
index c2ea100c121..17be79dfcfd 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td
@@ -120,6 +120,13 @@ def SPV_AtomicCompareExchangeWeakOp : SPV_Op<"AtomicCompareExchangeWeak", []> {
```
}];
+ let availability = [
+ MinVersion<SPV_V_1_0>,
+ MaxVersion<SPV_V_1_3>,
+ Extension<[]>,
+ Capability<[SPV_C_Kernel]>
+ ];
+
let arguments = (ins
SPV_AnyPtr:$pointer,
SPV_ScopeAttr:$memory_scope,
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVAvailability.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVAvailability.td
new file mode 100644
index 00000000000..8ec74ac955a
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVAvailability.td
@@ -0,0 +1,86 @@
+//===- SPIRVAvailability.td - Op Availability Base file ----*- tablegen -*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SPIRV_AVAILABILITY
+#define SPIRV_AVAILABILITY
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Op availaility definitions
+//===----------------------------------------------------------------------===//
+
+// The base class for defining op availability dimensions.
+class Availability {
+ // The following are fields for controlling the generated C++ OpInterface.
+
+ // The name for the generated C++ OpInterface subclass.
+ string interfaceName = ?;
+ // The documentation for the generated C++ OpInterface subclass.
+ string interfaceDescription = "";
+
+ // The following are fields for controlling the query function signature.
+
+ // The query function's return type in the generated C++ OpInterface subclass.
+ string queryFnRetType = ?;
+ // The query function's name in the generated C++ OpInterface subclass.
+ string queryFnName = ?;
+
+ // The following are fields for controlling the query function implementation.
+
+ // The logic for merging two availability requirements. This is used to derive
+ // the final availability requirement when, for example, an op has two
+ // operands and these two operands have different availability requirements.
+ //
+ // The code should use `$overall` as the placeholder for the final requirement
+ // and `$instance` for the current availability requirement instance.
+ code mergeAction = ?;
+ // The initializer for the final availability requirement.
+ string initializer = ?;
+ // An availability instance's type.
+ string instanceType = ?;
+
+ // The following are fields for a concrete availability instance.
+
+ // The availability requirement carried by a concrete instance.
+ string instance = ?;
+}
+
+class MinVersionBase<string name, I32EnumAttr scheme, I32EnumAttrCase min>
+ : Availability {
+ let interfaceName = name;
+
+ let queryFnRetType = scheme.returnType;
+ let queryFnName = "getMinVersion";
+
+ let mergeAction = "$overall = static_cast<" # scheme.returnType # ">("
+ "std::max($overall, $instance))";
+ let initializer = "static_cast<" # scheme.returnType # ">(uint32_t(0))";
+ let instanceType = scheme.cppNamespace # "::" # scheme.className;
+
+ let instance = scheme.cppNamespace # "::" # scheme.className # "::" #
+ min.symbol;
+}
+
+class MaxVersionBase<string name, I32EnumAttr scheme, I32EnumAttrCase max>
+ : Availability {
+ let interfaceName = name;
+
+ let queryFnRetType = scheme.returnType;
+ let queryFnName = "getMaxVersion";
+
+ let mergeAction = "$overall = static_cast<" # scheme.returnType # ">("
+ "std::min($overall, $instance))";
+ let initializer = "static_cast<" # scheme.returnType # ">(~uint32_t(0))";
+ let instanceType = scheme.cppNamespace # "::" # scheme.className;
+
+ let instance = scheme.cppNamespace # "::" # scheme.className # "::" #
+ max.symbol;
+}
+
+#endif // SPIRV_AVAILABILITY
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
index 5751a32e169..acbbbfc296b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -16,6 +16,7 @@
#define SPIRV_BASE
include "mlir/IR/OpBase.td"
+include "mlir/Dialect/SPIRV/SPIRVAvailability.td"
//===----------------------------------------------------------------------===//
// SPIR-V dialect definitions
@@ -46,6 +47,142 @@ def SPV_Dialect : Dialect {
}
//===----------------------------------------------------------------------===//
+// SPIR-V availability definitions
+//===----------------------------------------------------------------------===//
+
+def SPV_V_1_0 : I32EnumAttrCase<"V_1_0", 0>;
+def SPV_V_1_1 : I32EnumAttrCase<"V_1_1", 1>;
+def SPV_V_1_2 : I32EnumAttrCase<"V_1_2", 2>;
+def SPV_V_1_3 : I32EnumAttrCase<"V_1_3", 3>;
+def SPV_V_1_4 : I32EnumAttrCase<"V_1_4", 4>;
+def SPV_V_1_5 : I32EnumAttrCase<"V_1_5", 5>;
+
+def SPV_VersionAttr : I32EnumAttr<"Version", "valid SPIR-V version", [
+ SPV_V_1_0, SPV_V_1_1, SPV_V_1_2, SPV_V_1_3, SPV_V_1_4, SPV_V_1_5]> {
+ let cppNamespace = "::mlir::spirv";
+}
+
+class MinVersion<I32EnumAttrCase min> : MinVersionBase<
+ "QueryMinVersionInterface", SPV_VersionAttr, min> {
+ let interfaceDescription = [{
+ Querying interface for minimal required SPIR-V version.
+
+ This interface provides a `getMinVersion()` method to query the minimal
+ required version for the implementing SPIR-V operation. The returned value
+ is a `mlir::spirv::Version` enumerant.
+ }];
+}
+
+class MaxVersion<I32EnumAttrCase max> : MaxVersionBase<
+ "QueryMaxVersionInterface", SPV_VersionAttr, max> {
+ let interfaceDescription = [{
+ Querying interface for maximal supported SPIR-V version.
+
+ This interface provides a `getMaxVersion()` method to query the maximal
+ supported version for the implementing SPIR-V operation. The returned value
+ is a `mlir::spirv::Version` enumerant.
+ }];
+}
+
+class Extension<list<StrEnumAttrCase> extensions> : Availability {
+ let interfaceName = "QueryExtensionInterface";
+ let interfaceDescription = [{
+ Querying interface for required SPIR-V extensions.
+
+ This interface provides a `getExtensions()` method to query the required
+ extensions for the implementing SPIR-V operation. The returned value
+ is a nested vector whose element is `mlir::spirv::Extension`s. The outer
+ vector's elements (which are vectors) should be interpreted as conjunction
+ while the innner vector's elements (which are `mlir::spirv::Extension`s)
+ should be interpreted as disjunction. For example, given
+
+ ```
+ {{Extension::A, Extension::B}, {Extension::C}, {{Extension::D, Extension::E}}
+ ```
+
+ The operation instance is available when (`Extension::A` OR `Extension::B`)
+ AND (`Extension::C`) AND (`Extension::D` OR `Extension::E`) is enabled.
+ }];
+
+ // TODO(antiagainst): Using SmallVector<SmallVector<...>> is an anti-pattern.
+ // Find a better way for this.
+ let queryFnRetType = "::llvm::SmallVector<::llvm::SmallVector<"
+ "::mlir::spirv::Extension, 1>, 1>";
+ let queryFnName = "getExtensions";
+
+ let mergeAction = !if(
+ !empty(extensions), "", "$overall.emplace_back($instance)");
+ let initializer = "{}";
+ let instanceType = "::llvm::SmallVector<::mlir::spirv::Extension, 1>";
+
+ // Compose all capabilities as an C++ initializer list
+ let instance = "std::initializer_list<::mlir::spirv::Extension>{" #
+ StrJoin<!foreach(
+ ext, extensions,
+ "::mlir::spirv::Extension::" # ext.symbol)>.result #
+ "}";
+}
+
+class Capability<list<I32EnumAttrCase> capabilities> : Availability {
+ let interfaceName = "QueryCapabilityInterface";
+ let interfaceDescription = [{
+ Querying interface for required SPIR-V capabilities.
+
+ This interface provides a `getCapabilities()` method to query the required
+ capabilities for the implementing SPIR-V operation. The returned value
+ is a neted vector whose element is `mlir::spirv::Capability`s. The outer
+ vector's elements (which are vectors) should be interpreted as conjunction
+ while the innner vector's elements (which are `mlir::spirv::Capability`s)
+ should be interpreted as disjunction. For example, given
+
+ ```
+ {{Capability::A, Capability::B}, {Capability::C}, {{Capability::D, Capability::E}}
+ ```
+
+ The operation instance is available when (`Capability::A` OR `Capability::B`)
+ AND (`Capability::C`) AND (`Capability::D` OR `Capability::E`) is enabled.
+ }];
+
+ let queryFnRetType = "::llvm::SmallVector<::llvm::SmallVector<"
+ "::mlir::spirv::Capability, 1>, 1>";
+ let queryFnName = "getCapabilities";
+
+ let mergeAction = !if(
+ !empty(capabilities), "", "$overall.emplace_back($instance)");
+ let initializer = "{}";
+ let instanceType = "::llvm::SmallVector<::mlir::spirv::Capability, 1>";
+
+ // Compose all capabilities as an C++ initializer list
+ let instance = "std::initializer_list<::mlir::spirv::Capability>{" #
+ StrJoin<!foreach(
+ cap, capabilities,
+ "::mlir::spirv::Capability::" # cap.symbol)>.result #
+ "}";
+}
+
+// TODO(antiagainst): the following interfaces definitions are duplicating with
+// the above. Remove them once we are able to support dialect-specific contents
+// in ODS.
+def QueryMinVersionInterface : OpInterface<"QueryMinVersionInterface"> {
+ let methods = [InterfaceMethod<"", "::mlir::spirv::Version", "getMinVersion">];
+}
+def QueryMaxVersionInterface : OpInterface<"QueryMaxVersionInterface"> {
+ let methods = [InterfaceMethod<"", "::mlir::spirv::Version", "getMaxVersion">];
+}
+def QueryExtensionInterface : OpInterface<"QueryExtensionInterface"> {
+ let methods = [InterfaceMethod<
+ "",
+ "::llvm::SmallVector<::llvm::SmallVector<::mlir::spirv::Extension, 1>, 1>",
+ "getExtensions">];
+}
+def QueryCapabilityInterface : OpInterface<"QueryCapabilityInterface"> {
+ let methods = [InterfaceMethod<
+ "",
+ "::llvm::SmallVector<::llvm::SmallVector<::mlir::spirv::Capability, 1>, 1>",
+ "getCapabilities">];
+}
+
+//===----------------------------------------------------------------------===//
// SPIR-V extension definitions
//===----------------------------------------------------------------------===//
@@ -1216,7 +1353,22 @@ def SPV_OpcodeAttr :
// Base class for all SPIR-V ops.
class SPV_Op<string mnemonic, list<OpTrait> traits = []> :
- Op<SPV_Dialect, mnemonic, traits> {
+ Op<SPV_Dialect, mnemonic, !listconcat(traits, [
+ // TODO(antiagainst): We don't need all of the following traits for
+ // every op; only the suitabble ones should be added automatically
+ // after ODS supports dialect-specific contents.
+ DeclareOpInterfaceMethods<QueryMinVersionInterface>,
+ DeclareOpInterfaceMethods<QueryMaxVersionInterface>,
+ DeclareOpInterfaceMethods<QueryExtensionInterface>,
+ DeclareOpInterfaceMethods<QueryCapabilityInterface>
+ ])> {
+ // Availability specification for this op itself.
+ list<Availability> availability = [
+ MinVersion<SPV_V_1_0>,
+ MaxVersion<SPV_V_1_5>,
+ Extension<[]>,
+ Capability<[]>
+ ];
// For each SPIR-V op, the following static functions need to be defined
// in SPVOps.cpp:
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
index f3a9a61a9e9..1ac0ae1b969 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
@@ -53,6 +53,13 @@ def SPV_GroupNonUniformBallotOp : SPV_Op<"GroupNonUniformBallot", []> {
```
}];
+ let availability = [
+ MinVersion<SPV_V_1_3>,
+ MaxVersion<SPV_V_1_5>,
+ Extension<[]>,
+ Capability<[SPV_C_GroupNonUniformBallot]>
+ ];
+
let arguments = (ins
SPV_ScopeAttr:$execution_scope,
SPV_Bool:$predicate
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
index 2fa417bfe25..3806418593f 100644
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
@@ -21,18 +21,23 @@ class OpBuilder;
namespace spirv {
+// TableGen'erated operation interfaces for querying versions, extensions, and
+// capabilities.
+#include "mlir/Dialect/SPIRV/SPIRVAvailability.h.inc"
+
+// TablenGen'erated operation declarations.
#define GET_OP_CLASSES
#include "mlir/Dialect/SPIRV/SPIRVOps.h.inc"
-/// Following methods are auto-generated.
-///
-/// Get the name used in the Op to refer to an enum value of the given
-/// `EnumClass`.
-/// template <typename EnumClass> StringRef attributeName();
-///
-/// Get the function that can be used to symbolize an enum value.
-/// template <typename EnumClass>
-/// Optional<EnumClass> (*)(StringRef) symbolizeEnum();
+// TableGen'erated helper functions.
+//
+// Get the name used in the Op to refer to an enum value of the given
+// `EnumClass`.
+// template <typename EnumClass> StringRef attributeName();
+//
+// Get the function that can be used to symbolize an enum value.
+// template <typename EnumClass>
+// Optional<EnumClass> (*)(StringRef) symbolizeEnum();
#include "mlir/Dialect/SPIRV/SPIRVOpUtils.inc"
} // end namespace spirv
diff --git a/mlir/lib/Dialect/SPIRV/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/CMakeLists.txt
index 2c3b1b95a68..d3af53e3aaa 100644
--- a/mlir/lib/Dialect/SPIRV/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/CMakeLists.txt
@@ -15,6 +15,7 @@ add_llvm_library(MLIRSPIRV
)
add_dependencies(MLIRSPIRV
+ MLIRSPIRVAvailabilityIncGen
MLIRSPIRVCanonicalizationIncGen
MLIRSPIRVEnumsIncGen
MLIRSPIRVLoweringStructGen
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
index f42c077f77e..1de7bceaf23 100644
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -3063,8 +3063,16 @@ static LogicalResult verify(spirv::VariableOp varOp) {
namespace mlir {
namespace spirv {
+// TableGen'erated operation interfaces for querying versions, extensions, and
+// capabilities.
+#include "mlir/Dialect/SPIRV/SPIRVAvailability.cpp.inc"
+
+// TablenGen'erated operation definitions.
#define GET_OP_CLASSES
#include "mlir/Dialect/SPIRV/SPIRVOps.cpp.inc"
+// TableGen'erated operation availability interface implementations.
+#include "mlir/Dialect/SPIRV/SPIRVOpAvailabilityImpl.inc"
+
} // namespace spirv
} // namespace mlir
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index 95792548221..571a0d863dd 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(Dialect)
add_subdirectory(EDSC)
add_subdirectory(mlir-cpu-runner)
add_subdirectory(SDBM)
diff --git a/mlir/test/Dialect/CMakeLists.txt b/mlir/test/Dialect/CMakeLists.txt
new file mode 100644
index 00000000000..cc1766c6127
--- /dev/null
+++ b/mlir/test/Dialect/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(SPIRV)
diff --git a/mlir/test/Dialect/SPIRV/CMakeLists.txt b/mlir/test/Dialect/SPIRV/CMakeLists.txt
new file mode 100644
index 00000000000..25ea9625ae8
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_llvm_library(MLIRSPIRVTestPasses
+ TestAvailability.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+ )
+
+target_link_libraries(MLIRSPIRVTestPasses
+ MLIRIR
+ MLIRPass
+ MLIRSPIRV
+ MLIRSupport
+ )
diff --git a/mlir/test/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/Dialect/SPIRV/TestAvailability.cpp
new file mode 100644
index 00000000000..bb164215995
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/TestAvailability.cpp
@@ -0,0 +1,73 @@
+//===- TestAvailability.cpp - Pass to test SPIR-V op availability ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/IR/Function.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// A pass for testing SPIR-V op availability.
+struct TestAvailability : public FunctionPass<TestAvailability> {
+ void runOnFunction() override;
+};
+} // end anonymous namespace
+
+void TestAvailability::runOnFunction() {
+ auto f = getFunction();
+ llvm::outs() << f.getName() << "\n";
+
+ Dialect *spvDialect = getContext().getRegisteredDialect("spv");
+
+ f.getOperation()->walk([&](Operation *op) {
+ if (op->getDialect() != spvDialect)
+ return WalkResult::advance();
+
+ auto &os = llvm::outs();
+
+ if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op))
+ os << " min version: "
+ << spirv::stringifyVersion(minVersion.getMinVersion()) << "\n";
+
+ if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op))
+ os << " max version: "
+ << spirv::stringifyVersion(maxVersion.getMaxVersion()) << "\n";
+
+ if (auto extension = dyn_cast<spirv::QueryExtensionInterface>(op)) {
+ os << " extensions: [";
+ for (const auto &exts : extension.getExtensions()) {
+ os << " [";
+ interleaveComma(exts, os, [&](spirv::Extension ext) {
+ os << spirv::stringifyExtension(ext);
+ });
+ os << "]";
+ }
+ os << " ]\n";
+ }
+
+ if (auto capability = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
+ os << " capabilities: [";
+ for (const auto &caps : capability.getCapabilities()) {
+ os << " [";
+ interleaveComma(caps, os, [&](spirv::Capability cap) {
+ os << spirv::stringifyCapability(cap);
+ });
+ os << "]";
+ }
+ os << " ]\n";
+ }
+ os.flush();
+
+ return WalkResult::advance();
+ });
+}
+
+static PassRegistration<TestAvailability> pass("test-spirv-op-availability",
+ "Test SPIR-V op availability");
diff --git a/mlir/test/Dialect/SPIRV/availability.mlir b/mlir/test/Dialect/SPIRV/availability.mlir
new file mode 100644
index 00000000000..ed4d29cc51d
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/availability.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt -disable-pass-threading -test-spirv-op-availability %s | FileCheck %s
+
+// CHECK-LABEL: iadd
+func @iadd(%arg: i32) -> i32 {
+ // CHECK: min version: V_1_0
+ // CHECK: max version: V_1_5
+ // CHECK: extensions: [ ]
+ // CHECK: capabilities: [ ]
+ %0 = spv.IAdd %arg, %arg: i32
+ return %0: i32
+}
+
+// CHECK: atomic_compare_exchange_weak
+func @atomic_compare_exchange_weak(%ptr: !spv.ptr<i32, Workgroup>, %value: i32, %comparator: i32) -> i32 {
+ // CHECK: min version: V_1_0
+ // CHECK: max version: V_1_3
+ // CHECK: extensions: [ ]
+ // CHECK: capabilities: [ [Kernel] ]
+ %0 = spv.AtomicCompareExchangeWeak "Workgroup" "Release" "Acquire" %ptr, %value, %comparator: !spv.ptr<i32, Workgroup>
+ return %0: i32
+}
+
+// CHECK-LABEL: subgroup_ballot
+func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
+ // CHECK: min version: V_1_3
+ // CHECK: max version: V_1_5
+ // CHECK: extensions: [ ]
+ // CHECK: capabilities: [ [GroupNonUniformBallot] ]
+ %0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xi32>
+ return %0: vector<4xi32>
+}
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index b30d7e39ce8..1281569ef27 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -41,6 +41,7 @@ set(LIBS
MLIRROCDLIR
MLIRSPIRV
MLIRStandardToSPIRVTransforms
+ MLIRSPIRVTestPasses
MLIRSPIRVTransforms
MLIRStandardOps
MLIRStandardToLLVM
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index d65b216e109..639f01458a6 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -13,6 +13,7 @@
#include "mlir/Support/StringExtras.h"
#include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/Sequence.h"
@@ -44,6 +45,233 @@ using mlir::tblgen::NamedTypeConstraint;
using mlir::tblgen::Operator;
//===----------------------------------------------------------------------===//
+// Availability Wrapper Class
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Wrapper class with helper methods for accessing availability defined in
+// TableGen.
+class Availability {
+public:
+ explicit Availability(const Record *def);
+
+ // Returns the name of the direct TableGen class for this availability
+ // instance.
+ StringRef getClass() const;
+
+ // Returns the generated C++ interface's class name.
+ StringRef getInterfaceClassName() const;
+
+ // Returns the generated C++ interface's description.
+ StringRef getInterfaceDescription() const;
+
+ // Returns the name of the query function insided the generated C++ interface.
+ StringRef getQueryFnName() const;
+
+ // Returns the return type of the query function insided the generated C++
+ // interface.
+ StringRef getQueryFnRetType() const;
+
+ // Returns the code for merging availability requirements.
+ StringRef getMergeActionCode() const;
+
+ // Returns the initializer expression for initializing the final availability
+ // requirements.
+ StringRef getMergeInitializer() const;
+
+ // Returns the C++ type for an availability instance.
+ StringRef getMergeInstanceType() const;
+
+ // Returns the concrete availability instance carried in this case.
+ StringRef getMergeInstance() const;
+
+private:
+ // The TableGen definition of this availability.
+ const llvm::Record *def;
+};
+} // namespace
+
+Availability::Availability(const llvm::Record *def) : def(def) {
+ assert(def->isSubClassOf("Availability") &&
+ "must be subclass of TableGen 'Availability' class");
+}
+
+StringRef Availability::getClass() const {
+ SmallVector<Record *, 1> parentClass;
+ def->getDirectSuperClasses(parentClass);
+ if (parentClass.size() != 1) {
+ PrintFatalError(def->getLoc(),
+ "expected to only have one direct superclass");
+ }
+ return parentClass.front()->getName();
+}
+
+StringRef Availability::getInterfaceClassName() const {
+ return def->getValueAsString("interfaceName");
+}
+
+StringRef Availability::getInterfaceDescription() const {
+ return def->getValueAsString("interfaceDescription");
+}
+
+StringRef Availability::getQueryFnRetType() const {
+ return def->getValueAsString("queryFnRetType");
+}
+
+StringRef Availability::getQueryFnName() const {
+ return def->getValueAsString("queryFnName");
+}
+
+StringRef Availability::getMergeActionCode() const {
+ return def->getValueAsString("mergeAction");
+}
+
+StringRef Availability::getMergeInitializer() const {
+ return def->getValueAsString("initializer");
+}
+
+StringRef Availability::getMergeInstanceType() const {
+ return def->getValueAsString("instanceType");
+}
+
+StringRef Availability::getMergeInstance() const {
+ return def->getValueAsString("instance");
+}
+
+//===----------------------------------------------------------------------===//
+// Availability Interface Definitions AutoGen
+//===----------------------------------------------------------------------===//
+
+static void emitInterfaceDef(const Availability &availability,
+ raw_ostream &os) {
+ StringRef methodName = availability.getQueryFnName();
+ os << availability.getQueryFnRetType() << " "
+ << availability.getInterfaceClassName() << "::" << methodName << "() {\n"
+ << " return getImpl()->" << methodName << "(getOperation());\n"
+ << "}\n";
+}
+
+static bool emitInterfaceDefs(const RecordKeeper &recordKeeper,
+ raw_ostream &os) {
+ llvm::emitSourceFileHeader("Availability Interface Definitions", os);
+
+ auto defs = recordKeeper.getAllDerivedDefinitions("Availability");
+ SmallVector<const Record *, 1> handledClasses;
+ for (const Record *def : defs) {
+ SmallVector<Record *, 1> parent;
+ def->getDirectSuperClasses(parent);
+ if (parent.size() != 1) {
+ PrintFatalError(def->getLoc(),
+ "expected to only have one direct superclass");
+ }
+ if (llvm::is_contained(handledClasses, parent.front()))
+ continue;
+
+ Availability availability(def);
+ emitInterfaceDef(availability, os);
+ handledClasses.push_back(parent.front());
+ }
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
+// Availability Interface Declarations AutoGen
+//===----------------------------------------------------------------------===//
+
+static void emitConceptDecl(const Availability &availability, raw_ostream &os) {
+ os << " class Concept {\n"
+ << " public:\n"
+ << " virtual ~Concept() = default;\n"
+ << " virtual " << availability.getQueryFnRetType() << " "
+ << availability.getQueryFnName() << "(Operation *tblgen_opaque_op) = 0;\n"
+ << " };\n";
+}
+
+static void emitModelDecl(const Availability &availability, raw_ostream &os) {
+ os << " template<typename ConcreteOp>\n";
+ os << " class Model : public Concept {\n"
+ << " public:\n"
+ << " " << availability.getQueryFnRetType() << " "
+ << availability.getQueryFnName()
+ << "(Operation *tblgen_opaque_op) final {\n"
+ << " auto op = llvm::cast<ConcreteOp>(tblgen_opaque_op);\n"
+ << " (void)op;\n"
+ // Forward to the method on the concrete operation type.
+ << " return op." << availability.getQueryFnName() << "();\n"
+ << " }\n"
+ << " };\n";
+}
+
+static void emitInterfaceDecl(const Availability &availability,
+ raw_ostream &os) {
+ StringRef interfaceName = availability.getInterfaceClassName();
+ std::string interfaceTraitsName = formatv("{0}Traits", interfaceName);
+
+ // Emit the traits struct containing the concept and model declarations.
+ os << "namespace detail {\n"
+ << "struct " << interfaceTraitsName << " {\n";
+ emitConceptDecl(availability, os);
+ os << '\n';
+ emitModelDecl(availability, os);
+ os << "};\n} // end namespace detail\n\n";
+
+ // Emit the main interface class declaration.
+ os << "/*\n" << availability.getInterfaceDescription().trim() << "\n*/\n";
+ os << llvm::formatv("class {0} : public OpInterface<{1}, detail::{2}> {\n"
+ "public:\n"
+ " using OpInterface<{1}, detail::{2}>::OpInterface;\n",
+ interfaceName, interfaceName, interfaceTraitsName);
+
+ // Emit query function declaration.
+ os << " " << availability.getQueryFnRetType() << " "
+ << availability.getQueryFnName() << "();\n";
+ os << "};\n\n";
+}
+
+static bool emitInterfaceDecls(const RecordKeeper &recordKeeper,
+ raw_ostream &os) {
+ llvm::emitSourceFileHeader("Availability Interface Declarations", os);
+
+ auto defs = recordKeeper.getAllDerivedDefinitions("Availability");
+ SmallVector<const Record *, 4> handledClasses;
+ for (const Record *def : defs) {
+ SmallVector<Record *, 1> parent;
+ def->getDirectSuperClasses(parent);
+ if (parent.size() != 1) {
+ PrintFatalError(def->getLoc(),
+ "expected to only have one direct superclass");
+ }
+ if (llvm::is_contained(handledClasses, parent.front()))
+ continue;
+
+ Availability avail(def);
+ emitInterfaceDecl(avail, os);
+ handledClasses.push_back(parent.front());
+ }
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
+// Availability Interface Hook Registration
+//===----------------------------------------------------------------------===//
+
+// Registers the operation interface generator to mlir-tblgen.
+static mlir::GenRegistration
+ genInterfaceDecls("gen-avail-interface-decls",
+ "Generate availability interface declarations",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return emitInterfaceDecls(records, os);
+ });
+
+// Registers the operation interface generator to mlir-tblgen.
+static mlir::GenRegistration
+ genInterfaceDefs("gen-avail-interface-defs",
+ "Generate op interface definitions",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return emitInterfaceDefs(records, os);
+ });
+
+//===----------------------------------------------------------------------===//
// Serialization AutoGen
//===----------------------------------------------------------------------===//
@@ -651,6 +879,17 @@ static bool emitSerializationFns(const RecordKeeper &recordKeeper,
}
//===----------------------------------------------------------------------===//
+// Serialization Hook Registration
+//===----------------------------------------------------------------------===//
+
+static mlir::GenRegistration genSerialization(
+ "gen-spirv-serialization",
+ "Generate SPIR-V (de)serialization utilities and functions",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return emitSerializationFns(records, os);
+ });
+
+//===----------------------------------------------------------------------===//
// Op Utils AutoGen
//===----------------------------------------------------------------------===//
@@ -707,19 +946,92 @@ static bool emitOpUtils(const RecordKeeper &recordKeeper, raw_ostream &os) {
}
//===----------------------------------------------------------------------===//
-// Hook Registration
+// Op Utils Hook Registration
//===----------------------------------------------------------------------===//
-static mlir::GenRegistration genSerialization(
- "gen-spirv-serialization",
- "Generate SPIR-V (de)serialization utilities and functions",
- [](const RecordKeeper &records, raw_ostream &os) {
- return emitSerializationFns(records, os);
- });
-
static mlir::GenRegistration
genOpUtils("gen-spirv-op-utils",
"Generate SPIR-V operation utility definitions",
[](const RecordKeeper &records, raw_ostream &os) {
return emitOpUtils(records, os);
});
+
+//===----------------------------------------------------------------------===//
+// SPIR-V Availability Impl AutoGen
+//===----------------------------------------------------------------------===//
+
+// Returns the availability spec of the given `def`.
+std::vector<Availability> getAvailabilities(const Record &def) {
+ std::vector<Availability> availabilities;
+ if (auto *availListInit = def.getValueAsListInit("availability")) {
+ availabilities.reserve(availListInit->size());
+ for (auto *availInit : *availListInit)
+ availabilities.emplace_back(
+ llvm::cast<llvm::DefInit>(availInit)->getDef());
+ }
+ return availabilities;
+}
+
+static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
+ mlir::tblgen::FmtContext fctx;
+ fctx.addSubst("overall", "overall");
+
+ std::vector<Availability> opAvailabilities =
+ getAvailabilities(srcOp.getDef());
+
+ // First collect all availablity classes this op should implement.
+ // All availablity instances keep information for the generated interface and
+ // the instance's specific requirement. Here we remember a random instance so
+ // we can get the information regarding the generated interface.
+ llvm::StringMap<Availability> availClasses;
+ for (const Availability &avail : opAvailabilities)
+ availClasses.try_emplace(avail.getClass(), avail);
+
+ // Then generate implementation for each availability class.
+ for (const auto &availClass : availClasses) {
+ StringRef availClassName = availClass.getKey();
+ Availability avail = availClass.getValue();
+
+ // Generate the implementation method signature.
+ os << formatv("{0} {1}::{2}() {{\n", avail.getQueryFnRetType(),
+ srcOp.getCppClassName(), avail.getQueryFnName());
+
+ // Create the variable for the final requirement and initialize it.
+ os << formatv(" {0} overall = {1};\n", avail.getQueryFnRetType(),
+ avail.getMergeInitializer());
+
+ // Update with the op's specific availability spec.
+ for (const Availability &avail : opAvailabilities)
+ if (avail.getClass() == availClassName) {
+ os << " "
+ << tgfmt(avail.getMergeActionCode(),
+ &fctx.addSubst("instance", avail.getMergeInstance()))
+ << ";\n";
+ }
+ os << " return overall;\n";
+ os << "}\n";
+ }
+}
+
+static bool emitAvailabilityImpl(const RecordKeeper &recordKeeper,
+ raw_ostream &os) {
+ llvm::emitSourceFileHeader("SPIR-V Op Availability Implementations", os);
+
+ auto defs = recordKeeper.getAllDerivedDefinitions("SPV_Op");
+ for (const auto *def : defs) {
+ Operator op(def);
+ emitAvailabilityImpl(op, os);
+ }
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
+// Op Availability Implementation Hook Registration
+//===----------------------------------------------------------------------===//
+
+static mlir::GenRegistration
+ genOpAvailabilityImpl("gen-spirv-avail-impls",
+ "Generate SPIR-V operation utility definitions",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return emitAvailabilityImpl(records, os);
+ });
OpenPOWER on IntegriCloud