Diffstat (limited to 'llvm/lib/CodeGen')
-rw-r--r-- | llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 276
1 file changed, 276 insertions, 0 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 8bfb79ad41a..9a744956be6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -20,6 +20,7 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
@@ -375,6 +376,7 @@ namespace {
                      unsigned PosOpcode, unsigned NegOpcode, const SDLoc &DL);
     SDNode *MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
+    SDValue MatchLoadCombine(SDNode *N);
     SDValue ReduceLoadWidth(SDNode *N);
     SDValue ReduceLoadOpStoreWidth(SDNode *N);
     SDValue splitMergedValStore(StoreSDNode *ST);
@@ -3969,6 +3971,9 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
   if (SDNode *Rot = MatchRotate(N0, N1, SDLoc(N)))
     return SDValue(Rot, 0);
 
+  if (SDValue Load = MatchLoadCombine(N))
+    return Load;
+
   // Simplify the operands using demanded-bits information.
   if (!VT.isVector() &&
       SimplifyDemandedBits(SDValue(N, 0)))
@@ -4340,6 +4345,277 @@ struct BaseIndexOffset {
 };
 } // namespace
 
+namespace {
+/// Represents the origin of an individual byte in a load combine pattern. The
+/// value of the byte is either unknown, zero, or comes from memory.
+struct ByteProvider {
+  enum ProviderTy {
+    Unknown,
+    ZeroConstant,
+    Memory
+  };
+
+  ProviderTy Kind;
+  // Load and ByteOffset are set for Memory providers only.
+  // Load represents the node which loads the byte from memory.
+  // ByteOffset is the offset of the byte in the value produced by the load.
+  LoadSDNode *Load;
+  unsigned ByteOffset;
+
+  ByteProvider() : Kind(ProviderTy::Unknown), Load(nullptr), ByteOffset(0) {}
+
+  static ByteProvider getUnknown() {
+    return ByteProvider(ProviderTy::Unknown, nullptr, 0);
+  }
+  static ByteProvider getMemory(LoadSDNode *Load, unsigned ByteOffset) {
+    return ByteProvider(ProviderTy::Memory, Load, ByteOffset);
+  }
+  static ByteProvider getZero() {
+    return ByteProvider(ProviderTy::ZeroConstant, nullptr, 0);
+  }
+
+  bool operator==(const ByteProvider &Other) const {
+    return Other.Kind == Kind && Other.Load == Load &&
+           Other.ByteOffset == ByteOffset;
+  }
+
+private:
+  ByteProvider(ProviderTy Kind, LoadSDNode *Load, unsigned ByteOffset)
+      : Kind(Kind), Load(Load), ByteOffset(ByteOffset) {}
+};
+
+/// Recursively traverses the expression collecting the origin of individual
+/// bytes of the given value. For all values except the root of the
+/// expression, verifies that the value doesn't have uses outside of the
+/// expression.
+const Optional<SmallVector<ByteProvider, 4> >
+collectByteProviders(SDValue Op, bool CheckNumberOfUses = false) {
+  if (CheckNumberOfUses && !Op.hasOneUse())
+    return None;
+
+  unsigned BitWidth = Op.getScalarValueSizeInBits();
+  if (BitWidth % 8 != 0)
+    return None;
+  unsigned ByteWidth = BitWidth / 8;
+
+  switch (Op.getOpcode()) {
+  case ISD::OR: {
+    auto LHS = collectByteProviders(Op->getOperand(0),
+                                    /*CheckNumberOfUses=*/true);
+    auto RHS = collectByteProviders(Op->getOperand(1),
+                                    /*CheckNumberOfUses=*/true);
+    if (!LHS || !RHS)
+      return None;
+
+    auto OR = [](ByteProvider LHS, ByteProvider RHS) {
+      if (LHS == RHS)
+        return LHS;
+      if (LHS.Kind == ByteProvider::Unknown ||
+          RHS.Kind == ByteProvider::Unknown)
+        return ByteProvider::getUnknown();
+      if (LHS.Kind == ByteProvider::Memory && RHS.Kind == ByteProvider::Memory)
+        return ByteProvider::getUnknown();
+
+      if (LHS.Kind == ByteProvider::Memory)
+        return LHS;
+      else
+        return RHS;
+    };
+
+    SmallVector<ByteProvider, 4> Result(ByteWidth);
+    for (unsigned i = 0; i < LHS->size(); i++)
+      Result[i] = OR(LHS.getValue()[i], RHS.getValue()[i]);
+
+    return Result;
+  }
+  case ISD::SHL: {
+    auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
+    if (!ShiftOp)
+      return None;
+
+    uint64_t BitShift = ShiftOp->getZExtValue();
+    if (BitShift % 8 != 0)
+      return None;
+    uint64_t ByteShift = BitShift / 8;
+
+    auto Original = collectByteProviders(Op->getOperand(0),
+                                         /*CheckNumberOfUses=*/true);
+    if (!Original)
+      return None;
+
+    SmallVector<ByteProvider, 4> Result;
+    Result.insert(Result.begin(), ByteShift, ByteProvider::getZero());
+    Result.insert(Result.end(), Original->begin(),
+                  std::prev(Original->end(), ByteShift));
+    assert(Result.size() == ByteWidth && "sanity");
+    return Result;
+  }
+  case ISD::ZERO_EXTEND: {
+    auto Original = collectByteProviders(Op->getOperand(0),
+                                         /*CheckNumberOfUses=*/true);
+    if (!Original)
+      return None;
+
+    SmallVector<ByteProvider, 4> Result;
+    unsigned NarrowByteWidth = Original->size();
+    Result.insert(Result.begin(), Original->begin(), Original->end());
+    Result.insert(Result.end(), ByteWidth - NarrowByteWidth,
+                  ByteProvider::getZero());
+    assert(Result.size() == ByteWidth && "sanity");
+    return Result;
+  }
+  case ISD::LOAD: {
+    auto L = cast<LoadSDNode>(Op.getNode());
+    if (L->isVolatile() || L->isIndexed() ||
+        L->getExtensionType() != ISD::NON_EXTLOAD)
+      return None;
+
+    EVT VT = L->getMemoryVT();
+    assert(BitWidth == VT.getSizeInBits() && "sanity");
+
+    SmallVector<ByteProvider, 4> Result(ByteWidth);
+    for (unsigned i = 0; i < ByteWidth; i++)
+      Result[i] = ByteProvider::getMemory(L, i);
+
+    return Result;
+  }
+  }
+
+  return None;
+}
+} // namespace
+
+/// Match a pattern where a wide type scalar value is loaded by several narrow
+/// loads and combined by shifts and ORs. Fold it into a single load or a load
+/// and a BSWAP if the target supports it.
+///
+/// Assuming little endian target:
+///  i8 *a = ...
+///  i32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
+/// =>
+///  i32 val = *((i32)a)
+///
+///  i8 *a = ...
+///  i32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
+/// =>
+///  i32 val = BSWAP(*((i32)a))
+SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
+  assert(N->getOpcode() == ISD::OR &&
+         "Can only match load combining against OR nodes");
+
+  // Handles simple types only
+  EVT VT = N->getValueType(0);
+  if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
+    return SDValue();
+
+  // There is nothing to do here if the target can't load a value of this type
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  if (!TLI.isOperationLegal(ISD::LOAD, VT))
+    return SDValue();
+
+  // Calculate byte providers for the OR we are looking at
+  auto Res = collectByteProviders(SDValue(N, 0));
+  if (!Res)
+    return SDValue();
+  auto &Bytes = Res.getValue();
+  unsigned ByteWidth = Bytes.size();
+  assert(VT.getSizeInBits() == ByteWidth * 8 && "sanity");
+
+  auto LittleEndianByteAt = [](unsigned BW, unsigned i) { return i; };
+  auto BigEndianByteAt = [](unsigned BW, unsigned i) { return BW - i - 1; };
+
+  Optional<BaseIndexOffset> Base;
+  SDValue Chain;
+
+  SmallSet<LoadSDNode *, 8> Loads;
+  LoadSDNode *FirstLoad = nullptr;
+
+  // Check if all the bytes of the OR we are looking at are loaded from the
+  // same base address. Collect byte offsets from the Base address in
+  // ByteOffsets.
+  SmallVector<int64_t, 4> ByteOffsets(ByteWidth);
+  for (unsigned i = 0; i < ByteWidth; i++) {
+    // All the bytes must be loaded from memory
+    if (Bytes[i].Kind != ByteProvider::Memory)
+      return SDValue();
+
+    LoadSDNode *L = Bytes[i].Load;
+    assert(L->hasNUsesOfValue(1, 0) && !L->isVolatile() && !L->isIndexed() &&
+           (L->getExtensionType() == ISD::NON_EXTLOAD) &&
+           "Must be enforced by collectByteProviders");
+    assert(L->getOffset().isUndef() && "Unindexed load must have undef offset");
+
+    // All loads must share the same chain
+    SDValue LChain = L->getChain();
+    if (!Chain)
+      Chain = LChain;
+    if (Chain != LChain)
+      return SDValue();
+
+    // Loads must share the same base address
+    BaseIndexOffset Ptr = BaseIndexOffset::match(L->getBasePtr(), DAG);
+    if (!Base)
+      Base = Ptr;
+    if (!Base->equalBaseIndex(Ptr))
+      return SDValue();
+
+    // Calculate the offset of the current byte from the base address
+    unsigned LoadByteWidth = L->getMemoryVT().getSizeInBits() / 8;
+    int64_t MemoryByteOffset =
+        DAG.getDataLayout().isBigEndian()
+            ? BigEndianByteAt(LoadByteWidth, Bytes[i].ByteOffset)
+            : LittleEndianByteAt(LoadByteWidth, Bytes[i].ByteOffset);
+    int64_t ByteOffsetFromBase = Ptr.Offset + MemoryByteOffset;
+    ByteOffsets[i] = ByteOffsetFromBase;
+
+    // Remember the first byte load
+    if (ByteOffsetFromBase == 0)
+      FirstLoad = L;
+
+    Loads.insert(L);
+  }
+  assert(Base && "must be set");
+
+  // Check if the bytes of the OR we are looking at match either a big or
+  // little endian value load
+  bool BigEndian = true, LittleEndian = true;
+  for (unsigned i = 0; i < ByteWidth; i++) {
+    LittleEndian &= ByteOffsets[i] == LittleEndianByteAt(ByteWidth, i);
+    BigEndian &= ByteOffsets[i] == BigEndianByteAt(ByteWidth, i);
+    if (!BigEndian && !LittleEndian)
+      return SDValue();
+  }
+  assert((BigEndian != LittleEndian) && "should be either or");
+  assert(FirstLoad && "must be set");
+
+  // The node we are looking at matches with the pattern, check if we can
+  // replace it with a single load and bswap if needed.
+
+  // If the load needs byte swap check if the target supports it
+  bool NeedsBswap = DAG.getDataLayout().isBigEndian() != BigEndian;
+  if (NeedsBswap && !TLI.isOperationLegal(ISD::BSWAP, VT))
+    return SDValue();
+
+  // Check that a load of the wide type is both allowed and fast on the target
+  bool Fast = false;
+  bool Allowed = TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(),
+                                        VT, FirstLoad->getAddressSpace(),
+                                        FirstLoad->getAlignment(), &Fast);
+  if (!Allowed || !Fast)
+    return SDValue();
+
+  SDValue NewLoad =
+      DAG.getLoad(VT, SDLoc(N), Chain, FirstLoad->getBasePtr(),
+                  FirstLoad->getPointerInfo(), FirstLoad->getAlignment());
+
+  // Transfer chain users from old loads to the new load.
+  for (LoadSDNode *L : Loads)
+    DAG.ReplaceAllUsesOfValueWith(SDValue(L, 1), SDValue(NewLoad.getNode(), 1));
+
+  if (NeedsBswap)
+    return DAG.getNode(ISD::BSWAP, SDLoc(N), VT, NewLoad);
+  else
+    return NewLoad;
+}
+
 SDValue DAGCombiner::visitXOR(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
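
A worked trace of collectByteProviders may help; this is not part of the patch, and it uses an informal Memory(load, byte-offset) notation for the providers the code above computes. Assume an i32 expression (zext(load i16 L0)) | ((zext(load i16 L1)) << 16):

  zext(L0)        -> { Memory(L0,0), Memory(L0,1), Zero, Zero }
  zext(L1)        -> { Memory(L1,0), Memory(L1,1), Zero, Zero }
  zext(L1) << 16  -> { Zero, Zero, Memory(L1,0), Memory(L1,1) }
  whole OR        -> { Memory(L0,0), Memory(L0,1), Memory(L1,0), Memory(L1,1) }

The OR lambda keeps the Memory provider when exactly one side supplies a byte from memory and the other supplies zero; two Memory providers for the same byte, or any Unknown byte, yields Unknown, which later fails the all-bytes-from-memory check in MatchLoadCombine.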
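For context, a minimal compilable sketch of the source-level idiom this combine targets. The function names are illustrative, not part of the patch, and the described folding assumes a little-endian target where i32 loads are legal:

  #include <cstdint>

  // Byte-by-byte reads that MatchLoadCombine can fold.
  uint32_t read_le32(const uint8_t *p) {
    // On a little-endian target this OR tree becomes a single 32-bit load.
    return (uint32_t)p[0] | ((uint32_t)p[1] << 8) |
           ((uint32_t)p[2] << 16) | ((uint32_t)p[3] << 24);
  }

  uint32_t read_be32(const uint8_t *p) {
    // On a little-endian target this becomes a 32-bit load followed by a
    // BSWAP, provided ISD::BSWAP is legal for i32.
    return ((uint32_t)p[0] << 24) | ((uint32_t)p[1] << 16) |
           ((uint32_t)p[2] << 8) | (uint32_t)p[3];
  }

Note that the combine only fires when all byte loads share one chain and base address, the wide load type is legal, and allowsMemoryAccess reports the access as both allowed and fast; otherwise the original shift-and-OR sequence is left untouched.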