diff options
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp | 63
-rw-r--r-- | llvm/test/Transforms/InstCombine/rotate.ll | 95
2 files changed, 121 insertions, 37 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 88a72bb8eb5..26d0b522f01 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1546,6 +1546,66 @@ static Instruction *factorizeMinMaxTree(SelectPatternFlavor SPF, Value *LHS,
   return SelectInst::Create(CmpABC, MinMaxOp, ThirdOp);
 }
 
+/// Try to reduce a rotate pattern that includes a compare and select into a
+/// sequence of ALU ops only. Example:
+/// rotl32(a, b) --> (b == 0 ? a : ((a >> (32 - b)) | (a << b)))
+///              --> (a >> (-b & 31)) | (a << (b & 31))
+static Instruction *foldSelectRotate(SelectInst &Sel,
+                                     InstCombiner::BuilderTy &Builder) {
+  // The false value of the select must be a rotate of the true value.
+  Value *Or0, *Or1;
+  if (!match(Sel.getFalseValue(), m_OneUse(m_Or(m_Value(Or0), m_Value(Or1)))))
+    return nullptr;
+
+  Value *TVal = Sel.getTrueValue();
+  Value *SA0, *SA1;
+  if (!match(Or0, m_OneUse(m_LogicalShift(m_Specific(TVal), m_Value(SA0)))) ||
+      !match(Or1, m_OneUse(m_LogicalShift(m_Specific(TVal), m_Value(SA1)))))
+    return nullptr;
+
+  auto ShiftOpcode0 = cast<BinaryOperator>(Or0)->getOpcode();
+  auto ShiftOpcode1 = cast<BinaryOperator>(Or1)->getOpcode();
+  if (ShiftOpcode0 == ShiftOpcode1)
+    return nullptr;
+
+  // We have one of these patterns so far:
+  // select ?, TVal, (or (lshr TVal, SA0), (shl TVal, SA1))
+  // select ?, TVal, (or (shl TVal, SA0), (lshr TVal, SA1))
+  // This must be a power-of-2 rotate for a bitmasking transform to be valid.
+  unsigned Width = Sel.getType()->getScalarSizeInBits();
+  if (!isPowerOf2_32(Width))
+    return nullptr;
+
+  // Check the shift amounts to see if they are an opposite pair.
+  Value *ShAmt;
+  if (match(SA1, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(SA0)))))
+    ShAmt = SA0;
+  else if (match(SA0, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(SA1)))))
+    ShAmt = SA1;
+  else
+    return nullptr;
+
+  // Finally, see if the select is filtering out a shift-by-zero.
+  Value *Cond = Sel.getCondition();
+  ICmpInst::Predicate Pred;
+  if (!match(Cond, m_OneUse(m_ICmp(Pred, m_Specific(ShAmt), m_ZeroInt()))) ||
+      Pred != ICmpInst::ICMP_EQ)
+    return nullptr;
+
+  // This is a rotate that avoids shift-by-bitwidth UB in a suboptimal way.
+  // Convert to safely bitmasked shifts.
+  // TODO: When we can canonicalize to funnel shift intrinsics without risk of
+  // performance regressions, replace this sequence with that call.
+  Value *NegShAmt = Builder.CreateNeg(ShAmt);
+  Value *MaskedShAmt = Builder.CreateAnd(ShAmt, Width - 1);
+  Value *MaskedNegShAmt = Builder.CreateAnd(NegShAmt, Width - 1);
+  Value *NewSA0 = ShAmt == SA0 ? MaskedShAmt : MaskedNegShAmt;
+  Value *NewSA1 = ShAmt == SA1 ? MaskedShAmt : MaskedNegShAmt;
+  Value *NewSh0 = Builder.CreateBinOp(ShiftOpcode0, TVal, NewSA0);
+  Value *NewSh1 = Builder.CreateBinOp(ShiftOpcode1, TVal, NewSA1);
+  return BinaryOperator::CreateOr(NewSh0, NewSh1);
+}
+
 Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
   Value *CondVal = SI.getCondition();
   Value *TrueVal = SI.getTrueValue();
@@ -2010,5 +2070,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
   if (Instruction *Select = foldSelectBinOpIdentity(SI, TLI))
     return Select;
 
+  if (Instruction *Rot = foldSelectRotate(SI, Builder))
+    return Rot;
+
   return nullptr;
 }
diff --git a/llvm/test/Transforms/InstCombine/rotate.ll b/llvm/test/Transforms/InstCombine/rotate.ll
index 4401539220a..6150063ab72 100644
--- a/llvm/test/Transforms/InstCombine/rotate.ll
+++ b/llvm/test/Transforms/InstCombine/rotate.ll
@@ -309,16 +309,16 @@ define i8 @rotateleft_8_neg_mask_wide_amount_commute(i8 %v, i32 %shamt) {
   ret i8 %ret
 }
 
-; TODO: Convert select pattern to masked shift that ends in 'or'.
+; Convert select pattern to masked shift that ends in 'or'.

 define i32 @rotr_select(i32 %x, i32 %shamt) {
 ; CHECK-LABEL: @rotr_select(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[SHAMT:%.*]], 0
-; CHECK-NEXT:    [[SUB:%.*]] = sub i32 32, [[SHAMT]]
-; CHECK-NEXT:    [[SHR:%.*]] = lshr i32 [[X:%.*]], [[SHAMT]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[X]], [[SUB]]
-; CHECK-NEXT:    [[OR:%.*]] = or i32 [[SHR]], [[SHL]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[CMP]], i32 [[X]], i32 [[OR]]
+; CHECK-NEXT:    [[TMP1:%.*]] = sub i32 0, [[SHAMT:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[SHAMT]], 31
+; CHECK-NEXT:    [[TMP3:%.*]] = and i32 [[TMP1]], 31
+; CHECK-NEXT:    [[TMP4:%.*]] = lshr i32 [[X:%.*]], [[TMP2]]
+; CHECK-NEXT:    [[TMP5:%.*]] = shl i32 [[X]], [[TMP3]]
+; CHECK-NEXT:    [[R:%.*]] = or i32 [[TMP4]], [[TMP5]]
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %cmp = icmp eq i32 %shamt, 0
@@ -330,16 +330,16 @@ define i32 @rotr_select(i32 %x, i32 %shamt) {
   ret i32 %r
 }
 
-; TODO: Convert select pattern to masked shift that ends in 'or'.
+; Convert select pattern to masked shift that ends in 'or'.
 
 define i8 @rotr_select_commute(i8 %x, i8 %shamt) {
 ; CHECK-LABEL: @rotr_select_commute(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[SHAMT:%.*]], 0
-; CHECK-NEXT:    [[SUB:%.*]] = sub i8 8, [[SHAMT]]
-; CHECK-NEXT:    [[SHR:%.*]] = lshr i8 [[X:%.*]], [[SHAMT]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl i8 [[X]], [[SUB]]
-; CHECK-NEXT:    [[OR:%.*]] = or i8 [[SHL]], [[SHR]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[CMP]], i8 [[X]], i8 [[OR]]
+; CHECK-NEXT:    [[TMP1:%.*]] = sub i8 0, [[SHAMT:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i8 [[SHAMT]], 7
+; CHECK-NEXT:    [[TMP3:%.*]] = and i8 [[TMP1]], 7
+; CHECK-NEXT:    [[TMP4:%.*]] = shl i8 [[X:%.*]], [[TMP3]]
+; CHECK-NEXT:    [[TMP5:%.*]] = lshr i8 [[X]], [[TMP2]]
+; CHECK-NEXT:    [[R:%.*]] = or i8 [[TMP4]], [[TMP5]]
 ; CHECK-NEXT:    ret i8 [[R]]
 ;
   %cmp = icmp eq i8 %shamt, 0
@@ -351,16 +351,16 @@ define i8 @rotr_select_commute(i8 %x, i8 %shamt) {
   ret i8 %r
 }
 
-; TODO: Convert select pattern to masked shift that ends in 'or'.
+; Convert select pattern to masked shift that ends in 'or'.
 
 define i16 @rotl_select(i16 %x, i16 %shamt) {
 ; CHECK-LABEL: @rotl_select(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i16 [[SHAMT:%.*]], 0
-; CHECK-NEXT:    [[SUB:%.*]] = sub i16 16, [[SHAMT]]
-; CHECK-NEXT:    [[SHR:%.*]] = lshr i16 [[X:%.*]], [[SUB]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl i16 [[X]], [[SHAMT]]
-; CHECK-NEXT:    [[OR:%.*]] = or i16 [[SHR]], [[SHL]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[CMP]], i16 [[X]], i16 [[OR]]
+; CHECK-NEXT:    [[TMP1:%.*]] = sub i16 0, [[SHAMT:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i16 [[SHAMT]], 15
+; CHECK-NEXT:    [[TMP3:%.*]] = and i16 [[TMP1]], 15
+; CHECK-NEXT:    [[TMP4:%.*]] = lshr i16 [[X:%.*]], [[TMP3]]
+; CHECK-NEXT:    [[TMP5:%.*]] = shl i16 [[X]], [[TMP2]]
+; CHECK-NEXT:    [[R:%.*]] = or i16 [[TMP4]], [[TMP5]]
 ; CHECK-NEXT:    ret i16 [[R]]
 ;
   %cmp = icmp eq i16 %shamt, 0
@@ -372,24 +372,45 @@ define i16 @rotl_select(i16 %x, i16 %shamt) {
   ret i16 %r
 }
 
-; TODO: Convert select pattern to masked shift that ends in 'or'.
+; Convert select pattern to masked shift that ends in 'or'.
 
-define i64 @rotl_select_commute(i64 %x, i64 %shamt) {
+define <2 x i64> @rotl_select_commute(<2 x i64> %x, <2 x i64> %shamt) {
 ; CHECK-LABEL: @rotl_select_commute(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i64 [[SHAMT:%.*]], 0
-; CHECK-NEXT:    [[SUB:%.*]] = sub i64 64, [[SHAMT]]
-; CHECK-NEXT:    [[SHR:%.*]] = lshr i64 [[X:%.*]], [[SUB]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl i64 [[X]], [[SHAMT]]
-; CHECK-NEXT:    [[OR:%.*]] = or i64 [[SHL]], [[SHR]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[CMP]], i64 [[X]], i64 [[OR]]
-; CHECK-NEXT:    ret i64 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = sub <2 x i64> zeroinitializer, [[SHAMT:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and <2 x i64> [[SHAMT]], <i64 63, i64 63>
+; CHECK-NEXT:    [[TMP3:%.*]] = and <2 x i64> [[TMP1]], <i64 63, i64 63>
+; CHECK-NEXT:    [[TMP4:%.*]] = shl <2 x i64> [[X:%.*]], [[TMP2]]
+; CHECK-NEXT:    [[TMP5:%.*]] = lshr <2 x i64> [[X]], [[TMP3]]
+; CHECK-NEXT:    [[R:%.*]] = or <2 x i64> [[TMP4]], [[TMP5]]
+; CHECK-NEXT:    ret <2 x i64> [[R]]
+;
+  %cmp = icmp eq <2 x i64> %shamt, zeroinitializer
+  %sub = sub <2 x i64> <i64 64, i64 64>, %shamt
+  %shr = lshr <2 x i64> %x, %sub
+  %shl = shl <2 x i64> %x, %shamt
+  %or = or <2 x i64> %shl, %shr
+  %r = select <2 x i1> %cmp, <2 x i64> %x, <2 x i64> %or
+  ret <2 x i64> %r
+}
+
+; Negative test - the transform is only valid with power-of-2 types.
+
+define i24 @rotl_select_weird_type(i24 %x, i24 %shamt) {
+; CHECK-LABEL: @rotl_select_weird_type(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i24 [[SHAMT:%.*]], 0
+; CHECK-NEXT:    [[SUB:%.*]] = sub i24 24, [[SHAMT]]
+; CHECK-NEXT:    [[SHR:%.*]] = lshr i24 [[X:%.*]], [[SUB]]
+; CHECK-NEXT:    [[SHL:%.*]] = shl i24 [[X]], [[SHAMT]]
+; CHECK-NEXT:    [[OR:%.*]] = or i24 [[SHL]], [[SHR]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[CMP]], i24 [[X]], i24 [[OR]]
+; CHECK-NEXT:    ret i24 [[R]]
 ;
-  %cmp = icmp eq i64 %shamt, 0
-  %sub = sub i64 64, %shamt
-  %shr = lshr i64 %x, %sub
-  %shl = shl i64 %x, %shamt
-  %or = or i64 %shl, %shr
-  %r = select i1 %cmp, i64 %x, i64 %or
-  ret i64 %r
+  %cmp = icmp eq i24 %shamt, 0
+  %sub = sub i24 24, %shamt
+  %shr = lshr i24 %x, %sub
+  %shl = shl i24 %x, %shamt
+  %or = or i24 %shl, %shr
+  %r = select i1 %cmp, i24 %x, i24 %or
+  ret i24 %r
 }