summaryrefslogtreecommitdiffstats
path: root/parallel-libs/streamexecutor/lib
diff options
context:
space:
mode:
Diffstat (limited to 'parallel-libs/streamexecutor/lib')
-rw-r--r--parallel-libs/streamexecutor/lib/Device.cpp7
-rw-r--r--parallel-libs/streamexecutor/lib/Kernel.cpp41
-rw-r--r--parallel-libs/streamexecutor/lib/PlatformInterfaces.cpp2
-rw-r--r--parallel-libs/streamexecutor/lib/Stream.cpp37
-rw-r--r--parallel-libs/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h12
-rw-r--r--parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp12
6 files changed, 84 insertions, 27 deletions
diff --git a/parallel-libs/streamexecutor/lib/Device.cpp b/parallel-libs/streamexecutor/lib/Device.cpp
index 54f03849c68..0d81fb78e2d 100644
--- a/parallel-libs/streamexecutor/lib/Device.cpp
+++ b/parallel-libs/streamexecutor/lib/Device.cpp
@@ -28,14 +28,11 @@ Device::Device(PlatformDevice *PDevice) : PDevice(PDevice) {}
Device::~Device() = default;
Expected<Stream> Device::createStream() {
- Expected<std::unique_ptr<PlatformStreamHandle>> MaybePlatformStream =
- PDevice->createStream();
+ Expected<const void *> MaybePlatformStream = PDevice->createStream();
if (!MaybePlatformStream) {
return MaybePlatformStream.takeError();
}
- assert((*MaybePlatformStream)->getDevice() == PDevice &&
- "an executor created a stream with a different stored executor");
- return Stream(std::move(*MaybePlatformStream));
+ return Stream(PDevice, *MaybePlatformStream);
}
} // namespace streamexecutor
diff --git a/parallel-libs/streamexecutor/lib/Kernel.cpp b/parallel-libs/streamexecutor/lib/Kernel.cpp
index 1f4218c4df3..61305372f18 100644
--- a/parallel-libs/streamexecutor/lib/Kernel.cpp
+++ b/parallel-libs/streamexecutor/lib/Kernel.cpp
@@ -12,16 +12,49 @@
///
//===----------------------------------------------------------------------===//
-#include "streamexecutor/Kernel.h"
+#include <cassert>
+
#include "streamexecutor/Device.h"
+#include "streamexecutor/Kernel.h"
#include "streamexecutor/PlatformInterfaces.h"
#include "llvm/DebugInfo/Symbolize/Symbolize.h"
namespace streamexecutor {
-KernelBase::KernelBase(llvm::StringRef Name)
- : Name(Name), DemangledName(llvm::symbolize::LLVMSymbolizer::DemangleName(
- Name, nullptr)) {}
+KernelBase::KernelBase(PlatformDevice *D, const void *PlatformKernelHandle,
+ llvm::StringRef Name)
+ : PDevice(D), PlatformKernelHandle(PlatformKernelHandle), Name(Name),
+ DemangledName(
+ llvm::symbolize::LLVMSymbolizer::DemangleName(Name, nullptr)) {
+ assert(D != nullptr &&
+ "cannot construct a kernel object with a null platform device");
+ assert(PlatformKernelHandle != nullptr &&
+ "cannot construct a kernel object with a null platform kernel handle");
+}
+
+KernelBase::KernelBase(KernelBase &&Other)
+ : PDevice(Other.PDevice), PlatformKernelHandle(Other.PlatformKernelHandle),
+ Name(std::move(Other.Name)),
+ DemangledName(std::move(Other.DemangledName)) {
+ Other.PDevice = nullptr;
+ Other.PlatformKernelHandle = nullptr;
+}
+
+KernelBase &KernelBase::operator=(KernelBase &&Other) {
+ PDevice = Other.PDevice;
+ PlatformKernelHandle = Other.PlatformKernelHandle;
+ Name = std::move(Other.Name);
+ DemangledName = std::move(Other.DemangledName);
+ Other.PDevice = nullptr;
+ Other.PlatformKernelHandle = nullptr;
+ return *this;
+}
+
+KernelBase::~KernelBase() {
+ if (PlatformKernelHandle)
+ // TODO(jhen): Handle the error here.
+ consumeError(PDevice->destroyKernel(PlatformKernelHandle));
+}
} // namespace streamexecutor
diff --git a/parallel-libs/streamexecutor/lib/PlatformInterfaces.cpp b/parallel-libs/streamexecutor/lib/PlatformInterfaces.cpp
index 770cd170c4f..e9378b519df 100644
--- a/parallel-libs/streamexecutor/lib/PlatformInterfaces.cpp
+++ b/parallel-libs/streamexecutor/lib/PlatformInterfaces.cpp
@@ -16,8 +16,6 @@
namespace streamexecutor {
-PlatformStreamHandle::~PlatformStreamHandle() = default;
-
PlatformDevice::~PlatformDevice() = default;
} // namespace streamexecutor
diff --git a/parallel-libs/streamexecutor/lib/Stream.cpp b/parallel-libs/streamexecutor/lib/Stream.cpp
index e1fca58cc19..96aad044c9c 100644
--- a/parallel-libs/streamexecutor/lib/Stream.cpp
+++ b/parallel-libs/streamexecutor/lib/Stream.cpp
@@ -12,14 +12,43 @@
///
//===----------------------------------------------------------------------===//
+#include <cassert>
+
#include "streamexecutor/Stream.h"
namespace streamexecutor {
-Stream::Stream(std::unique_ptr<PlatformStreamHandle> PStream)
- : PDevice(PStream->getDevice()), ThePlatformStream(std::move(PStream)),
- ErrorMessageMutex(llvm::make_unique<llvm::sys::RWMutex>()) {}
+Stream::Stream(PlatformDevice *D, const void *PlatformStreamHandle)
+ : PDevice(D), PlatformStreamHandle(PlatformStreamHandle),
+ ErrorMessageMutex(llvm::make_unique<llvm::sys::RWMutex>()) {
+ assert(D != nullptr &&
+ "cannot construct a stream object with a null platform device");
+ assert(PlatformStreamHandle != nullptr &&
+ "cannot construct a stream object with a null platform stream handle");
+}
+
+Stream::Stream(Stream &&Other)
+ : PDevice(Other.PDevice), PlatformStreamHandle(Other.PlatformStreamHandle),
+ ErrorMessageMutex(std::move(Other.ErrorMessageMutex)),
+ ErrorMessage(std::move(Other.ErrorMessage)) {
+ Other.PDevice = nullptr;
+ Other.PlatformStreamHandle = nullptr;
+}
+
+Stream &Stream::operator=(Stream &&Other) {
+ PDevice = Other.PDevice;
+ PlatformStreamHandle = Other.PlatformStreamHandle;
+ ErrorMessageMutex = std::move(Other.ErrorMessageMutex);
+ ErrorMessage = std::move(Other.ErrorMessage);
+ Other.PDevice = nullptr;
+ Other.PlatformStreamHandle = nullptr;
+ return *this;
+}
-Stream::~Stream() = default;
+Stream::~Stream() {
+ if (PlatformStreamHandle)
+ // TODO(jhen): Handle error condition here.
+ consumeError(PDevice->destroyStream(PlatformStreamHandle));
+}
} // namespace streamexecutor
diff --git a/parallel-libs/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h b/parallel-libs/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h
index 184c2d7f273..b54b31dd457 100644
--- a/parallel-libs/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h
+++ b/parallel-libs/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h
@@ -34,9 +34,7 @@ class SimpleHostPlatformDevice : public streamexecutor::PlatformDevice {
public:
std::string getName() const override { return "SimpleHostPlatformDevice"; }
- streamexecutor::Expected<
- std::unique_ptr<streamexecutor::PlatformStreamHandle>>
- createStream() override {
+ streamexecutor::Expected<const void *> createStream() override {
return nullptr;
}
@@ -69,7 +67,7 @@ public:
return streamexecutor::Error::success();
}
- streamexecutor::Error copyD2H(streamexecutor::PlatformStreamHandle *S,
+ streamexecutor::Error copyD2H(const void *StreamHandle,
const void *DeviceHandleSrc,
size_t SrcByteOffset, void *HostDst,
size_t DstByteOffset,
@@ -80,8 +78,8 @@ public:
return streamexecutor::Error::success();
}
- streamexecutor::Error copyH2D(streamexecutor::PlatformStreamHandle *S,
- const void *HostSrc, size_t SrcByteOffset,
+ streamexecutor::Error copyH2D(const void *StreamHandle, const void *HostSrc,
+ size_t SrcByteOffset,
const void *DeviceHandleDst,
size_t DstByteOffset,
size_t ByteCount) override {
@@ -92,7 +90,7 @@ public:
}
streamexecutor::Error
- copyD2D(streamexecutor::PlatformStreamHandle *S, const void *DeviceHandleSrc,
+ copyD2D(const void *StreamHandle, const void *DeviceHandleSrc,
size_t SrcByteOffset, const void *DeviceHandleDst,
size_t DstByteOffset, size_t ByteCount) override {
std::memcpy(static_cast<char *>(const_cast<void *>(DeviceHandleDst)) +
diff --git a/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp b/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp
index 4f42bbe8e72..3a0f4e6fdd2 100644
--- a/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp
+++ b/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp
@@ -34,11 +34,11 @@ const auto &getDeviceValue =
class StreamTest : public ::testing::Test {
public:
StreamTest()
- : Device(&PDevice),
- Stream(llvm::make_unique<se::PlatformStreamHandle>(&PDevice)),
- HostA5{0, 1, 2, 3, 4}, HostB5{5, 6, 7, 8, 9},
- HostA7{10, 11, 12, 13, 14, 15, 16}, HostB7{17, 18, 19, 20, 21, 22, 23},
- Host5{24, 25, 26, 27, 28}, Host7{29, 30, 31, 32, 33, 34, 35},
+ : DummyPlatformStream(1), Device(&PDevice),
+ Stream(&PDevice, &DummyPlatformStream), HostA5{0, 1, 2, 3, 4},
+ HostB5{5, 6, 7, 8, 9}, HostA7{10, 11, 12, 13, 14, 15, 16},
+ HostB7{17, 18, 19, 20, 21, 22, 23}, Host5{24, 25, 26, 27, 28},
+ Host7{29, 30, 31, 32, 33, 34, 35},
DeviceA5(getOrDie(Device.allocateDeviceMemory<int>(5))),
DeviceB5(getOrDie(Device.allocateDeviceMemory<int>(5))),
DeviceA7(getOrDie(Device.allocateDeviceMemory<int>(7))),
@@ -50,6 +50,8 @@ public:
}
protected:
+ int DummyPlatformStream; // Mimicking a platform where the platform stream
+ // handle is just a stream number.
se::test::SimpleHostPlatformDevice PDevice;
se::Device Device;
se::Stream Stream;
OpenPOWER on IntegriCloud