diff options
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 107 |
1 files changed, 86 insertions, 21 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 406aed3bf84..36d60dee0c2 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -4866,23 +4866,23 @@ static bool getTargetShuffleMaskIndices(SDValue MaskNode, return true; } -static const Constant *getTargetShuffleMaskConstant(SDValue MaskNode) { - MaskNode = peekThroughBitcasts(MaskNode); +static const Constant *getTargetConstantFromNode(SDValue Op) { + Op = peekThroughBitcasts(Op); - auto *MaskLoad = dyn_cast<LoadSDNode>(MaskNode); - if (!MaskLoad) + auto *Load = dyn_cast<LoadSDNode>(Op); + if (!Load) return nullptr; - SDValue Ptr = MaskLoad->getBasePtr(); + SDValue Ptr = Load->getBasePtr(); if (Ptr->getOpcode() == X86ISD::Wrapper || Ptr->getOpcode() == X86ISD::WrapperRIP) Ptr = Ptr->getOperand(0); - auto *MaskCP = dyn_cast<ConstantPoolSDNode>(Ptr); - if (!MaskCP || MaskCP->isMachineConstantPoolEntry()) + auto *CNode = dyn_cast<ConstantPoolSDNode>(Ptr); + if (!CNode || CNode->isMachineConstantPoolEntry()) return nullptr; - return dyn_cast<Constant>(MaskCP->getConstVal()); + return dyn_cast<Constant>(CNode->getConstVal()); } /// Calculates the shuffle mask corresponding to the target-specific opcode. @@ -4992,7 +4992,7 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, DecodeVPERMILPMask(VT, RawMask, Mask); break; } - if (auto *C = getTargetShuffleMaskConstant(MaskNode)) { + if (auto *C = getTargetConstantFromNode(MaskNode)) { DecodeVPERMILPMask(C, MaskEltSize, Mask); break; } @@ -5006,7 +5006,7 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, DecodePSHUFBMask(RawMask, Mask); break; } - if (auto *C = getTargetShuffleMaskConstant(MaskNode)) { + if (auto *C = getTargetConstantFromNode(MaskNode)) { DecodePSHUFBMask(C, Mask); break; } @@ -5055,7 +5055,7 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, DecodeVPERMIL2PMask(VT, CtrlImm, RawMask, Mask); break; } - if (auto *C = getTargetShuffleMaskConstant(MaskNode)) { + if (auto *C = getTargetConstantFromNode(MaskNode)) { DecodeVPERMIL2PMask(C, CtrlImm, MaskEltSize, Mask); break; } @@ -5070,7 +5070,7 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, DecodeVPPERMMask(RawMask, Mask); break; } - if (auto *C = getTargetShuffleMaskConstant(MaskNode)) { + if (auto *C = getTargetConstantFromNode(MaskNode)) { DecodeVPPERMMask(C, Mask); break; } @@ -5087,7 +5087,7 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, DecodeVPERMVMask(RawMask, Mask); break; } - if (auto *C = getTargetShuffleMaskConstant(MaskNode)) { + if (auto *C = getTargetConstantFromNode(MaskNode)) { DecodeVPERMVMask(C, VT, Mask); break; } @@ -5099,7 +5099,7 @@ static bool getTargetShuffleMask(SDNode *N, MVT VT, bool AllowSentinelZero, Ops.push_back(N->getOperand(0)); Ops.push_back(N->getOperand(2)); SDValue MaskNode = N->getOperand(1); - if (auto *C = getTargetShuffleMaskConstant(MaskNode)) { + if (auto *C = getTargetConstantFromNode(MaskNode)) { DecodeVPERMV3Mask(C, VT, Mask); break; } @@ -30358,6 +30358,18 @@ static SDValue combineFneg(SDNode *N, SelectionDAG &DAG, case X86ISD::FNMSUB: return DAG.getNode(X86ISD::FMADD, DL, VT, Arg.getOperand(0), Arg.getOperand(1), Arg.getOperand(2)); + case X86ISD::FMADD_RND: + return DAG.getNode(X86ISD::FNMSUB_RND, DL, VT, Arg.getOperand(0), + Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3)); + case X86ISD::FMSUB_RND: + return DAG.getNode(X86ISD::FNMADD_RND, DL, VT, Arg.getOperand(0), + Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3)); + case X86ISD::FNMADD_RND: + return DAG.getNode(X86ISD::FMSUB_RND, DL, VT, Arg.getOperand(0), + Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3)); + case X86ISD::FNMSUB_RND: + return DAG.getNode(X86ISD::FMADD_RND, DL, VT, Arg.getOperand(0), + Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3)); } } return SDValue(); @@ -30386,6 +30398,45 @@ static SDValue lowerX86FPLogicOp(SDNode *N, SelectionDAG &DAG, } return SDValue(); } + +/// Returns true if the node \p N is FNEG(x) or FXOR (x, 0x80000000). +bool isFNEG(const SDNode *N) { + if (N->getOpcode() == ISD::FNEG) + return true; + + if (N->getOpcode() == X86ISD::FXOR) { + unsigned EltBits = N->getSimpleValueType(0).getScalarSizeInBits(); + SDValue Op1 = N->getOperand(1); + + auto isSignBitValue = [&](const ConstantFP *C) { + return C->getValueAPF().bitcastToAPInt() == APInt::getSignBit(EltBits); + }; + + // There is more than one way to represent the same constant on + // the different X86 targets. The type of the node may also depend on size. + // - load scalar value and broadcast + // - BUILD_VECTOR node + // - load from a constant pool. + // We check all variants here. + if (Op1.getOpcode() == X86ISD::VBROADCAST) { + if (auto *C = getTargetConstantFromNode(Op1.getOperand(0))) + return isSignBitValue(cast<ConstantFP>(C)); + + } else if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Op1)) { + if (ConstantFPSDNode *CN = BV->getConstantFPSplatNode()) + return isSignBitValue(CN->getConstantFPValue()); + + } else if (auto *C = getTargetConstantFromNode(Op1)) { + if (C->getType()->isVectorTy()) { + if (auto *SplatV = C->getSplatValue()) + return isSignBitValue(cast<ConstantFP>(SplatV)); + } else if (auto *FPConst = dyn_cast<ConstantFP>(C)) + return isSignBitValue(FPConst); + } + } + return false; +} + /// Do target-specific dag combines on X86ISD::FOR and X86ISD::FXOR nodes. static SDValue combineFOr(SDNode *N, SelectionDAG &DAG, const X86Subtarget &Subtarget) { @@ -30401,6 +30452,9 @@ static SDValue combineFOr(SDNode *N, SelectionDAG &DAG, if (C->getValueAPF().isPosZero()) return N->getOperand(0); + if (isFNEG(N)) + if (SDValue NewVal = combineFneg(N, DAG, Subtarget)) + return NewVal; return lowerX86FPLogicOp(N, DAG, Subtarget); } @@ -30810,9 +30864,9 @@ static SDValue combineFMA(SDNode *N, SelectionDAG &DAG, SDValue B = N->getOperand(1); SDValue C = N->getOperand(2); - bool NegA = (A.getOpcode() == ISD::FNEG); - bool NegB = (B.getOpcode() == ISD::FNEG); - bool NegC = (C.getOpcode() == ISD::FNEG); + bool NegA = isFNEG(A.getNode()); + bool NegB = isFNEG(B.getNode()); + bool NegC = isFNEG(C.getNode()); // Negative multiplication when NegA xor NegB bool NegMul = (NegA != NegB); @@ -30823,13 +30877,22 @@ static SDValue combineFMA(SDNode *N, SelectionDAG &DAG, if (NegC) C = C.getOperand(0); - unsigned Opcode; + unsigned NewOpcode; if (!NegMul) - Opcode = (!NegC) ? X86ISD::FMADD : X86ISD::FMSUB; + NewOpcode = (!NegC) ? X86ISD::FMADD : X86ISD::FMSUB; else - Opcode = (!NegC) ? X86ISD::FNMADD : X86ISD::FNMSUB; + NewOpcode = (!NegC) ? X86ISD::FNMADD : X86ISD::FNMSUB; - return DAG.getNode(Opcode, dl, VT, A, B, C); + if (N->getOpcode() == X86ISD::FMADD_RND) { + switch (NewOpcode) { + case X86ISD::FMADD: NewOpcode = X86ISD::FMADD_RND; break; + case X86ISD::FMSUB: NewOpcode = X86ISD::FMSUB_RND; break; + case X86ISD::FNMADD: NewOpcode = X86ISD::FNMADD_RND; break; + case X86ISD::FNMSUB: NewOpcode = X86ISD::FNMSUB_RND; break; + } + return DAG.getNode(NewOpcode, dl, VT, A, B, C, N->getOperand(3)); + } + return DAG.getNode(NewOpcode, dl, VT, A, B, C); } static SDValue combineZext(SDNode *N, SelectionDAG &DAG, @@ -31559,6 +31622,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N, case X86ISD::VPERM2X128: case X86ISD::VZEXT_MOVL: case ISD::VECTOR_SHUFFLE: return combineShuffle(N, DAG, DCI,Subtarget); + case X86ISD::FMADD: + case X86ISD::FMADD_RND: case ISD::FMA: return combineFMA(N, DAG, Subtarget); case ISD::MGATHER: case ISD::MSCATTER: return combineGatherScatter(N, DAG); |