diff options
author | Lei Zhang <antiagainst@google.com> | 2019-07-30 10:21:25 -0700 |
---|---|---|
committer | A. Unique TensorFlower <gardener@tensorflow.org> | 2019-07-30 11:55:33 -0700 |
commit | 4a55bd5f28e64a0c134adfbbcc20e3ea3af937c6 (patch) | |
tree | f5d91c01b3b45078272f670b3c9842ef7d8e22ef /mlir/unittests/Dialect | |
parent | 4598c04dfe00f3aa493c8a171584b79706c2cdaa (diff) | |
download | bcm5719-llvm-4a55bd5f28e64a0c134adfbbcc20e3ea3af937c6.tar.gz bcm5719-llvm-4a55bd5f28e64a0c134adfbbcc20e3ea3af937c6.zip |
[spirv] Add basic infrastructure for negative deserializer tests
We are relying on serializer to construct positive cases to drive
the test for deserializer. This leaves negative cases untested.
This CL adds a basic test fixture for covering the negative
corner cases to enforce a more robust deserializer.
Refactored common SPIR-V building methods out of serializer to
share it with the deserialization test.
PiperOrigin-RevId: 260742733
Diffstat (limited to 'mlir/unittests/Dialect')
-rw-r--r-- | mlir/unittests/Dialect/CMakeLists.txt | 2 | ||||
-rw-r--r-- | mlir/unittests/Dialect/SPIRV/CMakeLists.txt | 8 | ||||
-rw-r--r-- | mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp | 203 |
3 files changed, 213 insertions, 0 deletions
diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt index 4c622cbcf7f..87eccaae8a3 100644 --- a/mlir/unittests/Dialect/CMakeLists.txt +++ b/mlir/unittests/Dialect/CMakeLists.txt @@ -5,3 +5,5 @@ target_link_libraries(MLIRDialectTests PRIVATE MLIRIR MLIRDialect) + +add_subdirectory(SPIRV) diff --git a/mlir/unittests/Dialect/SPIRV/CMakeLists.txt b/mlir/unittests/Dialect/SPIRV/CMakeLists.txt new file mode 100644 index 00000000000..4e851601f27 --- /dev/null +++ b/mlir/unittests/Dialect/SPIRV/CMakeLists.txt @@ -0,0 +1,8 @@ +add_mlir_unittest(MLIRSPIRVTests + DeserializationTest.cpp +) +target_link_libraries(MLIRSPIRVTests + PRIVATE + MLIRSPIRV + MLIRSPIRVSerialization) + diff --git a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp new file mode 100644 index 00000000000..e4b3ee51d2c --- /dev/null +++ b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp @@ -0,0 +1,203 @@ +//===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// The purpose of this file is to provide negative deserialization tests. +// For positive deserialization tests, please use serialization and +// deserialization for roundtripping. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h" +#include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/Dialect/SPIRV/Serialization.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/MLIRContext.h" +#include "gmock/gmock.h" + +#include <memory> + +using namespace mlir; + +using ::testing::StrEq; + +//===----------------------------------------------------------------------===// +// Test Fixture +//===----------------------------------------------------------------------===// + +/// A deserialization test fixture providing minimal SPIR-V building and +/// diagnostic checking utilities. +class DeserializationTest : public ::testing::Test { +protected: + DeserializationTest() { + // Register a diagnostic handler to capture the diagnostic so that we can + // check it later. + context.getDiagEngine().setHandler([&](Diagnostic diag) { + diagnostic.reset(new Diagnostic(std::move(diag))); + }); + } + + /// Performs deserialization and returns the constructed spv.module op. + Optional<spirv::ModuleOp> deserialize() { + return spirv::deserialize(binary, &context); + } + + /// Checks there is a diagnostic generated with the given `errorMessage`. + void expectDiagnostic(StringRef errorMessage) { + ASSERT_NE(nullptr, diagnostic.get()); + + // TODO(antiagainst): check error location too. + EXPECT_THAT(diagnostic->str(), StrEq(errorMessage)); + } + + //===--------------------------------------------------------------------===// + // SPIR-V builder methods + //===--------------------------------------------------------------------===// + + /// Adds the SPIR-V module header to `binary`. + void addHeader() { spirv::appendModuleHeader(binary, /*idBound=*/0); } + + /// Adds the SPIR-V instruction into `binary`. + void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) { + uint32_t wordCount = 1 + operands.size(); + assert(((wordCount >> 16) == 0) && "word count out of range!"); + + uint32_t prefixedOpcode = (wordCount << 16) | static_cast<uint32_t>(op); + binary.push_back(prefixedOpcode); + binary.append(operands.begin(), operands.end()); + } + + uint32_t addVoidType() { + auto id = nextID++; + addInstruction(spirv::Opcode::OpTypeVoid, {id}); + return id; + } + + uint32_t addIntType(uint32_t bitwidth) { + auto id = nextID++; + addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1}); + return id; + } + + uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) { + auto id = nextID++; + SmallVector<uint32_t, 4> operands; + operands.push_back(id); + operands.push_back(retType); + operands.append(paramTypes.begin(), paramTypes.end()); + addInstruction(spirv::Opcode::OpTypeFunction, operands); + return id; + } + + uint32_t addFunction(uint32_t retType, uint32_t fnType) { + auto id = nextID++; + addInstruction(spirv::Opcode::OpFunction, + {retType, id, + static_cast<uint32_t>(spirv::FunctionControl::None), + fnType}); + return id; + } + + uint32_t addFunctionEnd() { + auto id = nextID++; + addInstruction(spirv::Opcode::OpFunctionEnd, {id}); + return id; + } + +protected: + SmallVector<uint32_t, 5> binary; + uint32_t nextID = 1; + MLIRContext context; + std::unique_ptr<Diagnostic> diagnostic; +}; + +//===----------------------------------------------------------------------===// +// Basics +//===----------------------------------------------------------------------===// + +TEST_F(DeserializationTest, EmptyModuleFailure) { + ASSERT_EQ(llvm::None, deserialize()); + expectDiagnostic("SPIR-V binary module must have a 5-word header"); +} + +TEST_F(DeserializationTest, WrongMagicNumberFailure) { + addHeader(); + binary.front() = 0xdeadbeef; // Change to a wrong magic number + ASSERT_EQ(llvm::None, deserialize()); + expectDiagnostic("incorrect magic number"); +} + +TEST_F(DeserializationTest, OnlyHeaderSuccess) { + addHeader(); + EXPECT_NE(llvm::None, deserialize()); +} + +TEST_F(DeserializationTest, ZeroWordCountFailure) { + addHeader(); + binary.push_back(0); // OpNop with zero word count + + ASSERT_EQ(llvm::None, deserialize()); + expectDiagnostic("word count cannot be zero"); +} + +TEST_F(DeserializationTest, InsufficientWordFailure) { + addHeader(); + binary.push_back((2u << 16) | + static_cast<uint32_t>(spirv::Opcode::OpTypeVoid)); + // Missing word for type <id> + + ASSERT_EQ(llvm::None, deserialize()); + expectDiagnostic("insufficient words for the last instruction"); +} + +//===----------------------------------------------------------------------===// +// Types +//===----------------------------------------------------------------------===// + +TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) { + addHeader(); + addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32}); + + ASSERT_EQ(llvm::None, deserialize()); + expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters"); +} + +//===----------------------------------------------------------------------===// +// Functions +//===----------------------------------------------------------------------===// + +TEST_F(DeserializationTest, FunctionMissingEndFailure) { + addHeader(); + auto voidType = addVoidType(); + auto fnType = addFunctionType(voidType, {}); + addFunction(voidType, fnType); + // Missing OpFunctionEnd + + ASSERT_EQ(llvm::None, deserialize()); + expectDiagnostic("expected OpFunctionEnd instruction"); +} + +TEST_F(DeserializationTest, FunctionMissingParameterFailure) { + addHeader(); + auto voidType = addVoidType(); + auto i32Type = addIntType(32); + auto fnType = addFunctionType(voidType, {i32Type}); + addFunction(voidType, fnType); + // Missing OpFunctionParameter + + ASSERT_EQ(llvm::None, deserialize()); + expectDiagnostic("expected OpFunctionParameter instruction"); +} |