diff options
Diffstat (limited to 'llvm/lib')
| -rw-r--r-- | llvm/lib/IR/Instructions.cpp | 57 |
1 files changed, 39 insertions, 18 deletions
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp index 8812df35e26..ad082a9c24f 100644 --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -45,6 +45,12 @@ using namespace llvm; +static cl::opt<bool> SwitchInstProfUpdateWrapperStrict( + "switch-inst-prof-update-wrapper-strict", cl::Hidden, + cl::desc("Assert that prof branch_weights metadata is valid when creating " + "an instance of SwitchInstProfUpdateWrapper"), + cl::init(false)); + //===----------------------------------------------------------------------===// // AllocaInst Class //===----------------------------------------------------------------------===// @@ -3880,7 +3886,7 @@ SwitchInstProfUpdateWrapper::getProfBranchWeightsMD(const SwitchInst &SI) { } MDNode *SwitchInstProfUpdateWrapper::buildProfBranchWeightsMD() { - assert(Changed && "called only if metadata has changed"); + assert(State == Changed && "called only if metadata has changed"); if (!Weights) return nullptr; @@ -3897,11 +3903,20 @@ MDNode *SwitchInstProfUpdateWrapper::buildProfBranchWeightsMD() { return MDBuilder(SI.getParent()->getContext()).createBranchWeights(*Weights); } -Optional<SmallVector<uint32_t, 8> > -SwitchInstProfUpdateWrapper::getProfBranchWeights() { +void SwitchInstProfUpdateWrapper::init() { MDNode *ProfileData = getProfBranchWeightsMD(SI); - if (!ProfileData) - return None; + if (!ProfileData) { + State = Initialized; + return; + } + + if (ProfileData->getNumOperands() != SI.getNumSuccessors() + 1) { + State = Invalid; + if (SwitchInstProfUpdateWrapperStrict) + assert(!"number of prof branch_weights metadata operands corresponds to" + " number of succesors"); + return; + } SmallVector<uint32_t, 8> Weights; for (unsigned CI = 1, CE = SI.getNumSuccessors(); CI <= CE; ++CI) { @@ -3909,7 +3924,8 @@ SwitchInstProfUpdateWrapper::getProfBranchWeights() { uint32_t CW = C->getValue().getZExtValue(); Weights.push_back(CW); } - return Weights; + State = Initialized; + this->Weights = std::move(Weights); } SwitchInst::CaseIt @@ -3917,7 +3933,7 @@ SwitchInstProfUpdateWrapper::removeCase(SwitchInst::CaseIt I) { if (Weights) { assert(SI.getNumSuccessors() == Weights->size() && "num of prof branch_weights must accord with num of successors"); - Changed = true; + State = Changed; // Copy the last case to the place of the removed one and shrink. // This is tightly coupled with the way SwitchInst::removeCase() removes // the cases in SwitchInst::removeCase(CaseIt). @@ -3932,12 +3948,15 @@ void SwitchInstProfUpdateWrapper::addCase( SwitchInstProfUpdateWrapper::CaseWeightOpt W) { SI.addCase(OnVal, Dest); + if (State == Invalid) + return; + if (!Weights && W && *W) { - Changed = true; + State = Changed; Weights = SmallVector<uint32_t, 8>(SI.getNumSuccessors(), 0); Weights.getValue()[SI.getNumSuccessors() - 1] = *W; } else if (Weights) { - Changed = true; + State = Changed; Weights.getValue().push_back(W ? *W : 0); } if (Weights) @@ -3948,10 +3967,11 @@ void SwitchInstProfUpdateWrapper::addCase( SymbolTableList<Instruction>::iterator SwitchInstProfUpdateWrapper::eraseFromParent() { // Instruction is erased. Mark as unchanged to not touch it in the destructor. - Changed = false; - - if (Weights) - Weights->resize(0); + if (State != Invalid) { + State = Initialized; + if (Weights) + Weights->resize(0); + } return SI.eraseFromParent(); } @@ -3964,7 +3984,7 @@ SwitchInstProfUpdateWrapper::getSuccessorWeight(unsigned idx) { void SwitchInstProfUpdateWrapper::setSuccessorWeight( unsigned idx, SwitchInstProfUpdateWrapper::CaseWeightOpt W) { - if (!W) + if (!W || State == Invalid) return; if (!Weights && *W) @@ -3973,7 +3993,7 @@ void SwitchInstProfUpdateWrapper::setSuccessorWeight( if (Weights) { auto &OldW = Weights.getValue()[idx]; if (*W != OldW) { - Changed = true; + State = Changed; OldW = *W; } } @@ -3983,9 +4003,10 @@ SwitchInstProfUpdateWrapper::CaseWeightOpt SwitchInstProfUpdateWrapper::getSuccessorWeight(const SwitchInst &SI, unsigned idx) { if (MDNode *ProfileData = getProfBranchWeightsMD(SI)) - return mdconst::extract<ConstantInt>(ProfileData->getOperand(idx + 1)) - ->getValue() - .getZExtValue(); + if (ProfileData->getNumOperands() == SI.getNumSuccessors() + 1) + return mdconst::extract<ConstantInt>(ProfileData->getOperand(idx + 1)) + ->getValue() + .getZExtValue(); return None; } |

