//===-- HostSaxpy.cpp - Example of host saxpy with StreamExecutor API -----===// // // The LLVM Compiler Infrastructure // // This file is distributed under the University of Illinois Open Source // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// /// /// \file /// This file contains example code demonstrating the usage of the /// StreamExecutor API for a host platform. /// //===----------------------------------------------------------------------===// #include #include #include #include #include "streamexecutor/StreamExecutor.h" void Saxpy(float A, float *X, float *Y, size_t N) { for (size_t I = 0; I < N; ++I) X[I] = A * X[I] + Y[I]; } namespace __compilergen { using SaxpyKernel = streamexecutor::Kernel, streamexecutor::GlobalDeviceMemory, size_t>; // Wrapper function converts argument addresses to arguments. void SaxpyWrapper(const void *const *ArgumentAddresses) { Saxpy(*static_cast(ArgumentAddresses[0]), static_cast(const_cast(ArgumentAddresses[1])), static_cast(const_cast(ArgumentAddresses[2])), *static_cast(ArgumentAddresses[3])); } // The wrapper function is what gets registered. static streamexecutor::MultiKernelLoaderSpec SaxpyLoaderSpec = []() { streamexecutor::MultiKernelLoaderSpec Spec; Spec.addHostFunction("Saxpy", SaxpyWrapper); return Spec; }(); } // namespace __compilergen int main() { namespace se = ::streamexecutor; namespace cg = ::__compilergen; // Create some host data. float A = 42.0f; std::vector HostX = {0, 1, 2, 3}; std::vector HostY = {4, 5, 6, 7}; size_t ArraySize = HostX.size(); // Get a device object. se::Platform *Platform = getOrDie(se::PlatformManager::getPlatformByName("host")); if (Platform->getDeviceCount() == 0) { return EXIT_FAILURE; } se::Device *Device = getOrDie(Platform->getDevice(0)); // Load the kernel onto the device. cg::SaxpyKernel Kernel = getOrDie(Device->createKernel(cg::SaxpyLoaderSpec)); se::RegisteredHostMemory RegisteredX = getOrDie(Device->registerHostMemory(HostX)); se::RegisteredHostMemory RegisteredY = getOrDie(Device->registerHostMemory(HostY)); // Allocate memory on the device. se::GlobalDeviceMemory X = getOrDie(Device->allocateDeviceMemory(ArraySize)); se::GlobalDeviceMemory Y = getOrDie(Device->allocateDeviceMemory(ArraySize)); // Run operations on a stream. se::Stream Stream = getOrDie(Device->createStream()); Stream.thenCopyH2D(RegisteredX, X) .thenCopyH2D(RegisteredY, Y) .thenLaunch(1, 1, Kernel, A, X, Y, ArraySize) .thenCopyD2H(X, RegisteredX); // Wait for the stream to complete. se::dieIfError(Stream.blockHostUntilDone()); // Process output data in HostX. std::vector ExpectedX = {4, 47, 90, 133}; assert(std::equal(ExpectedX.begin(), ExpectedX.end(), HostX.begin())); }