diff options
author | Sebastian Pop <sebpop@gmail.com> | 2019-09-07 20:24:51 +0000 |
---|---|---|
committer | Sebastian Pop <sebpop@gmail.com> | 2019-09-07 20:24:51 +0000 |
commit | eacb2c2c975cf88676a75d0835f85420c72cd46f (patch) | |
tree | e7be437262a14ccac8e1960a543a018c079692ab /llvm/lib | |
parent | c4450437ec91334d81a28084c4cf637cfdd8bbcb (diff) | |
download | bcm5719-llvm-eacb2c2c975cf88676a75d0835f85420c72cd46f.tar.gz bcm5719-llvm-eacb2c2c975cf88676a75d0835f85420c72cd46f.zip |
[aarch64] Add combine patterns for fp16 fmla
This patch enables generation of fused multiply add/sub for instructions operating on fp16.
Tested on aarch64-linux.
Differential Revision: https://reviews.llvm.org/D67297
llvm-svn: 371321
Diffstat (limited to 'llvm/lib')
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64InstrInfo.cpp | 342 |
1 files changed, 280 insertions, 62 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp index a9f54a1bc9e..3e1e798e43b 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp @@ -3466,13 +3466,19 @@ static bool isCombineInstrCandidateFP(const MachineInstr &Inst) { switch (Inst.getOpcode()) { default: break; + case AArch64::FADDHrr: case AArch64::FADDSrr: case AArch64::FADDDrr: + case AArch64::FADDv4f16: + case AArch64::FADDv8f16: case AArch64::FADDv2f32: case AArch64::FADDv2f64: case AArch64::FADDv4f32: + case AArch64::FSUBHrr: case AArch64::FSUBSrr: case AArch64::FSUBDrr: + case AArch64::FSUBv4f16: + case AArch64::FSUBv8f16: case AArch64::FSUBv2f32: case AArch64::FSUBv2f64: case AArch64::FSUBv4f32: @@ -3682,9 +3688,21 @@ static bool getFMAPatterns(MachineInstr &Root, default: assert(false && "Unsupported FP instruction in combiner\n"); break; + case AArch64::FADDHrr: + assert(Root.getOperand(1).isReg() && Root.getOperand(2).isReg() && + "FADDHrr does not have register operands"); + if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULHrr)) { + Patterns.push_back(MachineCombinerPattern::FMULADDH_OP1); + Found = true; + } + if (canCombineWithFMUL(MBB, Root.getOperand(2), AArch64::FMULHrr)) { + Patterns.push_back(MachineCombinerPattern::FMULADDH_OP2); + Found = true; + } + break; case AArch64::FADDSrr: assert(Root.getOperand(1).isReg() && Root.getOperand(2).isReg() && - "FADDWrr does not have register operands"); + "FADDSrr does not have register operands"); if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULSrr)) { Patterns.push_back(MachineCombinerPattern::FMULADDS_OP1); Found = true; @@ -3720,6 +3738,46 @@ static bool getFMAPatterns(MachineInstr &Root, Found = true; } break; + case AArch64::FADDv4f16: + if (canCombineWithFMUL(MBB, Root.getOperand(1), + AArch64::FMULv4i16_indexed)) { + Patterns.push_back(MachineCombinerPattern::FMLAv4i16_indexed_OP1); + Found = true; + } else if (canCombineWithFMUL(MBB, Root.getOperand(1), + AArch64::FMULv4f16)) { + Patterns.push_back(MachineCombinerPattern::FMLAv4f16_OP1); + Found = true; + } + if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv4i16_indexed)) { + Patterns.push_back(MachineCombinerPattern::FMLAv4i16_indexed_OP2); + Found = true; + } else if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv4f16)) { + Patterns.push_back(MachineCombinerPattern::FMLAv4f16_OP2); + Found = true; + } + break; + case AArch64::FADDv8f16: + if (canCombineWithFMUL(MBB, Root.getOperand(1), + AArch64::FMULv8i16_indexed)) { + Patterns.push_back(MachineCombinerPattern::FMLAv8i16_indexed_OP1); + Found = true; + } else if (canCombineWithFMUL(MBB, Root.getOperand(1), + AArch64::FMULv8f16)) { + Patterns.push_back(MachineCombinerPattern::FMLAv8f16_OP1); + Found = true; + } + if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv8i16_indexed)) { + Patterns.push_back(MachineCombinerPattern::FMLAv8i16_indexed_OP2); + Found = true; + } else if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv8f16)) { + Patterns.push_back(MachineCombinerPattern::FMLAv8f16_OP2); + Found = true; + } + break; case AArch64::FADDv2f32: if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULv2i32_indexed)) { @@ -3781,6 +3839,20 @@ static bool getFMAPatterns(MachineInstr &Root, } break; + case AArch64::FSUBHrr: + if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULHrr)) { + Patterns.push_back(MachineCombinerPattern::FMULSUBH_OP1); + Found = true; + } + if (canCombineWithFMUL(MBB, Root.getOperand(2), AArch64::FMULHrr)) { + Patterns.push_back(MachineCombinerPattern::FMULSUBH_OP2); + Found = true; + } + if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FNMULHrr)) { + Patterns.push_back(MachineCombinerPattern::FNMULSUBH_OP1); + Found = true; + } + break; case AArch64::FSUBSrr: if (canCombineWithFMUL(MBB, Root.getOperand(1), AArch64::FMULSrr)) { Patterns.push_back(MachineCombinerPattern::FMULSUBS_OP1); @@ -3817,6 +3889,46 @@ static bool getFMAPatterns(MachineInstr &Root, Found = true; } break; + case AArch64::FSUBv4f16: + if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv4i16_indexed)) { + Patterns.push_back(MachineCombinerPattern::FMLSv4i16_indexed_OP2); + Found = true; + } else if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv4f16)) { + Patterns.push_back(MachineCombinerPattern::FMLSv4f16_OP2); + Found = true; + } + if (canCombineWithFMUL(MBB, Root.getOperand(1), + AArch64::FMULv4i16_indexed)) { + Patterns.push_back(MachineCombinerPattern::FMLSv2i32_indexed_OP1); + Found = true; + } else if (canCombineWithFMUL(MBB, Root.getOperand(1), + AArch64::FMULv4f16)) { + Patterns.push_back(MachineCombinerPattern::FMLSv2f32_OP1); + Found = true; + } + break; + case AArch64::FSUBv8f16: + if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv8i16_indexed)) { + Patterns.push_back(MachineCombinerPattern::FMLSv8i16_indexed_OP2); + Found = true; + } else if (canCombineWithFMUL(MBB, Root.getOperand(2), + AArch64::FMULv8f16)) { + Patterns.push_back(MachineCombinerPattern::FMLSv8f16_OP2); + Found = true; + } + if (canCombineWithFMUL(MBB, Root.getOperand(1), + AArch64::FMULv8i16_indexed)) { + Patterns.push_back(MachineCombinerPattern::FMLSv8i16_indexed_OP1); + Found = true; + } else if (canCombineWithFMUL(MBB, Root.getOperand(1), + AArch64::FMULv8f16)) { + Patterns.push_back(MachineCombinerPattern::FMLSv8f16_OP1); + Found = true; + } + break; case AArch64::FSUBv2f32: if (canCombineWithFMUL(MBB, Root.getOperand(2), AArch64::FMULv2i32_indexed)) { @@ -3889,6 +4001,10 @@ bool AArch64InstrInfo::isThroughputPattern( switch (Pattern) { default: break; + case MachineCombinerPattern::FMULADDH_OP1: + case MachineCombinerPattern::FMULADDH_OP2: + case MachineCombinerPattern::FMULSUBH_OP1: + case MachineCombinerPattern::FMULSUBH_OP2: case MachineCombinerPattern::FMULADDS_OP1: case MachineCombinerPattern::FMULADDS_OP2: case MachineCombinerPattern::FMULSUBS_OP1: @@ -3897,12 +4013,21 @@ bool AArch64InstrInfo::isThroughputPattern( case MachineCombinerPattern::FMULADDD_OP2: case MachineCombinerPattern::FMULSUBD_OP1: case MachineCombinerPattern::FMULSUBD_OP2: + case MachineCombinerPattern::FNMULSUBH_OP1: case MachineCombinerPattern::FNMULSUBS_OP1: case MachineCombinerPattern::FNMULSUBD_OP1: + case MachineCombinerPattern::FMLAv4i16_indexed_OP1: + case MachineCombinerPattern::FMLAv4i16_indexed_OP2: + case MachineCombinerPattern::FMLAv8i16_indexed_OP1: + case MachineCombinerPattern::FMLAv8i16_indexed_OP2: case MachineCombinerPattern::FMLAv1i32_indexed_OP1: case MachineCombinerPattern::FMLAv1i32_indexed_OP2: case MachineCombinerPattern::FMLAv1i64_indexed_OP1: case MachineCombinerPattern::FMLAv1i64_indexed_OP2: + case MachineCombinerPattern::FMLAv4f16_OP2: + case MachineCombinerPattern::FMLAv4f16_OP1: + case MachineCombinerPattern::FMLAv8f16_OP1: + case MachineCombinerPattern::FMLAv8f16_OP2: case MachineCombinerPattern::FMLAv2f32_OP2: case MachineCombinerPattern::FMLAv2f32_OP1: case MachineCombinerPattern::FMLAv2f64_OP1: @@ -3915,10 +4040,16 @@ bool AArch64InstrInfo::isThroughputPattern( case MachineCombinerPattern::FMLAv4f32_OP2: case MachineCombinerPattern::FMLAv4i32_indexed_OP1: case MachineCombinerPattern::FMLAv4i32_indexed_OP2: + case MachineCombinerPattern::FMLSv4i16_indexed_OP2: + case MachineCombinerPattern::FMLSv8i16_indexed_OP1: + case MachineCombinerPattern::FMLSv8i16_indexed_OP2: case MachineCombinerPattern::FMLSv1i32_indexed_OP2: case MachineCombinerPattern::FMLSv1i64_indexed_OP2: case MachineCombinerPattern::FMLSv2i32_indexed_OP2: case MachineCombinerPattern::FMLSv2i64_indexed_OP2: + case MachineCombinerPattern::FMLSv4f16_OP2: + case MachineCombinerPattern::FMLSv8f16_OP1: + case MachineCombinerPattern::FMLSv8f16_OP2: case MachineCombinerPattern::FMLSv2f32_OP2: case MachineCombinerPattern::FMLSv2f64_OP2: case MachineCombinerPattern::FMLSv4i32_indexed_OP2: @@ -4266,34 +4397,35 @@ void AArch64InstrInfo::genAlternativeCodeSequence( break; } // Floating Point Support + case MachineCombinerPattern::FMULADDH_OP1: + Opc = AArch64::FMADDHrrr; + RC = &AArch64::FPR16RegClass; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; case MachineCombinerPattern::FMULADDS_OP1: + Opc = AArch64::FMADDSrrr; + RC = &AArch64::FPR32RegClass; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; case MachineCombinerPattern::FMULADDD_OP1: - // MUL I=A,B,0 - // ADD R,I,C - // ==> MADD R,A,B,C - // --- Create(MADD); - if (Pattern == MachineCombinerPattern::FMULADDS_OP1) { - Opc = AArch64::FMADDSrrr; - RC = &AArch64::FPR32RegClass; - } else { - Opc = AArch64::FMADDDrrr; - RC = &AArch64::FPR64RegClass; - } + Opc = AArch64::FMADDDrrr; + RC = &AArch64::FPR64RegClass; MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); break; + + case MachineCombinerPattern::FMULADDH_OP2: + Opc = AArch64::FMADDHrrr; + RC = &AArch64::FPR16RegClass; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + break; case MachineCombinerPattern::FMULADDS_OP2: + Opc = AArch64::FMADDSrrr; + RC = &AArch64::FPR32RegClass; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + break; case MachineCombinerPattern::FMULADDD_OP2: - // FMUL I=A,B,0 - // FADD R,C,I - // ==> FMADD R,A,B,C - // --- Create(FMADD); - if (Pattern == MachineCombinerPattern::FMULADDS_OP2) { - Opc = AArch64::FMADDSrrr; - RC = &AArch64::FPR32RegClass; - } else { - Opc = AArch64::FMADDDrrr; - RC = &AArch64::FPR64RegClass; - } + Opc = AArch64::FMADDDrrr; + RC = &AArch64::FPR64RegClass; MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); break; @@ -4323,6 +4455,31 @@ void AArch64InstrInfo::genAlternativeCodeSequence( FMAInstKind::Indexed); break; + case MachineCombinerPattern::FMLAv4i16_indexed_OP1: + RC = &AArch64::FPR64RegClass; + Opc = AArch64::FMLAv4i16_indexed; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC, + FMAInstKind::Indexed); + break; + case MachineCombinerPattern::FMLAv4f16_OP1: + RC = &AArch64::FPR64RegClass; + Opc = AArch64::FMLAv4f16; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC, + FMAInstKind::Accumulator); + break; + case MachineCombinerPattern::FMLAv4i16_indexed_OP2: + RC = &AArch64::FPR64RegClass; + Opc = AArch64::FMLAv4i16_indexed; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC, + FMAInstKind::Indexed); + break; + case MachineCombinerPattern::FMLAv4f16_OP2: + RC = &AArch64::FPR64RegClass; + Opc = AArch64::FMLAv4f16; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC, + FMAInstKind::Accumulator); + break; + case MachineCombinerPattern::FMLAv2i32_indexed_OP1: case MachineCombinerPattern::FMLAv2f32_OP1: RC = &AArch64::FPR64RegClass; @@ -4350,6 +4507,31 @@ void AArch64InstrInfo::genAlternativeCodeSequence( } break; + case MachineCombinerPattern::FMLAv8i16_indexed_OP1: + RC = &AArch64::FPR128RegClass; + Opc = AArch64::FMLAv8i16_indexed; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC, + FMAInstKind::Indexed); + break; + case MachineCombinerPattern::FMLAv8f16_OP1: + RC = &AArch64::FPR128RegClass; + Opc = AArch64::FMLAv8f16; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC, + FMAInstKind::Accumulator); + break; + case MachineCombinerPattern::FMLAv8i16_indexed_OP2: + RC = &AArch64::FPR128RegClass; + Opc = AArch64::FMLAv8i16_indexed; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC, + FMAInstKind::Indexed); + break; + case MachineCombinerPattern::FMLAv8f16_OP2: + RC = &AArch64::FPR128RegClass; + Opc = AArch64::FMLAv8f16; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC, + FMAInstKind::Accumulator); + break; + case MachineCombinerPattern::FMLAv2i64_indexed_OP1: case MachineCombinerPattern::FMLAv2f64_OP1: RC = &AArch64::FPR128RegClass; @@ -4405,56 +4587,53 @@ void AArch64InstrInfo::genAlternativeCodeSequence( } break; + case MachineCombinerPattern::FMULSUBH_OP1: + Opc = AArch64::FNMSUBHrrr; + RC = &AArch64::FPR16RegClass; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; case MachineCombinerPattern::FMULSUBS_OP1: - case MachineCombinerPattern::FMULSUBD_OP1: { - // FMUL I=A,B,0 - // FSUB R,I,C - // ==> FNMSUB R,A,B,C // = -C + A*B - // --- Create(FNMSUB); - if (Pattern == MachineCombinerPattern::FMULSUBS_OP1) { - Opc = AArch64::FNMSUBSrrr; - RC = &AArch64::FPR32RegClass; - } else { - Opc = AArch64::FNMSUBDrrr; - RC = &AArch64::FPR64RegClass; - } + Opc = AArch64::FNMSUBSrrr; + RC = &AArch64::FPR32RegClass; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; + case MachineCombinerPattern::FMULSUBD_OP1: + Opc = AArch64::FNMSUBDrrr; + RC = &AArch64::FPR64RegClass; MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); break; - } + case MachineCombinerPattern::FNMULSUBH_OP1: + Opc = AArch64::FNMADDHrrr; + RC = &AArch64::FPR16RegClass; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; case MachineCombinerPattern::FNMULSUBS_OP1: - case MachineCombinerPattern::FNMULSUBD_OP1: { - // FNMUL I=A,B,0 - // FSUB R,I,C - // ==> FNMADD R,A,B,C // = -A*B - C - // --- Create(FNMADD); - if (Pattern == MachineCombinerPattern::FNMULSUBS_OP1) { - Opc = AArch64::FNMADDSrrr; - RC = &AArch64::FPR32RegClass; - } else { - Opc = AArch64::FNMADDDrrr; - RC = &AArch64::FPR64RegClass; - } + Opc = AArch64::FNMADDSrrr; + RC = &AArch64::FPR32RegClass; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; + case MachineCombinerPattern::FNMULSUBD_OP1: + Opc = AArch64::FNMADDDrrr; + RC = &AArch64::FPR64RegClass; MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); break; - } + case MachineCombinerPattern::FMULSUBH_OP2: + Opc = AArch64::FMSUBHrrr; + RC = &AArch64::FPR16RegClass; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + break; case MachineCombinerPattern::FMULSUBS_OP2: - case MachineCombinerPattern::FMULSUBD_OP2: { - // FMUL I=A,B,0 - // FSUB R,C,I - // ==> FMSUB R,A,B,C (computes C - A*B) - // --- Create(FMSUB); - if (Pattern == MachineCombinerPattern::FMULSUBS_OP2) { - Opc = AArch64::FMSUBSrrr; - RC = &AArch64::FPR32RegClass; - } else { - Opc = AArch64::FMSUBDrrr; - RC = &AArch64::FPR64RegClass; - } + Opc = AArch64::FMSUBSrrr; + RC = &AArch64::FPR32RegClass; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + break; + case MachineCombinerPattern::FMULSUBD_OP2: + Opc = AArch64::FMSUBDrrr; + RC = &AArch64::FPR64RegClass; MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); break; - } case MachineCombinerPattern::FMLSv1i32_indexed_OP2: Opc = AArch64::FMLSv1i32_indexed; @@ -4470,6 +4649,19 @@ void AArch64InstrInfo::genAlternativeCodeSequence( FMAInstKind::Indexed); break; + case MachineCombinerPattern::FMLSv4f16_OP2: + RC = &AArch64::FPR64RegClass; + Opc = AArch64::FMLSv4f16; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC, + FMAInstKind::Accumulator); + break; + case MachineCombinerPattern::FMLSv4i16_indexed_OP2: + RC = &AArch64::FPR64RegClass; + Opc = AArch64::FMLSv4i16_indexed; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC, + FMAInstKind::Indexed); + break; + case MachineCombinerPattern::FMLSv2f32_OP2: case MachineCombinerPattern::FMLSv2i32_indexed_OP2: RC = &AArch64::FPR64RegClass; @@ -4484,6 +4676,32 @@ void AArch64InstrInfo::genAlternativeCodeSequence( } break; + case MachineCombinerPattern::FMLSv8f16_OP1: + RC = &AArch64::FPR128RegClass; + Opc = AArch64::FMLSv8f16; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC, + FMAInstKind::Accumulator); + break; + case MachineCombinerPattern::FMLSv8i16_indexed_OP1: + RC = &AArch64::FPR128RegClass; + Opc = AArch64::FMLSv8i16_indexed; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC, + FMAInstKind::Indexed); + break; + + case MachineCombinerPattern::FMLSv8f16_OP2: + RC = &AArch64::FPR128RegClass; + Opc = AArch64::FMLSv8f16; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC, + FMAInstKind::Accumulator); + break; + case MachineCombinerPattern::FMLSv8i16_indexed_OP2: + RC = &AArch64::FPR128RegClass; + Opc = AArch64::FMLSv8i16_indexed; + MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC, + FMAInstKind::Indexed); + break; + case MachineCombinerPattern::FMLSv2f64_OP2: case MachineCombinerPattern::FMLSv2i64_indexed_OP2: RC = &AArch64::FPR128RegClass; |