summaryrefslogtreecommitdiffstats
path: root/llvm/lib/CodeGen
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/CodeGen')
-rw-r--r--llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp276
1 files changed, 276 insertions, 0 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 8bfb79ad41a..9a744956be6 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -20,6 +20,7 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
@@ -375,6 +376,7 @@ namespace {
unsigned PosOpcode, unsigned NegOpcode,
const SDLoc &DL);
SDNode *MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
+ SDValue MatchLoadCombine(SDNode *N);
SDValue ReduceLoadWidth(SDNode *N);
SDValue ReduceLoadOpStoreWidth(SDNode *N);
SDValue splitMergedValStore(StoreSDNode *ST);
@@ -3969,6 +3971,9 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
if (SDNode *Rot = MatchRotate(N0, N1, SDLoc(N)))
return SDValue(Rot, 0);
+ if (SDValue Load = MatchLoadCombine(N))
+ return Load;
+
// Simplify the operands using demanded-bits information.
if (!VT.isVector() &&
SimplifyDemandedBits(SDValue(N, 0)))
@@ -4340,6 +4345,277 @@ struct BaseIndexOffset {
};
} // namespace
+namespace {
+/// Represents the origin of an individual byte in load combine pattern. The
+/// value of the byte is either unknown, zero or comes from memory.
+struct ByteProvider {
+ enum ProviderTy {
+ Unknown,
+ ZeroConstant,
+ Memory
+ };
+
+ ProviderTy Kind;
+ // Load and ByteOffset are set for Memory providers only.
+ // Load represents the node which loads the byte from memory.
+ // ByteOffset is the offset of the byte in the value produced by the load.
+ LoadSDNode *Load;
+ unsigned ByteOffset;
+
+ ByteProvider() : Kind(ProviderTy::Unknown), Load(nullptr), ByteOffset(0) {}
+
+ static ByteProvider getUnknown() {
+ return ByteProvider(ProviderTy::Unknown, nullptr, 0);
+ }
+ static ByteProvider getMemory(LoadSDNode *Load, unsigned ByteOffset) {
+ return ByteProvider(ProviderTy::Memory, Load, ByteOffset);
+ }
+ static ByteProvider getZero() {
+ return ByteProvider(ProviderTy::ZeroConstant, nullptr, 0);
+ }
+
+ bool operator==(const ByteProvider &Other) const {
+ return Other.Kind == Kind && Other.Load == Load &&
+ Other.ByteOffset == ByteOffset;
+ }
+
+private:
+ ByteProvider(ProviderTy Kind, LoadSDNode *Load, unsigned ByteOffset)
+ : Kind(Kind), Load(Load), ByteOffset(ByteOffset) {}
+};
+
+/// Recursively traverses the expression collecting the origin of individual
+/// bytes of the given value. For all the values except the root of the
+/// expression verifies that it doesn't have uses outside of the expression.
+const Optional<SmallVector<ByteProvider, 4> >
+collectByteProviders(SDValue Op, bool CheckNumberOfUses = false) {
+ if (CheckNumberOfUses && !Op.hasOneUse())
+ return None;
+
+ unsigned BitWidth = Op.getScalarValueSizeInBits();
+ if (BitWidth % 8 != 0)
+ return None;
+ unsigned ByteWidth = BitWidth / 8;
+
+ switch (Op.getOpcode()) {
+ case ISD::OR: {
+ auto LHS = collectByteProviders(Op->getOperand(0),
+ /*CheckNumberOfUses=*/true);
+ auto RHS = collectByteProviders(Op->getOperand(1),
+ /*CheckNumberOfUses=*/true);
+ if (!LHS || !RHS)
+ return None;
+
+ auto OR = [](ByteProvider LHS, ByteProvider RHS) {
+ if (LHS == RHS)
+ return LHS;
+ if (LHS.Kind == ByteProvider::Unknown ||
+ RHS.Kind == ByteProvider::Unknown)
+ return ByteProvider::getUnknown();
+ if (LHS.Kind == ByteProvider::Memory && RHS.Kind == ByteProvider::Memory)
+ return ByteProvider::getUnknown();
+
+ if (LHS.Kind == ByteProvider::Memory)
+ return LHS;
+ else
+ return RHS;
+ };
+
+ SmallVector<ByteProvider, 4> Result(ByteWidth);
+ for (unsigned i = 0; i < LHS->size(); i++)
+ Result[i] = OR(LHS.getValue()[i], RHS.getValue()[i]);
+
+ return Result;
+ }
+ case ISD::SHL: {
+ auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
+ if (!ShiftOp)
+ return None;
+
+ uint64_t BitShift = ShiftOp->getZExtValue();
+ if (BitShift % 8 != 0)
+ return None;
+ uint64_t ByteShift = BitShift / 8;
+
+ auto Original = collectByteProviders(Op->getOperand(0),
+ /*CheckNumberOfUses=*/true);
+ if (!Original)
+ return None;
+
+ SmallVector<ByteProvider, 4> Result;
+ Result.insert(Result.begin(), ByteShift, ByteProvider::getZero());
+ Result.insert(Result.end(), Original->begin(),
+ std::prev(Original->end(), ByteShift));
+ assert(Result.size() == ByteWidth && "sanity");
+ return Result;
+ }
+ case ISD::ZERO_EXTEND: {
+ auto Original = collectByteProviders(Op->getOperand(0),
+ /*CheckNumberOfUses=*/true);
+ if (!Original)
+ return None;
+
+ SmallVector<ByteProvider, 4> Result;
+ unsigned NarrowByteWidth = Original->size();
+ Result.insert(Result.begin(), Original->begin(), Original->end());
+ Result.insert(Result.end(), ByteWidth - NarrowByteWidth,
+ ByteProvider::getZero());
+ assert(Result.size() == ByteWidth && "sanity");
+ return Result;
+ }
+ case ISD::LOAD: {
+ auto L = cast<LoadSDNode>(Op.getNode());
+ if (L->isVolatile() || L->isIndexed() ||
+ L->getExtensionType() != ISD::NON_EXTLOAD)
+ return None;
+
+ EVT VT = L->getMemoryVT();
+ assert(BitWidth == VT.getSizeInBits() && "sanity");
+
+ SmallVector<ByteProvider, 4> Result(ByteWidth);
+ for (unsigned i = 0; i < ByteWidth; i++)
+ Result[i] = ByteProvider::getMemory(L, i);
+
+ return Result;
+ }
+ }
+
+ return None;
+}
+} // namespace
+
+/// Match a pattern where a wide type scalar value is loaded by several narrow
+/// loads and combined by shifts and ors. Fold it into a single load or a load
+/// and a BSWAP if the targets supports it.
+///
+/// Assuming little endian target:
+/// i8 *a = ...
+/// i32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
+/// =>
+/// i32 val = *((i32)a)
+///
+/// i8 *a = ...
+/// i32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
+/// =>
+/// i32 val = BSWAP(*((i32)a))
+SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
+ assert(N->getOpcode() == ISD::OR &&
+ "Can only match load combining against OR nodes");
+
+ // Handles simple types only
+ EVT VT = N->getValueType(0);
+ if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
+ return SDValue();
+
+ // There is nothing to do here if the target can't load a value of this type
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ if (!TLI.isOperationLegal(ISD::LOAD, VT))
+ return SDValue();
+
+ // Calculate byte providers for the OR we are looking at
+ auto Res = collectByteProviders(SDValue(N, 0));
+ if (!Res)
+ return SDValue();
+ auto &Bytes = Res.getValue();
+ unsigned ByteWidth = Bytes.size();
+ assert(VT.getSizeInBits() == ByteWidth * 8 && "sanity");
+
+ auto LittleEndianByteAt = [](unsigned BW, unsigned i) { return i; };
+ auto BigEndianByteAt = [](unsigned BW, unsigned i) { return BW - i - 1; };
+
+ Optional<BaseIndexOffset> Base;
+ SDValue Chain;
+
+ SmallSet<LoadSDNode *, 8> Loads;
+ LoadSDNode *FirstLoad = nullptr;
+
+ // Check if all the bytes of the OR we are looking at are loaded from the same
+ // base address. Collect bytes offsets from Base address in ByteOffsets.
+ SmallVector<int64_t, 4> ByteOffsets(ByteWidth);
+ for (unsigned i = 0; i < ByteWidth; i++) {
+ // All the bytes must be loaded from memory
+ if (Bytes[i].Kind != ByteProvider::Memory)
+ return SDValue();
+
+ LoadSDNode *L = Bytes[i].Load;
+ assert(L->hasNUsesOfValue(1, 0) && !L->isVolatile() && !L->isIndexed() &&
+ (L->getExtensionType() == ISD::NON_EXTLOAD) &&
+ "Must be enforced by collectByteProviders");
+ assert(L->getOffset().isUndef() && "Unindexed load must have undef offset");
+
+ // All loads must share the same chain
+ SDValue LChain = L->getChain();
+ if (!Chain)
+ Chain = LChain;
+ if (Chain != LChain)
+ return SDValue();
+
+ // Loads must share the same base address
+ BaseIndexOffset Ptr = BaseIndexOffset::match(L->getBasePtr(), DAG);
+ if (!Base)
+ Base = Ptr;
+ if (!Base->equalBaseIndex(Ptr))
+ return SDValue();
+
+ // Calculate the offset of the current byte from the base address
+ unsigned LoadByteWidth = L->getMemoryVT().getSizeInBits() / 8;
+ int64_t MemoryByteOffset =
+ DAG.getDataLayout().isBigEndian()
+ ? BigEndianByteAt(LoadByteWidth, Bytes[i].ByteOffset)
+ : LittleEndianByteAt(LoadByteWidth, Bytes[i].ByteOffset);
+ int64_t ByteOffsetFromBase = Ptr.Offset + MemoryByteOffset;
+ ByteOffsets[i] = ByteOffsetFromBase;
+
+ // Remember the first byte load
+ if (ByteOffsetFromBase == 0)
+ FirstLoad = L;
+
+ Loads.insert(L);
+ }
+ assert(Base && "must be set");
+
+ // Check if the bytes of the OR we are looking at match with either big or
+ // little endian value load
+ bool BigEndian = true, LittleEndian = true;
+ for (unsigned i = 0; i < ByteWidth; i++) {
+ LittleEndian &= ByteOffsets[i] == LittleEndianByteAt(ByteWidth, i);
+ BigEndian &= ByteOffsets[i] == BigEndianByteAt(ByteWidth, i);
+ if (!BigEndian && !LittleEndian)
+ return SDValue();
+ }
+ assert((BigEndian != LittleEndian) && "should be either or");
+ assert(FirstLoad && "must be set");
+
+ // The node we are looking at matches with the pattern, check if we can
+ // replace it with a single load and bswap if needed.
+
+ // If the load needs byte swap check if the target supports it
+ bool NeedsBswap = DAG.getDataLayout().isBigEndian() != BigEndian;
+ if (NeedsBswap && !TLI.isOperationLegal(ISD::BSWAP, VT))
+ return SDValue();
+
+ // Check that a load of the wide type is both allowed and fast on the target
+ bool Fast = false;
+ bool Allowed = TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(),
+ VT, FirstLoad->getAddressSpace(),
+ FirstLoad->getAlignment(), &Fast);
+ if (!Allowed || !Fast)
+ return SDValue();
+
+ SDValue NewLoad =
+ DAG.getLoad(VT, SDLoc(N), Chain, FirstLoad->getBasePtr(),
+ FirstLoad->getPointerInfo(), FirstLoad->getAlignment());
+
+ // Transfer chain users from old loads to the new load.
+ for (LoadSDNode *L : Loads)
+ DAG.ReplaceAllUsesOfValueWith(SDValue(L, 1), SDValue(NewLoad.getNode(), 1));
+
+ if (NeedsBswap)
+ return DAG.getNode(ISD::BSWAP, SDLoc(N), VT, NewLoad);
+ else
+ return NewLoad;
+}
+
SDValue DAGCombiner::visitXOR(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
OpenPOWER on IntegriCloud