diff options
| -rw-r--r-- | llvm/lib/Target/BPF/BPFISelDAGToDAG.cpp | 312 | ||||
| -rw-r--r-- | llvm/test/CodeGen/BPF/remove_truncate_1.ll | 87 | ||||
| -rw-r--r-- | llvm/test/CodeGen/BPF/remove_truncate_2.ll | 65 |
3 files changed, 387 insertions, 77 deletions
diff --git a/llvm/lib/Target/BPF/BPFISelDAGToDAG.cpp b/llvm/lib/Target/BPF/BPFISelDAGToDAG.cpp index c6ddd6bdad5..f48429ee57b 100644 --- a/llvm/lib/Target/BPF/BPFISelDAGToDAG.cpp +++ b/llvm/lib/Target/BPF/BPFISelDAGToDAG.cpp @@ -16,6 +16,7 @@ #include "BPFRegisterInfo.h" #include "BPFSubtarget.h" #include "BPFTargetMachine.h" +#include "llvm/CodeGen/FunctionLoweringInfo.h" #include "llvm/CodeGen/MachineConstantPool.h" #include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/CodeGen/MachineFunction.h" @@ -57,6 +58,11 @@ private: bool SelectAddr(SDValue Addr, SDValue &Base, SDValue &Offset); bool SelectFIAddr(SDValue Addr, SDValue &Base, SDValue &Offset); + // Node preprocessing cases + void PreprocessLoad(SDNode *Node, SelectionDAG::allnodes_iterator I); + void PreprocessCopyToReg(SDNode *Node); + void PreprocessTrunc(SDNode *Node, SelectionDAG::allnodes_iterator I); + // Find constants from a constant structure typedef std::vector<unsigned char> val_vec_type; bool fillGenericConstant(const DataLayout &DL, const Constant *CV, @@ -69,9 +75,12 @@ private: val_vec_type &Vals, int Offset); bool getConstantFieldValue(const GlobalAddressSDNode *Node, uint64_t Offset, uint64_t Size, unsigned char *ByteSeq); + bool checkLoadDef(unsigned DefReg, unsigned match_load_op); // Mapping from ConstantStruct global value to corresponding byte-list values std::map<const void *, val_vec_type> cs_vals_; + // Mapping from vreg to load memory opcode + std::map<unsigned, unsigned> load_to_vreg_; }; } // namespace @@ -203,89 +212,110 @@ void BPFDAGToDAGISel::Select(SDNode *Node) { SelectCode(Node); } +void BPFDAGToDAGISel::PreprocessLoad(SDNode *Node, + SelectionDAG::allnodes_iterator I) { + union { + uint8_t c[8]; + uint16_t s; + uint32_t i; + uint64_t d; + } new_val; // hold up the constant values replacing loads. + bool to_replace = false; + SDLoc DL(Node); + const LoadSDNode *LD = cast<LoadSDNode>(Node); + uint64_t size = LD->getMemOperand()->getSize(); + + if (!size || size > 8 || (size & (size - 1))) + return; + + SDNode *LDAddrNode = LD->getOperand(1).getNode(); + // Match LDAddr against either global_addr or (global_addr + offset) + unsigned opcode = LDAddrNode->getOpcode(); + if (opcode == ISD::ADD) { + SDValue OP1 = LDAddrNode->getOperand(0); + SDValue OP2 = LDAddrNode->getOperand(1); + + // We want to find the pattern global_addr + offset + SDNode *OP1N = OP1.getNode(); + if (OP1N->getOpcode() <= ISD::BUILTIN_OP_END || OP1N->getNumOperands() == 0) + return; + + DEBUG(dbgs() << "Check candidate load: "; LD->dump(); dbgs() << '\n'); + + const GlobalAddressSDNode *GADN = + dyn_cast<GlobalAddressSDNode>(OP1N->getOperand(0).getNode()); + const ConstantSDNode *CDN = dyn_cast<ConstantSDNode>(OP2.getNode()); + if (GADN && CDN) + to_replace = + getConstantFieldValue(GADN, CDN->getZExtValue(), size, new_val.c); + } else if (LDAddrNode->getOpcode() > ISD::BUILTIN_OP_END && + LDAddrNode->getNumOperands() > 0) { + DEBUG(dbgs() << "Check candidate load: "; LD->dump(); dbgs() << '\n'); + + SDValue OP1 = LDAddrNode->getOperand(0); + if (const GlobalAddressSDNode *GADN = + dyn_cast<GlobalAddressSDNode>(OP1.getNode())) + to_replace = getConstantFieldValue(GADN, 0, size, new_val.c); + } + + if (!to_replace) + return; + + // replacing the old with a new value + uint64_t val; + if (size == 1) + val = new_val.c[0]; + else if (size == 2) + val = new_val.s; + else if (size == 4) + val = new_val.i; + else { + val = new_val.d; + } + + DEBUG(dbgs() << "Replacing load of size " << size << " with constant " << val + << '\n'); + SDValue NVal = CurDAG->getConstant(val, DL, MVT::i64); + + // After replacement, the current node is dead, we need to + // go backward one step to make iterator still work + I--; + SDValue From[] = {SDValue(Node, 0), SDValue(Node, 1)}; + SDValue To[] = {NVal, NVal}; + CurDAG->ReplaceAllUsesOfValuesWith(From, To, 2); + I++; + // It is safe to delete node now + CurDAG->DeleteNode(Node); +} + void BPFDAGToDAGISel::PreprocessISelDAG() { - // Iterate through all nodes, only interested in loads from ConstantStruct - // ConstantArray should have converted by IR->DAG processing + // Iterate through all nodes, interested in the following cases: + // + // . loads from ConstantStruct or ConstantArray of constructs + // which can be turns into constant itself, with this we can + // avoid reading from read-only section at runtime. + // + // . reg truncating is often the result of 8/16/32bit->64bit or + // 8/16bit->32bit conversion. If the reg value is loaded with + // masked byte width, the AND operation can be removed since + // BPF LOAD already has zero extension. + // + // This also solved a correctness issue. + // In BPF socket-related program, e.g., __sk_buff->{data, data_end} + // are 32-bit registers, but later on, kernel verifier will rewrite + // it with 64-bit value. Therefore, truncating the value after the + // load will result in incorrect code. for (SelectionDAG::allnodes_iterator I = CurDAG->allnodes_begin(), E = CurDAG->allnodes_end(); I != E;) { SDNode *Node = &*I++; unsigned Opcode = Node->getOpcode(); - if (Opcode != ISD::LOAD) - continue; - - union { - uint8_t c[8]; - uint16_t s; - uint32_t i; - uint64_t d; - } new_val; // hold up the constant values replacing loads. - bool to_replace = false; - SDLoc DL(Node); - const LoadSDNode *LD = cast<LoadSDNode>(Node); - uint64_t size = LD->getMemOperand()->getSize(); - if (!size || size > 8 || (size & (size - 1))) - continue; - - SDNode *LDAddrNode = LD->getOperand(1).getNode(); - // Match LDAddr against either global_addr or (global_addr + offset) - unsigned opcode = LDAddrNode->getOpcode(); - if (opcode == ISD::ADD) { - SDValue OP1 = LDAddrNode->getOperand(0); - SDValue OP2 = LDAddrNode->getOperand(1); - - // We want to find the pattern global_addr + offset - SDNode *OP1N = OP1.getNode(); - if (OP1N->getOpcode() <= ISD::BUILTIN_OP_END || - OP1N->getNumOperands() == 0) - continue; - - DEBUG(dbgs() << "Check candidate load: "; LD->dump(); dbgs() << '\n'); - - const GlobalAddressSDNode *GADN = - dyn_cast<GlobalAddressSDNode>(OP1N->getOperand(0).getNode()); - const ConstantSDNode *CDN = dyn_cast<ConstantSDNode>(OP2.getNode()); - if (GADN && CDN) - to_replace = - getConstantFieldValue(GADN, CDN->getZExtValue(), size, new_val.c); - } else if (LDAddrNode->getOpcode() > ISD::BUILTIN_OP_END && - LDAddrNode->getNumOperands() > 0) { - DEBUG(dbgs() << "Check candidate load: "; LD->dump(); dbgs() << '\n'); - - SDValue OP1 = LDAddrNode->getOperand(0); - if (const GlobalAddressSDNode *GADN = - dyn_cast<GlobalAddressSDNode>(OP1.getNode())) - to_replace = getConstantFieldValue(GADN, 0, size, new_val.c); - } - - if (!to_replace) - continue; - - // replacing the old with a new value - uint64_t val; - if (size == 1) - val = new_val.c[0]; - else if (size == 2) - val = new_val.s; - else if (size == 4) - val = new_val.i; - else { - val = new_val.d; - } - - DEBUG(dbgs() << "Replacing load of size " << size << " with constant " - << val << '\n'); - SDValue NVal = CurDAG->getConstant(val, DL, MVT::i64); - - // After replacement, the current node is dead, we need to - // go backward one step to make iterator still work - I--; - SDValue From[] = {SDValue(Node, 0), SDValue(Node, 1)}; - SDValue To[] = {NVal, NVal}; - CurDAG->ReplaceAllUsesOfValuesWith(From, To, 2); - I++; - // It is safe to delete node now - CurDAG->DeleteNode(Node); + if (Opcode == ISD::LOAD) + PreprocessLoad(Node, I); + else if (Opcode == ISD::CopyToReg) + PreprocessCopyToReg(Node); + else if (Opcode == ISD::AND) + PreprocessTrunc(Node, I); } } @@ -415,6 +445,134 @@ bool BPFDAGToDAGISel::fillConstantStruct(const DataLayout &DL, return true; } +void BPFDAGToDAGISel::PreprocessCopyToReg(SDNode *Node) { + const RegisterSDNode *RegN = dyn_cast<RegisterSDNode>(Node->getOperand(1)); + if (!RegN || !TargetRegisterInfo::isVirtualRegister(RegN->getReg())) + return; + + const LoadSDNode *LD = dyn_cast<LoadSDNode>(Node->getOperand(2)); + if (!LD) + return; + + // Assign a load value to a virtual register. record its load width + unsigned mem_load_op = 0; + switch (LD->getMemOperand()->getSize()) { + default: + return; + case 4: + mem_load_op = BPF::LDW; + break; + case 2: + mem_load_op = BPF::LDH; + break; + case 1: + mem_load_op = BPF::LDB; + break; + } + + DEBUG(dbgs() << "Find Load Value to VReg " + << TargetRegisterInfo::virtReg2Index(RegN->getReg()) << '\n'); + load_to_vreg_[RegN->getReg()] = mem_load_op; +} + +void BPFDAGToDAGISel::PreprocessTrunc(SDNode *Node, + SelectionDAG::allnodes_iterator I) { + ConstantSDNode *MaskN = dyn_cast<ConstantSDNode>(Node->getOperand(1)); + if (!MaskN) + return; + + unsigned match_load_op = 0; + switch (MaskN->getZExtValue()) { + default: + return; + case 0xFFFFFFFF: + match_load_op = BPF::LDW; + break; + case 0xFFFF: + match_load_op = BPF::LDH; + break; + case 0xFF: + match_load_op = BPF::LDB; + break; + } + + // The Reg operand should be a virtual register, which is defined + // outside the current basic block. DAG combiner has done a pretty + // good job in removing truncating inside a single basic block. + SDValue BaseV = Node->getOperand(0); + if (BaseV.getOpcode() != ISD::CopyFromReg) + return; + + const RegisterSDNode *RegN = + dyn_cast<RegisterSDNode>(BaseV.getNode()->getOperand(1)); + if (!RegN || !TargetRegisterInfo::isVirtualRegister(RegN->getReg())) + return; + unsigned AndOpReg = RegN->getReg(); + DEBUG(dbgs() << "Examine %vreg" << TargetRegisterInfo::virtReg2Index(AndOpReg) + << '\n'); + + // Examine the PHI insns in the MachineBasicBlock to found out the + // definitions of this virtual register. At this stage (DAG2DAG + // transformation), only PHI machine insns are available in the machine basic + // block. + MachineBasicBlock *MBB = FuncInfo->MBB; + MachineInstr *MII = nullptr; + for (auto &MI : *MBB) { + for (unsigned i = 0; i < MI.getNumOperands(); ++i) { + const MachineOperand &MOP = MI.getOperand(i); + if (!MOP.isReg() || !MOP.isDef()) + continue; + unsigned Reg = MOP.getReg(); + if (TargetRegisterInfo::isVirtualRegister(Reg) && Reg == AndOpReg) { + MII = &MI; + break; + } + } + } + + if (MII == nullptr) { + // No phi definition in this block. + if (!checkLoadDef(AndOpReg, match_load_op)) + return; + } else { + // The PHI node looks like: + // %vreg2<def> = PHI %vreg0, <BB#1>, %vreg1, <BB#3> + // Trace each incoming definition, e.g., (%vreg0, BB#1) and (%vreg1, BB#3) + // The AND operation can be removed if both %vreg0 in BB#1 and %vreg1 in + // BB#3 are defined with with a load matching the MaskN. + DEBUG(dbgs() << "Check PHI Insn: "; MII->dump(); dbgs() << '\n'); + unsigned PrevReg = -1; + for (unsigned i = 0; i < MII->getNumOperands(); ++i) { + const MachineOperand &MOP = MII->getOperand(i); + if (MOP.isReg()) { + if (MOP.isDef()) + continue; + PrevReg = MOP.getReg(); + if (!TargetRegisterInfo::isVirtualRegister(PrevReg)) + return; + if (!checkLoadDef(PrevReg, match_load_op)) + return; + } + } + } + + DEBUG(dbgs() << "Remove the redundant AND operation in: "; Node->dump(); + dbgs() << '\n'); + + I--; + CurDAG->ReplaceAllUsesWith(SDValue(Node, 0), BaseV); + I++; + CurDAG->DeleteNode(Node); +} + +bool BPFDAGToDAGISel::checkLoadDef(unsigned DefReg, unsigned match_load_op) { + auto it = load_to_vreg_.find(DefReg); + if (it == load_to_vreg_.end()) + return false; // The definition of register is not exported yet. + + return it->second == match_load_op; +} + FunctionPass *llvm::createBPFISelDag(BPFTargetMachine &TM) { return new BPFDAGToDAGISel(TM); } diff --git a/llvm/test/CodeGen/BPF/remove_truncate_1.ll b/llvm/test/CodeGen/BPF/remove_truncate_1.ll new file mode 100644 index 00000000000..65433853b9d --- /dev/null +++ b/llvm/test/CodeGen/BPF/remove_truncate_1.ll @@ -0,0 +1,87 @@ +; RUN: llc < %s -march=bpf -verify-machineinstrs | FileCheck %s + +; Source code: +; struct xdp_md { +; unsigned data; +; unsigned data_end; +; }; +; +; int gbl; +; int xdp_dummy(struct xdp_md *xdp) +; { +; char tmp; +; long addr; +; +; if (gbl) { +; long addr1 = (long)xdp->data; +; tmp = *(char *)addr1; +; if (tmp == 1) +; return 3; +; } else { +; tmp = *(volatile char *)(long)xdp->data_end; +; if (tmp == 1) +; return 2; +; } +; addr = (long)xdp->data; +; tmp = *(volatile char *)addr; +; if (tmp == 0) +; return 1; +; return 0; +; } + +%struct.xdp_md = type { i32, i32 } + +@gbl = common local_unnamed_addr global i32 0, align 4 + +; Function Attrs: norecurse nounwind +define i32 @xdp_dummy(%struct.xdp_md* nocapture readonly %xdp) local_unnamed_addr #0 { +entry: + %0 = load i32, i32* @gbl, align 4 + %tobool = icmp eq i32 %0, 0 + br i1 %tobool, label %if.else, label %if.then + +if.then: ; preds = %entry + %data = getelementptr inbounds %struct.xdp_md, %struct.xdp_md* %xdp, i64 0, i32 0 + %1 = load i32, i32* %data, align 4 + %conv = zext i32 %1 to i64 + %2 = inttoptr i64 %conv to i8* + %3 = load i8, i8* %2, align 1 + %cmp = icmp eq i8 %3, 1 + br i1 %cmp, label %cleanup20, label %if.end12 +; CHECK: r1 = *(u32 *)(r1 + 0) +; CHECK: r2 = *(u8 *)(r1 + 0) + +if.else: ; preds = %entry + %data_end = getelementptr inbounds %struct.xdp_md, %struct.xdp_md* %xdp, i64 0, i32 1 + %4 = load i32, i32* %data_end, align 4 + %conv6 = zext i32 %4 to i64 +; CHECK: r2 = *(u32 *)(r1 + 4) + %5 = inttoptr i64 %conv6 to i8* + %6 = load volatile i8, i8* %5, align 1 + %cmp8 = icmp eq i8 %6, 1 + br i1 %cmp8, label %cleanup20, label %if.else.if.end12_crit_edge + +if.else.if.end12_crit_edge: ; preds = %if.else + %data13.phi.trans.insert = getelementptr inbounds %struct.xdp_md, %struct.xdp_md* %xdp, i64 0, i32 0 + %.pre = load i32, i32* %data13.phi.trans.insert, align 4 + br label %if.end12 +; CHECK: r1 = *(u32 *)(r1 + 0) + +if.end12: ; preds = %if.else.if.end12_crit_edge, %if.then + %7 = phi i32 [ %.pre, %if.else.if.end12_crit_edge ], [ %1, %if.then ] + %conv14 = zext i32 %7 to i64 +; CHECK-NOT: r1 <<= 32 +; CHECK-NOT: r1 >>= 32 + %8 = inttoptr i64 %conv14 to i8* + %9 = load volatile i8, i8* %8, align 1 +; CHECK: r1 = *(u8 *)(r1 + 0) + %cmp16 = icmp eq i8 %9, 0 + %.28 = zext i1 %cmp16 to i32 + br label %cleanup20 + +cleanup20: ; preds = %if.then, %if.end12, %if.else + %retval.1 = phi i32 [ 3, %if.then ], [ 2, %if.else ], [ %.28, %if.end12 ] + ret i32 %retval.1 +} + +attributes #0 = { norecurse nounwind } diff --git a/llvm/test/CodeGen/BPF/remove_truncate_2.ll b/llvm/test/CodeGen/BPF/remove_truncate_2.ll new file mode 100644 index 00000000000..979d820dd85 --- /dev/null +++ b/llvm/test/CodeGen/BPF/remove_truncate_2.ll @@ -0,0 +1,65 @@ +; RUN: llc < %s -march=bpf -verify-machineinstrs | FileCheck %s + +; Source code: +; struct xdp_md { +; unsigned data; +; unsigned data_end; +; }; +; +; int gbl; +; int xdp_dummy(struct xdp_md *xdp) +; { +; char addr = *(char *)(long)xdp->data; +; if (gbl) { +; if (gbl == 1) +; return 1; +; if (addr == 1) +; return 3; +; } else if (addr == 0) +; return 2; +; return 0; +; } + +%struct.xdp_md = type { i32, i32 } + +@gbl = common local_unnamed_addr global i32 0, align 4 + +; Function Attrs: norecurse nounwind readonly +define i32 @xdp_dummy(%struct.xdp_md* nocapture readonly %xdp) local_unnamed_addr #0 { +entry: + %data = getelementptr inbounds %struct.xdp_md, %struct.xdp_md* %xdp, i64 0, i32 0 + %0 = load i32, i32* %data, align 4 + %conv = zext i32 %0 to i64 + %1 = inttoptr i64 %conv to i8* + %2 = load i8, i8* %1, align 1 +; CHECK: r1 = *(u32 *)(r1 + 0) +; CHECK: r1 = *(u8 *)(r1 + 0) + %3 = load i32, i32* @gbl, align 4 + switch i32 %3, label %if.end [ + i32 0, label %if.else + i32 1, label %cleanup + ] + +if.end: ; preds = %entry + %cmp4 = icmp eq i8 %2, 1 +; CHECK: r0 = 3 +; CHECK-NOT: r1 &= 255 +; CHECK: if r1 == 1 goto + br i1 %cmp4, label %cleanup, label %if.end13 + +if.else: ; preds = %entry + %cmp9 = icmp eq i8 %2, 0 +; CHECK: r0 = 2 +; CHECK-NOT: r1 &= 255 +; CHECK: if r1 == 0 goto + br i1 %cmp9, label %cleanup, label %if.end13 + +if.end13: ; preds = %if.else, %if.end + br label %cleanup + +cleanup: ; preds = %if.else, %if.end, %entry, %if.end13 + %retval.0 = phi i32 [ 0, %if.end13 ], [ 1, %entry ], [ 3, %if.end ], [ 2, %if.else ] + ret i32 %retval.0 +} + +attributes #0 = { norecurse nounwind readonly } |

