diff options
6 files changed, 123 insertions, 143 deletions
diff --git a/parallel-libs/streamexecutor/examples/Example.cpp b/parallel-libs/streamexecutor/examples/Example.cpp index 76027a860f6..a96648abac6 100644 --- a/parallel-libs/streamexecutor/examples/Example.cpp +++ b/parallel-libs/streamexecutor/examples/Example.cpp @@ -133,9 +133,5 @@ int main() {    for (size_t I = 0; I < ArraySize; ++I) {      assert(HostX[I] == ExpectedX[I]);    } - -  // Free device memory. -  se::dieIfError(Device->freeDeviceMemory(X)); -  se::dieIfError(Device->freeDeviceMemory(Y));    /// [Example saxpy host main]  } diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Device.h b/parallel-libs/streamexecutor/include/streamexecutor/Device.h index 26d0636ed66..0171d06f77c 100644 --- a/parallel-libs/streamexecutor/include/streamexecutor/Device.h +++ b/parallel-libs/streamexecutor/include/streamexecutor/Device.h @@ -56,13 +56,7 @@ public:          PDevice->allocateDeviceMemory(ElementCount * sizeof(T));      if (!MaybeMemory)        return MaybeMemory.takeError(); -    return GlobalDeviceMemory<T>::makeFromElementCount(*MaybeMemory, -                                                       ElementCount); -  } - -  /// Frees memory previously allocated with allocateDeviceMemory. -  template <typename T> Error freeDeviceMemory(GlobalDeviceMemory<T> Memory) { -    return PDevice->freeDeviceMemory(Memory.getHandle()); +    return GlobalDeviceMemory<T>(this, *MaybeMemory, ElementCount);    }    /// Allocates an array of ElementCount entries of type T in host memory. @@ -304,6 +298,12 @@ public:    ///@} End host-synchronous device memory copying functions  private: +  // Only a GlobalDeviceMemoryBase may free device memory. +  friend GlobalDeviceMemoryBase; +  Error freeDeviceMemory(const GlobalDeviceMemoryBase &Memory) { +    return PDevice->freeDeviceMemory(Memory.getHandle()); +  } +    PlatformDevice *PDevice;  }; diff --git a/parallel-libs/streamexecutor/include/streamexecutor/DeviceMemory.h b/parallel-libs/streamexecutor/include/streamexecutor/DeviceMemory.h index d841d26745d..b7cd3d1e4fe 100644 --- a/parallel-libs/streamexecutor/include/streamexecutor/DeviceMemory.h +++ b/parallel-libs/streamexecutor/include/streamexecutor/DeviceMemory.h @@ -14,24 +14,15 @@  /// from the device. Host code cannot have a handle to device shared memory  /// because that memory only exists during the execution of a kernel.  /// -/// GlobalDeviceMemoryBase is similar to a pair consisting of a void* pointer -/// and a byte count to tell how much memory is pointed to by that void*. +/// GlobalDeviceMemory<T> is a handle to an array of elements of type T in +/// global device memory. It is similar to a pair of a std::unique_ptr<T> and an +/// element count to tell how many elements of type T fit in the memory pointed +/// to by that T*.  /// -/// GlobalDeviceMemory<T> is a subclass of GlobalDeviceMemoryBase which keeps -/// track of the type of element to be stored in the device memory. It is -/// similar to a pair of a T* pointer and an element count to tell how many -/// elements of type T fit in the memory pointed to by that T*. -/// -/// SharedDeviceMemoryBase is just the size in bytes of a shared memory buffer. -/// -/// SharedDeviceMemory<T> is a subclass of SharedDeviceMemoryBase which knows -/// how many elements of type T it can hold. -/// -/// These classes are useful for keeping track of which memory space a buffer -/// lives in, and the typed subclasses are useful for type-checking. -/// -/// The typed subclass will be used by user code, and the untyped base classes -/// will be used for type-unsafe operations inside of StreamExecutor. +/// SharedDeviceMemory<T> is just the size in elements of an array of elements +/// of type T in device shared memory. No resources are actually attached to +/// this class, it is just like a memo to the device to allocate space in shared +/// memory.  ///  //===----------------------------------------------------------------------===// @@ -41,56 +32,11 @@  #include <cassert>  #include <cstddef> -namespace streamexecutor { - -/// Wrapper around a generic global device memory allocation. -/// -/// This class represents a buffer of untyped bytes in the global memory space -/// of a device. See GlobalDeviceMemory<T> for the corresponding type that -/// includes type information for the elements in its buffer. -/// -/// This is effectively a pair consisting of an opaque handle and a buffer size -/// in bytes. The opaque handle is a platform-dependent handle to the actual -/// memory that is allocated on the device. -/// -/// In some cases, such as in the CUDA platform, the opaque handle may actually -/// be a pointer in the virtual address space and it may be valid to perform -/// arithmetic on it to obtain other device pointers, but this is not the case -/// in general. -/// -/// For example, in the OpenCL platform, the handle is a pointer to a _cl_mem -/// handle object which really is completely opaque to the user. -/// -/// The only fully platform-generic operations on handles are using them to -/// create new GlobalDeviceMemoryBase objects, and comparing them to each other -/// for equality. -class GlobalDeviceMemoryBase { -public: -  /// Creates a GlobalDeviceMemoryBase from an optional handle and an optional -  /// byte count. -  explicit GlobalDeviceMemoryBase(const void *Handle = nullptr, -                                  size_t ByteCount = 0) -      : Handle(Handle), ByteCount(ByteCount) {} - -  /// Copyable like a pointer. -  GlobalDeviceMemoryBase(const GlobalDeviceMemoryBase &) = default; +#include "streamexecutor/Utils/Error.h" -  /// Copy-assignable like a pointer. -  GlobalDeviceMemoryBase &operator=(const GlobalDeviceMemoryBase &) = default; - -  /// Returns the size, in bytes, for the backing memory. -  size_t getByteCount() const { return ByteCount; } - -  /// Gets the internal handle. -  /// -  /// Warning: note that the pointer returned is not necessarily directly to -  /// device virtual address space, but is platform-dependent. -  const void *getHandle() const { return Handle; } +namespace streamexecutor { -private: -  const void *Handle; // Platform-dependent value representing allocated memory. -  size_t ByteCount;   // Size in bytes of this allocation. -}; +class Device;  template <typename ElemT> class GlobalDeviceMemory; @@ -115,7 +61,7 @@ public:    }    /// Gets the GlobalDeviceMemory backing this slice. -  GlobalDeviceMemory<ElemT> getBaseMemory() const { return BaseMemory; } +  const GlobalDeviceMemory<ElemT> &getBaseMemory() const { return BaseMemory; }    /// Gets the offset of this slice from the base memory.    /// @@ -152,11 +98,68 @@ public:    }  private: -  GlobalDeviceMemory<ElemT> BaseMemory; +  const GlobalDeviceMemory<ElemT> &BaseMemory;    size_t ElementOffset;    size_t ElementCount;  }; +/// Wrapper around a generic global device memory allocation. +/// +/// This class represents a buffer of untyped bytes in the global memory space +/// of a device. See GlobalDeviceMemory<T> for the corresponding type that +/// includes type information for the elements in its buffer. +/// +/// This is effectively a pair consisting of an opaque handle and a buffer size +/// in bytes. The opaque handle is a platform-dependent handle to the actual +/// memory that is allocated on the device. +/// +/// In some cases, such as in the CUDA platform, the opaque handle may actually +/// be a pointer in the virtual address space and it may be valid to perform +/// arithmetic on it to obtain other device pointers, but this is not the case +/// in general. +/// +/// For example, in the OpenCL platform, the handle is a pointer to a _cl_mem +/// handle object which really is completely opaque to the user. +class GlobalDeviceMemoryBase { +public: +  /// Returns an opaque handle to the underlying memory. +  const void *getHandle() const { return Handle; } + +  // Cannot copy because the handle must be owned by a single object. +  GlobalDeviceMemoryBase(const GlobalDeviceMemoryBase &) = delete; +  GlobalDeviceMemoryBase &operator=(const GlobalDeviceMemoryBase &) = delete; + +protected: +  /// Creates a GlobalDeviceMemoryBase from a handle and a byte count. +  GlobalDeviceMemoryBase(Device *D, const void *Handle, size_t ByteCount) +      : TheDevice(D), Handle(Handle), ByteCount(ByteCount) {} + +  /// Transfer ownership of the underlying handle. +  GlobalDeviceMemoryBase(GlobalDeviceMemoryBase &&Other) +      : TheDevice(Other.TheDevice), Handle(Other.Handle), +        ByteCount(Other.ByteCount) { +    Other.TheDevice = nullptr; +    Other.Handle = nullptr; +    Other.ByteCount = 0; +  } + +  GlobalDeviceMemoryBase &operator=(GlobalDeviceMemoryBase &&Other) { +    TheDevice = Other.TheDevice; +    Handle = Other.Handle; +    ByteCount = Other.ByteCount; +    Other.TheDevice = nullptr; +    Other.Handle = nullptr; +    Other.ByteCount = 0; +    return *this; +  } + +  ~GlobalDeviceMemoryBase(); + +  Device *TheDevice;  // Pointer to the device on which this memory lives. +  const void *Handle; // Platform-dependent value representing allocated memory. +  size_t ByteCount;   // Size in bytes of this allocation. +}; +  /// Typed wrapper around the "void *"-like GlobalDeviceMemoryBase class.  ///  /// For example, GlobalDeviceMemory<int> is a simple wrapper around @@ -165,31 +168,12 @@ private:  template <typename ElemT>  class GlobalDeviceMemory : public GlobalDeviceMemoryBase {  public: -  /// Creates a typed area of GlobalDeviceMemory with a given opaque handle and -  /// the given element count. -  static GlobalDeviceMemory<ElemT> makeFromElementCount(const void *Handle, -                                                        size_t ElementCount) { -    return GlobalDeviceMemory<ElemT>(Handle, ElementCount); -  } - -  /// Creates a typed device memory region from an untyped device memory region. -  /// -  /// This effectively amounts to a cast from a void* to an ElemT*, but it also -  /// manages the difference in the size measurements when -  /// GlobalDeviceMemoryBase is measured in bytes and GlobalDeviceMemory is -  /// measured in elements. -  explicit GlobalDeviceMemory(const GlobalDeviceMemoryBase &Other) -      : GlobalDeviceMemoryBase(Other.getHandle(), Other.getByteCount()) {} - -  /// Copyable like a pointer. -  GlobalDeviceMemory(const GlobalDeviceMemory &) = default; - -  /// Copy-assignable like a pointer. -  GlobalDeviceMemory &operator=(const GlobalDeviceMemory &) = default; +  GlobalDeviceMemory(GlobalDeviceMemory &&Other) = default; +  GlobalDeviceMemory &operator=(GlobalDeviceMemory &&Other) = default;    /// Returns the number of elements of type ElemT that constitute this    /// allocation. -  size_t getElementCount() const { return getByteCount() / sizeof(ElemT); } +  size_t getElementCount() const { return ByteCount / sizeof(ElemT); }    /// Converts this memory object into a slice.    GlobalDeviceMemorySlice<ElemT> asSlice() const { @@ -197,23 +181,17 @@ public:    }  private: -  /// Constructs a GlobalDeviceMemory instance from an opaque handle and an -  /// element count. -  /// -  /// This constructor is not public because there is a potential for confusion -  /// between the size of the buffer in bytes and the size of the buffer in -  /// elements. -  /// -  /// The static method makeFromElementCount is provided for users of this class -  /// because its name makes the meaning of the size parameter clear. -  GlobalDeviceMemory(const void *Handle, size_t ElementCount) -      : GlobalDeviceMemoryBase(Handle, ElementCount * sizeof(ElemT)) {} +  GlobalDeviceMemory(const GlobalDeviceMemory &) = delete; +  GlobalDeviceMemory &operator=(const GlobalDeviceMemory &) = delete; + +  // Only a Device can create a GlobalDeviceMemory instance. +  friend Device; +  GlobalDeviceMemory(Device *D, const void *Handle, size_t ElementCount) +      : GlobalDeviceMemoryBase(D, Handle, ElementCount * sizeof(ElemT)) {}  }; -/// A class to represent the size of a dynamic shared memory buffer on a device. -/// -/// This class maintains no information about the types to be stored in the -/// buffer. For the typed version of this class see SharedDeviceMemory<ElemT>. +/// A class to represent the size of a dynamic shared memory buffer of elements +/// of type T on a device.  ///  /// Shared memory buffers exist only on the device and cannot be manipulated  /// from the host, so instances of this class do not have an opaque handle, only @@ -232,31 +210,7 @@ private:  /// multiple SharedDeviceMemory arguments, and simply adding together all the  /// shared memory sizes to get the final shared memory size that is used to  /// launch the kernel. -class SharedDeviceMemoryBase { -public: -  /// Creates an untyped shared memory array from a byte count. -  SharedDeviceMemoryBase(size_t ByteCount) : ByteCount(ByteCount) {} - -  /// Copyable because it is just an array size. -  SharedDeviceMemoryBase(const SharedDeviceMemoryBase &) = default; - -  /// Copy-assignable because it is just an array size. -  SharedDeviceMemoryBase &operator=(const SharedDeviceMemoryBase &) = default; - -  /// Gets the byte count. -  size_t getByteCount() const { return ByteCount; } - -private: -  size_t ByteCount; -}; - -/// Typed wrapper around the untyped SharedDeviceMemoryBase class. -/// -/// For example, SharedDeviceMemory<int> is a wrapper around -/// SharedDeviceMemoryBase that represents a buffer of integers stored in shared -/// device memory. -template <typename ElemT> -class SharedDeviceMemory : public SharedDeviceMemoryBase { +template <typename ElemT> class SharedDeviceMemory {  public:    /// Creates a typed area of shared device memory with a given number of    /// elements. @@ -272,7 +226,7 @@ public:    /// Returns the number of elements of type ElemT that can fit this memory    /// buffer. -  size_t getElementCount() const { return getByteCount() / sizeof(ElemT); } +  size_t getElementCount() const { return ElementCount; }    /// Returns whether this is a single-element memory buffer.    bool isScalar() const { return getElementCount() == 1; } @@ -287,7 +241,9 @@ private:    /// The static method makeFromElementCount is provided for users of this class    /// because its name makes the meaning of the size parameter clear.    explicit SharedDeviceMemory(size_t ElementCount) -      : SharedDeviceMemoryBase(ElementCount * sizeof(ElemT)) {} +      : ElementCount(ElementCount) {} + +  size_t ElementCount;  };  } // namespace streamexecutor diff --git a/parallel-libs/streamexecutor/lib/CMakeLists.txt b/parallel-libs/streamexecutor/lib/CMakeLists.txt index aa16f508661..79ae5c748b0 100644 --- a/parallel-libs/streamexecutor/lib/CMakeLists.txt +++ b/parallel-libs/streamexecutor/lib/CMakeLists.txt @@ -7,6 +7,7 @@ add_library(      streamexecutor      $<TARGET_OBJECTS:utils>      Device.cpp +    DeviceMemory.cpp      Kernel.cpp      KernelSpec.cpp      PackedKernelArgumentArray.cpp diff --git a/parallel-libs/streamexecutor/lib/DeviceMemory.cpp b/parallel-libs/streamexecutor/lib/DeviceMemory.cpp new file mode 100644 index 00000000000..62b702b8acf --- /dev/null +++ b/parallel-libs/streamexecutor/lib/DeviceMemory.cpp @@ -0,0 +1,28 @@ +//===-- DeviceMemory.cpp - DeviceMemory implementation --------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// Implementation of DeviceMemory class internals. +/// +//===----------------------------------------------------------------------===// + +#include "streamexecutor/DeviceMemory.h" + +#include "streamexecutor/Device.h" + +namespace streamexecutor { + +GlobalDeviceMemoryBase::~GlobalDeviceMemoryBase() { +  if (Handle) { +    // TODO(jhen): How to handle errors here. +    consumeError(TheDevice->freeDeviceMemory(*this)); +  } +} + +} // namespace streamexecutor diff --git a/parallel-libs/streamexecutor/lib/unittests/DeviceTest.cpp b/parallel-libs/streamexecutor/lib/unittests/DeviceTest.cpp index 6e55aa51207..593f1d1cd37 100644 --- a/parallel-libs/streamexecutor/lib/unittests/DeviceTest.cpp +++ b/parallel-libs/streamexecutor/lib/unittests/DeviceTest.cpp @@ -78,7 +78,6 @@ TEST_F(DeviceTest, AllocateAndFreeDeviceMemory) {    se::Expected<se::GlobalDeviceMemory<int>> MaybeMemory =        Device.allocateDeviceMemory<int>(10);    EXPECT_TRUE(static_cast<bool>(MaybeMemory)); -  EXPECT_NO_ERROR(Device.freeDeviceMemory(*MaybeMemory));  }  TEST_F(DeviceTest, AllocateAndFreeHostMemory) {  | 

