diff options
| author | David Green <david.green@arm.com> | 2018-09-25 10:08:47 +0000 | 
|---|---|---|
| committer | David Green <david.green@arm.com> | 2018-09-25 10:08:47 +0000 | 
| commit | 9108c2b92155decd107489f2069907b44f234250 (patch) | |
| tree | 7d031061b1680b4927cacc21b07c859887c62c37 | |
| parent | 029aa8ec7f52a6dabd3ac3b15e488ebe0148ae2c (diff) | |
| download | bcm5719-llvm-9108c2b92155decd107489f2069907b44f234250.tar.gz bcm5719-llvm-9108c2b92155decd107489f2069907b44f234250.zip  | |
[LoopUnroll] Add check to Latch's terminator in UnrollRuntimeLoopRemainder
In this patch, I'm adding an extra check to the Latch's terminator in llvm::UnrollRuntimeLoopRemainder,
similar to how it is already done in the llvm::UnrollLoop.
The compiler would crash if this function is called with a malformed loop.
Patch by Rodrigo Caetano Rocha!
Differential Revision: https://reviews.llvm.org/D51486
llvm-svn: 342958
4 files changed, 123 insertions, 5 deletions
diff --git a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp index ce078ab8f7b..3361883acd0 100644 --- a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp @@ -545,13 +545,27 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count,    BasicBlock *Header = L->getHeader();    BranchInst *LatchBR = cast<BranchInst>(Latch->getTerminator()); + +  if (!LatchBR || LatchBR->isUnconditional()) { +    // The loop-rotate pass can be helpful to avoid this in many cases. +    LLVM_DEBUG( +        dbgs() +        << "Loop latch not terminated by a conditional branch.\n"); +    return false; +  } +    unsigned ExitIndex = LatchBR->getSuccessor(0) == Header ? 1 : 0;    BasicBlock *LatchExit = LatchBR->getSuccessor(ExitIndex); -  // Cloning the loop basic blocks (`CloneLoopBlocks`) requires that one of the -  // targets of the Latch be an exit block out of the loop. This needs -  // to be guaranteed by the callers of UnrollRuntimeLoopRemainder. -  assert(!L->contains(LatchExit) && -         "one of the loop latch successors should be the exit block!"); + +  if (L->contains(LatchExit)) { +    // Cloning the loop basic blocks (`CloneLoopBlocks`) requires that one of the +    // targets of the Latch be an exit block out of the loop. +    LLVM_DEBUG( +        dbgs() +        << "One of the loop latch successors must be the exit block.\n"); +    return false; +  } +    // These are exit blocks other than the target of the latch exiting block.    SmallVector<BasicBlock *, 4> OtherExits;    bool isMultiExitUnrollingEnabled = diff --git a/llvm/test/Transforms/LoopUnroll/runtime-loop-non-exiting-latch.ll b/llvm/test/Transforms/LoopUnroll/runtime-loop-non-exiting-latch.ll new file mode 100644 index 00000000000..6915981375a --- /dev/null +++ b/llvm/test/Transforms/LoopUnroll/runtime-loop-non-exiting-latch.ll @@ -0,0 +1,27 @@ +; REQUIRES: asserts +; RUN: opt < %s -S -loop-unroll -unroll-runtime=true -unroll-allow-remainder=true -unroll-count=4 + +; Make sure that the runtime unroll does not break with a non-exiting latch. +define i32 @test(i32* %a, i32* %b, i32* %c, i64 %n) { +entry: +  br label %while.cond + +while.cond:                                       ; preds = %while.body, %entry +  %i.0 = phi i64 [ 0, %entry ], [ %inc, %while.body ] +  %cmp = icmp slt i64 %i.0, %n +  br i1 %cmp, label %while.body, label %while.end + +while.body:                                       ; preds = %while.cond +  %arrayidx = getelementptr inbounds i32, i32* %b, i64 %i.0 +  %0 = load i32, i32* %arrayidx +  %arrayidx1 = getelementptr inbounds i32, i32* %c, i64 %i.0 +  %1 = load i32, i32* %arrayidx1 +  %mul = mul nsw i32 %0, %1 +  %arrayidx2 = getelementptr inbounds i32, i32* %a, i64 %i.0 +  store i32 %mul, i32* %arrayidx2 +  %inc = add nsw i64 %i.0, 1 +  br label %while.cond + +while.end:                                        ; preds = %while.cond +  ret i32 0 +} diff --git a/llvm/unittests/Transforms/Utils/CMakeLists.txt b/llvm/unittests/Transforms/Utils/CMakeLists.txt index 13c321b38a1..785b79865dc 100644 --- a/llvm/unittests/Transforms/Utils/CMakeLists.txt +++ b/llvm/unittests/Transforms/Utils/CMakeLists.txt @@ -15,5 +15,6 @@ add_llvm_unittest(UtilsTests    IntegerDivisionTest.cpp    LocalTest.cpp    SSAUpdaterBulkTest.cpp +  UnrollLoopTest.cpp    ValueMapperTest.cpp    ) diff --git a/llvm/unittests/Transforms/Utils/UnrollLoopTest.cpp b/llvm/unittests/Transforms/Utils/UnrollLoopTest.cpp new file mode 100644 index 00000000000..5f7c2d62380 --- /dev/null +++ b/llvm/unittests/Transforms/Utils/UnrollLoopTest.cpp @@ -0,0 +1,76 @@ +//===- UnrollLoopTest.cpp - Unit tests for UnrollLoop ---------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/UnrollLoop.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/AsmParser/Parser.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/Support/SourceMgr.h" +#include "gtest/gtest.h" + +using namespace llvm; + +static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) { +  SMDiagnostic Err; +  std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C); +  if (!Mod) +    Err.print("UnrollLoopTests", errs()); +  return Mod; +} + +TEST(LoopUnrollRuntime, Latch) { +  LLVMContext C; + +  std::unique_ptr<Module> M = parseIR( +    C, +    R"(define i32 @test(i32* %a, i32* %b, i32* %c, i64 %n) { +entry: +  br label %while.cond + +while.cond:                                       ; preds = %while.body, %entry +  %i.0 = phi i64 [ 0, %entry ], [ %inc, %while.body ] +  %cmp = icmp slt i64 %i.0, %n +  br i1 %cmp, label %while.body, label %while.end + +while.body:                                       ; preds = %while.cond +  %arrayidx = getelementptr inbounds i32, i32* %b, i64 %i.0 +  %0 = load i32, i32* %arrayidx +  %arrayidx1 = getelementptr inbounds i32, i32* %c, i64 %i.0 +  %1 = load i32, i32* %arrayidx1 +  %mul = mul nsw i32 %0, %1 +  %arrayidx2 = getelementptr inbounds i32, i32* %a, i64 %i.0 +  store i32 %mul, i32* %arrayidx2 +  %inc = add nsw i64 %i.0, 1 +  br label %while.cond + +while.end:                                        ; preds = %while.cond +  ret i32 0 +})" +    ); + +  auto *F = M->getFunction("test"); +  DominatorTree DT(*F); +  LoopInfo LI(DT); +  AssumptionCache AC(*F); +  TargetLibraryInfoImpl TLII; +  TargetLibraryInfo TLI(TLII); +  ScalarEvolution SE(*F, TLI, AC, DT, LI); + +  Loop *L = *LI.begin(); + +  bool PreserveLCSSA = L->isRecursivelyLCSSAForm(DT,LI); + +  bool ret = UnrollRuntimeLoopRemainder(L, 4, true, false, false, &LI, &SE, &DT, &AC, PreserveLCSSA); +  EXPECT_FALSE(ret); +}  | 

