diff options
Diffstat (limited to 'parallel-libs/streamexecutor/examples/HostSaxpy.cpp')
| -rw-r--r-- | parallel-libs/streamexecutor/examples/HostSaxpy.cpp | 94 |
1 files changed, 94 insertions, 0 deletions
diff --git a/parallel-libs/streamexecutor/examples/HostSaxpy.cpp b/parallel-libs/streamexecutor/examples/HostSaxpy.cpp new file mode 100644 index 00000000000..525c4453b01 --- /dev/null +++ b/parallel-libs/streamexecutor/examples/HostSaxpy.cpp @@ -0,0 +1,94 @@ +//===-- HostSaxpy.cpp - Example of host saxpy with StreamExecutor API -----===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains example code demonstrating the usage of the +/// StreamExecutor API for a host platform. +/// +//===----------------------------------------------------------------------===// + +#include <algorithm> +#include <cassert> +#include <cstdio> +#include <vector> + +#include "streamexecutor/StreamExecutor.h" + +void Saxpy(float A, float *X, float *Y, size_t N) { + for (size_t I = 0; I < N; ++I) + X[I] = A * X[I] + Y[I]; +} + +namespace __compilergen { +using SaxpyKernel = + streamexecutor::Kernel<float, streamexecutor::GlobalDeviceMemory<float>, + streamexecutor::GlobalDeviceMemory<float>, size_t>; + +// Wrapper function converts argument addresses to arguments. +void SaxpyWrapper(const void *const *ArgumentAddresses) { + Saxpy(*static_cast<const float *>(ArgumentAddresses[0]), + static_cast<float *>(const_cast<void *>(ArgumentAddresses[1])), + static_cast<float *>(const_cast<void *>(ArgumentAddresses[2])), + *static_cast<const size_t *>(ArgumentAddresses[3])); +} + +// The wrapper function is what gets registered. +static streamexecutor::MultiKernelLoaderSpec SaxpyLoaderSpec = []() { + streamexecutor::MultiKernelLoaderSpec Spec; + Spec.addHostFunction("Saxpy", SaxpyWrapper); + return Spec; +}(); +} // namespace __compilergen + +int main() { + namespace se = ::streamexecutor; + namespace cg = ::__compilergen; + + // Create some host data. + float A = 42.0f; + std::vector<float> HostX = {0, 1, 2, 3}; + std::vector<float> HostY = {4, 5, 6, 7}; + size_t ArraySize = HostX.size(); + + // Get a device object. + se::Platform *Platform = + getOrDie(se::PlatformManager::getPlatformByName("host")); + if (Platform->getDeviceCount() == 0) { + return EXIT_FAILURE; + } + se::Device *Device = getOrDie(Platform->getDevice(0)); + + // Load the kernel onto the device. + cg::SaxpyKernel Kernel = + getOrDie(Device->createKernel<cg::SaxpyKernel>(cg::SaxpyLoaderSpec)); + + se::RegisteredHostMemory<float> RegisteredX = + getOrDie(Device->registerHostMemory<float>(HostX)); + se::RegisteredHostMemory<float> RegisteredY = + getOrDie(Device->registerHostMemory<float>(HostY)); + + // Allocate memory on the device. + se::GlobalDeviceMemory<float> X = + getOrDie(Device->allocateDeviceMemory<float>(ArraySize)); + se::GlobalDeviceMemory<float> Y = + getOrDie(Device->allocateDeviceMemory<float>(ArraySize)); + + // Run operations on a stream. + se::Stream Stream = getOrDie(Device->createStream()); + Stream.thenCopyH2D(RegisteredX, X) + .thenCopyH2D(RegisteredY, Y) + .thenLaunch(1, 1, Kernel, A, X, Y, ArraySize) + .thenCopyD2H(X, RegisteredX); + // Wait for the stream to complete. + se::dieIfError(Stream.blockHostUntilDone()); + + // Process output data in HostX. + std::vector<float> ExpectedX = {4, 47, 90, 133}; + assert(std::equal(ExpectedX.begin(), ExpectedX.end(), HostX.begin())); +} |

