summaryrefslogtreecommitdiffstats
path: root/clang/unittests/AST/ASTPrint.h
diff options
context:
space:
mode:
Diffstat (limited to 'clang/unittests/AST/ASTPrint.h')
-rw-r--r--clang/unittests/AST/ASTPrint.h92
1 files changed, 92 insertions, 0 deletions
diff --git a/clang/unittests/AST/ASTPrint.h b/clang/unittests/AST/ASTPrint.h
new file mode 100644
index 00000000000..c3b6b842316
--- /dev/null
+++ b/clang/unittests/AST/ASTPrint.h
@@ -0,0 +1,92 @@
+//===- unittests/AST/ASTPrint.h ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Helpers to simplify testing of printing of AST constructs provided in the/
+// form of the source code.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/AST/ASTContext.h"
+#include "clang/ASTMatchers/ASTMatchFinder.h"
+#include "clang/Tooling/Tooling.h"
+#include "llvm/ADT/SmallString.h"
+#include "gtest/gtest.h"
+
+namespace clang {
+
+using PolicyAdjusterType =
+ Optional<llvm::function_ref<void(PrintingPolicy &Policy)>>;
+
+static void PrintStmt(raw_ostream &Out, const ASTContext *Context,
+ const Stmt *S, PolicyAdjusterType PolicyAdjuster) {
+ assert(S != nullptr && "Expected non-null Stmt");
+ PrintingPolicy Policy = Context->getPrintingPolicy();
+ if (PolicyAdjuster)
+ (*PolicyAdjuster)(Policy);
+ S->printPretty(Out, /*Helper*/ nullptr, Policy);
+}
+
+class PrintMatch : public ast_matchers::MatchFinder::MatchCallback {
+ SmallString<1024> Printed;
+ unsigned NumFoundStmts;
+ PolicyAdjusterType PolicyAdjuster;
+
+public:
+ PrintMatch(PolicyAdjusterType PolicyAdjuster)
+ : NumFoundStmts(0), PolicyAdjuster(PolicyAdjuster) {}
+
+ void run(const ast_matchers::MatchFinder::MatchResult &Result) override {
+ const Stmt *S = Result.Nodes.getNodeAs<Stmt>("id");
+ if (!S)
+ return;
+ NumFoundStmts++;
+ if (NumFoundStmts > 1)
+ return;
+
+ llvm::raw_svector_ostream Out(Printed);
+ PrintStmt(Out, Result.Context, S, PolicyAdjuster);
+ }
+
+ StringRef getPrinted() const { return Printed; }
+
+ unsigned getNumFoundStmts() const { return NumFoundStmts; }
+};
+
+template <typename T>
+::testing::AssertionResult
+PrintedStmtMatches(StringRef Code, const std::vector<std::string> &Args,
+ const T &NodeMatch, StringRef ExpectedPrinted,
+ PolicyAdjusterType PolicyAdjuster = None) {
+
+ PrintMatch Printer(PolicyAdjuster);
+ ast_matchers::MatchFinder Finder;
+ Finder.addMatcher(NodeMatch, &Printer);
+ std::unique_ptr<tooling::FrontendActionFactory> Factory(
+ tooling::newFrontendActionFactory(&Finder));
+
+ if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args))
+ return testing::AssertionFailure()
+ << "Parsing error in \"" << Code.str() << "\"";
+
+ if (Printer.getNumFoundStmts() == 0)
+ return testing::AssertionFailure() << "Matcher didn't find any statements";
+
+ if (Printer.getNumFoundStmts() > 1)
+ return testing::AssertionFailure()
+ << "Matcher should match only one statement (found "
+ << Printer.getNumFoundStmts() << ")";
+
+ if (Printer.getPrinted() != ExpectedPrinted)
+ return ::testing::AssertionFailure()
+ << "Expected \"" << ExpectedPrinted.str() << "\", got \""
+ << Printer.getPrinted().str() << "\"";
+
+ return ::testing::AssertionSuccess();
+}
+
+} // namespace clang
OpenPOWER on IntegriCloud