diff options
6 files changed, 66 insertions, 22 deletions
diff --git a/clang-tools-extra/include-fixer/IncludeFixer.cpp b/clang-tools-extra/include-fixer/IncludeFixer.cpp index 204d0c4eb01..3023caf92e3 100644 --- a/clang-tools-extra/include-fixer/IncludeFixer.cpp +++ b/clang-tools-extra/include-fixer/IncludeFixer.cpp @@ -365,6 +365,9 @@ IncludeFixerSemaSource::query(StringRef Query, StringRef ScopedQualifiers, .getLocWithOffset(Range.getOffset()) .print(llvm::dbgs(), CI->getSourceManager())); DEBUG(llvm::dbgs() << " ..."); + llvm::StringRef FileName = CI->getSourceManager().getFilename( + CI->getSourceManager().getLocForStartOfFile( + CI->getSourceManager().getMainFileID())); QuerySymbolInfos.push_back({Query.str(), ScopedQualifiers, Range}); @@ -385,9 +388,10 @@ IncludeFixerSemaSource::query(StringRef Query, StringRef ScopedQualifiers, // context, it might treat the identifier as a nested class of the scoped // namespace. std::vector<find_all_symbols::SymbolInfo> MatchedSymbols = - SymbolIndexMgr.search(QueryString, /*IsNestedSearch=*/false); + SymbolIndexMgr.search(QueryString, /*IsNestedSearch=*/false, FileName); if (MatchedSymbols.empty()) - MatchedSymbols = SymbolIndexMgr.search(Query); + MatchedSymbols = + SymbolIndexMgr.search(Query, /*IsNestedSearch=*/true, FileName); DEBUG(llvm::dbgs() << "Having found " << MatchedSymbols.size() << " symbols\n"); // We store a copy of MatchedSymbols in a place where it's globally reachable. diff --git a/clang-tools-extra/include-fixer/SymbolIndexManager.cpp b/clang-tools-extra/include-fixer/SymbolIndexManager.cpp index 45890f4cea6..3bdb5912ddc 100644 --- a/clang-tools-extra/include-fixer/SymbolIndexManager.cpp +++ b/clang-tools-extra/include-fixer/SymbolIndexManager.cpp @@ -12,6 +12,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/Path.h" #define DEBUG_TYPE "include-fixer" @@ -20,30 +21,57 @@ namespace include_fixer { using clang::find_all_symbols::SymbolInfo; -/// Sorts SymbolInfos based on the popularity info in SymbolInfo. -static void rankByPopularity(std::vector<SymbolInfo> &Symbols) { - // First collect occurrences per header file. - llvm::DenseMap<llvm::StringRef, unsigned> HeaderPopularity; - for (const SymbolInfo &Symbol : Symbols) { - unsigned &Popularity = HeaderPopularity[Symbol.getFilePath()]; - Popularity = std::max(Popularity, Symbol.getNumOccurrences()); +// Calculate a score based on whether we think the given header is closely +// related to the given source file. +static double similarityScore(llvm::StringRef FileName, + llvm::StringRef Header) { + // Compute the maximum number of common path segements between Header and + // a suffix of FileName. + // We do not do a full longest common substring computation, as Header + // specifies the path we would directly #include, so we assume it is rooted + // relatively to a subproject of the repository. + int MaxSegments = 1; + for (auto FileI = llvm::sys::path::begin(FileName), + FileE = llvm::sys::path::end(FileName); + FileI != FileE; ++FileI) { + int Segments = 0; + for (auto HeaderI = llvm::sys::path::begin(Header), + HeaderE = llvm::sys::path::end(Header), I = FileI; + HeaderI != HeaderE && *I == *HeaderI && I != FileE; ++I, ++HeaderI) { + ++Segments; + } + MaxSegments = std::max(Segments, MaxSegments); } + return MaxSegments; +} - // Sort by the gathered popularities. Use file name as a tie breaker so we can +static void rank(std::vector<SymbolInfo> &Symbols, + llvm::StringRef FileName) { + llvm::DenseMap<llvm::StringRef, double> Score; + for (const SymbolInfo &Symbol : Symbols) { + // Calculate a score from the similarity of the header the symbol is in + // with the current file and the popularity of the symbol. + double NewScore = similarityScore(FileName, Symbol.getFilePath()) * + (1.0 + std::log2(1 + Symbol.getNumOccurrences())); + double &S = Score[Symbol.getFilePath()]; + S = std::max(S, NewScore); + } + // Sort by the gathered scores. Use file name as a tie breaker so we can // deduplicate. std::sort(Symbols.begin(), Symbols.end(), [&](const SymbolInfo &A, const SymbolInfo &B) { - auto APop = HeaderPopularity[A.getFilePath()]; - auto BPop = HeaderPopularity[B.getFilePath()]; - if (APop != BPop) - return APop > BPop; + auto AS = Score[A.getFilePath()]; + auto BS = Score[B.getFilePath()]; + if (AS != BS) + return AS > BS; return A.getFilePath() < B.getFilePath(); }); } std::vector<find_all_symbols::SymbolInfo> SymbolIndexManager::search(llvm::StringRef Identifier, - bool IsNestedSearch) const { + bool IsNestedSearch, + llvm::StringRef FileName) const { // The identifier may be fully qualified, so split it and get all the context // names. llvm::SmallVector<llvm::StringRef, 8> Names; @@ -119,7 +147,7 @@ SymbolIndexManager::search(llvm::StringRef Identifier, TookPrefix = true; } while (MatchedSymbols.empty() && !Names.empty() && IsNestedSearch); - rankByPopularity(MatchedSymbols); + rank(MatchedSymbols, FileName); return MatchedSymbols; } diff --git a/clang-tools-extra/include-fixer/SymbolIndexManager.h b/clang-tools-extra/include-fixer/SymbolIndexManager.h index 6a1d22cc145..e1e52879dd5 100644 --- a/clang-tools-extra/include-fixer/SymbolIndexManager.h +++ b/clang-tools-extra/include-fixer/SymbolIndexManager.h @@ -42,7 +42,8 @@ public: /// /// \returns A list of symbol candidates. std::vector<find_all_symbols::SymbolInfo> - search(llvm::StringRef Identifier, bool IsNestedSearch = true) const; + search(llvm::StringRef Identifier, bool IsNestedSearch = true, + llvm::StringRef FileName = "") const; private: std::vector<std::shared_future<std::unique_ptr<SymbolIndex>>> SymbolIndices; diff --git a/clang-tools-extra/include-fixer/tool/ClangIncludeFixer.cpp b/clang-tools-extra/include-fixer/tool/ClangIncludeFixer.cpp index 351cad5e07f..47942a97afa 100644 --- a/clang-tools-extra/include-fixer/tool/ClangIncludeFixer.cpp +++ b/clang-tools-extra/include-fixer/tool/ClangIncludeFixer.cpp @@ -332,7 +332,8 @@ int includeFixerMain(int argc, const char **argv) { // Query symbol mode. if (!QuerySymbol.empty()) { - auto MatchedSymbols = SymbolIndexMgr->search(QuerySymbol); + auto MatchedSymbols = SymbolIndexMgr->search( + QuerySymbol, /*IsNestedSearch=*/true, SourceFilePath); for (auto &Symbol : MatchedSymbols) { std::string HeaderPath = Symbol.getFilePath().str(); Symbol.SetFilePath(((HeaderPath[0] == '"' || HeaderPath[0] == '<') diff --git a/clang-tools-extra/test/include-fixer/Inputs/fake_yaml_db.yaml b/clang-tools-extra/test/include-fixer/Inputs/fake_yaml_db.yaml index a2f991324f7..2cf30279e4b 100644 --- a/clang-tools-extra/test/include-fixer/Inputs/fake_yaml_db.yaml +++ b/clang-tools-extra/test/include-fixer/Inputs/fake_yaml_db.yaml @@ -9,7 +9,6 @@ FilePath: foo.h LineNumber: 1 Type: Class NumOccurrences: 1 -... --- Name: bar Contexts: @@ -21,7 +20,7 @@ FilePath: ../include/bar.h LineNumber: 1 Type: Class NumOccurrences: 1 -... +--- Name: bar Contexts: - ContextType: Namespace @@ -32,7 +31,7 @@ FilePath: ../include/bar.h LineNumber: 2 Type: Class NumOccurrences: 3 -... +--- Name: bar Contexts: - ContextType: Namespace @@ -50,4 +49,12 @@ FilePath: var.h LineNumber: 1 Type: Variable NumOccurrences: 1 -... +--- +Name: bar +Contexts: + - ContextType: Namespace + ContextName: c +FilePath: test/include-fixer/baz.h +LineNumber: 1 +Type: Class +NumOccurrences: 1 diff --git a/clang-tools-extra/test/include-fixer/ranking.cpp b/clang-tools-extra/test/include-fixer/ranking.cpp index 000f6a58c7e..2dabe16fa63 100644 --- a/clang-tools-extra/test/include-fixer/ranking.cpp +++ b/clang-tools-extra/test/include-fixer/ranking.cpp @@ -1,6 +1,9 @@ // RUN: clang-include-fixer -db=yaml -input=%S/Inputs/fake_yaml_db.yaml -output-headers %s -- | FileCheck %s +// RUN: clang-include-fixer -query-symbol bar -db=yaml -input=%S/Inputs/fake_yaml_db.yaml -output-headers %s -- | FileCheck %s // CHECK: "HeaderInfos": [ +// CHECK-NEXT: {"Header": "\"test/include-fixer/baz.h\"", +// CHECK-NEXT: "QualifiedName": "c::bar"}, // CHECK-NEXT: {"Header": "\"../include/bar.h\"", // CHECK-NEXT: "QualifiedName": "b::a::bar"}, // CHECK-NEXT: {"Header": "\"../include/zbar.h\"", |