diff options
Diffstat (limited to 'parallel-libs/streamexecutor/lib')
9 files changed, 248 insertions, 4 deletions
diff --git a/parallel-libs/streamexecutor/lib/CMakeLists.txt b/parallel-libs/streamexecutor/lib/CMakeLists.txt index 6dcb5d87edd..b7f7c278b18 100644 --- a/parallel-libs/streamexecutor/lib/CMakeLists.txt +++ b/parallel-libs/streamexecutor/lib/CMakeLists.txt @@ -7,7 +7,11 @@ add_library(      streamexecutor      $<TARGET_OBJECTS:utils>      Kernel.cpp -    KernelSpec.cpp) +    KernelSpec.cpp +    PackedKernelArgumentArray.cpp +    PlatformInterfaces.cpp +    Stream.cpp +    StreamExecutor.cpp)  target_link_libraries(streamexecutor ${llvm_libs})  if(STREAM_EXECUTOR_UNIT_TESTS) diff --git a/parallel-libs/streamexecutor/lib/Kernel.cpp b/parallel-libs/streamexecutor/lib/Kernel.cpp index af95bbe820d..3c3ec20674f 100644 --- a/parallel-libs/streamexecutor/lib/Kernel.cpp +++ b/parallel-libs/streamexecutor/lib/Kernel.cpp @@ -13,7 +13,7 @@  //===----------------------------------------------------------------------===//  #include "streamexecutor/Kernel.h" -#include "streamexecutor/Interfaces.h" +#include "streamexecutor/PlatformInterfaces.h"  #include "streamexecutor/StreamExecutor.h"  #include "llvm/DebugInfo/Symbolize/Symbolize.h" diff --git a/parallel-libs/streamexecutor/lib/PackedKernelArgumentArray.cpp b/parallel-libs/streamexecutor/lib/PackedKernelArgumentArray.cpp new file mode 100644 index 00000000000..04ac80d74ed --- /dev/null +++ b/parallel-libs/streamexecutor/lib/PackedKernelArgumentArray.cpp @@ -0,0 +1,21 @@ +//===-- PackedKernelArgumentArray.cpp - Packed argument array impl --------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// Implementation details for classes from PackedKernelArgumentArray.h. +/// +//===----------------------------------------------------------------------===// + +#include "streamexecutor/PackedKernelArgumentArray.h" + +namespace streamexecutor { + +PackedKernelArgumentArrayBase::~PackedKernelArgumentArrayBase() = default; + +} // namespace streamexecutor diff --git a/parallel-libs/streamexecutor/lib/PlatformInterfaces.cpp b/parallel-libs/streamexecutor/lib/PlatformInterfaces.cpp new file mode 100644 index 00000000000..527c0a934bd --- /dev/null +++ b/parallel-libs/streamexecutor/lib/PlatformInterfaces.cpp @@ -0,0 +1,23 @@ +//===-- PlatformInterfaces.cpp - Platform interface implementations -------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// Implementation file for PlatformInterfaces.h. +/// +//===----------------------------------------------------------------------===// + +#include "streamexecutor/PlatformInterfaces.h" + +namespace streamexecutor { + +PlatformStreamHandle::~PlatformStreamHandle() = default; + +PlatformStreamExecutor::~PlatformStreamExecutor() = default; + +} // namespace streamexecutor diff --git a/parallel-libs/streamexecutor/lib/Stream.cpp b/parallel-libs/streamexecutor/lib/Stream.cpp new file mode 100644 index 00000000000..adfef5fbbe1 --- /dev/null +++ b/parallel-libs/streamexecutor/lib/Stream.cpp @@ -0,0 +1,25 @@ +//===-- Stream.cpp - General stream implementation ------------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the implementation details for a general stream object. +/// +//===----------------------------------------------------------------------===// + +#include "streamexecutor/Stream.h" + +namespace streamexecutor { + +Stream::Stream(std::unique_ptr<PlatformStreamHandle> PStream) +    : PlatformExecutor(PStream->getExecutor()), +      ThePlatformStream(std::move(PStream)) {} + +Stream::~Stream() = default; + +} // namespace streamexecutor diff --git a/parallel-libs/streamexecutor/lib/StreamExecutor.cpp b/parallel-libs/streamexecutor/lib/StreamExecutor.cpp new file mode 100644 index 00000000000..33e7096f51d --- /dev/null +++ b/parallel-libs/streamexecutor/lib/StreamExecutor.cpp @@ -0,0 +1,42 @@ +//===-- StreamExecutor.cpp - StreamExecutor implementation ----------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// Implementation of StreamExecutor class internals. +/// +//===----------------------------------------------------------------------===// + +#include "streamexecutor/StreamExecutor.h" + +#include <cassert> + +#include "streamexecutor/PlatformInterfaces.h" +#include "streamexecutor/Stream.h" + +#include "llvm/ADT/STLExtras.h" + +namespace streamexecutor { + +StreamExecutor::StreamExecutor(PlatformStreamExecutor *PlatformExecutor) +    : PlatformExecutor(PlatformExecutor) {} + +StreamExecutor::~StreamExecutor() = default; + +Expected<std::unique_ptr<Stream>> StreamExecutor::createStream() { +  Expected<std::unique_ptr<PlatformStreamHandle>> MaybePlatformStream = +      PlatformExecutor->createStream(); +  if (!MaybePlatformStream) { +    return MaybePlatformStream.takeError(); +  } +  assert((*MaybePlatformStream)->getExecutor() == PlatformExecutor && +         "an executor created a stream with a different stored executor"); +  return llvm::make_unique<Stream>(std::move(*MaybePlatformStream)); +} + +} // namespace streamexecutor diff --git a/parallel-libs/streamexecutor/lib/unittests/CMakeLists.txt b/parallel-libs/streamexecutor/lib/unittests/CMakeLists.txt index efe5f76eb26..f6e6edbebfd 100644 --- a/parallel-libs/streamexecutor/lib/unittests/CMakeLists.txt +++ b/parallel-libs/streamexecutor/lib/unittests/CMakeLists.txt @@ -23,7 +23,19 @@ add_executable(      PackedKernelArgumentArrayTest.cpp)  target_link_libraries(      packed_kernel_argument_array_test +    streamexecutor      ${llvm_libs}      ${GTEST_BOTH_LIBRARIES}      ${CMAKE_THREAD_LIBS_INIT})  add_test(PackedKernelArgumentArrayTest packed_kernel_argument_array_test) + +add_executable( +    stream_test +    StreamTest.cpp) +target_link_libraries( +    stream_test +    streamexecutor +    ${llvm_libs} +    ${GTEST_BOTH_LIBRARIES} +    ${CMAKE_THREAD_LIBS_INIT}) +add_test(StreamTest stream_test) diff --git a/parallel-libs/streamexecutor/lib/unittests/KernelTest.cpp b/parallel-libs/streamexecutor/lib/unittests/KernelTest.cpp index addcb83ec64..9974e994023 100644 --- a/parallel-libs/streamexecutor/lib/unittests/KernelTest.cpp +++ b/parallel-libs/streamexecutor/lib/unittests/KernelTest.cpp @@ -14,9 +14,9 @@  #include <cassert> -#include "streamexecutor/Interfaces.h"  #include "streamexecutor/Kernel.h"  #include "streamexecutor/KernelSpec.h" +#include "streamexecutor/PlatformInterfaces.h"  #include "streamexecutor/StreamExecutor.h"  #include "llvm/ADT/STLExtras.h" @@ -42,7 +42,8 @@ namespace se = ::streamexecutor;  class MockStreamExecutor : public se::StreamExecutor {  public:    MockStreamExecutor() -      : Unique(llvm::make_unique<se::KernelInterface>()), Raw(Unique.get()) {} +      : se::StreamExecutor(nullptr), +        Unique(llvm::make_unique<se::KernelInterface>()), Raw(Unique.get()) {}    // Moves the unique pointer into the returned se::Expected instance.    // diff --git a/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp b/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp new file mode 100644 index 00000000000..3d7128720cf --- /dev/null +++ b/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp @@ -0,0 +1,116 @@ +//===-- StreamTest.cpp - Tests for Stream ---------------------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the unit tests for Stream code. +/// +//===----------------------------------------------------------------------===// + +#include <cstring> + +#include "streamexecutor/Kernel.h" +#include "streamexecutor/KernelSpec.h" +#include "streamexecutor/PlatformInterfaces.h" +#include "streamexecutor/Stream.h" +#include "streamexecutor/StreamExecutor.h" + +#include "gtest/gtest.h" + +namespace { + +namespace se = ::streamexecutor; + +/// Mock PlatformStreamExecutor that performs asynchronous memcpy operations by +/// ignoring the stream argument and calling std::memcpy on device memory +/// handles. +class MockPlatformStreamExecutor : public se::PlatformStreamExecutor { +public: +  ~MockPlatformStreamExecutor() override {} + +  std::string getName() const override { return "MockPlatformStreamExecutor"; } + +  se::Expected<std::unique_ptr<se::PlatformStreamHandle>> +  createStream() override { +    return nullptr; +  } + +  se::Error memcpyD2H(se::PlatformStreamHandle *, +                      const se::GlobalDeviceMemoryBase &DeviceSrc, +                      void *HostDst, size_t ByteCount) override { +    std::memcpy(HostDst, DeviceSrc.getHandle(), ByteCount); +    return se::Error::success(); +  } + +  se::Error memcpyH2D(se::PlatformStreamHandle *, const void *HostSrc, +                      se::GlobalDeviceMemoryBase *DeviceDst, +                      size_t ByteCount) override { +    std::memcpy(const_cast<void *>(DeviceDst->getHandle()), HostSrc, ByteCount); +    return se::Error::success(); +  } + +  se::Error memcpyD2D(se::PlatformStreamHandle *, +                      const se::GlobalDeviceMemoryBase &DeviceSrc, +                      se::GlobalDeviceMemoryBase *DeviceDst, +                      size_t ByteCount) override { +    std::memcpy(const_cast<void *>(DeviceDst->getHandle()), +                DeviceSrc.getHandle(), ByteCount); +    return se::Error::success(); +  } +}; + +/// Test fixture to hold objects used by tests. +class StreamTest : public ::testing::Test { +public: +  StreamTest() +      : DeviceA(se::GlobalDeviceMemory<int>::makeFromElementCount(HostA, 10)), +        DeviceB(se::GlobalDeviceMemory<int>::makeFromElementCount(HostB, 10)), +        Stream(llvm::make_unique<se::PlatformStreamHandle>(&PlatformExecutor)) { +  } + +protected: +  // Device memory is backed by host arrays. +  int HostA[10]; +  se::GlobalDeviceMemory<int> DeviceA; +  int HostB[10]; +  se::GlobalDeviceMemory<int> DeviceB; + +  // Host memory to be used as actual host memory. +  int Host[10]; + +  MockPlatformStreamExecutor PlatformExecutor; +  se::Stream Stream; +}; + +TEST_F(StreamTest, MemcpyCorrectSize) { +  Stream.thenMemcpyH2D(llvm::ArrayRef<int>(Host), &DeviceA); +  EXPECT_TRUE(Stream.isOK()); + +  Stream.thenMemcpyD2H(DeviceA, llvm::MutableArrayRef<int>(Host)); +  EXPECT_TRUE(Stream.isOK()); + +  Stream.thenMemcpyD2D(DeviceA, &DeviceB); +  EXPECT_TRUE(Stream.isOK()); +} + +TEST_F(StreamTest, MemcpyH2DTooManyElements) { +  Stream.thenMemcpyH2D(llvm::ArrayRef<int>(Host), &DeviceA, 20); +  EXPECT_FALSE(Stream.isOK()); +} + +TEST_F(StreamTest, MemcpyD2HTooManyElements) { +  Stream.thenMemcpyD2H(DeviceA, llvm::MutableArrayRef<int>(Host), 20); +  EXPECT_FALSE(Stream.isOK()); +} + +TEST_F(StreamTest, MemcpyD2DTooManyElements) { +  Stream.thenMemcpyD2D(DeviceA, &DeviceB, 20); +  EXPECT_FALSE(Stream.isOK()); +} + +} // namespace  | 

