diff options
Diffstat (limited to 'llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp')
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 214 |
1 files changed, 214 insertions, 0 deletions
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp index cd308806c36..1ea47fc8d59 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -253,6 +253,12 @@ SDNode *NVPTXDAGToDAGISel::Select(SDNode *N) { case NVPTXISD::Suld3DV4I32Trap: ResNode = SelectSurfaceIntrinsic(N); break; + case ISD::AND: + case ISD::SRA: + case ISD::SRL: + // Try to select BFE + ResNode = SelectBFE(N); + break; case ISD::ADDRSPACECAST: ResNode = SelectAddrSpaceCast(N); break; @@ -2959,6 +2965,214 @@ SDNode *NVPTXDAGToDAGISel::SelectSurfaceIntrinsic(SDNode *N) { return Ret; } +/// SelectBFE - Look for instruction sequences that can be made more efficient +/// by using the 'bfe' (bit-field extract) PTX instruction +SDNode *NVPTXDAGToDAGISel::SelectBFE(SDNode *N) { + SDValue LHS = N->getOperand(0); + SDValue RHS = N->getOperand(1); + SDValue Len; + SDValue Start; + SDValue Val; + bool IsSigned = false; + + if (N->getOpcode() == ISD::AND) { + // Canonicalize the operands + // We want 'and %val, %mask' + if (isa<ConstantSDNode>(LHS) && !isa<ConstantSDNode>(RHS)) { + std::swap(LHS, RHS); + } + + ConstantSDNode *Mask = dyn_cast<ConstantSDNode>(RHS); + if (!Mask) { + // We need a constant mask on the RHS of the AND + return NULL; + } + + // Extract the mask bits + uint64_t MaskVal = Mask->getZExtValue(); + if (!isMask_64(MaskVal)) { + // We *could* handle shifted masks here, but doing so would require an + // 'and' operation to fix up the low-order bits so we would trade + // shr+and for bfe+and, which has the same throughput + return NULL; + } + + // How many bits are in our mask? + uint64_t NumBits = CountTrailingOnes_64(MaskVal); + Len = CurDAG->getTargetConstant(NumBits, MVT::i32); + + if (LHS.getOpcode() == ISD::SRL || LHS.getOpcode() == ISD::SRA) { + // We have a 'srl/and' pair, extract the effective start bit and length + Val = LHS.getNode()->getOperand(0); + Start = LHS.getNode()->getOperand(1); + ConstantSDNode *StartConst = dyn_cast<ConstantSDNode>(Start); + if (StartConst) { + uint64_t StartVal = StartConst->getZExtValue(); + // How many "good" bits do we have left? "good" is defined here as bits + // that exist in the original value, not shifted in. + uint64_t GoodBits = Start.getValueType().getSizeInBits() - StartVal; + if (NumBits > GoodBits) { + // Do not handle the case where bits have been shifted in. In theory + // we could handle this, but the cost is likely higher than just + // emitting the srl/and pair. + return NULL; + } + Start = CurDAG->getTargetConstant(StartVal, MVT::i32); + } else { + // Do not handle the case where the shift amount (can be zero if no srl + // was found) is not constant. We could handle this case, but it would + // require run-time logic that would be more expensive than just + // emitting the srl/and pair. + return NULL; + } + } else { + // Do not handle the case where the LHS of the and is not a shift. While + // it would be trivial to handle this case, it would just transform + // 'and' -> 'bfe', but 'and' has higher-throughput. + return NULL; + } + } else if (N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) { + if (LHS->getOpcode() == ISD::AND) { + ConstantSDNode *ShiftCnst = dyn_cast<ConstantSDNode>(RHS); + if (!ShiftCnst) { + // Shift amount must be constant + return NULL; + } + + uint64_t ShiftAmt = ShiftCnst->getZExtValue(); + + SDValue AndLHS = LHS->getOperand(0); + SDValue AndRHS = LHS->getOperand(1); + + // Canonicalize the AND to have the mask on the RHS + if (isa<ConstantSDNode>(AndLHS)) { + std::swap(AndLHS, AndRHS); + } + + ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(AndRHS); + if (!MaskCnst) { + // Mask must be constant + return NULL; + } + + uint64_t MaskVal = MaskCnst->getZExtValue(); + uint64_t NumZeros; + uint64_t NumBits; + if (isMask_64(MaskVal)) { + NumZeros = 0; + // The number of bits in the result bitfield will be the number of + // trailing ones (the AND) minus the number of bits we shift off + NumBits = CountTrailingOnes_64(MaskVal) - ShiftAmt; + } else if (isShiftedMask_64(MaskVal)) { + NumZeros = countTrailingZeros(MaskVal); + unsigned NumOnes = CountTrailingOnes_64(MaskVal >> NumZeros); + // The number of bits in the result bitfield will be the number of + // trailing zeros plus the number of set bits in the mask minus the + // number of bits we shift off + NumBits = NumZeros + NumOnes - ShiftAmt; + } else { + // This is not a mask we can handle + return NULL; + } + + if (ShiftAmt < NumZeros) { + // Handling this case would require extra logic that would make this + // transformation non-profitable + return NULL; + } + + Val = AndLHS; + Start = CurDAG->getTargetConstant(ShiftAmt, MVT::i32); + Len = CurDAG->getTargetConstant(NumBits, MVT::i32); + } else if (LHS->getOpcode() == ISD::SHL) { + // Here, we have a pattern like: + // + // (sra (shl val, NN), MM) + // or + // (srl (shl val, NN), MM) + // + // If MM >= NN, we can efficiently optimize this with bfe + Val = LHS->getOperand(0); + + SDValue ShlRHS = LHS->getOperand(1); + ConstantSDNode *ShlCnst = dyn_cast<ConstantSDNode>(ShlRHS); + if (!ShlCnst) { + // Shift amount must be constant + return NULL; + } + uint64_t InnerShiftAmt = ShlCnst->getZExtValue(); + + SDValue ShrRHS = RHS; + ConstantSDNode *ShrCnst = dyn_cast<ConstantSDNode>(ShrRHS); + if (!ShrCnst) { + // Shift amount must be constant + return NULL; + } + uint64_t OuterShiftAmt = ShrCnst->getZExtValue(); + + // To avoid extra codegen and be profitable, we need Outer >= Inner + if (OuterShiftAmt < InnerShiftAmt) { + return NULL; + } + + // If the outer shift is more than the type size, we have no bitfield to + // extract (since we also check that the inner shift is <= the outer shift + // then this also implies that the inner shift is < the type size) + if (OuterShiftAmt >= Val.getValueType().getSizeInBits()) { + return NULL; + } + + Start = + CurDAG->getTargetConstant(OuterShiftAmt - InnerShiftAmt, MVT::i32); + Len = + CurDAG->getTargetConstant(Val.getValueType().getSizeInBits() - + OuterShiftAmt, MVT::i32); + + if (N->getOpcode() == ISD::SRA) { + // If we have a arithmetic right shift, we need to use the signed bfe + // variant + IsSigned = true; + } + } else { + // No can do... + return NULL; + } + } else { + // No can do... + return NULL; + } + + + unsigned Opc; + // For the BFE operations we form here from "and" and "srl", always use the + // unsigned variants. + if (Val.getValueType() == MVT::i32) { + if (IsSigned) { + Opc = NVPTX::BFE_S32rii; + } else { + Opc = NVPTX::BFE_U32rii; + } + } else if (Val.getValueType() == MVT::i64) { + if (IsSigned) { + Opc = NVPTX::BFE_S64rii; + } else { + Opc = NVPTX::BFE_U64rii; + } + } else { + // We cannot handle this type + return NULL; + } + + SDValue Ops[] = { + Val, Start, Len + }; + + SDNode *Ret = + CurDAG->getMachineNode(Opc, SDLoc(N), N->getVTList(), Ops); + + return Ret; +} + // SelectDirectAddr - Match a direct address for DAG. // A direct address could be a globaladdress or externalsymbol. bool NVPTXDAGToDAGISel::SelectDirectAddr(SDValue N, SDValue &Address) { |