diff options
Diffstat (limited to 'parallel-libs/streamexecutor/lib')
6 files changed, 84 insertions, 27 deletions
diff --git a/parallel-libs/streamexecutor/lib/Device.cpp b/parallel-libs/streamexecutor/lib/Device.cpp index 54f03849c68..0d81fb78e2d 100644 --- a/parallel-libs/streamexecutor/lib/Device.cpp +++ b/parallel-libs/streamexecutor/lib/Device.cpp @@ -28,14 +28,11 @@ Device::Device(PlatformDevice *PDevice) : PDevice(PDevice) {} Device::~Device() = default; Expected<Stream> Device::createStream() { - Expected<std::unique_ptr<PlatformStreamHandle>> MaybePlatformStream = - PDevice->createStream(); + Expected<const void *> MaybePlatformStream = PDevice->createStream(); if (!MaybePlatformStream) { return MaybePlatformStream.takeError(); } - assert((*MaybePlatformStream)->getDevice() == PDevice && - "an executor created a stream with a different stored executor"); - return Stream(std::move(*MaybePlatformStream)); + return Stream(PDevice, *MaybePlatformStream); } } // namespace streamexecutor diff --git a/parallel-libs/streamexecutor/lib/Kernel.cpp b/parallel-libs/streamexecutor/lib/Kernel.cpp index 1f4218c4df3..61305372f18 100644 --- a/parallel-libs/streamexecutor/lib/Kernel.cpp +++ b/parallel-libs/streamexecutor/lib/Kernel.cpp @@ -12,16 +12,49 @@ /// //===----------------------------------------------------------------------===// -#include "streamexecutor/Kernel.h" +#include <cassert> + #include "streamexecutor/Device.h" +#include "streamexecutor/Kernel.h" #include "streamexecutor/PlatformInterfaces.h" #include "llvm/DebugInfo/Symbolize/Symbolize.h" namespace streamexecutor { -KernelBase::KernelBase(llvm::StringRef Name) - : Name(Name), DemangledName(llvm::symbolize::LLVMSymbolizer::DemangleName( - Name, nullptr)) {} +KernelBase::KernelBase(PlatformDevice *D, const void *PlatformKernelHandle, + llvm::StringRef Name) + : PDevice(D), PlatformKernelHandle(PlatformKernelHandle), Name(Name), + DemangledName( + llvm::symbolize::LLVMSymbolizer::DemangleName(Name, nullptr)) { + assert(D != nullptr && + "cannot construct a kernel object with a null platform device"); + assert(PlatformKernelHandle != nullptr && + "cannot construct a kernel object with a null platform kernel handle"); +} + +KernelBase::KernelBase(KernelBase &&Other) + : PDevice(Other.PDevice), PlatformKernelHandle(Other.PlatformKernelHandle), + Name(std::move(Other.Name)), + DemangledName(std::move(Other.DemangledName)) { + Other.PDevice = nullptr; + Other.PlatformKernelHandle = nullptr; +} + +KernelBase &KernelBase::operator=(KernelBase &&Other) { + PDevice = Other.PDevice; + PlatformKernelHandle = Other.PlatformKernelHandle; + Name = std::move(Other.Name); + DemangledName = std::move(Other.DemangledName); + Other.PDevice = nullptr; + Other.PlatformKernelHandle = nullptr; + return *this; +} + +KernelBase::~KernelBase() { + if (PlatformKernelHandle) + // TODO(jhen): Handle the error here. + consumeError(PDevice->destroyKernel(PlatformKernelHandle)); +} } // namespace streamexecutor diff --git a/parallel-libs/streamexecutor/lib/PlatformInterfaces.cpp b/parallel-libs/streamexecutor/lib/PlatformInterfaces.cpp index 770cd170c4f..e9378b519df 100644 --- a/parallel-libs/streamexecutor/lib/PlatformInterfaces.cpp +++ b/parallel-libs/streamexecutor/lib/PlatformInterfaces.cpp @@ -16,8 +16,6 @@ namespace streamexecutor { -PlatformStreamHandle::~PlatformStreamHandle() = default; - PlatformDevice::~PlatformDevice() = default; } // namespace streamexecutor diff --git a/parallel-libs/streamexecutor/lib/Stream.cpp b/parallel-libs/streamexecutor/lib/Stream.cpp index e1fca58cc19..96aad044c9c 100644 --- a/parallel-libs/streamexecutor/lib/Stream.cpp +++ b/parallel-libs/streamexecutor/lib/Stream.cpp @@ -12,14 +12,43 @@ /// //===----------------------------------------------------------------------===// +#include <cassert> + #include "streamexecutor/Stream.h" namespace streamexecutor { -Stream::Stream(std::unique_ptr<PlatformStreamHandle> PStream) - : PDevice(PStream->getDevice()), ThePlatformStream(std::move(PStream)), - ErrorMessageMutex(llvm::make_unique<llvm::sys::RWMutex>()) {} +Stream::Stream(PlatformDevice *D, const void *PlatformStreamHandle) + : PDevice(D), PlatformStreamHandle(PlatformStreamHandle), + ErrorMessageMutex(llvm::make_unique<llvm::sys::RWMutex>()) { + assert(D != nullptr && + "cannot construct a stream object with a null platform device"); + assert(PlatformStreamHandle != nullptr && + "cannot construct a stream object with a null platform stream handle"); +} + +Stream::Stream(Stream &&Other) + : PDevice(Other.PDevice), PlatformStreamHandle(Other.PlatformStreamHandle), + ErrorMessageMutex(std::move(Other.ErrorMessageMutex)), + ErrorMessage(std::move(Other.ErrorMessage)) { + Other.PDevice = nullptr; + Other.PlatformStreamHandle = nullptr; +} + +Stream &Stream::operator=(Stream &&Other) { + PDevice = Other.PDevice; + PlatformStreamHandle = Other.PlatformStreamHandle; + ErrorMessageMutex = std::move(Other.ErrorMessageMutex); + ErrorMessage = std::move(Other.ErrorMessage); + Other.PDevice = nullptr; + Other.PlatformStreamHandle = nullptr; + return *this; +} -Stream::~Stream() = default; +Stream::~Stream() { + if (PlatformStreamHandle) + // TODO(jhen): Handle error condition here. + consumeError(PDevice->destroyStream(PlatformStreamHandle)); +} } // namespace streamexecutor diff --git a/parallel-libs/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h b/parallel-libs/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h index 184c2d7f273..b54b31dd457 100644 --- a/parallel-libs/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h +++ b/parallel-libs/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h @@ -34,9 +34,7 @@ class SimpleHostPlatformDevice : public streamexecutor::PlatformDevice { public: std::string getName() const override { return "SimpleHostPlatformDevice"; } - streamexecutor::Expected< - std::unique_ptr<streamexecutor::PlatformStreamHandle>> - createStream() override { + streamexecutor::Expected<const void *> createStream() override { return nullptr; } @@ -69,7 +67,7 @@ public: return streamexecutor::Error::success(); } - streamexecutor::Error copyD2H(streamexecutor::PlatformStreamHandle *S, + streamexecutor::Error copyD2H(const void *StreamHandle, const void *DeviceHandleSrc, size_t SrcByteOffset, void *HostDst, size_t DstByteOffset, @@ -80,8 +78,8 @@ public: return streamexecutor::Error::success(); } - streamexecutor::Error copyH2D(streamexecutor::PlatformStreamHandle *S, - const void *HostSrc, size_t SrcByteOffset, + streamexecutor::Error copyH2D(const void *StreamHandle, const void *HostSrc, + size_t SrcByteOffset, const void *DeviceHandleDst, size_t DstByteOffset, size_t ByteCount) override { @@ -92,7 +90,7 @@ public: } streamexecutor::Error - copyD2D(streamexecutor::PlatformStreamHandle *S, const void *DeviceHandleSrc, + copyD2D(const void *StreamHandle, const void *DeviceHandleSrc, size_t SrcByteOffset, const void *DeviceHandleDst, size_t DstByteOffset, size_t ByteCount) override { std::memcpy(static_cast<char *>(const_cast<void *>(DeviceHandleDst)) + diff --git a/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp b/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp index 4f42bbe8e72..3a0f4e6fdd2 100644 --- a/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp +++ b/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp @@ -34,11 +34,11 @@ const auto &getDeviceValue = class StreamTest : public ::testing::Test { public: StreamTest() - : Device(&PDevice), - Stream(llvm::make_unique<se::PlatformStreamHandle>(&PDevice)), - HostA5{0, 1, 2, 3, 4}, HostB5{5, 6, 7, 8, 9}, - HostA7{10, 11, 12, 13, 14, 15, 16}, HostB7{17, 18, 19, 20, 21, 22, 23}, - Host5{24, 25, 26, 27, 28}, Host7{29, 30, 31, 32, 33, 34, 35}, + : DummyPlatformStream(1), Device(&PDevice), + Stream(&PDevice, &DummyPlatformStream), HostA5{0, 1, 2, 3, 4}, + HostB5{5, 6, 7, 8, 9}, HostA7{10, 11, 12, 13, 14, 15, 16}, + HostB7{17, 18, 19, 20, 21, 22, 23}, Host5{24, 25, 26, 27, 28}, + Host7{29, 30, 31, 32, 33, 34, 35}, DeviceA5(getOrDie(Device.allocateDeviceMemory<int>(5))), DeviceB5(getOrDie(Device.allocateDeviceMemory<int>(5))), DeviceA7(getOrDie(Device.allocateDeviceMemory<int>(7))), @@ -50,6 +50,8 @@ public: } protected: + int DummyPlatformStream; // Mimicking a platform where the platform stream + // handle is just a stream number. se::test::SimpleHostPlatformDevice PDevice; se::Device Device; se::Stream Stream; |

