diff options
Diffstat (limited to 'parallel-libs/streamexecutor')
10 files changed, 142 insertions, 95 deletions
diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Device.h b/parallel-libs/streamexecutor/include/streamexecutor/Device.h index 95d9b5c62fb..0ee2b2fbc0b 100644 --- a/parallel-libs/streamexecutor/include/streamexecutor/Device.h +++ b/parallel-libs/streamexecutor/include/streamexecutor/Device.h @@ -35,12 +35,11 @@ public:    Expected<typename std::enable_if<std::is_base_of<KernelBase, KernelT>::value,                                     KernelT>::type>    createKernel(const MultiKernelLoaderSpec &Spec) { -    Expected<std::unique_ptr<PlatformKernelHandle>> MaybeKernelHandle = -        PDevice->createKernel(Spec); +    Expected<const void *> MaybeKernelHandle = PDevice->createKernel(Spec);      if (!MaybeKernelHandle) {        return MaybeKernelHandle.takeError();      } -    return KernelT(Spec.getKernelName(), std::move(*MaybeKernelHandle)); +    return KernelT(PDevice, *MaybeKernelHandle, Spec.getKernelName());    }    /// Creates a stream object for this device. diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Kernel.h b/parallel-libs/streamexecutor/include/streamexecutor/Kernel.h index c9b4180afee..6ea7c361803 100644 --- a/parallel-libs/streamexecutor/include/streamexecutor/Kernel.h +++ b/parallel-libs/streamexecutor/include/streamexecutor/Kernel.h @@ -28,19 +28,32 @@  namespace streamexecutor { -class PlatformKernelHandle; +class PlatformDevice;  /// The base class for all kernel types.  ///  /// Stores the name of the kernel in both mangled and demangled forms.  class KernelBase {  public: -  KernelBase(llvm::StringRef Name); +  KernelBase(PlatformDevice *D, const void *PlatformKernelHandle, +             llvm::StringRef Name); +  KernelBase(const KernelBase &Other) = delete; +  KernelBase &operator=(const KernelBase &Other) = delete; + +  KernelBase(KernelBase &&Other); +  KernelBase &operator=(KernelBase &&Other); + +  ~KernelBase(); + +  const void *getPlatformHandle() const { return PlatformKernelHandle; }    const std::string &getName() const { return Name; }    const std::string &getDemangledName() const { return DemangledName; }  private: +  PlatformDevice *PDevice; +  const void *PlatformKernelHandle; +    std::string Name;    std::string DemangledName;  }; @@ -51,17 +64,12 @@ private:  /// function.  template <typename... ParameterTs> class Kernel : public KernelBase {  public: -  Kernel(llvm::StringRef Name, std::unique_ptr<PlatformKernelHandle> PHandle) -      : KernelBase(Name), PHandle(std::move(PHandle)) {} +  Kernel(PlatformDevice *D, const void *PlatformKernelHandle, +         llvm::StringRef Name) +      : KernelBase(D, PlatformKernelHandle, Name) {}    Kernel(Kernel &&Other) = default;    Kernel &operator=(Kernel &&Other) = default; - -  /// Gets the underlying platform-specific handle for this kernel. -  PlatformKernelHandle *getPlatformHandle() const { return PHandle.get(); } - -private: -  std::unique_ptr<PlatformKernelHandle> PHandle;  };  } // namespace streamexecutor diff --git a/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h b/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h index b3deff31f50..946f8f96a94 100644 --- a/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h +++ b/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h @@ -31,34 +31,6 @@  namespace streamexecutor { -class PlatformDevice; - -/// Platform-specific kernel handle. -class PlatformKernelHandle { -public: -  explicit PlatformKernelHandle(PlatformDevice *PDevice) : PDevice(PDevice) {} - -  virtual ~PlatformKernelHandle(); - -  PlatformDevice *getDevice() { return PDevice; } - -private: -  PlatformDevice *PDevice; -}; - -/// Platform-specific stream handle. -class PlatformStreamHandle { -public: -  explicit PlatformStreamHandle(PlatformDevice *PDevice) : PDevice(PDevice) {} - -  virtual ~PlatformStreamHandle(); - -  PlatformDevice *getDevice() { return PDevice; } - -private: -  PlatformDevice *PDevice; -}; -  /// Raw executor methods that must be implemented by each platform.  ///  /// This class defines the platform interface that supports executing work on a @@ -73,19 +45,30 @@ public:    virtual std::string getName() const = 0;    /// Creates a platform-specific kernel. -  virtual Expected<std::unique_ptr<PlatformKernelHandle>> +  virtual Expected<const void *>    createKernel(const MultiKernelLoaderSpec &Spec) {      return make_error("createKernel not implemented for platform " + getName());    } +  virtual Error destroyKernel(const void *Handle) { +    return make_error("destroyKernel not implemented for platform " + +                      getName()); +  } +    /// Creates a platform-specific stream. -  virtual Expected<std::unique_ptr<PlatformStreamHandle>> createStream() { +  virtual Expected<const void *> createStream() {      return make_error("createStream not implemented for platform " + getName());    } +  virtual Error destroyStream(const void *Handle) { +    return make_error("destroyStream not implemented for platform " + +                      getName()); +  } +    /// Launches a kernel on the given stream. -  virtual Error launch(PlatformStreamHandle *S, BlockDimensions BlockSize, -                       GridDimensions GridSize, PlatformKernelHandle *K, +  virtual Error launch(const void *PlatformStreamHandle, +                       BlockDimensions BlockSize, GridDimensions GridSize, +                       const void *PKernelHandle,                         const PackedKernelArgumentArrayBase &ArgumentArray) {      return make_error("launch not implemented for platform " + getName());    } @@ -94,9 +77,9 @@ public:    ///    /// HostDst should have been allocated by allocateHostMemory or registered    /// with registerHostMemory. -  virtual Error copyD2H(PlatformStreamHandle *S, const void *DeviceSrcHandle, -                        size_t SrcByteOffset, void *HostDst, -                        size_t DstByteOffset, size_t ByteCount) { +  virtual Error copyD2H(const void *PlatformStreamHandle, +                        const void *DeviceSrcHandle, size_t SrcByteOffset, +                        void *HostDst, size_t DstByteOffset, size_t ByteCount) {      return make_error("copyD2H not implemented for platform " + getName());    } @@ -104,22 +87,23 @@ public:    ///    /// HostSrc should have been allocated by allocateHostMemory or registered    /// with registerHostMemory. -  virtual Error copyH2D(PlatformStreamHandle *S, const void *HostSrc, +  virtual Error copyH2D(const void *PlatformStreamHandle, const void *HostSrc,                          size_t SrcByteOffset, const void *DeviceDstHandle,                          size_t DstByteOffset, size_t ByteCount) {      return make_error("copyH2D not implemented for platform " + getName());    }    /// Copies data from one device location to another. -  virtual Error copyD2D(PlatformStreamHandle *S, const void *DeviceSrcHandle, -                        size_t SrcByteOffset, const void *DeviceDstHandle, -                        size_t DstByteOffset, size_t ByteCount) { +  virtual Error copyD2D(const void *PlatformStreamHandle, +                        const void *DeviceSrcHandle, size_t SrcByteOffset, +                        const void *DeviceDstHandle, size_t DstByteOffset, +                        size_t ByteCount) {      return make_error("copyD2D not implemented for platform " + getName());    }    /// Blocks the host until the given stream completes all the work enqueued up    /// to the point this function is called. -  virtual Error blockHostUntilDone(PlatformStreamHandle *S) { +  virtual Error blockHostUntilDone(const void *PlatformStreamHandle) {      return make_error("blockHostUntilDone not implemented for platform " +                        getName());    } diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Stream.h b/parallel-libs/streamexecutor/include/streamexecutor/Stream.h index 81f9ada7792..48dcf32368a 100644 --- a/parallel-libs/streamexecutor/include/streamexecutor/Stream.h +++ b/parallel-libs/streamexecutor/include/streamexecutor/Stream.h @@ -59,10 +59,13 @@ namespace streamexecutor {  /// of a stream once it is in an error state.  class Stream {  public: -  explicit Stream(std::unique_ptr<PlatformStreamHandle> PStream); +  Stream(PlatformDevice *D, const void *PlatformStreamHandle); -  Stream(Stream &&Other) = default; -  Stream &operator=(Stream &&Other) = default; +  Stream(const Stream &Other) = delete; +  Stream &operator=(const Stream &Other) = delete; + +  Stream(Stream &&Other); +  Stream &operator=(Stream &&Other);    ~Stream(); @@ -88,7 +91,7 @@ public:    //    // Returns the result of getStatus() after the Stream work completes.    Error blockHostUntilDone() { -    setError(PDevice->blockHostUntilDone(ThePlatformStream.get())); +    setError(PDevice->blockHostUntilDone(PlatformStreamHandle));      return getStatus();    } @@ -105,7 +108,7 @@ public:                       const ParameterTs &... Arguments) {      auto ArgumentArray =          make_kernel_argument_pack<ParameterTs...>(Arguments...); -    setError(PDevice->launch(ThePlatformStream.get(), BlockSize, GridSize, +    setError(PDevice->launch(PlatformStreamHandle, BlockSize, GridSize,                               K.getPlatformHandle(), ArgumentArray));      return *this;    } @@ -136,7 +139,7 @@ public:        setError("copying too many elements, " + llvm::Twine(ElementCount) +                 ", to a host array of element count " + llvm::Twine(Dst.size()));      else -      setError(PDevice->copyD2H(ThePlatformStream.get(), +      setError(PDevice->copyD2H(PlatformStreamHandle,                                  Src.getBaseMemory().getHandle(),                                  Src.getElementOffset() * sizeof(T), Dst.data(),                                  0, ElementCount * sizeof(T))); @@ -196,10 +199,9 @@ public:                 ", to a device array of element count " +                 llvm::Twine(Dst.getElementCount()));      else -      setError(PDevice->copyH2D(ThePlatformStream.get(), Src.data(), 0, -                                Dst.getBaseMemory().getHandle(), -                                Dst.getElementOffset() * sizeof(T), -                                ElementCount * sizeof(T))); +      setError(PDevice->copyH2D( +          PlatformStreamHandle, Src.data(), 0, Dst.getBaseMemory().getHandle(), +          Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T)));      return *this;    } @@ -254,7 +256,7 @@ public:                 llvm::Twine(Dst.getElementCount()));      else        setError(PDevice->copyD2D( -          ThePlatformStream.get(), Src.getBaseMemory().getHandle(), +          PlatformStreamHandle, Src.getBaseMemory().getHandle(),            Src.getElementOffset() * sizeof(T), Dst.getBaseMemory().getHandle(),            Dst.getElementOffset() * sizeof(T), ElementCount * sizeof(T)));      return *this; @@ -342,7 +344,7 @@ private:    PlatformDevice *PDevice;    /// The platform-specific stream handle for this instance. -  std::unique_ptr<PlatformStreamHandle> ThePlatformStream; +  const void *PlatformStreamHandle;    /// Mutex that guards the error state flags.    std::unique_ptr<llvm::sys::RWMutex> ErrorMessageMutex; @@ -350,9 +352,6 @@ private:    /// First error message for an operation in this stream or empty if there have    /// been no errors.    llvm::Optional<std::string> ErrorMessage; - -  Stream(const Stream &) = delete; -  void operator=(const Stream &) = delete;  };  } // namespace streamexecutor diff --git a/parallel-libs/streamexecutor/lib/Device.cpp b/parallel-libs/streamexecutor/lib/Device.cpp index 54f03849c68..0d81fb78e2d 100644 --- a/parallel-libs/streamexecutor/lib/Device.cpp +++ b/parallel-libs/streamexecutor/lib/Device.cpp @@ -28,14 +28,11 @@ Device::Device(PlatformDevice *PDevice) : PDevice(PDevice) {}  Device::~Device() = default;  Expected<Stream> Device::createStream() { -  Expected<std::unique_ptr<PlatformStreamHandle>> MaybePlatformStream = -      PDevice->createStream(); +  Expected<const void *> MaybePlatformStream = PDevice->createStream();    if (!MaybePlatformStream) {      return MaybePlatformStream.takeError();    } -  assert((*MaybePlatformStream)->getDevice() == PDevice && -         "an executor created a stream with a different stored executor"); -  return Stream(std::move(*MaybePlatformStream)); +  return Stream(PDevice, *MaybePlatformStream);  }  } // namespace streamexecutor diff --git a/parallel-libs/streamexecutor/lib/Kernel.cpp b/parallel-libs/streamexecutor/lib/Kernel.cpp index 1f4218c4df3..61305372f18 100644 --- a/parallel-libs/streamexecutor/lib/Kernel.cpp +++ b/parallel-libs/streamexecutor/lib/Kernel.cpp @@ -12,16 +12,49 @@  ///  //===----------------------------------------------------------------------===// -#include "streamexecutor/Kernel.h" +#include <cassert> +  #include "streamexecutor/Device.h" +#include "streamexecutor/Kernel.h"  #include "streamexecutor/PlatformInterfaces.h"  #include "llvm/DebugInfo/Symbolize/Symbolize.h"  namespace streamexecutor { -KernelBase::KernelBase(llvm::StringRef Name) -    : Name(Name), DemangledName(llvm::symbolize::LLVMSymbolizer::DemangleName( -                      Name, nullptr)) {} +KernelBase::KernelBase(PlatformDevice *D, const void *PlatformKernelHandle, +                       llvm::StringRef Name) +    : PDevice(D), PlatformKernelHandle(PlatformKernelHandle), Name(Name), +      DemangledName( +          llvm::symbolize::LLVMSymbolizer::DemangleName(Name, nullptr)) { +  assert(D != nullptr && +         "cannot construct a kernel object with a null platform device"); +  assert(PlatformKernelHandle != nullptr && +         "cannot construct a kernel object with a null platform kernel handle"); +} + +KernelBase::KernelBase(KernelBase &&Other) +    : PDevice(Other.PDevice), PlatformKernelHandle(Other.PlatformKernelHandle), +      Name(std::move(Other.Name)), +      DemangledName(std::move(Other.DemangledName)) { +  Other.PDevice = nullptr; +  Other.PlatformKernelHandle = nullptr; +} + +KernelBase &KernelBase::operator=(KernelBase &&Other) { +  PDevice = Other.PDevice; +  PlatformKernelHandle = Other.PlatformKernelHandle; +  Name = std::move(Other.Name); +  DemangledName = std::move(Other.DemangledName); +  Other.PDevice = nullptr; +  Other.PlatformKernelHandle = nullptr; +  return *this; +} + +KernelBase::~KernelBase() { +  if (PlatformKernelHandle) +    // TODO(jhen): Handle the error here. +    consumeError(PDevice->destroyKernel(PlatformKernelHandle)); +}  } // namespace streamexecutor diff --git a/parallel-libs/streamexecutor/lib/PlatformInterfaces.cpp b/parallel-libs/streamexecutor/lib/PlatformInterfaces.cpp index 770cd170c4f..e9378b519df 100644 --- a/parallel-libs/streamexecutor/lib/PlatformInterfaces.cpp +++ b/parallel-libs/streamexecutor/lib/PlatformInterfaces.cpp @@ -16,8 +16,6 @@  namespace streamexecutor { -PlatformStreamHandle::~PlatformStreamHandle() = default; -  PlatformDevice::~PlatformDevice() = default;  } // namespace streamexecutor diff --git a/parallel-libs/streamexecutor/lib/Stream.cpp b/parallel-libs/streamexecutor/lib/Stream.cpp index e1fca58cc19..96aad044c9c 100644 --- a/parallel-libs/streamexecutor/lib/Stream.cpp +++ b/parallel-libs/streamexecutor/lib/Stream.cpp @@ -12,14 +12,43 @@  ///  //===----------------------------------------------------------------------===// +#include <cassert> +  #include "streamexecutor/Stream.h"  namespace streamexecutor { -Stream::Stream(std::unique_ptr<PlatformStreamHandle> PStream) -    : PDevice(PStream->getDevice()), ThePlatformStream(std::move(PStream)), -      ErrorMessageMutex(llvm::make_unique<llvm::sys::RWMutex>()) {} +Stream::Stream(PlatformDevice *D, const void *PlatformStreamHandle) +    : PDevice(D), PlatformStreamHandle(PlatformStreamHandle), +      ErrorMessageMutex(llvm::make_unique<llvm::sys::RWMutex>()) { +  assert(D != nullptr && +         "cannot construct a stream object with a null platform device"); +  assert(PlatformStreamHandle != nullptr && +         "cannot construct a stream object with a null platform stream handle"); +} + +Stream::Stream(Stream &&Other) +    : PDevice(Other.PDevice), PlatformStreamHandle(Other.PlatformStreamHandle), +      ErrorMessageMutex(std::move(Other.ErrorMessageMutex)), +      ErrorMessage(std::move(Other.ErrorMessage)) { +  Other.PDevice = nullptr; +  Other.PlatformStreamHandle = nullptr; +} + +Stream &Stream::operator=(Stream &&Other) { +  PDevice = Other.PDevice; +  PlatformStreamHandle = Other.PlatformStreamHandle; +  ErrorMessageMutex = std::move(Other.ErrorMessageMutex); +  ErrorMessage = std::move(Other.ErrorMessage); +  Other.PDevice = nullptr; +  Other.PlatformStreamHandle = nullptr; +  return *this; +} -Stream::~Stream() = default; +Stream::~Stream() { +  if (PlatformStreamHandle) +    // TODO(jhen): Handle error condition here. +    consumeError(PDevice->destroyStream(PlatformStreamHandle)); +}  } // namespace streamexecutor diff --git a/parallel-libs/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h b/parallel-libs/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h index 184c2d7f273..b54b31dd457 100644 --- a/parallel-libs/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h +++ b/parallel-libs/streamexecutor/lib/unittests/SimpleHostPlatformDevice.h @@ -34,9 +34,7 @@ class SimpleHostPlatformDevice : public streamexecutor::PlatformDevice {  public:    std::string getName() const override { return "SimpleHostPlatformDevice"; } -  streamexecutor::Expected< -      std::unique_ptr<streamexecutor::PlatformStreamHandle>> -  createStream() override { +  streamexecutor::Expected<const void *> createStream() override {      return nullptr;    } @@ -69,7 +67,7 @@ public:      return streamexecutor::Error::success();    } -  streamexecutor::Error copyD2H(streamexecutor::PlatformStreamHandle *S, +  streamexecutor::Error copyD2H(const void *StreamHandle,                                  const void *DeviceHandleSrc,                                  size_t SrcByteOffset, void *HostDst,                                  size_t DstByteOffset, @@ -80,8 +78,8 @@ public:      return streamexecutor::Error::success();    } -  streamexecutor::Error copyH2D(streamexecutor::PlatformStreamHandle *S, -                                const void *HostSrc, size_t SrcByteOffset, +  streamexecutor::Error copyH2D(const void *StreamHandle, const void *HostSrc, +                                size_t SrcByteOffset,                                  const void *DeviceHandleDst,                                  size_t DstByteOffset,                                  size_t ByteCount) override { @@ -92,7 +90,7 @@ public:    }    streamexecutor::Error -  copyD2D(streamexecutor::PlatformStreamHandle *S, const void *DeviceHandleSrc, +  copyD2D(const void *StreamHandle, const void *DeviceHandleSrc,            size_t SrcByteOffset, const void *DeviceHandleDst,            size_t DstByteOffset, size_t ByteCount) override {      std::memcpy(static_cast<char *>(const_cast<void *>(DeviceHandleDst)) + diff --git a/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp b/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp index 4f42bbe8e72..3a0f4e6fdd2 100644 --- a/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp +++ b/parallel-libs/streamexecutor/lib/unittests/StreamTest.cpp @@ -34,11 +34,11 @@ const auto &getDeviceValue =  class StreamTest : public ::testing::Test {  public:    StreamTest() -      : Device(&PDevice), -        Stream(llvm::make_unique<se::PlatformStreamHandle>(&PDevice)), -        HostA5{0, 1, 2, 3, 4}, HostB5{5, 6, 7, 8, 9}, -        HostA7{10, 11, 12, 13, 14, 15, 16}, HostB7{17, 18, 19, 20, 21, 22, 23}, -        Host5{24, 25, 26, 27, 28}, Host7{29, 30, 31, 32, 33, 34, 35}, +      : DummyPlatformStream(1), Device(&PDevice), +        Stream(&PDevice, &DummyPlatformStream), HostA5{0, 1, 2, 3, 4}, +        HostB5{5, 6, 7, 8, 9}, HostA7{10, 11, 12, 13, 14, 15, 16}, +        HostB7{17, 18, 19, 20, 21, 22, 23}, Host5{24, 25, 26, 27, 28}, +        Host7{29, 30, 31, 32, 33, 34, 35},          DeviceA5(getOrDie(Device.allocateDeviceMemory<int>(5))),          DeviceB5(getOrDie(Device.allocateDeviceMemory<int>(5))),          DeviceA7(getOrDie(Device.allocateDeviceMemory<int>(7))), @@ -50,6 +50,8 @@ public:    }  protected: +  int DummyPlatformStream; // Mimicking a platform where the platform stream +                           // handle is just a stream number.    se::test::SimpleHostPlatformDevice PDevice;    se::Device Device;    se::Stream Stream;  | 

