diff options
4 files changed, 14 insertions, 32 deletions
diff --git a/parallel-libs/streamexecutor/examples/HostSaxpy.cpp b/parallel-libs/streamexecutor/examples/HostSaxpy.cpp index 5bcbcc898ce..cf81b0ba915 100644 --- a/parallel-libs/streamexecutor/examples/HostSaxpy.cpp +++ b/parallel-libs/streamexecutor/examples/HostSaxpy.cpp @@ -33,8 +33,8 @@ using SaxpyKernel =  // Wrapper function converts argument addresses to arguments.  void SaxpyWrapper(const void *const *ArgumentAddresses) {    Saxpy(*static_cast<const float *>(ArgumentAddresses[0]), -        static_cast<float *>(const_cast<void *>(ArgumentAddresses[1])), -        static_cast<float *>(const_cast<void *>(ArgumentAddresses[2])), +        *static_cast<float **>(const_cast<void *>(ArgumentAddresses[1])), +        *static_cast<float **>(const_cast<void *>(ArgumentAddresses[2])),          *static_cast<const size_t *>(ArgumentAddresses[3]));  } diff --git a/parallel-libs/streamexecutor/include/streamexecutor/DeviceMemory.h b/parallel-libs/streamexecutor/include/streamexecutor/DeviceMemory.h index d8f7cefc398..62f6e579933 100644 --- a/parallel-libs/streamexecutor/include/streamexecutor/DeviceMemory.h +++ b/parallel-libs/streamexecutor/include/streamexecutor/DeviceMemory.h @@ -133,6 +133,9 @@ public:    /// Returns an opaque handle to the underlying memory.    const void *getHandle() const { return Handle; } +  /// Returns the address of the opaque handle as stored by this object. +  const void *const *getHandleAddress() const { return &Handle; } +    // Cannot copy because the handle must be owned by a single object.    GlobalDeviceMemoryBase(const GlobalDeviceMemoryBase &) = delete;    GlobalDeviceMemoryBase &operator=(const GlobalDeviceMemoryBase &) = delete; diff --git a/parallel-libs/streamexecutor/include/streamexecutor/PackedKernelArgumentArray.h b/parallel-libs/streamexecutor/include/streamexecutor/PackedKernelArgumentArray.h index ba53ea4669c..f34ec67089f 100644 --- a/parallel-libs/streamexecutor/include/streamexecutor/PackedKernelArgumentArray.h +++ b/parallel-libs/streamexecutor/include/streamexecutor/PackedKernelArgumentArray.h @@ -164,31 +164,10 @@ private:      Types[Index] = KernelArgumentType::VALUE;    } -  // Pack a GlobalDeviceMemoryBase argument. -  void PackOneArgument(size_t Index, const GlobalDeviceMemoryBase &Argument) { -    Addresses[Index] = Argument.getHandle(); -    Sizes[Index] = sizeof(void *); -    Types[Index] = KernelArgumentType::GLOBAL_DEVICE_MEMORY; -  } - -  // Pack a GlobalDeviceMemoryBase pointer argument. -  void PackOneArgument(size_t Index, GlobalDeviceMemoryBase *Argument) { -    Addresses[Index] = Argument->getHandle(); -    Sizes[Index] = sizeof(void *); -    Types[Index] = KernelArgumentType::GLOBAL_DEVICE_MEMORY; -  } - -  // Pack a const GlobalDeviceMemoryBase pointer argument. -  void PackOneArgument(size_t Index, const GlobalDeviceMemoryBase *Argument) { -    Addresses[Index] = Argument->getHandle(); -    Sizes[Index] = sizeof(void *); -    Types[Index] = KernelArgumentType::GLOBAL_DEVICE_MEMORY; -  } -    // Pack a GlobalDeviceMemory<T> argument.    template <typename T>    void PackOneArgument(size_t Index, const GlobalDeviceMemory<T> &Argument) { -    Addresses[Index] = Argument.getHandle(); +    Addresses[Index] = Argument.getHandleAddress();      Sizes[Index] = sizeof(void *);      Types[Index] = KernelArgumentType::GLOBAL_DEVICE_MEMORY;    } @@ -196,7 +175,7 @@ private:    // Pack a GlobalDeviceMemory<T> pointer argument.    template <typename T>    void PackOneArgument(size_t Index, GlobalDeviceMemory<T> *Argument) { -    Addresses[Index] = Argument->getHandle(); +    Addresses[Index] = Argument->getHandleAddress();      Sizes[Index] = sizeof(void *);      Types[Index] = KernelArgumentType::GLOBAL_DEVICE_MEMORY;    } @@ -204,7 +183,7 @@ private:    // Pack a const GlobalDeviceMemory<T> pointer argument.    template <typename T>    void PackOneArgument(size_t Index, const GlobalDeviceMemory<T> *Argument) { -    Addresses[Index] = Argument->getHandle(); +    Addresses[Index] = Argument->getHandleAddress();      Sizes[Index] = sizeof(void *);      Types[Index] = KernelArgumentType::GLOBAL_DEVICE_MEMORY;    } diff --git a/parallel-libs/streamexecutor/unittests/CoreTests/PackedKernelArgumentArrayTest.cpp b/parallel-libs/streamexecutor/unittests/CoreTests/PackedKernelArgumentArrayTest.cpp index 443cdde371a..860f21c323a 100644 --- a/parallel-libs/streamexecutor/unittests/CoreTests/PackedKernelArgumentArrayTest.cpp +++ b/parallel-libs/streamexecutor/unittests/CoreTests/PackedKernelArgumentArrayTest.cpp @@ -76,7 +76,7 @@ TEST_F(DeviceMemoryPackingTest, SingleValue) {  TEST_F(DeviceMemoryPackingTest, SingleTypedGlobal) {    auto Array = se::make_kernel_argument_pack(TypedGlobal); -  ExpectEqual(TypedGlobal.getHandle(), sizeof(void *), +  ExpectEqual(TypedGlobal.getHandleAddress(), sizeof(void *),                Type::GLOBAL_DEVICE_MEMORY, Array, 0);    EXPECT_EQ(1u, Array.getArgumentCount());    EXPECT_EQ(0u, Array.getSharedCount()); @@ -84,7 +84,7 @@ TEST_F(DeviceMemoryPackingTest, SingleTypedGlobal) {  TEST_F(DeviceMemoryPackingTest, SingleTypedGlobalPointer) {    auto Array = se::make_kernel_argument_pack(&TypedGlobal); -  ExpectEqual(TypedGlobal.getHandle(), sizeof(void *), +  ExpectEqual(TypedGlobal.getHandleAddress(), sizeof(void *),                Type::GLOBAL_DEVICE_MEMORY, Array, 0);    EXPECT_EQ(1u, Array.getArgumentCount());    EXPECT_EQ(0u, Array.getSharedCount()); @@ -93,7 +93,7 @@ TEST_F(DeviceMemoryPackingTest, SingleTypedGlobalPointer) {  TEST_F(DeviceMemoryPackingTest, SingleConstTypedGlobalPointer) {    const se::GlobalDeviceMemory<int> *ArgumentPointer = &TypedGlobal;    auto Array = se::make_kernel_argument_pack(ArgumentPointer); -  ExpectEqual(TypedGlobal.getHandle(), sizeof(void *), +  ExpectEqual(TypedGlobal.getHandleAddress(), sizeof(void *),                Type::GLOBAL_DEVICE_MEMORY, Array, 0);    EXPECT_EQ(1u, Array.getArgumentCount());    EXPECT_EQ(0u, Array.getSharedCount()); @@ -131,11 +131,11 @@ TEST_F(DeviceMemoryPackingTest, PackSeveralArguments) {                                               TypedGlobalPointer, TypedShared,                                               &TypedShared, TypedSharedPointer);    ExpectEqual(&Value, sizeof(Value), Type::VALUE, Array, 0); -  ExpectEqual(TypedGlobal.getHandle(), sizeof(void *), +  ExpectEqual(TypedGlobal.getHandleAddress(), sizeof(void *),                Type::GLOBAL_DEVICE_MEMORY, Array, 1); -  ExpectEqual(TypedGlobal.getHandle(), sizeof(void *), +  ExpectEqual(TypedGlobal.getHandleAddress(), sizeof(void *),                Type::GLOBAL_DEVICE_MEMORY, Array, 2); -  ExpectEqual(TypedGlobal.getHandle(), sizeof(void *), +  ExpectEqual(TypedGlobal.getHandleAddress(), sizeof(void *),                Type::GLOBAL_DEVICE_MEMORY, Array, 3);    ExpectEqual(nullptr, TypedShared.getByteCount(), Type::SHARED_DEVICE_MEMORY,                Array, 4);  | 

