summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--parallel-libs/streamexecutor/include/streamexecutor/KernelSpec.h9
-rw-r--r--parallel-libs/streamexecutor/lib/KernelSpec.cpp9
-rw-r--r--parallel-libs/streamexecutor/unittests/CoreTests/KernelSpecTest.cpp9
3 files changed, 15 insertions, 12 deletions
diff --git a/parallel-libs/streamexecutor/include/streamexecutor/KernelSpec.h b/parallel-libs/streamexecutor/include/streamexecutor/KernelSpec.h
index caf6f1bdc4f..a6a293001ec 100644
--- a/parallel-libs/streamexecutor/include/streamexecutor/KernelSpec.h
+++ b/parallel-libs/streamexecutor/include/streamexecutor/KernelSpec.h
@@ -121,12 +121,11 @@ public:
llvm::StringRef KernelName,
const llvm::ArrayRef<CUDAPTXInMemorySpec::PTXSpec> SpecList);
- /// Returns a pointer to the PTX code for the requested compute capability.
+ /// Returns a pointer to the PTX code for the greatest compute capability not
+ /// exceeding the requested compute capability.
///
- /// Returns nullptr on failed lookup (if the requested compute capability is
- /// not available). Matches exactly the specified compute capability. Doesn't
- /// try to do anything smart like finding the next best compute capability if
- /// the specified capability cannot be found.
+ /// Returns nullptr on failed lookup (if the requested version is not
+ /// available and no lower versions are available).
const char *getCode(int ComputeCapabilityMajor,
int ComputeCapabilityMinor) const;
diff --git a/parallel-libs/streamexecutor/lib/KernelSpec.cpp b/parallel-libs/streamexecutor/lib/KernelSpec.cpp
index b5753a489d1..951ea8fc41c 100644
--- a/parallel-libs/streamexecutor/lib/KernelSpec.cpp
+++ b/parallel-libs/streamexecutor/lib/KernelSpec.cpp
@@ -31,12 +31,13 @@ CUDAPTXInMemorySpec::CUDAPTXInMemorySpec(
const char *CUDAPTXInMemorySpec::getCode(int ComputeCapabilityMajor,
int ComputeCapabilityMinor) const {
- auto PTXIter =
- PTXByComputeCapability.find(CUDAPTXInMemorySpec::ComputeCapability{
+ auto Iterator =
+ PTXByComputeCapability.upper_bound(CUDAPTXInMemorySpec::ComputeCapability{
ComputeCapabilityMajor, ComputeCapabilityMinor});
- if (PTXIter == PTXByComputeCapability.end())
+ if (Iterator == PTXByComputeCapability.begin())
return nullptr;
- return PTXIter->second;
+ --Iterator;
+ return Iterator->second;
}
CUDAFatbinInMemorySpec::CUDAFatbinInMemorySpec(llvm::StringRef KernelName,
diff --git a/parallel-libs/streamexecutor/unittests/CoreTests/KernelSpecTest.cpp b/parallel-libs/streamexecutor/unittests/CoreTests/KernelSpecTest.cpp
index fc9eb549968..486a3504091 100644
--- a/parallel-libs/streamexecutor/unittests/CoreTests/KernelSpecTest.cpp
+++ b/parallel-libs/streamexecutor/unittests/CoreTests/KernelSpecTest.cpp
@@ -30,8 +30,9 @@ TEST(CUDAPTXInMemorySpec, SingleComputeCapability) {
const char *PTXCodeString = "Dummy PTX code";
se::CUDAPTXInMemorySpec Spec("KernelName", {{{1, 0}, PTXCodeString}});
EXPECT_EQ("KernelName", Spec.getKernelName());
+ EXPECT_EQ(nullptr, Spec.getCode(0, 5));
EXPECT_EQ(PTXCodeString, Spec.getCode(1, 0));
- EXPECT_EQ(nullptr, Spec.getCode(2, 0));
+ EXPECT_EQ(PTXCodeString, Spec.getCode(2, 0));
}
TEST(CUDAPTXInMemorySpec, TwoComputeCapabilities) {
@@ -40,9 +41,10 @@ TEST(CUDAPTXInMemorySpec, TwoComputeCapabilities) {
se::CUDAPTXInMemorySpec Spec(
"KernelName", {{{1, 0}, PTXCodeString10}, {{3, 0}, PTXCodeString30}});
EXPECT_EQ("KernelName", Spec.getKernelName());
+ EXPECT_EQ(nullptr, Spec.getCode(0, 5));
EXPECT_EQ(PTXCodeString10, Spec.getCode(1, 0));
EXPECT_EQ(PTXCodeString30, Spec.getCode(3, 0));
- EXPECT_EQ(nullptr, Spec.getCode(2, 0));
+ EXPECT_EQ(PTXCodeString10, Spec.getCode(2, 0));
}
TEST(CUDAFatbinInMemorySpec, BasicUsage) {
@@ -89,8 +91,9 @@ TEST(MultiKernelLoaderSpec, Registration) {
EXPECT_TRUE(MultiSpec.hasOpenCLTextInMemory());
EXPECT_EQ(KernelName, MultiSpec.getCUDAPTXInMemory().getKernelName());
+ EXPECT_EQ(nullptr, MultiSpec.getCUDAPTXInMemory().getCode(0, 5));
EXPECT_EQ(PTXCodeString, MultiSpec.getCUDAPTXInMemory().getCode(1, 0));
- EXPECT_EQ(nullptr, MultiSpec.getCUDAPTXInMemory().getCode(2, 0));
+ EXPECT_EQ(PTXCodeString, MultiSpec.getCUDAPTXInMemory().getCode(2, 0));
EXPECT_EQ(KernelName, MultiSpec.getCUDAFatbinInMemory().getKernelName());
EXPECT_EQ(FatbinBytes, MultiSpec.getCUDAFatbinInMemory().getBytes());
OpenPOWER on IntegriCloud