summaryrefslogtreecommitdiffstats
path: root/mlir/unittests/Dialect
diff options
context:
space:
mode:
authorLei Zhang <antiagainst@google.com>2019-07-30 10:21:25 -0700
committerA. Unique TensorFlower <gardener@tensorflow.org>2019-07-30 11:55:33 -0700
commit4a55bd5f28e64a0c134adfbbcc20e3ea3af937c6 (patch)
treef5d91c01b3b45078272f670b3c9842ef7d8e22ef /mlir/unittests/Dialect
parent4598c04dfe00f3aa493c8a171584b79706c2cdaa (diff)
downloadbcm5719-llvm-4a55bd5f28e64a0c134adfbbcc20e3ea3af937c6.tar.gz
bcm5719-llvm-4a55bd5f28e64a0c134adfbbcc20e3ea3af937c6.zip
[spirv] Add basic infrastructure for negative deserializer tests
We are relying on serializer to construct positive cases to drive the test for deserializer. This leaves negative cases untested. This CL adds a basic test fixture for covering the negative corner cases to enforce a more robust deserializer. Refactored common SPIR-V building methods out of serializer to share it with the deserialization test. PiperOrigin-RevId: 260742733
Diffstat (limited to 'mlir/unittests/Dialect')
-rw-r--r--mlir/unittests/Dialect/CMakeLists.txt2
-rw-r--r--mlir/unittests/Dialect/SPIRV/CMakeLists.txt8
-rw-r--r--mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp203
3 files changed, 213 insertions, 0 deletions
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt
index 4c622cbcf7f..87eccaae8a3 100644
--- a/mlir/unittests/Dialect/CMakeLists.txt
+++ b/mlir/unittests/Dialect/CMakeLists.txt
@@ -5,3 +5,5 @@ target_link_libraries(MLIRDialectTests
PRIVATE
MLIRIR
MLIRDialect)
+
+add_subdirectory(SPIRV)
diff --git a/mlir/unittests/Dialect/SPIRV/CMakeLists.txt b/mlir/unittests/Dialect/SPIRV/CMakeLists.txt
new file mode 100644
index 00000000000..4e851601f27
--- /dev/null
+++ b/mlir/unittests/Dialect/SPIRV/CMakeLists.txt
@@ -0,0 +1,8 @@
+add_mlir_unittest(MLIRSPIRVTests
+ DeserializationTest.cpp
+)
+target_link_libraries(MLIRSPIRVTests
+ PRIVATE
+ MLIRSPIRV
+ MLIRSPIRVSerialization)
+
diff --git a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp
new file mode 100644
index 00000000000..e4b3ee51d2c
--- /dev/null
+++ b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp
@@ -0,0 +1,203 @@
+//===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// The purpose of this file is to provide negative deserialization tests.
+// For positive deserialization tests, please use serialization and
+// deserialization for roundtripping.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/Serialization.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/MLIRContext.h"
+#include "gmock/gmock.h"
+
+#include <memory>
+
+using namespace mlir;
+
+using ::testing::StrEq;
+
+//===----------------------------------------------------------------------===//
+// Test Fixture
+//===----------------------------------------------------------------------===//
+
+/// A deserialization test fixture providing minimal SPIR-V building and
+/// diagnostic checking utilities.
+class DeserializationTest : public ::testing::Test {
+protected:
+ DeserializationTest() {
+ // Register a diagnostic handler to capture the diagnostic so that we can
+ // check it later.
+ context.getDiagEngine().setHandler([&](Diagnostic diag) {
+ diagnostic.reset(new Diagnostic(std::move(diag)));
+ });
+ }
+
+ /// Performs deserialization and returns the constructed spv.module op.
+ Optional<spirv::ModuleOp> deserialize() {
+ return spirv::deserialize(binary, &context);
+ }
+
+ /// Checks there is a diagnostic generated with the given `errorMessage`.
+ void expectDiagnostic(StringRef errorMessage) {
+ ASSERT_NE(nullptr, diagnostic.get());
+
+ // TODO(antiagainst): check error location too.
+ EXPECT_THAT(diagnostic->str(), StrEq(errorMessage));
+ }
+
+ //===--------------------------------------------------------------------===//
+ // SPIR-V builder methods
+ //===--------------------------------------------------------------------===//
+
+ /// Adds the SPIR-V module header to `binary`.
+ void addHeader() { spirv::appendModuleHeader(binary, /*idBound=*/0); }
+
+ /// Adds the SPIR-V instruction into `binary`.
+ void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) {
+ uint32_t wordCount = 1 + operands.size();
+ assert(((wordCount >> 16) == 0) && "word count out of range!");
+
+ uint32_t prefixedOpcode = (wordCount << 16) | static_cast<uint32_t>(op);
+ binary.push_back(prefixedOpcode);
+ binary.append(operands.begin(), operands.end());
+ }
+
+ uint32_t addVoidType() {
+ auto id = nextID++;
+ addInstruction(spirv::Opcode::OpTypeVoid, {id});
+ return id;
+ }
+
+ uint32_t addIntType(uint32_t bitwidth) {
+ auto id = nextID++;
+ addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1});
+ return id;
+ }
+
+ uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) {
+ auto id = nextID++;
+ SmallVector<uint32_t, 4> operands;
+ operands.push_back(id);
+ operands.push_back(retType);
+ operands.append(paramTypes.begin(), paramTypes.end());
+ addInstruction(spirv::Opcode::OpTypeFunction, operands);
+ return id;
+ }
+
+ uint32_t addFunction(uint32_t retType, uint32_t fnType) {
+ auto id = nextID++;
+ addInstruction(spirv::Opcode::OpFunction,
+ {retType, id,
+ static_cast<uint32_t>(spirv::FunctionControl::None),
+ fnType});
+ return id;
+ }
+
+ uint32_t addFunctionEnd() {
+ auto id = nextID++;
+ addInstruction(spirv::Opcode::OpFunctionEnd, {id});
+ return id;
+ }
+
+protected:
+ SmallVector<uint32_t, 5> binary;
+ uint32_t nextID = 1;
+ MLIRContext context;
+ std::unique_ptr<Diagnostic> diagnostic;
+};
+
+//===----------------------------------------------------------------------===//
+// Basics
+//===----------------------------------------------------------------------===//
+
+TEST_F(DeserializationTest, EmptyModuleFailure) {
+ ASSERT_EQ(llvm::None, deserialize());
+ expectDiagnostic("SPIR-V binary module must have a 5-word header");
+}
+
+TEST_F(DeserializationTest, WrongMagicNumberFailure) {
+ addHeader();
+ binary.front() = 0xdeadbeef; // Change to a wrong magic number
+ ASSERT_EQ(llvm::None, deserialize());
+ expectDiagnostic("incorrect magic number");
+}
+
+TEST_F(DeserializationTest, OnlyHeaderSuccess) {
+ addHeader();
+ EXPECT_NE(llvm::None, deserialize());
+}
+
+TEST_F(DeserializationTest, ZeroWordCountFailure) {
+ addHeader();
+ binary.push_back(0); // OpNop with zero word count
+
+ ASSERT_EQ(llvm::None, deserialize());
+ expectDiagnostic("word count cannot be zero");
+}
+
+TEST_F(DeserializationTest, InsufficientWordFailure) {
+ addHeader();
+ binary.push_back((2u << 16) |
+ static_cast<uint32_t>(spirv::Opcode::OpTypeVoid));
+ // Missing word for type <id>
+
+ ASSERT_EQ(llvm::None, deserialize());
+ expectDiagnostic("insufficient words for the last instruction");
+}
+
+//===----------------------------------------------------------------------===//
+// Types
+//===----------------------------------------------------------------------===//
+
+TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) {
+ addHeader();
+ addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32});
+
+ ASSERT_EQ(llvm::None, deserialize());
+ expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters");
+}
+
+//===----------------------------------------------------------------------===//
+// Functions
+//===----------------------------------------------------------------------===//
+
+TEST_F(DeserializationTest, FunctionMissingEndFailure) {
+ addHeader();
+ auto voidType = addVoidType();
+ auto fnType = addFunctionType(voidType, {});
+ addFunction(voidType, fnType);
+ // Missing OpFunctionEnd
+
+ ASSERT_EQ(llvm::None, deserialize());
+ expectDiagnostic("expected OpFunctionEnd instruction");
+}
+
+TEST_F(DeserializationTest, FunctionMissingParameterFailure) {
+ addHeader();
+ auto voidType = addVoidType();
+ auto i32Type = addIntType(32);
+ auto fnType = addFunctionType(voidType, {i32Type});
+ addFunction(voidType, fnType);
+ // Missing OpFunctionParameter
+
+ ASSERT_EQ(llvm::None, deserialize());
+ expectDiagnostic("expected OpFunctionParameter instruction");
+}
OpenPOWER on IntegriCloud