summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp104
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXISelLowering.h4
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXInstrInfo.td9
-rw-r--r--llvm/test/CodeGen/NVPTX/f16-instructions.ll8
-rw-r--r--llvm/test/CodeGen/NVPTX/f16x2-instructions.ll13
-rw-r--r--llvm/test/CodeGen/NVPTX/math-intrins.ll12
6 files changed, 128 insertions, 22 deletions
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 2e955f89049..cae94d49759 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -546,13 +546,19 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// These map to conversion instructions for scalar FP types.
for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
- ISD::FROUND, ISD::FTRUNC}) {
+ ISD::FTRUNC}) {
setOperationAction(Op, MVT::f16, Legal);
setOperationAction(Op, MVT::f32, Legal);
setOperationAction(Op, MVT::f64, Legal);
setOperationAction(Op, MVT::v2f16, Expand);
}
+ setOperationAction(ISD::FROUND, MVT::f16, Promote);
+ setOperationAction(ISD::FROUND, MVT::v2f16, Expand);
+ setOperationAction(ISD::FROUND, MVT::f32, Custom);
+ setOperationAction(ISD::FROUND, MVT::f64, Custom);
+
+
// 'Expand' implements FCOPYSIGN without calling an external library.
setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand);
setOperationAction(ISD::FCOPYSIGN, MVT::v2f16, Expand);
@@ -2068,6 +2074,100 @@ SDValue NVPTXTargetLowering::LowerShiftLeftParts(SDValue Op,
}
}
+// Entry point for custom lowering of ISD::FROUND: forwards to the f32 or f64
+// specific routine based on the result type of Op. Any other type reaching
+// this point is treated as unreachable.
+SDValue NVPTXTargetLowering::LowerFROUND(SDValue Op, SelectionDAG &DAG) const {
+  const EVT ResultTy = Op.getValueType();
+
+  if (ResultTy == MVT::f64)
+    return LowerFROUND64(Op, DAG);
+
+  if (ResultTy != MVT::f32)
+    llvm_unreachable("unhandled type");
+
+  return LowerFROUND32(Op, DAG);
+}
+
+
+// Custom lowering of ISD::FROUND for f32. Operand 0 of Op is the value to
+// round; the returned SDValue computes the rounded result.
+//
+// This is the rounding method used in CUDA libdevice, in C-like code:
+//   float roundf(float A)
+//   {
+//     float RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f));
+//     RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
+//     return abs(A) < 0.5 ? (float)(int)A : RoundedA;
+//   }
+SDValue NVPTXTargetLowering::LowerFROUND32(SDValue Op,
+                                           SelectionDAG &DAG) const {
+  SDLoc SL(Op);
+  SDValue A = Op.getOperand(0);
+  EVT VT = Op.getValueType();
+
+  SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A);
+
+  // RoundedA = (float) (int) ( A > 0 ? (A + 0.5f) : (A - 0.5f))
+  // Instead of comparing against zero, the sign bit of A is OR-ed into the
+  // bit pattern of 0.5f, giving +0.5 or -0.5 with the same sign as A.
+  SDValue Bitcast = DAG.getNode(ISD::BITCAST, SL, MVT::i32, A);
+  // Kept unsigned: 0x80000000 does not fit in a signed 32-bit int, so a signed
+  // constant would hold a negative value rather than the plain bit pattern.
+  const unsigned SignBitMask = 0x80000000u;
+  SDValue Sign = DAG.getNode(ISD::AND, SL, MVT::i32, Bitcast,
+                             DAG.getConstant(SignBitMask, SL, MVT::i32));
+  const unsigned PointFiveInBits = 0x3F000000u;
+  SDValue PointFiveWithSignRaw =
+      DAG.getNode(ISD::OR, SL, MVT::i32, Sign,
+                  DAG.getConstant(PointFiveInBits, SL, MVT::i32));
+  SDValue PointFiveWithSign =
+      DAG.getNode(ISD::BITCAST, SL, VT, PointFiveWithSignRaw);
+  SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, A, PointFiveWithSign);
+  SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA);
+
+  // RoundedA = abs(A) > 0x1.0p23 ? A : RoundedA;
+  EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
+  SDValue IsLarge =
+      DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 23.0), SL, VT),
+                   ISD::SETOGT);
+  RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
+
+  // return abs(A) < 0.5 ? (float)(int)A : RoundedA;
+  SDValue IsSmall = DAG.getSetCC(SL, SetCCVT, AbsA,
+                                 DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT);
+  SDValue RoundedAForSmallA = DAG.getNode(ISD::FTRUNC, SL, VT, A);
+  return DAG.getNode(ISD::SELECT, SL, VT, IsSmall, RoundedAForSmallA, RoundedA);
+}
+
+// Custom lowering of ISD::FROUND for f64. Operand 0 of Op is the value to
+// round; the returned SDValue computes the rounded result.
+//
+// The implementation of round(double) is similar to that of round(float) in
+// that they both separate the value range into three regions and use a method
+// specific to the region to round the values. However, round(double) first
+// calculates the round of the absolute value and then adds the sign back while
+// round(float) directly rounds the value with sign.
+SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
+                                           SelectionDAG &DAG) const {
+  SDLoc SL(Op);
+  SDValue A = Op.getOperand(0);
+  EVT VT = Op.getValueType();
+
+  SDValue AbsA = DAG.getNode(ISD::FABS, SL, VT, A);
+
+  // double RoundedA = (double) (int) (abs(A) + 0.5f);
+  SDValue AdjustedA = DAG.getNode(ISD::FADD, SL, VT, AbsA,
+                                  DAG.getConstantFP(0.5, SL, VT));
+  SDValue RoundedA = DAG.getNode(ISD::FTRUNC, SL, VT, AdjustedA);
+
+  // RoundedA = abs(A) < 0.5 ? (double)0 : RoundedA;
+  EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
+  SDValue IsSmall = DAG.getSetCC(SL, SetCCVT, AbsA,
+                                 DAG.getConstantFP(0.5, SL, VT), ISD::SETOLT);
+  RoundedA = DAG.getNode(ISD::SELECT, SL, VT, IsSmall,
+                         DAG.getConstantFP(0, SL, VT),
+                         RoundedA);
+
+  // Add the sign of A back to RoundedA (which was computed from abs(A)).
+  RoundedA = DAG.getNode(ISD::FCOPYSIGN, SL, VT, RoundedA, A);
+
+  // RoundedA = abs(A) > 0x1.0p52 ? A : RoundedA;
+  SDValue IsLarge =
+      DAG.getSetCC(SL, SetCCVT, AbsA, DAG.getConstantFP(pow(2.0, 52.0), SL, VT),
+                   ISD::SETOGT);
+  return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
+}
+
+
+
SDValue
NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
switch (Op.getOpcode()) {
@@ -2098,6 +2198,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return LowerShiftRightParts(Op, DAG);
case ISD::SELECT:
return LowerSelect(Op, DAG);
+ case ISD::FROUND:
+ return LowerFROUND(Op, DAG);
default:
llvm_unreachable("Custom lowering not defined for operation");
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index bbcc35f49d9..ef645fc1e54 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -556,6 +556,10 @@ private:
SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerFROUND(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const;
+
SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerLOADi1(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 1bf556c9287..2ee90abb411 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -3002,15 +3002,6 @@ def : Pat<(ffloor Float32Regs:$a),
def : Pat<(ffloor Float64Regs:$a),
(CVT_f64_f64 Float64Regs:$a, CvtRMI)>;
-def : Pat<(f16 (fround Float16Regs:$a)),
- (CVT_f16_f16 Float16Regs:$a, CvtRNI)>;
-def : Pat<(fround Float32Regs:$a),
- (CVT_f32_f32 Float32Regs:$a, CvtRNI_FTZ)>, Requires<[doF32FTZ]>;
-def : Pat<(f32 (fround Float32Regs:$a)),
- (CVT_f32_f32 Float32Regs:$a, CvtRNI)>, Requires<[doNoF32FTZ]>;
-def : Pat<(f64 (fround Float64Regs:$a)),
- (CVT_f64_f64 Float64Regs:$a, CvtRNI)>;
-
def : Pat<(ftrunc Float16Regs:$a),
(CVT_f16_f16 Float16Regs:$a, CvtRZI)>;
def : Pat<(ftrunc Float32Regs:$a),
diff --git a/llvm/test/CodeGen/NVPTX/f16-instructions.ll b/llvm/test/CodeGen/NVPTX/f16-instructions.ll
index 7788adc8698..9aa81dac126 100644
--- a/llvm/test/CodeGen/NVPTX/f16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/f16-instructions.ll
@@ -1107,9 +1107,11 @@ define half @test_nearbyint(half %a) #0 {
}
; CHECK-LABEL: test_round(
-; CHECK: ld.param.b16 [[A:%h[0-9]+]], [test_round_param_0];
-; CHECK: cvt.rni.f16.f16 [[R:%h[0-9]+]], [[A]];
-; CHECK: st.param.b16 [func_retval0+0], [[R]];
+; CHECK: ld.param.b16 {{.*}}, [test_round_param_0];
+; check the use of sign mask and 0.5 to implement round
+; CHECK: and.b32 [[R:%r[0-9]+]], {{.*}}, -2147483648;
+; CHECK: or.b32 {{.*}}, [[R]], 1056964608;
+; CHECK: st.param.b16 [func_retval0+0], {{.*}};
; CHECK: ret;
define half @test_round(half %a) #0 {
%r = call half @llvm.round.f16(half %a)
diff --git a/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll b/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll
index a8996815af4..44dda09a902 100644
--- a/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll
@@ -1378,12 +1378,13 @@ define <2 x half> @test_nearbyint(<2 x half> %a) #0 {
}
; CHECK-LABEL: test_round(
-; CHECK: ld.param.b32 [[A:%hh[0-9]+]], [test_round_param_0];
-; CHECK-DAG: mov.b32 {[[A0:%h[0-9]+]], [[A1:%h[0-9]+]]}, [[A]];
-; CHECK-DAG: cvt.rni.f16.f16 [[R1:%h[0-9]+]], [[A1]];
-; CHECK-DAG: cvt.rni.f16.f16 [[R0:%h[0-9]+]], [[A0]];
-; CHECK: mov.b32 [[R:%hh[0-9]+]], {[[R0]], [[R1]]}
-; CHECK: st.param.b32 [func_retval0+0], [[R]];
+; CHECK: ld.param.b32 {{.*}}, [test_round_param_0];
+; check the use of sign mask and 0.5 to implement round
+; CHECK: and.b32 [[R1:%r[0-9]+]], {{.*}}, -2147483648;
+; CHECK: or.b32 {{.*}}, [[R1]], 1056964608;
+; CHECK: and.b32 [[R2:%r[0-9]+]], {{.*}}, -2147483648;
+; CHECK: or.b32 {{.*}}, [[R2]], 1056964608;
+; CHECK: st.param.b32 [func_retval0+0], {{.*}};
; CHECK: ret;
define <2 x half> @test_round(<2 x half> %a) #0 {
%r = call <2 x half> @llvm.round.f16(<2 x half> %a)
diff --git a/llvm/test/CodeGen/NVPTX/math-intrins.ll b/llvm/test/CodeGen/NVPTX/math-intrins.ll
index 828a8807dcf..412b25c7a3b 100644
--- a/llvm/test/CodeGen/NVPTX/math-intrins.ll
+++ b/llvm/test/CodeGen/NVPTX/math-intrins.ll
@@ -74,21 +74,27 @@ define double @floor_double(double %a) {
; CHECK-LABEL: round_float
define float @round_float(float %a) {
- ; CHECK: cvt.rni.f32.f32
+; check the use of sign mask and 0.5 to implement round
+; CHECK: and.b32 [[R1:%r[0-9]+]], {{.*}}, -2147483648;
+; CHECK: or.b32 {{.*}}, [[R1]], 1056964608;
%b = call float @llvm.round.f32(float %a)
ret float %b
}
; CHECK-LABEL: round_float_ftz
define float @round_float_ftz(float %a) #1 {
- ; CHECK: cvt.rni.ftz.f32.f32
+; check the use of sign mask and 0.5 to implement round
+; CHECK: and.b32 [[R1:%r[0-9]+]], {{.*}}, -2147483648;
+; CHECK: or.b32 {{.*}}, [[R1]], 1056964608;
%b = call float @llvm.round.f32(float %a)
ret float %b
}
; CHECK-LABEL: round_double
define double @round_double(double %a) {
- ; CHECK: cvt.rni.f64.f64
+; check the use of 0.5 to implement round
+; CHECK: setp.lt.f64 {{.*}}, [[R:%fd[0-9]+]], 0d3FE0000000000000;
+; CHECK: add.rn.f64 {{.*}}, [[R]], 0d3FE0000000000000;
%b = call double @llvm.round.f64(double %a)
ret double %b
}
OpenPOWER on IntegriCloud