diff options
Diffstat (limited to 'llvm/lib')
| -rw-r--r-- | llvm/lib/Target/ARM/ARMTargetMachine.h | 4 | ||||
| -rw-r--r-- | llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp | 44 |
2 files changed, 46 insertions, 2 deletions
diff --git a/llvm/lib/Target/ARM/ARMTargetMachine.h b/llvm/lib/Target/ARM/ARMTargetMachine.h index be6bec75928..d4caf5ca6e1 100644 --- a/llvm/lib/Target/ARM/ARMTargetMachine.h +++ b/llvm/lib/Target/ARM/ARMTargetMachine.h @@ -46,6 +46,10 @@ public: virtual ARMJITInfo *getJITInfo() { return &JITInfo; } virtual const ARMSubtarget *getSubtargetImpl() const { return &Subtarget; } + virtual const ARMTargetLowering *getTargetLowering() const { + // Implemented by derived classes + llvm_unreachable("getTargetLowering not implemented"); + } virtual const InstrItineraryData *getInstrItineraryData() const { return &InstrItins; } diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp index 404a6fff117..61cb1f6b9a3 100644 --- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp +++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp @@ -20,6 +20,7 @@ #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Support/Debug.h" #include "llvm/Target/TargetLowering.h" +#include "llvm/Target/CostTable.h" using namespace llvm; // Declare the pass initialization routine locally as target-specific passes @@ -34,18 +35,20 @@ namespace { class ARMTTI : public ImmutablePass, public TargetTransformInfo { const ARMBaseTargetMachine *TM; const ARMSubtarget *ST; + const ARMTargetLowering *TLI; /// Estimate the overhead of scalarizing an instruction. Insert and Extract /// are set if the result needs to be inserted and/or extracted from vectors. unsigned getScalarizationOverhead(Type *Ty, bool Insert, bool Extract) const; public: - ARMTTI() : ImmutablePass(ID), TM(0), ST(0) { + ARMTTI() : ImmutablePass(ID), TM(0), ST(0), TLI(0) { llvm_unreachable("This pass cannot be directly constructed"); } ARMTTI(const ARMBaseTargetMachine *TM) - : ImmutablePass(ID), TM(TM), ST(TM->getSubtargetImpl()) { + : ImmutablePass(ID), TM(TM), ST(TM->getSubtargetImpl()), + TLI(TM->getTargetLowering()) { initializeARMTTIPass(*PassRegistry::getPassRegistry()); } @@ -111,6 +114,9 @@ public: return 1; } + unsigned getCastInstrCost(unsigned Opcode, Type *Dst, + Type *Src) const; + /// @} }; @@ -157,3 +163,37 @@ unsigned ARMTTI::getIntImmCost(const APInt &Imm, Type *Ty) const { } return 2; } + +unsigned ARMTTI::getCastInstrCost(unsigned Opcode, Type *Dst, + Type *Src) const { + int ISD = TLI->InstructionOpcodeToISD(Opcode); + assert(ISD && "Invalid opcode"); + + EVT SrcTy = TLI->getValueType(Src); + EVT DstTy = TLI->getValueType(Dst); + + if (!SrcTy.isSimple() || !DstTy.isSimple()) + return TargetTransformInfo::getCastInstrCost(Opcode, Dst, Src); + + // Some arithmetic, load and store operations have specific instructions + // to cast up/down their types automatically at no extra cost + // TODO: Get these tables to know at least what the related operations are + static const TypeConversionCostTblEntry<MVT> NEONConversionTbl[] = { + { ISD::SIGN_EXTEND, MVT::v4i32, MVT::v4i16, 0 }, + { ISD::ZERO_EXTEND, MVT::v4i32, MVT::v4i16, 0 }, + { ISD::SIGN_EXTEND, MVT::v2i64, MVT::v2i32, 1 }, + { ISD::ZERO_EXTEND, MVT::v2i64, MVT::v2i32, 1 }, + { ISD::TRUNCATE, MVT::v4i32, MVT::v4i64, 0 }, + { ISD::TRUNCATE, MVT::v4i16, MVT::v4i32, 1 }, + }; + + if (ST->hasNEON()) { + int Idx = ConvertCostTableLookup<MVT>(NEONConversionTbl, + array_lengthof(NEONConversionTbl), + ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT()); + if (Idx != -1) + return NEONConversionTbl[Idx].Cost; + } + + return TargetTransformInfo::getCastInstrCost(Opcode, Dst, Src); +} |

