diff options
Diffstat (limited to 'llvm/unittests/Analysis/SparsePropagation.cpp')
| -rw-r--r-- | llvm/unittests/Analysis/SparsePropagation.cpp | 544 |
1 files changed, 544 insertions, 0 deletions
diff --git a/llvm/unittests/Analysis/SparsePropagation.cpp b/llvm/unittests/Analysis/SparsePropagation.cpp new file mode 100644 index 00000000000..298b1403eb5 --- /dev/null +++ b/llvm/unittests/Analysis/SparsePropagation.cpp @@ -0,0 +1,544 @@ +//===- SparsePropagation.cpp - Unit tests for the generic solver ----------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/SparsePropagation.h" +#include "llvm/ADT/PointerIntPair.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/IRBuilder.h" +#include "gtest/gtest.h" +using namespace llvm; + +namespace { +/// To enable interprocedural analysis, we assign LLVM values to the following +/// groups. The register group represents SSA registers, the return group +/// represents the return values of functions, and the memory group represents +/// in-memory values. An LLVM Value can technically be in more than one group. +/// It's necessary to distinguish these groups so we can, for example, track a +/// global variable separately from the value stored at its location. +enum class IPOGrouping { Register, Return, Memory }; + +/// Our LatticeKeys are PointerIntPairs composed of LLVM values and groupings. +/// The PointerIntPair header provides a DenseMapInfo specialization, so using +/// these as LatticeKeys is fine. +using TestLatticeKey = PointerIntPair<Value *, 2, IPOGrouping>; +} // namespace + +namespace llvm { +/// A specialization of LatticeKeyInfo for TestLatticeKeys. The generic solver +/// must translate between LatticeKeys and LLVM Values when adding Values to +/// its work list and inspecting the state of control-flow related values. +template <> struct LatticeKeyInfo<TestLatticeKey> { + static inline Value *getValueFromLatticeKey(TestLatticeKey Key) { + return Key.getPointer(); + } + static inline TestLatticeKey getLatticeKeyFromValue(Value *V) { + return TestLatticeKey(V, IPOGrouping::Register); + } +}; +} // namespace llvm + +namespace { +/// This class defines a simple test lattice value that could be used for +/// solving problems similar to constant propagation. The value is maintained +/// as a PointerIntPair. +class TestLatticeVal { +public: + /// The states of the lattices value. Only the ConstantVal state is + /// interesting; the rest are special states used by the generic solver. The + /// UntrackedVal state differs from the other three in that the generic + /// solver uses it to avoid doing unnecessary work. In particular, when a + /// value moves to the UntrackedVal state, it's users are not notified. + enum TestLatticeStateTy { + UndefinedVal, + ConstantVal, + OverdefinedVal, + UntrackedVal + }; + + TestLatticeVal() : LatticeVal(nullptr, UndefinedVal) {} + TestLatticeVal(Constant *C, TestLatticeStateTy State) + : LatticeVal(C, State) {} + + /// Return true if this lattice value is in the Constant state. This is used + /// for checking the solver results. + bool isConstant() const { return LatticeVal.getInt() == ConstantVal; } + + /// Return true if this lattice value is in the Overdefined state. This is + /// used for checking the solver results. + bool isOverdefined() const { return LatticeVal.getInt() == OverdefinedVal; } + + bool operator==(const TestLatticeVal &RHS) const { + return LatticeVal == RHS.LatticeVal; + } + + bool operator!=(const TestLatticeVal &RHS) const { + return LatticeVal != RHS.LatticeVal; + } + +private: + /// A simple lattice value type for problems similar to constant propagation. + /// It holds the constant value and the lattice state. + PointerIntPair<const Constant *, 2, TestLatticeStateTy> LatticeVal; +}; + +/// This class defines a simple test lattice function that could be used for +/// solving problems similar to constant propagation. The test lattice differs +/// from a "real" lattice in a few ways. First, it initializes all return +/// values, values stored in global variables, and arguments in the undefined +/// state. This means that there are no limitations on what we can track +/// interprocedurally. For simplicity, all global values in the tests will be +/// given internal linkage, since this is not something this lattice function +/// tracks. Second, it only handles the few instructions necessary for the +/// tests. +class TestLatticeFunc + : public AbstractLatticeFunction<TestLatticeKey, TestLatticeVal> { +public: + /// Construct a new test lattice function with special values for the + /// Undefined, Overdefined, and Untracked states. + TestLatticeFunc() + : AbstractLatticeFunction( + TestLatticeVal(nullptr, TestLatticeVal::UndefinedVal), + TestLatticeVal(nullptr, TestLatticeVal::OverdefinedVal), + TestLatticeVal(nullptr, TestLatticeVal::UntrackedVal)) {} + + /// Compute and return a TestLatticeVal for the given TestLatticeKey. For the + /// test analysis, a LatticeKey will begin in the undefined state, unless it + /// represents an LLVM Constant in the register grouping. + TestLatticeVal ComputeLatticeVal(TestLatticeKey Key) override { + if (Key.getInt() == IPOGrouping::Register) + if (auto *C = dyn_cast<Constant>(Key.getPointer())) + return TestLatticeVal(C, TestLatticeVal::ConstantVal); + return getUndefVal(); + } + + /// Merge the two given lattice values. This merge should be equivalent to + /// what is done for constant propagation. That is, the resulting lattice + /// value is constant only if the two given lattice values are constant and + /// hold the same value. + TestLatticeVal MergeValues(TestLatticeVal X, TestLatticeVal Y) override { + if (X == getUntrackedVal() || Y == getUntrackedVal()) + return getUntrackedVal(); + if (X == getOverdefinedVal() || Y == getOverdefinedVal()) + return getOverdefinedVal(); + if (X == getUndefVal() && Y == getUndefVal()) + return getUndefVal(); + if (X == getUndefVal()) + return Y; + if (Y == getUndefVal()) + return X; + if (X == Y) + return X; + return getOverdefinedVal(); + } + + /// Compute the lattice values that change as a result of executing the given + /// instruction. We only handle the few instructions needed for the tests. + void ComputeInstructionState( + Instruction &I, DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues, + SparseSolver<TestLatticeKey, TestLatticeVal> &SS) override { + switch (I.getOpcode()) { + case Instruction::Call: + return visitCallSite(cast<CallInst>(&I), ChangedValues, SS); + case Instruction::Ret: + return visitReturn(*cast<ReturnInst>(&I), ChangedValues, SS); + case Instruction::Store: + return visitStore(*cast<StoreInst>(&I), ChangedValues, SS); + default: + return visitInst(I, ChangedValues, SS); + } + } + +private: + /// Handle call sites. The state of a called function's argument is the merge + /// of the current formal argument state with the call site's corresponding + /// actual argument state. The call site state is the merge of the call site + /// state with the returned value state of the called function. + void visitCallSite(CallSite CS, + DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues, + SparseSolver<TestLatticeKey, TestLatticeVal> &SS) { + Function *F = CS.getCalledFunction(); + Instruction *I = CS.getInstruction(); + auto RegI = TestLatticeKey(I, IPOGrouping::Register); + if (!F) { + ChangedValues[RegI] = getOverdefinedVal(); + return; + } + SS.MarkBlockExecutable(&F->front()); + for (Argument &A : F->args()) { + auto RegFormal = TestLatticeKey(&A, IPOGrouping::Register); + auto RegActual = + TestLatticeKey(CS.getArgument(A.getArgNo()), IPOGrouping::Register); + ChangedValues[RegFormal] = + MergeValues(SS.getValueState(RegFormal), SS.getValueState(RegActual)); + } + auto RetF = TestLatticeKey(F, IPOGrouping::Return); + ChangedValues[RegI] = + MergeValues(SS.getValueState(RegI), SS.getValueState(RetF)); + } + + /// Handle return instructions. The function's return state is the merge of + /// the returned value state and the function's current return state. + void visitReturn(ReturnInst &I, + DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues, + SparseSolver<TestLatticeKey, TestLatticeVal> &SS) { + Function *F = I.getParent()->getParent(); + if (F->getReturnType()->isVoidTy()) + return; + auto RegR = TestLatticeKey(I.getReturnValue(), IPOGrouping::Register); + auto RetF = TestLatticeKey(F, IPOGrouping::Return); + ChangedValues[RetF] = + MergeValues(SS.getValueState(RegR), SS.getValueState(RetF)); + } + + /// Handle store instructions. If the pointer operand of the store is a + /// global variable, we attempt to track the value. The global variable state + /// is the merge of the stored value state with the current global variable + /// state. + void visitStore(StoreInst &I, + DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues, + SparseSolver<TestLatticeKey, TestLatticeVal> &SS) { + auto *GV = dyn_cast<GlobalVariable>(I.getPointerOperand()); + if (!GV) + return; + auto RegVal = TestLatticeKey(I.getValueOperand(), IPOGrouping::Register); + auto MemPtr = TestLatticeKey(GV, IPOGrouping::Memory); + ChangedValues[MemPtr] = + MergeValues(SS.getValueState(RegVal), SS.getValueState(MemPtr)); + } + + /// Handle all other instructions. All other instructions are marked + /// overdefined. + void visitInst(Instruction &I, + DenseMap<TestLatticeKey, TestLatticeVal> &ChangedValues, + SparseSolver<TestLatticeKey, TestLatticeVal> &SS) { + auto RegI = TestLatticeKey(&I, IPOGrouping::Register); + ChangedValues[RegI] = getOverdefinedVal(); + } +}; + +/// This class defines the common data used for all of the tests. The tests +/// should add code to the module and then run the solver. +class SparsePropagationTest : public testing::Test { +protected: + LLVMContext Context; + Module M; + IRBuilder<> Builder; + TestLatticeFunc Lattice; + SparseSolver<TestLatticeKey, TestLatticeVal> Solver; + +public: + SparsePropagationTest() + : M("", Context), Builder(Context), Solver(&Lattice) {} +}; +} // namespace + +/// Test that we mark discovered functions executable. +/// +/// define internal void @f() { +/// call void @g() +/// ret void +/// } +/// +/// define internal void @g() { +/// call void @f() +/// ret void +/// } +/// +/// For this test, we initially mark "f" executable, and the solver discovers +/// "g" because of the call in "f". The mutually recursive call in "g" also +/// tests that we don't add a block to the basic block work list if it is +/// already executable. Doing so would put the solver into an infinite loop. +TEST_F(SparsePropagationTest, MarkBlockExecutable) { + Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::InternalLinkage, "f", &M); + Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::InternalLinkage, "g", &M); + BasicBlock *FEntry = BasicBlock::Create(Context, "", F); + BasicBlock *GEntry = BasicBlock::Create(Context, "", G); + Builder.SetInsertPoint(FEntry); + Builder.CreateCall(G); + Builder.CreateRetVoid(); + Builder.SetInsertPoint(GEntry); + Builder.CreateCall(F); + Builder.CreateRetVoid(); + + Solver.MarkBlockExecutable(FEntry); + Solver.Solve(); + + EXPECT_TRUE(Solver.isBlockExecutable(GEntry)); +} + +/// Test that we propagate information through global variables. +/// +/// @gv = internal global i64 +/// +/// define internal void @f() { +/// store i64 1, i64* @gv +/// ret void +/// } +/// +/// define internal void @g() { +/// store i64 1, i64* @gv +/// ret void +/// } +/// +/// For this test, we initially mark both "f" and "g" executable, and the +/// solver computes the lattice state of the global variable as constant. +TEST_F(SparsePropagationTest, GlobalVariableConstant) { + Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::InternalLinkage, "f", &M); + Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::InternalLinkage, "g", &M); + GlobalVariable *GV = + new GlobalVariable(M, Builder.getInt64Ty(), false, + GlobalValue::InternalLinkage, nullptr, "gv"); + BasicBlock *FEntry = BasicBlock::Create(Context, "", F); + BasicBlock *GEntry = BasicBlock::Create(Context, "", G); + Builder.SetInsertPoint(FEntry); + Builder.CreateStore(Builder.getInt64(1), GV); + Builder.CreateRetVoid(); + Builder.SetInsertPoint(GEntry); + Builder.CreateStore(Builder.getInt64(1), GV); + Builder.CreateRetVoid(); + + Solver.MarkBlockExecutable(FEntry); + Solver.MarkBlockExecutable(GEntry); + Solver.Solve(); + + auto MemGV = TestLatticeKey(GV, IPOGrouping::Memory); + EXPECT_TRUE(Solver.getExistingValueState(MemGV).isConstant()); +} + +/// Test that we propagate information through global variables. +/// +/// @gv = internal global i64 +/// +/// define internal void @f() { +/// store i64 0, i64* @gv +/// ret void +/// } +/// +/// define internal void @g() { +/// store i64 1, i64* @gv +/// ret void +/// } +/// +/// For this test, we initially mark both "f" and "g" executable, and the +/// solver computes the lattice state of the global variable as overdefined. +TEST_F(SparsePropagationTest, GlobalVariableOverDefined) { + Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::InternalLinkage, "f", &M); + Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::InternalLinkage, "g", &M); + GlobalVariable *GV = + new GlobalVariable(M, Builder.getInt64Ty(), false, + GlobalValue::InternalLinkage, nullptr, "gv"); + BasicBlock *FEntry = BasicBlock::Create(Context, "", F); + BasicBlock *GEntry = BasicBlock::Create(Context, "", G); + Builder.SetInsertPoint(FEntry); + Builder.CreateStore(Builder.getInt64(0), GV); + Builder.CreateRetVoid(); + Builder.SetInsertPoint(GEntry); + Builder.CreateStore(Builder.getInt64(1), GV); + Builder.CreateRetVoid(); + + Solver.MarkBlockExecutable(FEntry); + Solver.MarkBlockExecutable(GEntry); + Solver.Solve(); + + auto MemGV = TestLatticeKey(GV, IPOGrouping::Memory); + EXPECT_TRUE(Solver.getExistingValueState(MemGV).isOverdefined()); +} + +/// Test that we propagate information through function returns. +/// +/// define internal i64 @f(i1* %cond) { +/// if: +/// %0 = load i1, i1* %cond +/// br i1 %0, label %then, label %else +/// +/// then: +/// ret i64 1 +/// +/// else: +/// ret i64 1 +/// } +/// +/// For this test, we initially mark "f" executable, and the solver computes +/// the return value of the function as constant. +TEST_F(SparsePropagationTest, FunctionDefined) { + Function *F = + Function::Create(FunctionType::get(Builder.getInt64Ty(), + {Type::getInt1PtrTy(Context)}, false), + GlobalValue::InternalLinkage, "f", &M); + BasicBlock *If = BasicBlock::Create(Context, "if", F); + BasicBlock *Then = BasicBlock::Create(Context, "then", F); + BasicBlock *Else = BasicBlock::Create(Context, "else", F); + F->arg_begin()->setName("cond"); + Builder.SetInsertPoint(If); + LoadInst *Cond = Builder.CreateLoad(F->arg_begin()); + Builder.CreateCondBr(Cond, Then, Else); + Builder.SetInsertPoint(Then); + Builder.CreateRet(Builder.getInt64(1)); + Builder.SetInsertPoint(Else); + Builder.CreateRet(Builder.getInt64(1)); + + Solver.MarkBlockExecutable(If); + Solver.Solve(); + + auto RetF = TestLatticeKey(F, IPOGrouping::Return); + EXPECT_TRUE(Solver.getExistingValueState(RetF).isConstant()); +} + +/// Test that we propagate information through function returns. +/// +/// define internal i64 @f(i1* %cond) { +/// if: +/// %0 = load i1, i1* %cond +/// br i1 %0, label %then, label %else +/// +/// then: +/// ret i64 0 +/// +/// else: +/// ret i64 1 +/// } +/// +/// For this test, we initially mark "f" executable, and the solver computes +/// the return value of the function as overdefined. +TEST_F(SparsePropagationTest, FunctionOverDefined) { + Function *F = + Function::Create(FunctionType::get(Builder.getInt64Ty(), + {Type::getInt1PtrTy(Context)}, false), + GlobalValue::InternalLinkage, "f", &M); + BasicBlock *If = BasicBlock::Create(Context, "if", F); + BasicBlock *Then = BasicBlock::Create(Context, "then", F); + BasicBlock *Else = BasicBlock::Create(Context, "else", F); + F->arg_begin()->setName("cond"); + Builder.SetInsertPoint(If); + LoadInst *Cond = Builder.CreateLoad(F->arg_begin()); + Builder.CreateCondBr(Cond, Then, Else); + Builder.SetInsertPoint(Then); + Builder.CreateRet(Builder.getInt64(0)); + Builder.SetInsertPoint(Else); + Builder.CreateRet(Builder.getInt64(1)); + + Solver.MarkBlockExecutable(If); + Solver.Solve(); + + auto RetF = TestLatticeKey(F, IPOGrouping::Return); + EXPECT_TRUE(Solver.getExistingValueState(RetF).isOverdefined()); +} + +/// Test that we propagate information through arguments. +/// +/// define internal void @f() { +/// call void @g(i64 0, i64 1) +/// call void @g(i64 1, i64 1) +/// ret void +/// } +/// +/// define internal void @g(i64 %a, i64 %b) { +/// ret void +/// } +/// +/// For this test, we initially mark "f" executable, and the solver discovers +/// "g" because of the calls in "f". The solver computes the state of argument +/// "a" as overdefined and the state of "b" as constant. +/// +/// In addition, this test demonstrates that ComputeInstructionState can alter +/// the state of multiple lattice values, in addition to the one associated +/// with the instruction definition. Each call instruction in this test updates +/// the state of arguments "a" and "b". +TEST_F(SparsePropagationTest, ComputeInstructionState) { + Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::InternalLinkage, "f", &M); + Function *G = Function::Create( + FunctionType::get(Builder.getVoidTy(), + {Builder.getInt64Ty(), Builder.getInt64Ty()}, false), + GlobalValue::InternalLinkage, "g", &M); + Argument *A = G->arg_begin(); + Argument *B = std::next(G->arg_begin()); + A->setName("a"); + B->setName("b"); + BasicBlock *FEntry = BasicBlock::Create(Context, "", F); + BasicBlock *GEntry = BasicBlock::Create(Context, "", G); + Builder.SetInsertPoint(FEntry); + Builder.CreateCall(G, {Builder.getInt64(0), Builder.getInt64(1)}); + Builder.CreateCall(G, {Builder.getInt64(1), Builder.getInt64(1)}); + Builder.CreateRetVoid(); + Builder.SetInsertPoint(GEntry); + Builder.CreateRetVoid(); + + Solver.MarkBlockExecutable(FEntry); + Solver.Solve(); + + auto RegA = TestLatticeKey(A, IPOGrouping::Register); + auto RegB = TestLatticeKey(B, IPOGrouping::Register); + EXPECT_TRUE(Solver.getExistingValueState(RegA).isOverdefined()); + EXPECT_TRUE(Solver.getExistingValueState(RegB).isConstant()); +} + +/// Test that we can handle exceptional terminator instructions. +/// +/// declare internal void @p() +/// +/// declare internal void @g() +/// +/// define internal void @f() personality i8* bitcast (void ()* @p to i8*) { +/// entry: +/// invoke void @g() +/// to label %exit unwind label %catch.pad +/// +/// catch.pad: +/// %0 = catchswitch within none [label %catch.body] unwind to caller +/// +/// catch.body: +/// %1 = catchpad within %0 [] +/// catchret from %1 to label %exit +/// +/// exit: +/// ret void +/// } +/// +/// For this test, we initially mark the entry block executable. The solver +/// then discovers the rest of the blocks in the function are executable. +TEST_F(SparsePropagationTest, ExceptionalTerminatorInsts) { + Function *P = Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::InternalLinkage, "p", &M); + Function *G = Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::InternalLinkage, "g", &M); + Function *F = Function::Create(FunctionType::get(Builder.getVoidTy(), false), + GlobalValue::InternalLinkage, "f", &M); + Constant *C = + ConstantExpr::getCast(Instruction::BitCast, P, Builder.getInt8PtrTy()); + F->setPersonalityFn(C); + BasicBlock *Entry = BasicBlock::Create(Context, "entry", F); + BasicBlock *Pad = BasicBlock::Create(Context, "catch.pad", F); + BasicBlock *Body = BasicBlock::Create(Context, "catch.body", F); + BasicBlock *Exit = BasicBlock::Create(Context, "exit", F); + Builder.SetInsertPoint(Entry); + Builder.CreateInvoke(G, Exit, Pad); + Builder.SetInsertPoint(Pad); + CatchSwitchInst *CatchSwitch = + Builder.CreateCatchSwitch(ConstantTokenNone::get(Context), nullptr, 1); + CatchSwitch->addHandler(Body); + Builder.SetInsertPoint(Body); + CatchPadInst *CatchPad = Builder.CreateCatchPad(CatchSwitch, {}); + Builder.CreateCatchRet(CatchPad, Exit); + Builder.SetInsertPoint(Exit); + Builder.CreateRetVoid(); + + Solver.MarkBlockExecutable(Entry); + Solver.Solve(); + + EXPECT_TRUE(Solver.isBlockExecutable(Pad)); + EXPECT_TRUE(Solver.isBlockExecutable(Body)); + EXPECT_TRUE(Solver.isBlockExecutable(Exit)); +} |

