diff options
| author | Jason Henline <jhen@google.com> | 2016-08-30 23:35:24 +0000 | 
|---|---|---|
| committer | Jason Henline <jhen@google.com> | 2016-08-30 23:35:24 +0000 | 
| commit | 90ce6e1e6496b222cf8e3022ed6f80ccc45dfc0e (patch) | |
| tree | 1149505b8ef06568639106498cd7dfc057a57749 /parallel-libs | |
| parent | ddb53dd080e233b7fe58bd69b46eafa3f093ca8c (diff) | |
| download | bcm5719-llvm-90ce6e1e6496b222cf8e3022ed6f80ccc45dfc0e.tar.gz bcm5719-llvm-90ce6e1e6496b222cf8e3022ed6f80ccc45dfc0e.zip  | |
[StreamExecutor] Simplify Kernel classes
Summary:
Make the Kernel class follow the pattern of the other classes. It now
has a type-safe user wrapper and a typeless, platform-specific handle.
Reviewers: jlebar
Subscribers: jprice, parallel_libs-commits
Differential Revision: https://reviews.llvm.org/D24043
llvm-svn: 280176
Diffstat (limited to 'parallel-libs')
7 files changed, 87 insertions, 212 deletions
diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Device.h b/parallel-libs/streamexecutor/include/streamexecutor/Device.h index 34bba80859d..c37f9b1affb 100644 --- a/parallel-libs/streamexecutor/include/streamexecutor/Device.h +++ b/parallel-libs/streamexecutor/include/streamexecutor/Device.h @@ -15,13 +15,14 @@  #ifndef STREAMEXECUTOR_DEVICE_H  #define STREAMEXECUTOR_DEVICE_H +#include <type_traits> +  #include "streamexecutor/KernelSpec.h"  #include "streamexecutor/PlatformInterfaces.h"  #include "streamexecutor/Utils/Error.h"  namespace streamexecutor { -class KernelInterface;  class Stream;  class Device { @@ -29,11 +30,24 @@ public:    explicit Device(PlatformDevice *PDevice);    virtual ~Device(); -  /// Gets the kernel implementation for the underlying platform. -  virtual Expected<std::unique_ptr<KernelInterface>> -  getKernelImplementation(const MultiKernelLoaderSpec &Spec) { -    // TODO(jhen): Implement this. -    return nullptr; +  /// Creates a kernel object for this device. +  /// +  /// If the return value is not an error, the returned pointer will never be +  /// null. +  /// +  /// See \ref CompilerGeneratedKernelExample "Kernel.h" for an example of how +  /// this method is used. +  template <typename KernelT> +  Expected<std::unique_ptr<typename std::enable_if< +      std::is_base_of<KernelBase, KernelT>::value, KernelT>::type>> +  createKernel(const MultiKernelLoaderSpec &Spec) { +    Expected<std::unique_ptr<PlatformKernelHandle>> MaybeKernelHandle = +        PDevice->createKernel(Spec); +    if (!MaybeKernelHandle) { +      return MaybeKernelHandle.takeError(); +    } +    return llvm::make_unique<KernelT>(Spec.getKernelName(), +                                      std::move(*MaybeKernelHandle));    }    Expected<std::unique_ptr<Stream>> createStream(); diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Kernel.h b/parallel-libs/streamexecutor/include/streamexecutor/Kernel.h index 4a2eeb4b915..63d9c711425 100644 --- a/parallel-libs/streamexecutor/include/streamexecutor/Kernel.h +++ b/parallel-libs/streamexecutor/include/streamexecutor/Kernel.h @@ -11,62 +11,64 @@  /// Types to represent device kernels (code compiled to run on GPU or other  /// accelerator).  /// -/// The TypedKernel class is used to provide type safety to the user API's -/// launch functions, and the KernelBase class is used like a void* function -/// pointer to perform type-unsafe operations inside StreamExecutor. -/// -/// With the kernel parameter types recorded in the TypedKernel template -/// parameters, type-safe kernel launch functions can be written with signatures -/// like the following: +/// With the kernel parameter types recorded in the Kernel template parameters, +/// type-safe kernel launch functions can be written with signatures like the +/// following:  /// \code  ///     template <typename... ParameterTs>  ///     void Launch( -///       const TypedKernel<ParameterTs...> &Kernel, ParamterTs... Arguments); +///       const Kernel<ParameterTs...> &Kernel, ParamterTs... Arguments);  /// \endcode  /// and the compiler will check that the user passes in arguments with types  /// matching the corresponding kernel parameters.  /// -/// A problem is that a TypedKernel template specialization with the right -/// parameter types must be passed as the first argument to the Launch function, -/// and it's just as hard to get the types right in that template specialization -/// as it is to get them right for the kernel arguments. +/// A problem is that a Kernel template specialization with the right parameter +/// types must be passed as the first argument to the Launch function, and it's +/// just as hard to get the types right in that template specialization as it is +/// to get them right for the kernel arguments.  ///  /// With this problem in mind, it is not recommended for users to specialize the -/// TypedKernel template class themselves, but instead to let the compiler do it -/// for them. When the compiler encounters a device kernel function, it can -/// create a TypedKernel template specialization in the host code that has the -/// right parameter types for that kernel and which has a type name based on the -/// name of the kernel function. +/// Kernel template class themselves, but instead to let the compiler do it for +/// them. When the compiler encounters a device kernel function, it can create a +/// Kernel template specialization in the host code that has the right parameter +/// types for that kernel and which has a type name based on the name of the +/// kernel function.  /// +/// \anchor CompilerGeneratedKernelExample  /// For example, if a CUDA device kernel function with the following signature  /// has been defined:  /// \code -///     void Saxpy(float *A, float *X, float *Y); +///     void Saxpy(float A, float *X, float *Y);  /// \endcode  /// the compiler can insert the following declaration in the host code:  /// \code  ///     namespace compiler_cuda_namespace { +///     namespace se = streamexecutor;  ///     using SaxpyKernel = -///         streamexecutor::TypedKernel<float *, float *, float *>; +///         se::Kernel< +///             float, +///             se::GlobalDeviceMemory<float>, +///             se::GlobalDeviceMemory<float>>;  ///     } // namespace compiler_cuda_namespace  /// \endcode  /// and then the user can launch the kernel by calling the StreamExecutor launch  /// function as follows:  /// \code  ///     namespace ccn = compiler_cuda_namespace; +///     using KernelPtr = std::unique_ptr<cnn::SaxpyKernel>;  ///     // Assumes Device is a pointer to the Device on which to launch the  ///     // kernel.  ///     //  ///     // See KernelSpec.h for details on how the compiler can create a  ///     // MultiKernelLoaderSpec instance like SaxpyKernelLoaderSpec below. -///     Expected<ccn::SaxpyKernel> MaybeKernel = -///         ccn::SaxpyKernel::create(Device, ccn::SaxpyKernelLoaderSpec); +///     Expected<KernelPtr> MaybeKernel = +///         Device->createKernel<ccn::SaxpyKernel>(ccn::SaxpyKernelLoaderSpec);  ///     if (!MaybeKernel) { /* Handle error */ } -///     ccn::SaxpyKernel SaxpyKernel = *MaybeKernel; -///     Launch(SaxpyKernel, A, X, Y); +///     KernelPtr SaxpyKernel = std::move(*MaybeKernel); +///     Launch(*SaxpyKernel, A, X, Y);  /// \endcode  /// -/// With the compiler's help in specializing TypedKernel for each device kernel +/// With the compiler's help in specializing Kernel for each device kernel  /// function (and generating a MultiKernelLoaderSpec instance for each kernel),  /// the user can safely launch the device kernel from the host and get an error  /// message at compile time if the argument types don't match the kernel @@ -84,73 +86,37 @@  namespace streamexecutor { -class Device; -class KernelInterface; +class PlatformKernelHandle; -/// The base class for device kernel functions. -/// -/// This class has no information about the types of the parameters taken by the -/// kernel, so it is analogous to a void* pointer to a device function. +/// The base class for all kernel types.  /// -/// See the TypedKernel class below for the subclass which does have information -/// about parameter types. +/// Stores the name of the kernel in both mangled and demangled forms.  class KernelBase {  public: -  KernelBase(KernelBase &&) = default; -  KernelBase &operator=(KernelBase &&) = default; -  ~KernelBase(); - -  /// Creates a kernel object from a Device and a MultiKernelLoaderSpec. -  /// -  /// The Device knows which platform it belongs to and the -  /// MultiKernelLoaderSpec knows how to find the kernel code for different -  /// platforms, so the combined information is enough to get the kernel code -  /// for the appropriate platform. -  static Expected<KernelBase> create(Device *Dev, -                                     const MultiKernelLoaderSpec &Spec); +  KernelBase(llvm::StringRef Name);    const std::string &getName() const { return Name; }    const std::string &getDemangledName() const { return DemangledName; } -  /// Gets a pointer to the platform-specific implementation of this kernel. -  KernelInterface *getImplementation() { return Implementation.get(); } -  private: -  KernelBase(Device *Dev, const std::string &Name, -             const std::string &DemangledName, -             std::unique_ptr<KernelInterface> Implementation); - -  Device *TheDevice;    std::string Name;    std::string DemangledName; -  std::unique_ptr<KernelInterface> Implementation; - -  KernelBase(const KernelBase &) = delete; -  KernelBase &operator=(const KernelBase &) = delete;  }; -/// A device kernel function with specified parameter types. -template <typename... ParameterTs> class TypedKernel : public KernelBase { +/// A StreamExecutor kernel. +/// +/// The template parameters are the types of the parameters to the kernel +/// function. +template <typename... ParameterTs> class Kernel : public KernelBase {  public: -  TypedKernel(TypedKernel &&) = default; -  TypedKernel &operator=(TypedKernel &&) = default; +  Kernel(llvm::StringRef Name, std::unique_ptr<PlatformKernelHandle> PHandle) +      : KernelBase(Name), PHandle(std::move(PHandle)) {} -  /// Parameters here have the same meaning as in KernelBase::create. -  static Expected<TypedKernel> create(Device *Dev, -                                      const MultiKernelLoaderSpec &Spec) { -    auto MaybeBase = KernelBase::create(Dev, Spec); -    if (!MaybeBase) { -      return MaybeBase.takeError(); -    } -    TypedKernel Instance(std::move(*MaybeBase)); -    return std::move(Instance); -  } +  /// Gets the underlying platform-specific handle for this kernel. +  PlatformKernelHandle *getPlatformHandle() const { return PHandle.get(); }  private: -  TypedKernel(KernelBase &&Base) : KernelBase(std::move(Base)) {} - -  TypedKernel(const TypedKernel &) = delete; -  TypedKernel &operator=(const TypedKernel &) = delete; +  std::unique_ptr<PlatformKernelHandle> PHandle;  };  } // namespace streamexecutor diff --git a/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h b/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h index b7737e82e7d..8fa31b63ef2 100644 --- a/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h +++ b/parallel-libs/streamexecutor/include/streamexecutor/PlatformInterfaces.h @@ -33,9 +33,17 @@ namespace streamexecutor {  class PlatformDevice; -/// Methods supported by device kernel function objects on all platforms. -class KernelInterface { -  // TODO(jhen): Add methods. +/// Platform-specific kernel handle. +class PlatformKernelHandle { +public: +  explicit PlatformKernelHandle(PlatformDevice *PDevice) : PDevice(PDevice) {} + +  virtual ~PlatformKernelHandle(); + +  PlatformDevice *getDevice() { return PDevice; } + +private: +  PlatformDevice *PDevice;  };  /// Platform-specific stream handle. @@ -64,12 +72,20 @@ public:    virtual std::string getName() const = 0; +  /// Creates a platform-specific kernel. +  virtual Expected<std::unique_ptr<PlatformKernelHandle>> +  createKernel(const MultiKernelLoaderSpec &Spec) { +    return make_error("createKernel not implemented for platform " + getName()); +  } +    /// Creates a platform-specific stream. -  virtual Expected<std::unique_ptr<PlatformStreamHandle>> createStream() = 0; +  virtual Expected<std::unique_ptr<PlatformStreamHandle>> createStream() { +    return make_error("createStream not implemented for platform " + getName()); +  }    /// Launches a kernel on the given stream.    virtual Error launch(PlatformStreamHandle *S, BlockDimensions BlockSize, -                       GridDimensions GridSize, const KernelBase &Kernel, +                       GridDimensions GridSize, PlatformKernelHandle *K,                         const PackedKernelArgumentArrayBase &ArgumentArray) {      return make_error("launch not implemented for platform " + getName());    } diff --git a/parallel-libs/streamexecutor/include/streamexecutor/Stream.h b/parallel-libs/streamexecutor/include/streamexecutor/Stream.h index 0e6e898b473..2937c5842e8 100644 --- a/parallel-libs/streamexecutor/include/streamexecutor/Stream.h +++ b/parallel-libs/streamexecutor/include/streamexecutor/Stream.h @@ -86,15 +86,15 @@ public:    /// These arguments can be device memory types like GlobalDeviceMemory<T> and    /// SharedDeviceMemory<T>, or they can be primitive types such as int. The    /// allowable argument types are determined by the template parameters to the -  /// TypedKernel argument. +  /// Kernel argument.    template <typename... ParameterTs>    Stream &thenLaunch(BlockDimensions BlockSize, GridDimensions GridSize, -                     const TypedKernel<ParameterTs...> &Kernel, +                     const Kernel<ParameterTs...> &K,                       const ParameterTs &... Arguments) {      auto ArgumentArray =          make_kernel_argument_pack<ParameterTs...>(Arguments...);      setError(PDevice->launch(ThePlatformStream.get(), BlockSize, GridSize, -                             Kernel, ArgumentArray)); +                             K.getPlatformHandle(), ArgumentArray));      return *this;    } diff --git a/parallel-libs/streamexecutor/lib/Kernel.cpp b/parallel-libs/streamexecutor/lib/Kernel.cpp index fa0992003a6..1f4218c4df3 100644 --- a/parallel-libs/streamexecutor/lib/Kernel.cpp +++ b/parallel-libs/streamexecutor/lib/Kernel.cpp @@ -20,26 +20,8 @@  namespace streamexecutor { -KernelBase::KernelBase(Device *Dev, const std::string &Name, -                       const std::string &DemangledName, -                       std::unique_ptr<KernelInterface> Implementation) -    : TheDevice(Dev), Name(Name), DemangledName(DemangledName), -      Implementation(std::move(Implementation)) {} - -KernelBase::~KernelBase() = default; - -Expected<KernelBase> KernelBase::create(Device *Dev, -                                        const MultiKernelLoaderSpec &Spec) { -  auto MaybeImplementation = Dev->getKernelImplementation(Spec); -  if (!MaybeImplementation) { -    return MaybeImplementation.takeError(); -  } -  std::string Name = Spec.getKernelName(); -  std::string DemangledName = -      llvm::symbolize::LLVMSymbolizer::DemangleName(Name, nullptr); -  KernelBase Instance(Dev, Name, DemangledName, -                      std::move(*MaybeImplementation)); -  return std::move(Instance); -} +KernelBase::KernelBase(llvm::StringRef Name) +    : Name(Name), DemangledName(llvm::symbolize::LLVMSymbolizer::DemangleName( +                      Name, nullptr)) {}  } // namespace streamexecutor diff --git a/parallel-libs/streamexecutor/lib/unittests/CMakeLists.txt b/parallel-libs/streamexecutor/lib/unittests/CMakeLists.txt index 3b414e342d9..e12b675f2c4 100644 --- a/parallel-libs/streamexecutor/lib/unittests/CMakeLists.txt +++ b/parallel-libs/streamexecutor/lib/unittests/CMakeLists.txt @@ -9,16 +9,6 @@ target_link_libraries(  add_test(DeviceTest device_test)  add_executable( -    kernel_test -    KernelTest.cpp) -target_link_libraries( -    kernel_test -    streamexecutor -    ${GTEST_BOTH_LIBRARIES} -    ${CMAKE_THREAD_LIBS_INIT}) -add_test(KernelTest kernel_test) - -add_executable(      kernel_spec_test      KernelSpecTest.cpp)  target_link_libraries( diff --git a/parallel-libs/streamexecutor/lib/unittests/KernelTest.cpp b/parallel-libs/streamexecutor/lib/unittests/KernelTest.cpp deleted file mode 100644 index a19ebfb96bd..00000000000 --- a/parallel-libs/streamexecutor/lib/unittests/KernelTest.cpp +++ /dev/null @@ -1,93 +0,0 @@ -//===-- KernelTest.cpp - Tests for Kernel objects -------------------------===// -// -//                     The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file contains the unit tests for the code in Kernel. -/// -//===----------------------------------------------------------------------===// - -#include <cassert> - -#include "streamexecutor/Device.h" -#include "streamexecutor/Kernel.h" -#include "streamexecutor/KernelSpec.h" -#include "streamexecutor/PlatformInterfaces.h" - -#include "llvm/ADT/STLExtras.h" - -#include "gtest/gtest.h" - -namespace { - -namespace se = ::streamexecutor; - -// A Device that returns a dummy KernelInterface. -// -// During construction it creates a unique_ptr to a dummy KernelInterface and it -// also stores a separate copy of the raw pointer that is stored by that -// unique_ptr. -// -// The expectation is that the code being tested will call the -// getKernelImplementation method and will thereby take ownership of the -// unique_ptr, but the copy of the raw pointer will stay behind in this mock -// object. The raw pointer copy can then be used to identify the unique_ptr in -// its new location (by comparing the raw pointer with unique_ptr::get), to -// verify that the unique_ptr ended up where it was supposed to be. -class MockDevice : public se::Device { -public: -  MockDevice() -      : se::Device(nullptr), Unique(llvm::make_unique<se::KernelInterface>()), -        Raw(Unique.get()) {} - -  // Moves the unique pointer into the returned se::Expected instance. -  // -  // Asserts that it is not called again after the unique pointer has been moved -  // out. -  se::Expected<std::unique_ptr<se::KernelInterface>> -  getKernelImplementation(const se::MultiKernelLoaderSpec &) override { -    assert(Unique && "MockDevice getKernelImplementation should not be " -                     "called more than once"); -    return std::move(Unique); -  } - -  // Gets the copy of the raw pointer from the original unique pointer. -  const se::KernelInterface *getRaw() const { return Raw; } - -private: -  std::unique_ptr<se::KernelInterface> Unique; -  const se::KernelInterface *Raw; -}; - -// Test fixture class for typed tests for KernelBase.getImplementation. -// -// The only purpose of this class is to provide a name that types can be bound -// to in the gtest infrastructure. -template <typename T> class GetImplementationTest : public ::testing::Test {}; - -// Types used with the GetImplementationTest fixture class. -typedef ::testing::Types<se::KernelBase, se::TypedKernel<>, -                         se::TypedKernel<int>> -    GetImplementationTypes; - -TYPED_TEST_CASE(GetImplementationTest, GetImplementationTypes); - -// Tests that the kernel create functions properly fetch the implementation -// pointers for the kernel objects they construct from the passed-in -// Device objects. -TYPED_TEST(GetImplementationTest, SetImplementationDuringCreate) { -  se::MultiKernelLoaderSpec Spec; -  MockDevice Dev; - -  auto MaybeKernel = TypeParam::create(&Dev, Spec); -  EXPECT_TRUE(static_cast<bool>(MaybeKernel)); -  se::KernelInterface *Implementation = MaybeKernel->getImplementation(); -  EXPECT_EQ(Dev.getRaw(), Implementation); -} - -} // namespace  | 

