| author | Sam Parker <sam.parker@arm.com> | 2019-07-11 07:47:50 +0000 | 
|---|---|---|
| committer | Sam Parker <sam.parker@arm.com> | 2019-07-11 07:47:50 +0000 | 
| commit | 85ad78b1cfa3932eb658365b74f5b08c25dbfb0e (patch) | |
| tree | 357bc2e534cd101cad146306d4230c8c4ef94ade /llvm | |
| parent | 274ad9c3717e43da1dacfec5fc0a61a4b4cd4fff (diff) | |
[ARM][ParallelDSP] Change the search for smlads
    
Two functional changes have been made here:
- We now search up from any add instruction to find the chains of
  operations that we may be able to turn into a smlad. This allows the
  generation of a smlad which doesn't accumulate into a phi (see the
  sketch below).
- The search function has been corrected so that it no longer falsely
  searches up through an invalid path.
    
The bulk of the changes come from making the Reduction struct a class
and making it more idiomatic C++, with getters and setters.
Differential Revision: https://reviews.llvm.org/D61780
llvm-svn: 365740
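
To illustrate the first change, here is a minimal C++ sketch (hypothetical source, mirroring the new inner-full-unroll.ll test added below) of a loop the pass can now handle: once the inner loop is fully unrolled, each row's reduction is rooted at a plain add feeding a store, with no loop-carried phi for the accumulator.

```cpp
// Hypothetical C++ source corresponding to the new inner-full-unroll.ll
// test. After the inner loop is fully unrolled, the reduction for each
// a[i] is a chain of adds ending in a store rather than a phi, which the
// old phi-driven search could not match.
void dot_rows(int *a, short **b, short **c, int n) {
  for (int i = 0; i < n; ++i) {
    int sum = 0;
    for (int j = 0; j < 4; ++j)  // fully unrolled by earlier passes
      sum += b[i][j] * c[i][j];
    a[i] = sum;
  }
}
```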
Diffstat (limited to 'llvm')
| -rw-r--r-- | llvm/lib/Target/ARM/ARMParallelDSP.cpp | 568 |
| -rw-r--r-- | llvm/test/CodeGen/ARM/ParallelDSP/aliasing.ll | 4 |
| -rw-r--r-- | llvm/test/CodeGen/ARM/ParallelDSP/inner-full-unroll.ll | 151 |
3 files changed, 470 insertions, 253 deletions
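
For context before the diff: the SMLAD family performs two signed 16x16-bit multiplications and adds both products to an accumulator in a single instruction. A scalar C++ sketch of the assumed semantics (the X variants exchange the halfwords of the second operand; smlald/smlaldx are the analogous 64-bit accumulating forms):

```cpp
#include <cstdint>

// Scalar reference sketch for smlad/smladx: each 32-bit operand packs two
// signed 16-bit lanes, and both lane products are added to the accumulator.
int32_t smlad(int32_t x, int32_t y, int32_t acc) {
  int16_t xlo = int16_t(x), xhi = int16_t(x >> 16);
  int16_t ylo = int16_t(y), yhi = int16_t(y >> 16);
  return acc + xlo * ylo + xhi * yhi;
}

// Exchanged variant: the halfwords of the second operand are swapped first.
int32_t smladx(int32_t x, int32_t y, int32_t acc) {
  int16_t xlo = int16_t(x), xhi = int16_t(x >> 16);
  int16_t ylo = int16_t(y), yhi = int16_t(y >> 16);
  return acc + xlo * yhi + xhi * ylo;
}
```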
```diff
diff --git a/llvm/lib/Target/ARM/ARMParallelDSP.cpp b/llvm/lib/Target/ARM/ARMParallelDSP.cpp
index 3cff9b56851..5389d09bf7d 100644
--- a/llvm/lib/Target/ARM/ARMParallelDSP.cpp
+++ b/llvm/lib/Target/ARM/ARMParallelDSP.cpp
@@ -48,7 +48,7 @@ DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
 namespace {
   struct OpChain;
   struct BinOpChain;
-  struct Reduction;
+  class Reduction;
 
   using OpChainList     = SmallVector<std::unique_ptr<OpChain>, 8>;
   using ReductionList   = SmallVector<Reduction, 8>;
@@ -79,10 +79,8 @@ namespace {
     unsigned size() const { return AllValues.size(); }
   };
 
-  // 'BinOpChain' and 'Reduction' are just some bookkeeping data structures.
-  // 'Reduction' contains the phi-node and accumulator statement from where we
-  // start pattern matching, and 'BinOpChain' the multiplication
-  // instructions that are candidates for parallel execution.
+  // 'BinOpChain' holds the multiplication instructions that are candidates
+  // for parallel execution.
   struct BinOpChain : public OpChain {
     ValueList     LHS;      // List of all (narrow) left hand operands.
     ValueList     RHS;      // List of all (narrow) right hand operands.
@@ -97,15 +95,70 @@ namespace {
     bool AreSymmetrical(BinOpChain *Other);
   };
 
-  struct Reduction {
-    PHINode         *Phi;             // The Phi-node from where we start
-                                      // pattern matching.
-    Instruction     *AccIntAdd;       // The accumulating integer add statement,
-                                      // i.e, the reduction statement.
-    OpChainList     MACCandidates;    // The MAC candidates associated with
-                                      // this reduction statement.
-    PMACPairList    PMACPairs;
-    Reduction (PHINode *P, Instruction *Acc) : Phi(P), AccIntAdd(Acc) { };
+  /// Represents a sequence of multiply-accumulate operations, with the aim of
+  /// performing the multiplications in parallel.
+  class Reduction {
+    Instruction     *Root = nullptr;
+    Value           *Acc = nullptr;
+    OpChainList     Muls;
+    PMACPairList    MulPairs;
+    SmallPtrSet<Instruction*, 4> Adds;
+
+  public:
+    Reduction() = delete;
+
+    Reduction (Instruction *Add) : Root(Add) { }
+
+    /// Record an Add instruction that is a part of this reduction.
+    void InsertAdd(Instruction *I) { Adds.insert(I); }
+
+    /// Record a BinOpChain, rooted at a Mul instruction, that is a part of
+    /// this reduction.
+    void InsertMul(Instruction *I, ValueList &LHS, ValueList &RHS) {
+      Muls.push_back(make_unique<BinOpChain>(I, LHS, RHS));
+    }
+
+    /// Add the incoming accumulator value; returns true if a value had not
+    /// already been added. Returning false signals to the user that this
+    /// reduction already has a value to initialise the accumulator.
+    bool InsertAcc(Value *V) {
+      if (Acc)
+        return false;
+      Acc = V;
+      return true;
+    }
+
+    /// Set two BinOpChains, rooted at muls, that can be executed as a single
+    /// parallel operation.
+    void AddMulPair(BinOpChain *Mul0, BinOpChain *Mul1) {
+      MulPairs.push_back(std::make_pair(Mul0, Mul1));
+    }
+
+    /// Return true if enough mul operations are found that can be executed in
+    /// parallel.
+    bool CreateParallelPairs();
+
+    /// Return the add instruction which is the root of the reduction.
+    Instruction *getRoot() { return Root; }
+
+    /// Return the incoming value to be accumulated. This may be null.
+    Value *getAccumulator() { return Acc; }
+
+    /// Return the set of adds that comprise the reduction.
+    SmallPtrSetImpl<Instruction*> &getAdds() { return Adds; }
+
+    /// Return the BinOpChains, rooted at mul instructions, that comprise
+    /// the reduction.
+    OpChainList &getMuls() { return Muls; }
+
+    /// Return the BinOpChains, rooted at mul instructions, that have been
+    /// paired for parallel execution.
+    PMACPairList &getMulPairs() { return MulPairs; }
+
+    /// To finalise, replace the uses of the root with the intrinsic call.
+    void UpdateRoot(Instruction *SMLAD) {
+      Root->replaceAllUsesWith(SMLAD);
+    }
   };
 
   class WidenedLoad {
@@ -133,25 +186,25 @@ namespace {
     const DataLayout  *DL;
     Module            *M;
     std::map<LoadInst*, LoadInst*> LoadPairs;
+    SmallPtrSet<LoadInst*, 4> OffsetLoads;
     std::map<LoadInst*, std::unique_ptr<WidenedLoad>> WideLoads;
 
+    template<unsigned>
+    bool IsNarrowSequence(Value *V, ValueList &VL);
+
     bool RecordMemoryOps(BasicBlock *BB);
-    bool InsertParallelMACs(Reduction &Reduction);
+    void InsertParallelMACs(Reduction &Reduction);
     bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
     LoadInst* CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
                              IntegerType *LoadTy);
-    void CreateParallelMACPairs(Reduction &R);
-    Instruction *CreateSMLADCall(SmallVectorImpl<LoadInst*> &VecLd0,
-                                 SmallVectorImpl<LoadInst*> &VecLd1,
-                                 Instruction *Acc, bool Exchange,
-                                 Instruction *InsertAfter);
+    bool CreateParallelPairs(Reduction &R);
 
     /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
     /// Dual performs two signed 16x16-bit multiplications. It adds the
     /// products to a 32-bit accumulate operand. Optionally, the instruction can
     /// exchange the halfwords of the second operand before performing the
     /// arithmetic.
-    bool MatchSMLAD(Function &F);
+    bool MatchSMLAD(Loop *L);
 
   public:
     static char ID;
@@ -201,11 +254,8 @@ namespace {
         return false;
       }
 
-      // We need a preheader as getIncomingValueForBlock assumes there is one.
-      if (!TheLoop->getLoopPreheader()) {
-        LLVM_DEBUG(dbgs() << "No preheader found, bailing out\n");
-        return false;
-      }
+      if (!TheLoop->getLoopPreheader())
+        InsertPreheaderForLoop(L, DT, LI, nullptr, true);
 
       Function &F = *Header->getParent();
       M = F.getParent();
@@ -242,7 +292,7 @@ namespace {
         return false;
       }
 
-      bool Changes = MatchSMLAD(F);
+      bool Changes = MatchSMLAD(L);
       return Changes;
     }
   };
@@ -275,6 +325,51 @@ bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
   return true;
 }
 
+// MaxBitWidth: the maximum supported bitwidth of the elements in the DSP
+// instructions, which is set to 16. So here we should collect all i8 and i16
+// narrow operations.
+// TODO: we currently only collect i16, and will support i8 later, so that's
+// why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth.
+template<unsigned MaxBitWidth>
+bool ARMParallelDSP::IsNarrowSequence(Value *V, ValueList &VL) {
+  ConstantInt *CInt;
+
+  if (match(V, m_ConstantInt(CInt))) {
+    // TODO: if a constant is used, it needs to fit within the bit width.
+    return false;
+  }
+
+  auto *I = dyn_cast<Instruction>(V);
+  if (!I)
+    return false;
+
+  Value *Val, *LHS, *RHS;
+  if (match(V, m_Trunc(m_Value(Val)))) {
+    if (cast<TruncInst>(I)->getDestTy()->getIntegerBitWidth() == MaxBitWidth)
+      return IsNarrowSequence<MaxBitWidth>(Val, VL);
+  } else if (match(V, m_Add(m_Value(LHS), m_Value(RHS)))) {
+    // TODO: we need to implement sadd16/sadd8 for this, which would enable us
+    // to also do the rewrite for smlad8.ll, but it is unsupported for now.
+    return false;
+  } else if (match(V, m_ZExtOrSExt(m_Value(Val)))) {
+    if (cast<CastInst>(I)->getSrcTy()->getIntegerBitWidth() != MaxBitWidth)
+      return false;
+
+    if (match(Val, m_Load(m_Value()))) {
+      auto *Ld = cast<LoadInst>(Val);
+
+      // Check that these loads could be paired.
+      if (!LoadPairs.count(Ld) && !OffsetLoads.count(Ld))
+        return false;
+
+      VL.push_back(Val);
+      VL.push_back(I);
+      return true;
+    }
+  }
+  return false;
+}
+
 /// Iterate through the block and record base, offset pairs of loads which can
 /// be widened into a single load.
 bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
@@ -342,6 +437,7 @@ bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
       if (AreSequentialAccesses<LoadInst>(Base, Offset, *DL, *SE) &&
          SafeToPair(Base, Offset)) {
         LoadPairs[Base] = Offset;
+        OffsetLoads.insert(Offset);
         break;
       }
     }
@@ -357,15 +453,150 @@ bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
   return LoadPairs.size() > 1;
 }
 
-void ARMParallelDSP::CreateParallelMACPairs(Reduction &R) {
-  OpChainList &Candidates = R.MACCandidates;
-  PMACPairList &PMACPairs = R.PMACPairs;
-  const unsigned Elems = Candidates.size();
+// Loop Pass that needs to identify integer add/sub reductions of 16-bit vector
+// multiplications.
+// To use SMLAD:
+// 1) we first need to find an integer add, then look for this pattern:
+//
+// acc0 = ...
+// ld0 = load i16
+// sext0 = sext i16 %ld0 to i32
+// ld1 = load i16
+// sext1 = sext i16 %ld1 to i32
+// mul0 = mul %sext0, %sext1
+// ld2 = load i16
+// sext2 = sext i16 %ld2 to i32
+// ld3 = load i16
+// sext3 = sext i16 %ld3 to i32
+// mul1 = mul i32 %sext2, %sext3
+// add0 = add i32 %mul0, %acc0
+// acc1 = add i32 %add0, %mul1
+//
+// Which can be selected to:
+//
+// ldr r0
+// ldr r1
+// smlad r2, r0, r1, r2
+//
+// If constants are used instead of loads, these will need to be hoisted
+// out and into a register.
+//
+// If loop invariants are used instead of loads, these need to be packed
+// before the loop begins.
+//
+bool ARMParallelDSP::MatchSMLAD(Loop *L) {
+  // Search recursively back through the operands to find a tree of values that
+  // form a multiply-accumulate chain. The search records the Add and Mul
+  // instructions that form the reduction and allows us to find a single value
+  // to be used as the initial input to the accumulator.
+  std::function<bool(Value*, Reduction&)> Search = [&]
+    (Value *V, Reduction &R) -> bool {
+
+    // If we find a non-instruction, try to use it as the initial accumulator
+    // value. This may have already been found during the search, in which case
+    // this function will return false, signaling a search fail.
+    auto *I = dyn_cast<Instruction>(V);
+    if (!I)
+      return R.InsertAcc(V);
+
+    switch (I->getOpcode()) {
+    default:
+      break;
+    case Instruction::PHI:
+      // Could be the accumulator value.
+      return R.InsertAcc(V);
+    case Instruction::Add: {
+      // Adds should be adding together two muls, or another add and a mul to
+      // be within the mac chain. One of the operands may also be the
+      // accumulator value, at which point we should stop searching.
+      bool ValidLHS = Search(I->getOperand(0), R);
+      bool ValidRHS = Search(I->getOperand(1), R);
+      if (!ValidLHS && !ValidRHS)
+        return false;
+      else if (ValidLHS && ValidRHS) {
+        R.InsertAdd(I);
+        return true;
+      } else {
+        R.InsertAdd(I);
+        return R.InsertAcc(I);
+      }
+    }
+    case Instruction::Mul: {
+      Value *MulOp0 = I->getOperand(0);
+      Value *MulOp1 = I->getOperand(1);
+      if (isa<SExtInst>(MulOp0) && isa<SExtInst>(MulOp1)) {
+        ValueList LHS;
+        ValueList RHS;
+        if (IsNarrowSequence<16>(MulOp0, LHS) &&
+            IsNarrowSequence<16>(MulOp1, RHS)) {
+          R.InsertMul(I, LHS, RHS);
+          return true;
+        }
+      }
+      return false;
+    }
+    case Instruction::SExt:
+      return Search(I->getOperand(0), R);
+    }
+    return false;
+  };
+
+  bool Changed = false;
+  SmallPtrSet<Instruction*, 4> AllAdds;
+  BasicBlock *Latch = L->getLoopLatch();
+
+  for (Instruction &I : reverse(*Latch)) {
+    if (I.getOpcode() != Instruction::Add)
+      continue;
+
+    if (AllAdds.count(&I))
+      continue;
+
+    const auto *Ty = I.getType();
+    if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
+      continue;
+
+    Reduction R(&I);
+    if (!Search(&I, R))
+      continue;
+
+    if (!CreateParallelPairs(R))
+      continue;
+
+    InsertParallelMACs(R);
+    Changed = true;
+    AllAdds.insert(R.getAdds().begin(), R.getAdds().end());
+  }
+
+  return Changed;
+}
+
+bool ARMParallelDSP::CreateParallelPairs(Reduction &R) {
+
+  // Not enough mul operations to make a pair.
+  if (R.getMuls().size() < 2)
+    return false;
+
+  // Check that the muls operate directly on sign-extended loads.
+  for (auto &MulChain : R.getMuls()) {
+    // A mul has 2 operands, and a narrow op consists of a sext and a load;
+    // thus we expect at least 4 items in this operand value list.
+    if (MulChain->size() < 4) {
+      LLVM_DEBUG(dbgs() << "Operand list too short.\n");
+      return false;
+    }
+    MulChain->PopulateLoads();
+    ValueList &LHS = static_cast<BinOpChain*>(MulChain.get())->LHS;
+    ValueList &RHS = static_cast<BinOpChain*>(MulChain.get())->RHS;
 
-  if (Elems < 2)
-    return;
+    // Use +=2 to skip over the expected extend instructions.
+    for (unsigned i = 0, e = LHS.size(); i < e; i += 2) {
+      if (!isa<LoadInst>(LHS[i]) || !isa<LoadInst>(RHS[i]))
+        return false;
+    }
+  }
 
-  auto CanPair = [&](BinOpChain *PMul0, BinOpChain *PMul1) {
+  auto CanPair = [&](Reduction &R, BinOpChain *PMul0, BinOpChain *PMul1) {
     if (!PMul0->AreSymmetrical(PMul1))
       return false;
 
@@ -391,13 +622,13 @@ void ARMParallelDSP::CreateParallelMACPairs(Reduction &R) {
       if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
         if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
           LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
-          PMACPairs.push_back(std::make_pair(PMul0, PMul1));
+          R.AddMulPair(PMul0, PMul1);
           return true;
         } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
           LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
           LLVM_DEBUG(dbgs() << "    exchanging Ld2 and Ld3\n");
           PMul1->Exchange = true;
-          PMACPairs.push_back(std::make_pair(PMul0, PMul1));
+          R.AddMulPair(PMul0, PMul1);
           return true;
         }
       } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
@@ -407,16 +638,18 @@ void ARMParallelDSP::CreateParallelMACPairs(Reduction &R) {
         LLVM_DEBUG(dbgs() << "    and swapping muls\n");
         PMul0->Exchange = true;
         // Only the second operand can be exchanged, so swap the muls.
-        PMACPairs.push_back(std::make_pair(PMul1, PMul0));
+        R.AddMulPair(PMul1, PMul0);
         return true;
       }
     }
     return false;
   };
 
+  OpChainList &Muls = R.getMuls();
+  const unsigned Elems = Muls.size();
   SmallPtrSet<const Instruction*, 4> Paired;
   for (unsigned i = 0; i < Elems; ++i) {
-    BinOpChain *PMul0 = static_cast<BinOpChain*>(Candidates[i].get());
+    BinOpChain *PMul0 = static_cast<BinOpChain*>(Muls[i].get());
     if (Paired.count(PMul0->Root))
       continue;
 
@@ -424,7 +657,7 @@ void ARMParallelDSP::CreateParallelMACPairs(Reduction &R) {
       if (i == j)
         continue;
 
-      BinOpChain *PMul1 = static_cast<BinOpChain*>(Candidates[j].get());
+      BinOpChain *PMul1 = static_cast<BinOpChain*>(Muls[j].get());
       if (Paired.count(PMul1->Root))
        continue;
 
@@ -435,199 +668,67 @@ void ARMParallelDSP::CreateParallelMACPairs(Reduction &R) {
       assert(PMul0 != PMul1 && "expected different chains");
 
-      if (CanPair(PMul0, PMul1)) {
+      if (CanPair(R, PMul0, PMul1)) {
         Paired.insert(Mul0);
         Paired.insert(Mul1);
         break;
       }
     }
   }
+  return !R.getMulPairs().empty();
 }
 
-bool ARMParallelDSP::InsertParallelMACs(Reduction &Reduction) {
-  Instruction *Acc = Reduction.Phi;
-  Instruction *InsertAfter = Reduction.AccIntAdd;
-  for (auto &Pair : Reduction.PMACPairs) {
+void ARMParallelDSP::InsertParallelMACs(Reduction &R) {
+
+  auto CreateSMLADCall = [&](SmallVectorImpl<LoadInst*> &VecLd0,
+                             SmallVectorImpl<LoadInst*> &VecLd1,
+                             Value *Acc, bool Exchange,
+                             Instruction *InsertAfter) {
+    // Replace the reduction chain with an intrinsic call
+    IntegerType *Ty = IntegerType::get(M->getContext(), 32);
+    LoadInst *WideLd0 = WideLoads.count(VecLd0[0]) ?
+      WideLoads[VecLd0[0]]->getLoad() : CreateWideLoad(VecLd0, Ty);
+    LoadInst *WideLd1 = WideLoads.count(VecLd1[0]) ?
+      WideLoads[VecLd1[0]]->getLoad() : CreateWideLoad(VecLd1, Ty);
+
+    Value* Args[] = { WideLd0, WideLd1, Acc };
+    Function *SMLAD = nullptr;
+    if (Exchange)
+      SMLAD = Acc->getType()->isIntegerTy(32) ?
+        Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) :
+        Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx);
+    else
+      SMLAD = Acc->getType()->isIntegerTy(32) ?
+        Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) :
+        Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);
+
+    IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
+                                ++BasicBlock::iterator(InsertAfter));
+    Instruction *Call = Builder.CreateCall(SMLAD, Args);
+    NumSMLAD++;
+    return Call;
+  };
+
+  Instruction *InsertAfter = R.getRoot();
+  Value *Acc = R.getAccumulator();
+  if (!Acc)
+    Acc = ConstantInt::get(IntegerType::get(M->getContext(), 32), 0);
+
+  LLVM_DEBUG(dbgs() << "Root: " << *InsertAfter << "\n"
+             << "Acc: " << *Acc << "\n");
+  for (auto &Pair : R.getMulPairs()) {
     BinOpChain *PMul0 = Pair.first;
     BinOpChain *PMul1 = Pair.second;
-    LLVM_DEBUG(dbgs() << "Found parallel MACs:\n"
+    LLVM_DEBUG(dbgs() << "Muls:\n"
               << "- " << *PMul0->Root << "\n"
               << "- " << *PMul1->Root << "\n");
 
     Acc = CreateSMLADCall(PMul0->VecLd, PMul1->VecLd, Acc, PMul1->Exchange,
                           InsertAfter);
-    InsertAfter = Acc;
-  }
-
-  if (Acc != Reduction.Phi) {
-    LLVM_DEBUG(dbgs() << "Replace Accumulate: "; Acc->dump());
-    Reduction.AccIntAdd->replaceAllUsesWith(Acc);
-    return true;
+    InsertAfter = cast<Instruction>(Acc);
   }
-  return false;
-}
-
-template<typename InstType, unsigned BitWidth>
-bool IsExtendingLoad(Value *V) {
-  auto *I = dyn_cast<InstType>(V);
-  if (!I)
-    return false;
-
-  if (I->getSrcTy()->getIntegerBitWidth() != BitWidth)
-    return false;
-
-  return isa<LoadInst>(I->getOperand(0));
-}
-
-static void MatchParallelMACSequences(Reduction &R,
-                                      OpChainList &Candidates) {
-  Instruction *Acc = R.AccIntAdd;
-  LLVM_DEBUG(dbgs() << "\n- Analysing:\t" << *Acc << "\n");
-
-  // Returns false to signal the search should be stopped.
-  std::function<bool(Value*)> Match =
-    [&Candidates, &Match](Value *V) -> bool {
-
-    auto *I = dyn_cast<Instruction>(V);
-    if (!I)
-      return false;
-
-    switch (I->getOpcode()) {
-    case Instruction::Add:
-      if (Match(I->getOperand(0)) || (Match(I->getOperand(1))))
-        return true;
-      break;
-    case Instruction::Mul: {
-      Value *Op0 = I->getOperand(0);
-      Value *Op1 = I->getOperand(1);
-      if (IsExtendingLoad<SExtInst, 16>(Op0) &&
-          IsExtendingLoad<SExtInst, 16>(Op1)) {
-        ValueList LHS = { cast<SExtInst>(Op0)->getOperand(0), Op0 };
-        ValueList RHS = { cast<SExtInst>(Op1)->getOperand(0), Op1 };
-        Candidates.push_back(make_unique<BinOpChain>(I, LHS, RHS));
-      }
-      return false;
-    }
-    case Instruction::SExt:
-      return Match(I->getOperand(0));
-    }
-    return false;
-  };
-
-  while (Match (Acc));
-  LLVM_DEBUG(dbgs() << "Finished matching MAC sequences, found "
-             << Candidates.size() << " candidates.\n");
-}
-
-static bool CheckMACMemory(OpChainList &Candidates) {
-  for (auto &C : Candidates) {
-    // A mul has 2 operands, and a narrow op consist of sext and a load; thus
-    // we expect at least 4 items in this operand value list.
-    if (C->size() < 4) {
-      LLVM_DEBUG(dbgs() << "Operand list too short.\n");
-      return false;
-    }
-    C->PopulateLoads();
-    ValueList &LHS = static_cast<BinOpChain*>(C.get())->LHS;
-    ValueList &RHS = static_cast<BinOpChain*>(C.get())->RHS;
-
-    // Use +=2 to skip over the expected extend instructions.
-    for (unsigned i = 0, e = LHS.size(); i < e; i += 2) {
-      if (!isa<LoadInst>(LHS[i]) || !isa<LoadInst>(RHS[i]))
-        return false;
-    }
-  }
-  return true;
-}
-
-// Loop Pass that needs to identify integer add/sub reductions of 16-bit vector
-// multiplications.
-// To use SMLAD:
-// 1) we first need to find integer add reduction PHIs,
-// 2) then from the PHI, look for this pattern:
-//
-// acc0 = phi i32 [0, %entry], [%acc1, %loop.body]
-// ld0 = load i16
-// sext0 = sext i16 %ld0 to i32
-// ld1 = load i16
-// sext1 = sext i16 %ld1 to i32
-// mul0 = mul %sext0, %sext1
-// ld2 = load i16
-// sext2 = sext i16 %ld2 to i32
-// ld3 = load i16
-// sext3 = sext i16 %ld3 to i32
-// mul1 = mul i32 %sext2, %sext3
-// add0 = add i32 %mul0, %acc0
-// acc1 = add i32 %add0, %mul1
-//
-// Which can be selected to:
-//
-// ldr.h r0
-// ldr.h r1
-// smlad r2, r0, r1, r2
-//
-// If constants are used instead of loads, these will need to be hoisted
-// out and into a register.
-//
-// If loop invariants are used instead of loads, these need to be packed
-// before the loop begins.
-//
-bool ARMParallelDSP::MatchSMLAD(Function &F) {
-
-  auto FindReductions = [&](ReductionList &Reductions) {
-    RecurrenceDescriptor RecDesc;
-    const bool HasFnNoNaNAttr =
-      F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true";
-    BasicBlock *Latch = L->getLoopLatch();
-
-    for (PHINode &Phi : Latch->phis()) {
-      const auto *Ty = Phi.getType();
-      if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
-        continue;
-
-      const bool IsReduction = RecurrenceDescriptor::AddReductionVar(
-        &Phi, RecurrenceDescriptor::RK_IntegerAdd, L, HasFnNoNaNAttr, RecDesc);
-
-      if (!IsReduction)
-        continue;
-
-      Instruction *Acc = dyn_cast<Instruction>(Phi.getIncomingValueForBlock(Latch));
-      if (!Acc)
-        continue;
-
-      Reductions.push_back(Reduction(&Phi, Acc));
-    }
-    return !Reductions.empty();
-  };
-
-  ReductionList Reductions;
-  if (!FindReductions(Reductions))
-    return false;
-
-  for (auto &R : Reductions) {
-    OpChainList MACCandidates;
-    MatchParallelMACSequences(R, MACCandidates);
-    if (!CheckMACMemory(MACCandidates))
-      continue;
-
-    R.MACCandidates = std::move(MACCandidates);
-
-    LLVM_DEBUG(dbgs() << "MAC candidates:\n";
-      for (auto &M : R.MACCandidates)
-        M->Root->dump();
-      dbgs() << "\n";);
-  }
-
-  bool Changed = false;
-  // Check whether statements in the basic block that write to memory alias
-  // with the memory locations accessed by the MAC-chains.
-  for (auto &R : Reductions) {
-    CreateParallelMACPairs(R);
-    Changed |= InsertParallelMACs(R);
-  }
-
-  return Changed;
+  R.UpdateRoot(cast<Instruction>(Acc));
 }
 
 LoadInst* ARMParallelDSP::CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
@@ -696,43 +797,6 @@ LoadInst* ARMParallelDSP::CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
   return WideLoad;
 }
 
-Instruction *ARMParallelDSP::CreateSMLADCall(SmallVectorImpl<LoadInst*> &VecLd0,
-                                             SmallVectorImpl<LoadInst*> &VecLd1,
-                                             Instruction *Acc, bool Exchange,
-                                             Instruction *InsertAfter) {
-  LLVM_DEBUG(dbgs() << "Create SMLAD intrinsic using:\n"
-             << "- " << *VecLd0[0] << "\n"
-             << "- " << *VecLd0[1] << "\n"
-             << "- " << *VecLd1[0] << "\n"
-             << "- " << *VecLd1[1] << "\n"
-             << "- " << *Acc << "\n"
-             << "- Exchange: " << Exchange << "\n");
-
-  // Replace the reduction chain with an intrinsic call
-  IntegerType *Ty = IntegerType::get(M->getContext(), 32);
-  LoadInst *WideLd0 = WideLoads.count(VecLd0[0]) ?
-    WideLoads[VecLd0[0]]->getLoad() : CreateWideLoad(VecLd0, Ty);
-  LoadInst *WideLd1 = WideLoads.count(VecLd1[0]) ?
-    WideLoads[VecLd1[0]]->getLoad() : CreateWideLoad(VecLd1, Ty);
-
-  Value* Args[] = { WideLd0, WideLd1, Acc };
-  Function *SMLAD = nullptr;
-  if (Exchange)
-    SMLAD = Acc->getType()->isIntegerTy(32) ?
-      Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) :
-      Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx);
-  else
-    SMLAD = Acc->getType()->isIntegerTy(32) ?
-      Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) :
-      Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);
-
-  IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
-                              ++BasicBlock::iterator(InsertAfter));
-  CallInst *Call = Builder.CreateCall(SMLAD, Args);
-  NumSMLAD++;
-  return Call;
-}
-
 // Compare the value lists in Other to this chain.
 bool BinOpChain::AreSymmetrical(BinOpChain *Other) {
   // Element-by-element comparison of Value lists returning true if they are
diff --git a/llvm/test/CodeGen/ARM/ParallelDSP/aliasing.ll b/llvm/test/CodeGen/ARM/ParallelDSP/aliasing.ll
index 47047c7f44b..4edf5bfbbef 100644
--- a/llvm/test/CodeGen/ARM/ParallelDSP/aliasing.ll
+++ b/llvm/test/CodeGen/ARM/ParallelDSP/aliasing.ll
@@ -451,8 +451,10 @@ for.body:
   br i1 %exitcond, label %for.body, label %for.cond.cleanup
 }
 
+; TODO: I think we should be able to generate one smlad here. The search fails
+; when it finds the alias.
 ; CHECK-LABEL: one_pair_alias
-; FIXME: This tests shows we have a bug with smlad insertion
+; CHECK-NOT: call i32 @llvm.arm.smlad
 define i32 @one_pair_alias(i16* noalias nocapture readonly %b, i16* noalias nocapture readonly %c) {
 entry:
   br label %for.body
diff --git a/llvm/test/CodeGen/ARM/ParallelDSP/inner-full-unroll.ll b/llvm/test/CodeGen/ARM/ParallelDSP/inner-full-unroll.ll
new file mode 100644
index 00000000000..052fb51a8dd
--- /dev/null
+++ b/llvm/test/CodeGen/ARM/ParallelDSP/inner-full-unroll.ll
@@ -0,0 +1,151 @@
+; RUN: opt -mtriple=thumbv7em -arm-parallel-dsp -dce -S %s -o - | FileCheck %s
+
+; CHECK-LABEL: full_unroll
+; CHECK: [[IV:%[^ ]+]] = phi i32
+; CHECK: [[AI:%[^ ]+]] = getelementptr inbounds i32, i32* %a, i32 [[IV]]
+; CHECK: [[BI:%[^ ]+]] = getelementptr inbounds i16*, i16** %b, i32 [[IV]]
+; CHECK: [[BIJ:%[^ ]+]] = load i16*, i16** %arrayidx5, align 4
+; CHECK: [[CI:%[^ ]+]] = getelementptr inbounds i16*, i16** %c, i32 [[IV]]
+; CHECK: [[CIJ:%[^ ]+]] = load i16*, i16** [[CI]], align 4
+; CHECK: [[BIJ_CAST:%[^ ]+]] = bitcast i16* [[BIJ]] to i32*
+; CHECK: [[BIJ_LD:%[^ ]+]] = load i32, i32* [[BIJ_CAST]], align 2
+; CHECK: [[CIJ_CAST:%[^ ]+]] = bitcast i16* [[CIJ]] to i32*
+; CHECK: [[CIJ_LD:%[^ ]+]] = load i32, i32* [[CIJ_CAST]], align 2
+; CHECK: [[BIJ_2:%[^ ]+]] = getelementptr inbounds i16, i16* [[BIJ]], i32 2
+; CHECK: [[BIJ_2_CAST:%[^ ]+]] = bitcast i16* [[BIJ_2]] to i32*
+; CHECK: [[BIJ_2_LD:%[^ ]+]] = load i32, i32* [[BIJ_2_CAST]], align 2
+; CHECK: [[CIJ_2:%[^ ]+]] = getelementptr inbounds i16, i16* [[CIJ]], i32 2
+; CHECK: [[CIJ_2_CAST:%[^ ]+]] = bitcast i16* [[CIJ_2]] to i32*
+; CHECK: [[CIJ_2_LD:%[^ ]+]] = load i32, i32* [[CIJ_2_CAST]], align 2
+; CHECK: [[SMLAD0:%[^ ]+]] = call i32 @llvm.arm.smlad(i32 [[CIJ_2_LD]], i32 [[BIJ_2_LD]], i32 0)
+; CHECK: [[SMLAD1:%[^ ]+]] = call i32 @llvm.arm.smlad(i32 [[CIJ_LD]], i32 [[BIJ_LD]], i32 [[SMLAD0]])
+; CHECK: store i32 [[SMLAD1]], i32* %arrayidx, align 4
+
+define void @full_unroll(i32* noalias nocapture %a, i16** noalias nocapture readonly %b, i16** noalias nocapture readonly %c, i32 %N) {
+entry:
+  %cmp29 = icmp eq i32 %N, 0
+  br i1 %cmp29, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  ret void
+
+for.body:                                         ; preds = %entry, %for.body
+  %i.030 = phi i32 [ %inc12, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i32, i32* %a, i32 %i.030
+  %arrayidx5 = getelementptr inbounds i16*, i16** %b, i32 %i.030
+  %0 = load i16*, i16** %arrayidx5, align 4
+  %arrayidx7 = getelementptr inbounds i16*, i16** %c, i32 %i.030
+  %1 = load i16*, i16** %arrayidx7, align 4
+  %2 = load i16, i16* %0, align 2
+  %conv = sext i16 %2 to i32
+  %3 = load i16, i16* %1, align 2
+  %conv9 = sext i16 %3 to i32
+  %mul = mul nsw i32 %conv9, %conv
+  %arrayidx6.1 = getelementptr inbounds i16, i16* %0, i32 1
+  %4 = load i16, i16* %arrayidx6.1, align 2
+  %conv.1 = sext i16 %4 to i32
+  %arrayidx8.1 = getelementptr inbounds i16, i16* %1, i32 1
+  %5 = load i16, i16* %arrayidx8.1, align 2
+  %conv9.1 = sext i16 %5 to i32
+  %mul.1 = mul nsw i32 %conv9.1, %conv.1
+  %add.1 = add nsw i32 %mul.1, %mul
+  %arrayidx6.2 = getelementptr inbounds i16, i16* %0, i32 2
+  %6 = load i16, i16* %arrayidx6.2, align 2
+  %conv.2 = sext i16 %6 to i32
+  %arrayidx8.2 = getelementptr inbounds i16, i16* %1, i32 2
+  %7 = load i16, i16* %arrayidx8.2, align 2
+  %conv9.2 = sext i16 %7 to i32
+  %mul.2 = mul nsw i32 %conv9.2, %conv.2
+  %add.2 = add nsw i32 %mul.2, %add.1
+  %arrayidx6.3 = getelementptr inbounds i16, i16* %0, i32 3
+  %8 = load i16, i16* %arrayidx6.3, align 2
+  %conv.3 = sext i16 %8 to i32
+  %arrayidx8.3 = getelementptr inbounds i16, i16* %1, i32 3
+  %9 = load i16, i16* %arrayidx8.3, align 2
+  %conv9.3 = sext i16 %9 to i32
+  %mul.3 = mul nsw i32 %conv9.3, %conv.3
+  %add.3 = add nsw i32 %mul.3, %add.2
+  store i32 %add.3, i32* %arrayidx, align 4
+  %inc12 = add nuw i32 %i.030, 1
+  %exitcond = icmp eq i32 %inc12, %N
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+}
+
+; CHECK-LABEL: full_unroll_sub
+; CHECK: [[IV:%[^ ]+]] = phi i32
+; CHECK: [[AI:%[^ ]+]] = getelementptr inbounds i32, i32* %a, i32 [[IV]]
+; CHECK: [[BI:%[^ ]+]] = getelementptr inbounds i16*, i16** %b, i32 [[IV]]
+; CHECK: [[BIJ:%[^ ]+]] = load i16*, i16** [[BI]], align 4
+; CHECK: [[CI:%[^ ]+]] = getelementptr inbounds i16*, i16** %c, i32 [[IV]]
+; CHECK: [[CIJ:%[^ ]+]] = load i16*, i16** [[CI]], align 4
+; CHECK: [[BIJ_LD:%[^ ]+]] = load i16, i16* [[BIJ]], align 2
+; CHECK: [[BIJ_LD_SXT:%[^ ]+]] = sext i16 [[BIJ_LD]] to i32
+; CHECK: [[CIJ_LD:%[^ ]+]] = load i16, i16* [[CIJ]], align 2
+; CHECK: [[CIJ_LD_SXT:%[^ ]+]] = sext i16 [[CIJ_LD]] to i32
+; CHECK: [[SUB:%[^ ]+]] = sub nsw i32 [[CIJ_LD_SXT]], [[BIJ_LD_SXT]]
+; CHECK: [[BIJ_1:%[^ ]+]] = getelementptr inbounds i16, i16* [[BIJ]], i32 1
+; CHECK: [[BIJ_1_LD:%[^ ]+]] = load i16, i16* [[BIJ_1]], align 2
+; CHECK: [[BIJ_1_LD_SXT:%[^ ]+]] = sext i16 [[BIJ_1_LD]] to i32
+; CHECK: [[CIJ_1:%[^ ]+]] = getelementptr inbounds i16, i16* [[CIJ]], i32 1
+; CHECK: [[CIJ_1_LD:%[^ ]+]] = load i16, i16* [[CIJ_1]], align 2
+; CHECK: [[CIJ_1_LD_SXT:%[^ ]+]] = sext i16 [[CIJ_1_LD]] to i32
+; CHECK: [[MUL:%[^ ]+]] = mul nsw i32 [[CIJ_1_LD_SXT]], [[BIJ_1_LD_SXT]]
+; CHECK: [[ACC:%[^ ]+]] = add nsw i32 [[MUL]], [[SUB]]
+; CHECK: [[BIJ_2:%[^ ]+]] = getelementptr inbounds i16, i16* [[BIJ]], i32 2
+; CHECK: [[BIJ_2_CAST:%[^ ]+]] = bitcast i16* [[BIJ_2]] to i32*
+; CHECK: [[BIJ_2_LD:%[^ ]+]] = load i32, i32* [[BIJ_2_CAST]], align 2
+; CHECK: [[CIJ_2:%[^ ]+]] = getelementptr inbounds i16, i16* [[CIJ]], i32 2
+; CHECK: [[CIJ_2_CAST:%[^ ]+]] = bitcast i16* [[CIJ_2]] to i32*
+; CHECK: [[CIJ_2_LD:%[^ ]+]] = load i32, i32* [[CIJ_2_CAST]], align 2
+; CHECK: [[SMLAD0:%[^ ]+]] = call i32 @llvm.arm.smlad(i32 [[CIJ_2_LD]], i32 [[BIJ_2_LD]], i32 [[ACC]])
+; CHECK: store i32 [[SMLAD0]], i32* %arrayidx, align 4
+
+define void @full_unroll_sub(i32* noalias nocapture %a, i16** noalias nocapture readonly %b, i16** noalias nocapture readonly %c, i32 %N) {
+entry:
+  %cmp29 = icmp eq i32 %N, 0
+  br i1 %cmp29, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  ret void
+
+for.body:                                         ; preds = %entry, %for.body
+  %i.030 = phi i32 [ %inc12, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i32, i32* %a, i32 %i.030
+  %arrayidx5 = getelementptr inbounds i16*, i16** %b, i32 %i.030
+  %0 = load i16*, i16** %arrayidx5, align 4
+  %arrayidx7 = getelementptr inbounds i16*, i16** %c, i32 %i.030
+  %1 = load i16*, i16** %arrayidx7, align 4
+  %2 = load i16, i16* %0, align 2
+  %conv = sext i16 %2 to i32
+  %3 = load i16, i16* %1, align 2
+  %conv9 = sext i16 %3 to i32
+  %sub = sub nsw i32 %conv9, %conv
+  %arrayidx6.1 = getelementptr inbounds i16, i16* %0, i32 1
+  %4 = load i16, i16* %arrayidx6.1, align 2
+  %conv.1 = sext i16 %4 to i32
+  %arrayidx8.1 = getelementptr inbounds i16, i16* %1, i32 1
+  %5 = load i16, i16* %arrayidx8.1, align 2
+  %conv9.1 = sext i16 %5 to i32
+  %mul.1 = mul nsw i32 %conv9.1, %conv.1
+  %add.1 = add nsw i32 %mul.1, %sub
+  %arrayidx6.2 = getelementptr inbounds i16, i16* %0, i32 2
+  %6 = load i16, i16* %arrayidx6.2, align 2
+  %conv.2 = sext i16 %6 to i32
+  %arrayidx8.2 = getelementptr inbounds i16, i16* %1, i32 2
+  %7 = load i16, i16* %arrayidx8.2, align 2
+  %conv9.2 = sext i16 %7 to i32
+  %mul.2 = mul nsw i32 %conv9.2, %conv.2
+  %add.2 = add nsw i32 %mul.2, %add.1
+  %arrayidx6.3 = getelementptr inbounds i16, i16* %0, i32 3
+  %8 = load i16, i16* %arrayidx6.3, align 2
+  %conv.3 = sext i16 %8 to i32
+  %arrayidx8.3 = getelementptr inbounds i16, i16* %1, i32 3
+  %9 = load i16, i16* %arrayidx8.3, align 2
+  %conv9.3 = sext i16 %9 to i32
+  %mul.3 = mul nsw i32 %conv9.3, %conv.3
+  %add.3 = add nsw i32 %mul.3, %add.2
+  store i32 %add.3, i32* %arrayidx, align 4
+  %inc12 = add nuw i32 %i.030, 1
+  %exitcond = icmp eq i32 %inc12, %N
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+}
```
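
As a closing note on InsertParallelMACs: each paired set of muls becomes one intrinsic call, the calls are chained through the accumulator operand, and the final call replaces all uses of the reduction's root add. A sketch of the resulting dataflow for two pairs, reusing the scalar smlad sketch above (names are hypothetical; wide_b/wide_c stand for the widened i32 loads):

```cpp
// What the rewritten reduction computes for two mul pairs: the first call
// is seeded with the incoming accumulator (or 0 when the reduction has
// none), and each subsequent call consumes the previous result.
int32_t rewritten(const int32_t *wide_b, const int32_t *wide_c,
                  int32_t init) {
  int32_t acc = smlad(wide_c[0], wide_b[0], init);  // first pair
  acc = smlad(wide_c[1], wide_b[1], acc);           // second pair chains on
  return acc;  // replaces all uses of the root add
}
```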

