diff options
3 files changed, 15 insertions, 12 deletions
diff --git a/parallel-libs/streamexecutor/include/streamexecutor/KernelSpec.h b/parallel-libs/streamexecutor/include/streamexecutor/KernelSpec.h index caf6f1bdc4f..a6a293001ec 100644 --- a/parallel-libs/streamexecutor/include/streamexecutor/KernelSpec.h +++ b/parallel-libs/streamexecutor/include/streamexecutor/KernelSpec.h @@ -121,12 +121,11 @@ public: llvm::StringRef KernelName, const llvm::ArrayRef<CUDAPTXInMemorySpec::PTXSpec> SpecList); - /// Returns a pointer to the PTX code for the requested compute capability. + /// Returns a pointer to the PTX code for the greatest compute capability not + /// exceeding the requested compute capability. /// - /// Returns nullptr on failed lookup (if the requested compute capability is - /// not available). Matches exactly the specified compute capability. Doesn't - /// try to do anything smart like finding the next best compute capability if - /// the specified capability cannot be found. + /// Returns nullptr on failed lookup (if the requested version is not + /// available and no lower versions are available). const char *getCode(int ComputeCapabilityMajor, int ComputeCapabilityMinor) const; diff --git a/parallel-libs/streamexecutor/lib/KernelSpec.cpp b/parallel-libs/streamexecutor/lib/KernelSpec.cpp index b5753a489d1..951ea8fc41c 100644 --- a/parallel-libs/streamexecutor/lib/KernelSpec.cpp +++ b/parallel-libs/streamexecutor/lib/KernelSpec.cpp @@ -31,12 +31,13 @@ CUDAPTXInMemorySpec::CUDAPTXInMemorySpec( const char *CUDAPTXInMemorySpec::getCode(int ComputeCapabilityMajor, int ComputeCapabilityMinor) const { - auto PTXIter = - PTXByComputeCapability.find(CUDAPTXInMemorySpec::ComputeCapability{ + auto Iterator = + PTXByComputeCapability.upper_bound(CUDAPTXInMemorySpec::ComputeCapability{ ComputeCapabilityMajor, ComputeCapabilityMinor}); - if (PTXIter == PTXByComputeCapability.end()) + if (Iterator == PTXByComputeCapability.begin()) return nullptr; - return PTXIter->second; + --Iterator; + return Iterator->second; } CUDAFatbinInMemorySpec::CUDAFatbinInMemorySpec(llvm::StringRef KernelName, diff --git a/parallel-libs/streamexecutor/unittests/CoreTests/KernelSpecTest.cpp b/parallel-libs/streamexecutor/unittests/CoreTests/KernelSpecTest.cpp index fc9eb549968..486a3504091 100644 --- a/parallel-libs/streamexecutor/unittests/CoreTests/KernelSpecTest.cpp +++ b/parallel-libs/streamexecutor/unittests/CoreTests/KernelSpecTest.cpp @@ -30,8 +30,9 @@ TEST(CUDAPTXInMemorySpec, SingleComputeCapability) { const char *PTXCodeString = "Dummy PTX code"; se::CUDAPTXInMemorySpec Spec("KernelName", {{{1, 0}, PTXCodeString}}); EXPECT_EQ("KernelName", Spec.getKernelName()); + EXPECT_EQ(nullptr, Spec.getCode(0, 5)); EXPECT_EQ(PTXCodeString, Spec.getCode(1, 0)); - EXPECT_EQ(nullptr, Spec.getCode(2, 0)); + EXPECT_EQ(PTXCodeString, Spec.getCode(2, 0)); } TEST(CUDAPTXInMemorySpec, TwoComputeCapabilities) { @@ -40,9 +41,10 @@ TEST(CUDAPTXInMemorySpec, TwoComputeCapabilities) { se::CUDAPTXInMemorySpec Spec( "KernelName", {{{1, 0}, PTXCodeString10}, {{3, 0}, PTXCodeString30}}); EXPECT_EQ("KernelName", Spec.getKernelName()); + EXPECT_EQ(nullptr, Spec.getCode(0, 5)); EXPECT_EQ(PTXCodeString10, Spec.getCode(1, 0)); EXPECT_EQ(PTXCodeString30, Spec.getCode(3, 0)); - EXPECT_EQ(nullptr, Spec.getCode(2, 0)); + EXPECT_EQ(PTXCodeString10, Spec.getCode(2, 0)); } TEST(CUDAFatbinInMemorySpec, BasicUsage) { @@ -89,8 +91,9 @@ TEST(MultiKernelLoaderSpec, Registration) { EXPECT_TRUE(MultiSpec.hasOpenCLTextInMemory()); EXPECT_EQ(KernelName, MultiSpec.getCUDAPTXInMemory().getKernelName()); + EXPECT_EQ(nullptr, MultiSpec.getCUDAPTXInMemory().getCode(0, 5)); EXPECT_EQ(PTXCodeString, MultiSpec.getCUDAPTXInMemory().getCode(1, 0)); - EXPECT_EQ(nullptr, MultiSpec.getCUDAPTXInMemory().getCode(2, 0)); + EXPECT_EQ(PTXCodeString, MultiSpec.getCUDAPTXInMemory().getCode(2, 0)); EXPECT_EQ(KernelName, MultiSpec.getCUDAFatbinInMemory().getKernelName()); EXPECT_EQ(FatbinBytes, MultiSpec.getCUDAFatbinInMemory().getBytes()); |