diff options
Diffstat (limited to 'net')
-rw-r--r-- | net/bpf/test_run.c | 15 | ||||
-rw-r--r-- | net/core/filter.c | 176 | ||||
-rw-r--r-- | net/core/sock_map.c | 11 | ||||
-rw-r--r-- | net/ipv4/tcp_bpf.c | 41 | ||||
-rw-r--r-- | net/tls/tls_main.c | 2 | ||||
-rw-r--r-- | net/tls/tls_sw.c | 3 |
6 files changed, 225 insertions, 23 deletions
diff --git a/net/bpf/test_run.c b/net/bpf/test_run.c index 0c423b8cd75c..c89c22c49015 100644 --- a/net/bpf/test_run.c +++ b/net/bpf/test_run.c @@ -10,6 +10,8 @@ #include <linux/etherdevice.h> #include <linux/filter.h> #include <linux/sched/signal.h> +#include <net/sock.h> +#include <net/tcp.h> static __always_inline u32 bpf_test_run_one(struct bpf_prog *prog, void *ctx, struct bpf_cgroup_storage *storage[MAX_BPF_CGROUP_STORAGE_TYPE]) @@ -115,6 +117,7 @@ int bpf_prog_test_run_skb(struct bpf_prog *prog, const union bpf_attr *kattr, u32 retval, duration; int hh_len = ETH_HLEN; struct sk_buff *skb; + struct sock *sk; void *data; int ret; @@ -137,11 +140,21 @@ int bpf_prog_test_run_skb(struct bpf_prog *prog, const union bpf_attr *kattr, break; } + sk = kzalloc(sizeof(struct sock), GFP_USER); + if (!sk) { + kfree(data); + return -ENOMEM; + } + sock_net_set(sk, current->nsproxy->net_ns); + sock_init_data(NULL, sk); + skb = build_skb(data, 0); if (!skb) { kfree(data); + kfree(sk); return -ENOMEM; } + skb->sk = sk; skb_reserve(skb, NET_SKB_PAD + NET_IP_ALIGN); __skb_put(skb, size); @@ -159,6 +172,7 @@ int bpf_prog_test_run_skb(struct bpf_prog *prog, const union bpf_attr *kattr, if (pskb_expand_head(skb, nhead, 0, GFP_USER)) { kfree_skb(skb); + kfree(sk); return -ENOMEM; } } @@ -171,6 +185,7 @@ int bpf_prog_test_run_skb(struct bpf_prog *prog, const union bpf_attr *kattr, size = skb_headlen(skb); ret = bpf_test_finish(kattr, uattr, skb->data, size, retval, duration); kfree_skb(skb); + kfree(sk); return ret; } diff --git a/net/core/filter.c b/net/core/filter.c index 1a3ac6c46873..35c6933c2622 100644 --- a/net/core/filter.c +++ b/net/core/filter.c @@ -2297,6 +2297,137 @@ static const struct bpf_func_proto bpf_msg_pull_data_proto = { .arg4_type = ARG_ANYTHING, }; +BPF_CALL_4(bpf_msg_push_data, struct sk_msg *, msg, u32, start, + u32, len, u64, flags) +{ + struct scatterlist sge, nsge, nnsge, rsge = {0}, *psge; + u32 new, i = 0, l, space, copy = 0, offset = 0; + u8 *raw, *to, *from; + struct page *page; + + if (unlikely(flags)) + return -EINVAL; + + /* First find the starting scatterlist element */ + i = msg->sg.start; + do { + l = sk_msg_elem(msg, i)->length; + + if (start < offset + l) + break; + offset += l; + sk_msg_iter_var_next(i); + } while (i != msg->sg.end); + + if (start >= offset + l) + return -EINVAL; + + space = MAX_MSG_FRAGS - sk_msg_elem_used(msg); + + /* If no space available will fallback to copy, we need at + * least one scatterlist elem available to push data into + * when start aligns to the beginning of an element or two + * when it falls inside an element. We handle the start equals + * offset case because its the common case for inserting a + * header. + */ + if (!space || (space == 1 && start != offset)) + copy = msg->sg.data[i].length; + + page = alloc_pages(__GFP_NOWARN | GFP_ATOMIC | __GFP_COMP, + get_order(copy + len)); + if (unlikely(!page)) + return -ENOMEM; + + if (copy) { + int front, back; + + raw = page_address(page); + + psge = sk_msg_elem(msg, i); + front = start - offset; + back = psge->length - front; + from = sg_virt(psge); + + if (front) + memcpy(raw, from, front); + + if (back) { + from += front; + to = raw + front + len; + + memcpy(to, from, back); + } + + put_page(sg_page(psge)); + } else if (start - offset) { + psge = sk_msg_elem(msg, i); + rsge = sk_msg_elem_cpy(msg, i); + + psge->length = start - offset; + rsge.length -= psge->length; + rsge.offset += start; + + sk_msg_iter_var_next(i); + sg_unmark_end(psge); + sk_msg_iter_next(msg, end); + } + + /* Slot(s) to place newly allocated data */ + new = i; + + /* Shift one or two slots as needed */ + if (!copy) { + sge = sk_msg_elem_cpy(msg, i); + + sk_msg_iter_var_next(i); + sg_unmark_end(&sge); + sk_msg_iter_next(msg, end); + + nsge = sk_msg_elem_cpy(msg, i); + if (rsge.length) { + sk_msg_iter_var_next(i); + nnsge = sk_msg_elem_cpy(msg, i); + } + + while (i != msg->sg.end) { + msg->sg.data[i] = sge; + sge = nsge; + sk_msg_iter_var_next(i); + if (rsge.length) { + nsge = nnsge; + nnsge = sk_msg_elem_cpy(msg, i); + } else { + nsge = sk_msg_elem_cpy(msg, i); + } + } + } + + /* Place newly allocated data buffer */ + sk_mem_charge(msg->sk, len); + msg->sg.size += len; + msg->sg.copy[new] = false; + sg_set_page(&msg->sg.data[new], page, len + copy, 0); + if (rsge.length) { + get_page(sg_page(&rsge)); + sk_msg_iter_var_next(new); + msg->sg.data[new] = rsge; + } + + sk_msg_compute_data_pointers(msg); + return 0; +} + +static const struct bpf_func_proto bpf_msg_push_data_proto = { + .func = bpf_msg_push_data, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_ANYTHING, +}; + BPF_CALL_1(bpf_get_cgroup_classid, const struct sk_buff *, skb) { return task_get_classid(skb); @@ -4854,6 +4985,7 @@ bool bpf_helper_changes_pkt_data(void *func) func == bpf_xdp_adjust_head || func == bpf_xdp_adjust_meta || func == bpf_msg_pull_data || + func == bpf_msg_push_data || func == bpf_xdp_adjust_tail || #if IS_ENABLED(CONFIG_IPV6_SEG6_BPF) func == bpf_lwt_seg6_store_bytes || @@ -4876,6 +5008,12 @@ bpf_base_func_proto(enum bpf_func_id func_id) return &bpf_map_update_elem_proto; case BPF_FUNC_map_delete_elem: return &bpf_map_delete_elem_proto; + case BPF_FUNC_map_push_elem: + return &bpf_map_push_elem_proto; + case BPF_FUNC_map_pop_elem: + return &bpf_map_pop_elem_proto; + case BPF_FUNC_map_peek_elem: + return &bpf_map_peek_elem_proto; case BPF_FUNC_get_prandom_u32: return &bpf_get_prandom_u32_proto; case BPF_FUNC_get_smp_processor_id: @@ -5124,6 +5262,8 @@ sk_msg_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) return &bpf_msg_cork_bytes_proto; case BPF_FUNC_msg_pull_data: return &bpf_msg_pull_data_proto; + case BPF_FUNC_msg_push_data: + return &bpf_msg_push_data_proto; case BPF_FUNC_get_local_storage: return &bpf_get_local_storage_proto; default: @@ -5346,6 +5486,40 @@ static bool sk_filter_is_valid_access(int off, int size, return bpf_skb_is_valid_access(off, size, type, prog, info); } +static bool cg_skb_is_valid_access(int off, int size, + enum bpf_access_type type, + const struct bpf_prog *prog, + struct bpf_insn_access_aux *info) +{ + switch (off) { + case bpf_ctx_range(struct __sk_buff, tc_classid): + case bpf_ctx_range(struct __sk_buff, data_meta): + case bpf_ctx_range(struct __sk_buff, flow_keys): + return false; + } + if (type == BPF_WRITE) { + switch (off) { + case bpf_ctx_range(struct __sk_buff, mark): + case bpf_ctx_range(struct __sk_buff, priority): + case bpf_ctx_range_till(struct __sk_buff, cb[0], cb[4]): + break; + default: + return false; + } + } + + switch (off) { + case bpf_ctx_range(struct __sk_buff, data): + info->reg_type = PTR_TO_PACKET; + break; + case bpf_ctx_range(struct __sk_buff, data_end): + info->reg_type = PTR_TO_PACKET_END; + break; + } + + return bpf_skb_is_valid_access(off, size, type, prog, info); +} + static bool lwt_is_valid_access(int off, int size, enum bpf_access_type type, const struct bpf_prog *prog, @@ -7038,7 +7212,7 @@ const struct bpf_prog_ops xdp_prog_ops = { const struct bpf_verifier_ops cg_skb_verifier_ops = { .get_func_proto = cg_skb_func_proto, - .is_valid_access = sk_filter_is_valid_access, + .is_valid_access = cg_skb_is_valid_access, .convert_ctx_access = bpf_convert_ctx_access, }; diff --git a/net/core/sock_map.c b/net/core/sock_map.c index 3c0e44cb811a..be6092ac69f8 100644 --- a/net/core/sock_map.c +++ b/net/core/sock_map.c @@ -175,12 +175,13 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs, } } - psock = sk_psock_get(sk); + psock = sk_psock_get_checked(sk); + if (IS_ERR(psock)) { + ret = PTR_ERR(psock); + goto out_progs; + } + if (psock) { - if (!sk_has_psock(sk)) { - ret = -EBUSY; - goto out_progs; - } if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) || (skb_progs && READ_ONCE(psock->progs.skb_parser))) { sk_psock_put(sk, psock); diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c index 80debb0daf37..b7918d4caa30 100644 --- a/net/ipv4/tcp_bpf.c +++ b/net/ipv4/tcp_bpf.c @@ -39,17 +39,19 @@ static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock, } int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock, - struct msghdr *msg, int len) + struct msghdr *msg, int len, int flags) { struct iov_iter *iter = &msg->msg_iter; + int peek = flags & MSG_PEEK; int i, ret, copied = 0; + struct sk_msg *msg_rx; + + msg_rx = list_first_entry_or_null(&psock->ingress_msg, + struct sk_msg, list); while (copied != len) { struct scatterlist *sge; - struct sk_msg *msg_rx; - msg_rx = list_first_entry_or_null(&psock->ingress_msg, - struct sk_msg, list); if (unlikely(!msg_rx)) break; @@ -70,21 +72,30 @@ int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock, } copied += copy; - sge->offset += copy; - sge->length -= copy; - sk_mem_uncharge(sk, copy); - if (!sge->length) { - i++; - if (i == MAX_SKB_FRAGS) - i = 0; - if (!msg_rx->skb) - put_page(page); + if (likely(!peek)) { + sge->offset += copy; + sge->length -= copy; + sk_mem_uncharge(sk, copy); + msg_rx->sg.size -= copy; + + if (!sge->length) { + sk_msg_iter_var_next(i); + if (!msg_rx->skb) + put_page(page); + } + } else { + sk_msg_iter_var_next(i); } if (copied == len) break; } while (i != msg_rx->sg.end); + if (unlikely(peek)) { + msg_rx = list_next_entry(msg_rx, list); + continue; + } + msg_rx->sg.start = i; if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) { list_del(&msg_rx->list); @@ -92,6 +103,8 @@ int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock, consume_skb(msg_rx->skb); kfree(msg_rx); } + msg_rx = list_first_entry_or_null(&psock->ingress_msg, + struct sk_msg, list); } return copied; @@ -114,7 +127,7 @@ int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len); lock_sock(sk); msg_bytes_ready: - copied = __tcp_bpf_recvmsg(sk, psock, msg, len); + copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags); if (!copied) { int data, err = 0; long timeo; diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index e90b6d537077..311cec8e533d 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -715,8 +715,6 @@ EXPORT_SYMBOL(tls_unregister_device); static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = { .name = "tls", - .uid = TCP_ULP_TLS, - .user_visible = true, .owner = THIS_MODULE, .init = tls_init, }; diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index a525fc4c2a4b..5cd88ba8acd1 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -1478,7 +1478,8 @@ int tls_sw_recvmsg(struct sock *sk, skb = tls_wait_data(sk, psock, flags, timeo, &err); if (!skb) { if (psock) { - int ret = __tcp_bpf_recvmsg(sk, psock, msg, len); + int ret = __tcp_bpf_recvmsg(sk, psock, + msg, len, flags); if (ret > 0) { copied += ret; |