diff options
| author | Jason Henline <jhen@google.com> | 2016-09-01 18:35:37 +0000 | 
|---|---|---|
| committer | Jason Henline <jhen@google.com> | 2016-09-01 18:35:37 +0000 | 
| commit | e9a12f117563440ae6a4345d3d23fde96c5260e1 (patch) | |
| tree | de6d1938d0d31927988dae58f53975d2526e81f4 /parallel-libs/streamexecutor | |
| parent | 22d36c167e57ab6301be5938d8c2f4fc53031866 (diff) | |
| download | bcm5719-llvm-e9a12f117563440ae6a4345d3d23fde96c5260e1.tar.gz bcm5719-llvm-e9a12f117563440ae6a4345d3d23fde96c5260e1.zip  | |
[SE] Make Stream movable
Summary:
The example code makes it clear that this is a much better design
decision.
Reviewers: jlebar
Subscribers: jprice, parallel_libs-commits
Differential Revision: https://reviews.llvm.org/D24142
llvm-svn: 280397
Diffstat (limited to 'parallel-libs/streamexecutor')
5 files changed, 17 insertions, 14 deletions
diff --git a/parallel-libs/streamexecutor/examples/Example.cpp b/parallel-libs/streamexecutor/examples/Example.cpp index af1994da415..8f42ffa0a3b 100644 --- a/parallel-libs/streamexecutor/examples/Example.cpp +++ b/parallel-libs/streamexecutor/examples/Example.cpp @@ -121,13 +121,13 @@ int main() {        getOrDie(Device->allocateDeviceMemory<float>(ArraySize));    // Run operations on a stream. -  std::unique_ptr<se::Stream> Stream = getOrDie(Device->createStream()); -  Stream->thenCopyH2D<float>(HostX, X) +  se::Stream Stream = getOrDie(Device->createStream()); +  Stream.thenCopyH2D<float>(HostX, X)        .thenCopyH2D<float>(HostY, Y)        .thenLaunch(ArraySize, 1, *Kernel, A, X, Y)        .thenCopyD2H<float>(X, HostX);    // Wait for the stream to complete. -  se::dieIfError(Stream->blockHostUntilDone()); +  se::dieIfError(Stream.blockHostUntilDone());    // Process output data in HostX.    std::vector<float> ExpectedX = {4, 47, 90, 133}; diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Device.h b/parallel-libs/streamexecutor/include/streamexecutor/Device.h index c37f9b1affb..48ecf22ae76 100644 --- a/parallel-libs/streamexecutor/include/streamexecutor/Device.h +++ b/parallel-libs/streamexecutor/include/streamexecutor/Device.h @@ -50,7 +50,8 @@ public:                                        std::move(*MaybeKernelHandle));    } -  Expected<std::unique_ptr<Stream>> createStream(); +  /// Creates a stream object for this device. +  Expected<Stream> createStream();    /// Allocates an array of ElementCount entries of type T in device memory.    template <typename T> diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Stream.h b/parallel-libs/streamexecutor/include/streamexecutor/Stream.h index d1c82f9e5ea..1acb18139d8 100644 --- a/parallel-libs/streamexecutor/include/streamexecutor/Stream.h +++ b/parallel-libs/streamexecutor/include/streamexecutor/Stream.h @@ -61,19 +61,22 @@ class Stream {  public:    explicit Stream(std::unique_ptr<PlatformStreamHandle> PStream); +  Stream(Stream &&Other) = default; +  Stream &operator=(Stream &&Other) = default; +    ~Stream();    /// Returns whether any error has occurred while entraining work on this    /// stream.    bool isOK() const { -    llvm::sys::ScopedReader ReaderLock(ErrorMessageMutex); +    llvm::sys::ScopedReader ReaderLock(*ErrorMessageMutex);      return !ErrorMessage;    }    /// Returns the status created by the first error that occurred while    /// entraining work on this stream.    Error getStatus() const { -    llvm::sys::ScopedReader ReaderLock(ErrorMessageMutex); +    llvm::sys::ScopedReader ReaderLock(*ErrorMessageMutex);      if (ErrorMessage)        return make_error(*ErrorMessage);      else @@ -315,7 +318,7 @@ private:    /// Does not overwrite the error if it is already set.    void setError(Error &&E) {      if (E) { -      llvm::sys::ScopedWriter WriterLock(ErrorMessageMutex); +      llvm::sys::ScopedWriter WriterLock(*ErrorMessageMutex);        if (!ErrorMessage)          ErrorMessage = consumeAndGetMessage(std::move(E));      } @@ -325,7 +328,7 @@ private:    ///    /// Does not overwrite the error if it is already set.    void setError(llvm::Twine Message) { -    llvm::sys::ScopedWriter WriterLock(ErrorMessageMutex); +    llvm::sys::ScopedWriter WriterLock(*ErrorMessageMutex);      if (!ErrorMessage)        ErrorMessage = Message.str();    } @@ -337,9 +340,7 @@ private:    std::unique_ptr<PlatformStreamHandle> ThePlatformStream;    /// Mutex that guards the error state flags. -  /// -  /// Mutable so that it can be obtained via const reader lock. -  mutable llvm::sys::RWMutex ErrorMessageMutex; +  std::unique_ptr<llvm::sys::RWMutex> ErrorMessageMutex;    /// First error message for an operation in this stream or empty if there have    /// been no errors. diff --git a/parallel-libs/streamexecutor/lib/Device.cpp b/parallel-libs/streamexecutor/lib/Device.cpp index 4a5ec11997d..54f03849c68 100644 --- a/parallel-libs/streamexecutor/lib/Device.cpp +++ b/parallel-libs/streamexecutor/lib/Device.cpp @@ -27,7 +27,7 @@ Device::Device(PlatformDevice *PDevice) : PDevice(PDevice) {}  Device::~Device() = default; -Expected<std::unique_ptr<Stream>> Device::createStream() { +Expected<Stream> Device::createStream() {    Expected<std::unique_ptr<PlatformStreamHandle>> MaybePlatformStream =        PDevice->createStream();    if (!MaybePlatformStream) { @@ -35,7 +35,7 @@ Expected<std::unique_ptr<Stream>> Device::createStream() {    }    assert((*MaybePlatformStream)->getDevice() == PDevice &&           "an executor created a stream with a different stored executor"); -  return llvm::make_unique<Stream>(std::move(*MaybePlatformStream)); +  return Stream(std::move(*MaybePlatformStream));  }  } // namespace streamexecutor diff --git a/parallel-libs/streamexecutor/lib/Stream.cpp b/parallel-libs/streamexecutor/lib/Stream.cpp index 20a817c2715..e1fca58cc19 100644 --- a/parallel-libs/streamexecutor/lib/Stream.cpp +++ b/parallel-libs/streamexecutor/lib/Stream.cpp @@ -17,7 +17,8 @@  namespace streamexecutor {  Stream::Stream(std::unique_ptr<PlatformStreamHandle> PStream) -    : PDevice(PStream->getDevice()), ThePlatformStream(std::move(PStream)) {} +    : PDevice(PStream->getDevice()), ThePlatformStream(std::move(PStream)), +      ErrorMessageMutex(llvm::make_unique<llvm::sys::RWMutex>()) {}  Stream::~Stream() = default;  | 

