summaryrefslogtreecommitdiffstats
path: root/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
diff options
context:
space:
mode:
authorRichard Trieu <rtrieu@google.com>2018-09-05 04:19:15 +0000
committerRichard Trieu <rtrieu@google.com>2018-09-05 04:19:15 +0000
commit47c2bc58b3d6142ff3400846bd847620d0a58548 (patch)
treef1f13e17002e0c10fa38303dd4b270a92e6c9ef9 /llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
parentc8f348cba7035286ab49d3a2e5d82439ffd814ec (diff)
downloadbcm5719-llvm-47c2bc58b3d6142ff3400846bd847620d0a58548.tar.gz
bcm5719-llvm-47c2bc58b3d6142ff3400846bd847620d0a58548.zip
Prevent unsigned overflow.
The sum of the weights is caculated in an APInt, which has a width smaller than 64. In certain cases, the sum of the widths would overflow when calculations are done inside an APInt, but would not if done with uint64_t. Since the values will be passed as uint64_t in the function call anyways, do all the math in 64 bits. Also added an assert in case the probabilities overflow 64 bits. llvm-svn: 341444
Diffstat (limited to 'llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp')
-rw-r--r--llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp16
1 files changed, 9 insertions, 7 deletions
diff --git a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
index 4da8a7113f6..80e16e7157e 100644
--- a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
+++ b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp
@@ -614,13 +614,15 @@ static bool CheckMDProf(MDNode *MD, BranchProbability &TrueProb,
ConstantInt *FalseWeight = mdconst::extract<ConstantInt>(MD->getOperand(2));
if (!TrueWeight || !FalseWeight)
return false;
- APInt TrueWt = TrueWeight->getValue();
- APInt FalseWt = FalseWeight->getValue();
- APInt SumWt = TrueWt + FalseWt;
- TrueProb = BranchProbability::getBranchProbability(TrueWt.getZExtValue(),
- SumWt.getZExtValue());
- FalseProb = BranchProbability::getBranchProbability(FalseWt.getZExtValue(),
- SumWt.getZExtValue());
+ uint64_t TrueWt = TrueWeight->getValue().getZExtValue();
+ uint64_t FalseWt = FalseWeight->getValue().getZExtValue();
+ uint64_t SumWt = TrueWt + FalseWt;
+
+ assert(SumWt >= TrueWt && SumWt >= FalseWt &&
+ "Overflow calculating branch probabilities.");
+
+ TrueProb = BranchProbability::getBranchProbability(TrueWt, SumWt);
+ FalseProb = BranchProbability::getBranchProbability(FalseWt, SumWt);
return true;
}
OpenPOWER on IntegriCloud