summaryrefslogtreecommitdiffstats
path: root/net/core
diff options
context:
space:
mode:
Diffstat (limited to 'net/core')
-rw-r--r--net/core/Makefile2
-rw-r--r--net/core/dev.c14
-rw-r--r--net/core/filter.c290
-rw-r--r--net/core/skmsg.c802
-rw-r--r--net/core/sock.c61
-rw-r--r--net/core/sock_map.c1002
6 files changed, 1910 insertions, 261 deletions
diff --git a/net/core/Makefile b/net/core/Makefile
index 80175e6a2eb8..fccd31e0e7f7 100644
--- a/net/core/Makefile
+++ b/net/core/Makefile
@@ -16,6 +16,7 @@ obj-y += dev.o ethtool.o dev_addr_lists.o dst.o netevent.o \
obj-y += net-sysfs.o
obj-$(CONFIG_PAGE_POOL) += page_pool.o
obj-$(CONFIG_PROC_FS) += net-procfs.o
+obj-$(CONFIG_NET_SOCK_MSG) += skmsg.o
obj-$(CONFIG_NET_PKTGEN) += pktgen.o
obj-$(CONFIG_NETPOLL) += netpoll.o
obj-$(CONFIG_FIB_RULES) += fib_rules.o
@@ -27,6 +28,7 @@ obj-$(CONFIG_CGROUP_NET_PRIO) += netprio_cgroup.o
obj-$(CONFIG_CGROUP_NET_CLASSID) += netclassid_cgroup.o
obj-$(CONFIG_LWTUNNEL) += lwtunnel.o
obj-$(CONFIG_LWTUNNEL_BPF) += lwt_bpf.o
+obj-$(CONFIG_BPF_STREAM_PARSER) += sock_map.o
obj-$(CONFIG_DST_CACHE) += dst_cache.o
obj-$(CONFIG_HWBM) += hwbm.o
obj-$(CONFIG_NET_DEVLINK) += devlink.o
diff --git a/net/core/dev.c b/net/core/dev.c
index 8497feea8fb5..022ad73d6253 100644
--- a/net/core/dev.c
+++ b/net/core/dev.c
@@ -4291,6 +4291,9 @@ static u32 netif_receive_generic_xdp(struct sk_buff *skb,
struct netdev_rx_queue *rxqueue;
void *orig_data, *orig_data_end;
u32 metalen, act = XDP_DROP;
+ __be16 orig_eth_type;
+ struct ethhdr *eth;
+ bool orig_bcast;
int hlen, off;
u32 mac_len;
@@ -4331,6 +4334,9 @@ static u32 netif_receive_generic_xdp(struct sk_buff *skb,
xdp->data_hard_start = skb->data - skb_headroom(skb);
orig_data_end = xdp->data_end;
orig_data = xdp->data;
+ eth = (struct ethhdr *)xdp->data;
+ orig_bcast = is_multicast_ether_addr_64bits(eth->h_dest);
+ orig_eth_type = eth->h_proto;
rxqueue = netif_get_rxqueue(skb);
xdp->rxq = &rxqueue->xdp_rxq;
@@ -4354,6 +4360,14 @@ static u32 netif_receive_generic_xdp(struct sk_buff *skb,
}
+ /* check if XDP changed eth hdr such SKB needs update */
+ eth = (struct ethhdr *)xdp->data;
+ if ((orig_eth_type != eth->h_proto) ||
+ (orig_bcast != is_multicast_ether_addr_64bits(eth->h_dest))) {
+ __skb_push(skb, ETH_HLEN);
+ skb->protocol = eth_type_trans(skb, skb->dev);
+ }
+
switch (act) {
case XDP_REDIRECT:
case XDP_TX:
diff --git a/net/core/filter.c b/net/core/filter.c
index 80da21b097b8..1a3ac6c46873 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -38,6 +38,7 @@
#include <net/protocol.h>
#include <net/netlink.h>
#include <linux/skbuff.h>
+#include <linux/skmsg.h>
#include <net/sock.h>
#include <net/flow_dissector.h>
#include <linux/errno.h>
@@ -2142,123 +2143,7 @@ static const struct bpf_func_proto bpf_redirect_proto = {
.arg2_type = ARG_ANYTHING,
};
-BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
- struct bpf_map *, map, void *, key, u64, flags)
-{
- struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
-
- /* If user passes invalid input drop the packet. */
- if (unlikely(flags & ~(BPF_F_INGRESS)))
- return SK_DROP;
-
- tcb->bpf.flags = flags;
- tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key);
- if (!tcb->bpf.sk_redir)
- return SK_DROP;
-
- return SK_PASS;
-}
-
-static const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
- .func = bpf_sk_redirect_hash,
- .gpl_only = false,
- .ret_type = RET_INTEGER,
- .arg1_type = ARG_PTR_TO_CTX,
- .arg2_type = ARG_CONST_MAP_PTR,
- .arg3_type = ARG_PTR_TO_MAP_KEY,
- .arg4_type = ARG_ANYTHING,
-};
-
-BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
- struct bpf_map *, map, u32, key, u64, flags)
-{
- struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
-
- /* If user passes invalid input drop the packet. */
- if (unlikely(flags & ~(BPF_F_INGRESS)))
- return SK_DROP;
-
- tcb->bpf.flags = flags;
- tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key);
- if (!tcb->bpf.sk_redir)
- return SK_DROP;
-
- return SK_PASS;
-}
-
-struct sock *do_sk_redirect_map(struct sk_buff *skb)
-{
- struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
-
- return tcb->bpf.sk_redir;
-}
-
-static const struct bpf_func_proto bpf_sk_redirect_map_proto = {
- .func = bpf_sk_redirect_map,
- .gpl_only = false,
- .ret_type = RET_INTEGER,
- .arg1_type = ARG_PTR_TO_CTX,
- .arg2_type = ARG_CONST_MAP_PTR,
- .arg3_type = ARG_ANYTHING,
- .arg4_type = ARG_ANYTHING,
-};
-
-BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg_buff *, msg,
- struct bpf_map *, map, void *, key, u64, flags)
-{
- /* If user passes invalid input drop the packet. */
- if (unlikely(flags & ~(BPF_F_INGRESS)))
- return SK_DROP;
-
- msg->flags = flags;
- msg->sk_redir = __sock_hash_lookup_elem(map, key);
- if (!msg->sk_redir)
- return SK_DROP;
-
- return SK_PASS;
-}
-
-static const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
- .func = bpf_msg_redirect_hash,
- .gpl_only = false,
- .ret_type = RET_INTEGER,
- .arg1_type = ARG_PTR_TO_CTX,
- .arg2_type = ARG_CONST_MAP_PTR,
- .arg3_type = ARG_PTR_TO_MAP_KEY,
- .arg4_type = ARG_ANYTHING,
-};
-
-BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg_buff *, msg,
- struct bpf_map *, map, u32, key, u64, flags)
-{
- /* If user passes invalid input drop the packet. */
- if (unlikely(flags & ~(BPF_F_INGRESS)))
- return SK_DROP;
-
- msg->flags = flags;
- msg->sk_redir = __sock_map_lookup_elem(map, key);
- if (!msg->sk_redir)
- return SK_DROP;
-
- return SK_PASS;
-}
-
-struct sock *do_msg_redirect_map(struct sk_msg_buff *msg)
-{
- return msg->sk_redir;
-}
-
-static const struct bpf_func_proto bpf_msg_redirect_map_proto = {
- .func = bpf_msg_redirect_map,
- .gpl_only = false,
- .ret_type = RET_INTEGER,
- .arg1_type = ARG_PTR_TO_CTX,
- .arg2_type = ARG_CONST_MAP_PTR,
- .arg3_type = ARG_ANYTHING,
- .arg4_type = ARG_ANYTHING,
-};
-
-BPF_CALL_2(bpf_msg_apply_bytes, struct sk_msg_buff *, msg, u32, bytes)
+BPF_CALL_2(bpf_msg_apply_bytes, struct sk_msg *, msg, u32, bytes)
{
msg->apply_bytes = bytes;
return 0;
@@ -2272,7 +2157,7 @@ static const struct bpf_func_proto bpf_msg_apply_bytes_proto = {
.arg2_type = ARG_ANYTHING,
};
-BPF_CALL_2(bpf_msg_cork_bytes, struct sk_msg_buff *, msg, u32, bytes)
+BPF_CALL_2(bpf_msg_cork_bytes, struct sk_msg *, msg, u32, bytes)
{
msg->cork_bytes = bytes;
return 0;
@@ -2286,45 +2171,37 @@ static const struct bpf_func_proto bpf_msg_cork_bytes_proto = {
.arg2_type = ARG_ANYTHING,
};
-#define sk_msg_iter_var(var) \
- do { \
- var++; \
- if (var == MAX_SKB_FRAGS) \
- var = 0; \
- } while (0)
-
-BPF_CALL_4(bpf_msg_pull_data,
- struct sk_msg_buff *, msg, u32, start, u32, end, u64, flags)
+BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start,
+ u32, end, u64, flags)
{
- unsigned int len = 0, offset = 0, copy = 0, poffset = 0;
- int bytes = end - start, bytes_sg_total;
- struct scatterlist *sg = msg->sg_data;
- int first_sg, last_sg, i, shift;
- unsigned char *p, *to, *from;
+ u32 len = 0, offset = 0, copy = 0, poffset = 0, bytes = end - start;
+ u32 first_sge, last_sge, i, shift, bytes_sg_total;
+ struct scatterlist *sge;
+ u8 *raw, *to, *from;
struct page *page;
if (unlikely(flags || end <= start))
return -EINVAL;
/* First find the starting scatterlist element */
- i = msg->sg_start;
+ i = msg->sg.start;
do {
- len = sg[i].length;
+ len = sk_msg_elem(msg, i)->length;
if (start < offset + len)
break;
offset += len;
- sk_msg_iter_var(i);
- } while (i != msg->sg_end);
+ sk_msg_iter_var_next(i);
+ } while (i != msg->sg.end);
if (unlikely(start >= offset + len))
return -EINVAL;
- first_sg = i;
+ first_sge = i;
/* The start may point into the sg element so we need to also
* account for the headroom.
*/
bytes_sg_total = start - offset + bytes;
- if (!msg->sg_copy[i] && bytes_sg_total <= len)
+ if (!msg->sg.copy[i] && bytes_sg_total <= len)
goto out;
/* At this point we need to linearize multiple scatterlist
@@ -2338,12 +2215,12 @@ BPF_CALL_4(bpf_msg_pull_data,
* will copy the entire sg entry.
*/
do {
- copy += sg[i].length;
- sk_msg_iter_var(i);
+ copy += sk_msg_elem(msg, i)->length;
+ sk_msg_iter_var_next(i);
if (bytes_sg_total <= copy)
break;
- } while (i != msg->sg_end);
- last_sg = i;
+ } while (i != msg->sg.end);
+ last_sge = i;
if (unlikely(bytes_sg_total > copy))
return -EINVAL;
@@ -2352,63 +2229,61 @@ BPF_CALL_4(bpf_msg_pull_data,
get_order(copy));
if (unlikely(!page))
return -ENOMEM;
- p = page_address(page);
- i = first_sg;
+ raw = page_address(page);
+ i = first_sge;
do {
- from = sg_virt(&sg[i]);
- len = sg[i].length;
- to = p + poffset;
+ sge = sk_msg_elem(msg, i);
+ from = sg_virt(sge);
+ len = sge->length;
+ to = raw + poffset;
memcpy(to, from, len);
poffset += len;
- sg[i].length = 0;
- put_page(sg_page(&sg[i]));
+ sge->length = 0;
+ put_page(sg_page(sge));
- sk_msg_iter_var(i);
- } while (i != last_sg);
+ sk_msg_iter_var_next(i);
+ } while (i != last_sge);
- sg[first_sg].length = copy;
- sg_set_page(&sg[first_sg], page, copy, 0);
+ sg_set_page(&msg->sg.data[first_sge], page, copy, 0);
/* To repair sg ring we need to shift entries. If we only
* had a single entry though we can just replace it and
* be done. Otherwise walk the ring and shift the entries.
*/
- WARN_ON_ONCE(last_sg == first_sg);
- shift = last_sg > first_sg ?
- last_sg - first_sg - 1 :
- MAX_SKB_FRAGS - first_sg + last_sg - 1;
+ WARN_ON_ONCE(last_sge == first_sge);
+ shift = last_sge > first_sge ?
+ last_sge - first_sge - 1 :
+ MAX_SKB_FRAGS - first_sge + last_sge - 1;
if (!shift)
goto out;
- i = first_sg;
- sk_msg_iter_var(i);
+ i = first_sge;
+ sk_msg_iter_var_next(i);
do {
- int move_from;
+ u32 move_from;
- if (i + shift >= MAX_SKB_FRAGS)
- move_from = i + shift - MAX_SKB_FRAGS;
+ if (i + shift >= MAX_MSG_FRAGS)
+ move_from = i + shift - MAX_MSG_FRAGS;
else
move_from = i + shift;
-
- if (move_from == msg->sg_end)
+ if (move_from == msg->sg.end)
break;
- sg[i] = sg[move_from];
- sg[move_from].length = 0;
- sg[move_from].page_link = 0;
- sg[move_from].offset = 0;
-
- sk_msg_iter_var(i);
+ msg->sg.data[i] = msg->sg.data[move_from];
+ msg->sg.data[move_from].length = 0;
+ msg->sg.data[move_from].page_link = 0;
+ msg->sg.data[move_from].offset = 0;
+ sk_msg_iter_var_next(i);
} while (1);
- msg->sg_end -= shift;
- if (msg->sg_end < 0)
- msg->sg_end += MAX_SKB_FRAGS;
+
+ msg->sg.end = msg->sg.end - shift > msg->sg.end ?
+ msg->sg.end - shift + MAX_MSG_FRAGS :
+ msg->sg.end - shift;
out:
- msg->data = sg_virt(&sg[first_sg]) + start - offset;
+ msg->data = sg_virt(&msg->sg.data[first_sge]) + start - offset;
msg->data_end = msg->data + bytes;
-
return 0;
}
@@ -4821,9 +4696,12 @@ static const struct bpf_func_proto bpf_lwt_seg6_adjust_srh_proto = {
static struct sock *sk_lookup(struct net *net, struct bpf_sock_tuple *tuple,
struct sk_buff *skb, u8 family, u8 proto)
{
- int dif = skb->dev->ifindex;
bool refcounted = false;
struct sock *sk = NULL;
+ int dif = 0;
+
+ if (skb->dev)
+ dif = skb->dev->ifindex;
if (family == AF_INET) {
__be32 src4 = tuple->ipv4.saddr;
@@ -4839,21 +4717,24 @@ static struct sock *sk_lookup(struct net *net, struct bpf_sock_tuple *tuple,
sk = __udp4_lib_lookup(net, src4, tuple->ipv4.sport,
dst4, tuple->ipv4.dport,
dif, sdif, &udp_table, skb);
-#if IS_REACHABLE(CONFIG_IPV6)
+#if IS_ENABLED(CONFIG_IPV6)
} else {
struct in6_addr *src6 = (struct in6_addr *)&tuple->ipv6.saddr;
struct in6_addr *dst6 = (struct in6_addr *)&tuple->ipv6.daddr;
+ u16 hnum = ntohs(tuple->ipv6.dport);
int sdif = inet6_sdif(skb);
if (proto == IPPROTO_TCP)
sk = __inet6_lookup(net, &tcp_hashinfo, skb, 0,
src6, tuple->ipv6.sport,
- dst6, tuple->ipv6.dport,
+ dst6, hnum,
dif, sdif, &refcounted);
- else
- sk = __udp6_lib_lookup(net, src6, tuple->ipv6.sport,
- dst6, tuple->ipv6.dport,
- dif, sdif, &udp_table, skb);
+ else if (likely(ipv6_bpf_stub))
+ sk = ipv6_bpf_stub->udp6_lib_lookup(net,
+ src6, tuple->ipv6.sport,
+ dst6, hnum,
+ dif, sdif,
+ &udp_table, skb);
#endif
}
@@ -5200,6 +5081,9 @@ xdp_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
}
}
+const struct bpf_func_proto bpf_sock_map_update_proto __weak;
+const struct bpf_func_proto bpf_sock_hash_update_proto __weak;
+
static const struct bpf_func_proto *
sock_ops_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
{
@@ -5223,6 +5107,9 @@ sock_ops_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
}
}
+const struct bpf_func_proto bpf_msg_redirect_map_proto __weak;
+const struct bpf_func_proto bpf_msg_redirect_hash_proto __weak;
+
static const struct bpf_func_proto *
sk_msg_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
{
@@ -5244,6 +5131,9 @@ sk_msg_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
}
}
+const struct bpf_func_proto bpf_sk_redirect_map_proto __weak;
+const struct bpf_func_proto bpf_sk_redirect_hash_proto __weak;
+
static const struct bpf_func_proto *
sk_skb_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
{
@@ -6998,22 +6888,22 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
switch (si->off) {
case offsetof(struct sk_msg_md, data):
- *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg_buff, data),
+ *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg, data),
si->dst_reg, si->src_reg,
- offsetof(struct sk_msg_buff, data));
+ offsetof(struct sk_msg, data));
break;
case offsetof(struct sk_msg_md, data_end):
- *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg_buff, data_end),
+ *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg, data_end),
si->dst_reg, si->src_reg,
- offsetof(struct sk_msg_buff, data_end));
+ offsetof(struct sk_msg, data_end));
break;
case offsetof(struct sk_msg_md, family):
BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_family) != 2);
*insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
- struct sk_msg_buff, sk),
+ struct sk_msg, sk),
si->dst_reg, si->src_reg,
- offsetof(struct sk_msg_buff, sk));
+ offsetof(struct sk_msg, sk));
*insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg,
offsetof(struct sock_common, skc_family));
break;
@@ -7022,9 +6912,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_daddr) != 4);
*insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
- struct sk_msg_buff, sk),
+ struct sk_msg, sk),
si->dst_reg, si->src_reg,
- offsetof(struct sk_msg_buff, sk));
+ offsetof(struct sk_msg, sk));
*insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
offsetof(struct sock_common, skc_daddr));
break;
@@ -7034,9 +6924,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
skc_rcv_saddr) != 4);
*insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
- struct sk_msg_buff, sk),
+ struct sk_msg, sk),
si->dst_reg, si->src_reg,
- offsetof(struct sk_msg_buff, sk));
+ offsetof(struct sk_msg, sk));
*insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
offsetof(struct sock_common,
skc_rcv_saddr));
@@ -7051,9 +6941,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
off = si->off;
off -= offsetof(struct sk_msg_md, remote_ip6[0]);
*insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
- struct sk_msg_buff, sk),
+ struct sk_msg, sk),
si->dst_reg, si->src_reg,
- offsetof(struct sk_msg_buff, sk));
+ offsetof(struct sk_msg, sk));
*insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
offsetof(struct sock_common,
skc_v6_daddr.s6_addr32[0]) +
@@ -7072,9 +6962,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
off = si->off;
off -= offsetof(struct sk_msg_md, local_ip6[0]);
*insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
- struct sk_msg_buff, sk),
+ struct sk_msg, sk),
si->dst_reg, si->src_reg,
- offsetof(struct sk_msg_buff, sk));
+ offsetof(struct sk_msg, sk));
*insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
offsetof(struct sock_common,
skc_v6_rcv_saddr.s6_addr32[0]) +
@@ -7088,9 +6978,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_dport) != 2);
*insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
- struct sk_msg_buff, sk),
+ struct sk_msg, sk),
si->dst_reg, si->src_reg,
- offsetof(struct sk_msg_buff, sk));
+ offsetof(struct sk_msg, sk));
*insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg,
offsetof(struct sock_common, skc_dport));
#ifndef __BIG_ENDIAN_BITFIELD
@@ -7102,9 +6992,9 @@ static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_num) != 2);
*insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(
- struct sk_msg_buff, sk),
+ struct sk_msg, sk),
si->dst_reg, si->src_reg,
- offsetof(struct sk_msg_buff, sk));
+ offsetof(struct sk_msg, sk));
*insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg,
offsetof(struct sock_common, skc_num));
break;
diff --git a/net/core/skmsg.c b/net/core/skmsg.c
new file mode 100644
index 000000000000..56a99d0c9aa0
--- /dev/null
+++ b/net/core/skmsg.c
@@ -0,0 +1,802 @@
+// SPDX-License-Identifier: GPL-2.0
+/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
+
+#include <linux/skmsg.h>
+#include <linux/skbuff.h>
+#include <linux/scatterlist.h>
+
+#include <net/sock.h>
+#include <net/tcp.h>
+
+static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce)
+{
+ if (msg->sg.end > msg->sg.start &&
+ elem_first_coalesce < msg->sg.end)
+ return true;
+
+ if (msg->sg.end < msg->sg.start &&
+ (elem_first_coalesce > msg->sg.start ||
+ elem_first_coalesce < msg->sg.end))
+ return true;
+
+ return false;
+}
+
+int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
+ int elem_first_coalesce)
+{
+ struct page_frag *pfrag = sk_page_frag(sk);
+ int ret = 0;
+
+ len -= msg->sg.size;
+ while (len > 0) {
+ struct scatterlist *sge;
+ u32 orig_offset;
+ int use, i;
+
+ if (!sk_page_frag_refill(sk, pfrag))
+ return -ENOMEM;
+
+ orig_offset = pfrag->offset;
+ use = min_t(int, len, pfrag->size - orig_offset);
+ if (!sk_wmem_schedule(sk, use))
+ return -ENOMEM;
+
+ i = msg->sg.end;
+ sk_msg_iter_var_prev(i);
+ sge = &msg->sg.data[i];
+
+ if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) &&
+ sg_page(sge) == pfrag->page &&
+ sge->offset + sge->length == orig_offset) {
+ sge->length += use;
+ } else {
+ if (sk_msg_full(msg)) {
+ ret = -ENOSPC;
+ break;
+ }
+
+ sge = &msg->sg.data[msg->sg.end];
+ sg_unmark_end(sge);
+ sg_set_page(sge, pfrag->page, use, orig_offset);
+ get_page(pfrag->page);
+ sk_msg_iter_next(msg, end);
+ }
+
+ sk_mem_charge(sk, use);
+ msg->sg.size += use;
+ pfrag->offset += use;
+ len -= use;
+ }
+
+ return ret;
+}
+EXPORT_SYMBOL_GPL(sk_msg_alloc);
+
+int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
+ u32 off, u32 len)
+{
+ int i = src->sg.start;
+ struct scatterlist *sge = sk_msg_elem(src, i);
+ u32 sge_len, sge_off;
+
+ if (sk_msg_full(dst))
+ return -ENOSPC;
+
+ while (off) {
+ if (sge->length > off)
+ break;
+ off -= sge->length;
+ sk_msg_iter_var_next(i);
+ if (i == src->sg.end && off)
+ return -ENOSPC;
+ sge = sk_msg_elem(src, i);
+ }
+
+ while (len) {
+ sge_len = sge->length - off;
+ sge_off = sge->offset + off;
+ if (sge_len > len)
+ sge_len = len;
+ off = 0;
+ len -= sge_len;
+ sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off);
+ sk_mem_charge(sk, sge_len);
+ sk_msg_iter_var_next(i);
+ if (i == src->sg.end && len)
+ return -ENOSPC;
+ sge = sk_msg_elem(src, i);
+ }
+
+ return 0;
+}
+EXPORT_SYMBOL_GPL(sk_msg_clone);
+
+void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes)
+{
+ int i = msg->sg.start;
+
+ do {
+ struct scatterlist *sge = sk_msg_elem(msg, i);
+
+ if (bytes < sge->length) {
+ sge->length -= bytes;
+ sge->offset += bytes;
+ sk_mem_uncharge(sk, bytes);
+ break;
+ }
+
+ sk_mem_uncharge(sk, sge->length);
+ bytes -= sge->length;
+ sge->length = 0;
+ sge->offset = 0;
+ sk_msg_iter_var_next(i);
+ } while (bytes && i != msg->sg.end);
+ msg->sg.start = i;
+}
+EXPORT_SYMBOL_GPL(sk_msg_return_zero);
+
+void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes)
+{
+ int i = msg->sg.start;
+
+ do {
+ struct scatterlist *sge = &msg->sg.data[i];
+ int uncharge = (bytes < sge->length) ? bytes : sge->length;
+
+ sk_mem_uncharge(sk, uncharge);
+ bytes -= uncharge;
+ sk_msg_iter_var_next(i);
+ } while (i != msg->sg.end);
+}
+EXPORT_SYMBOL_GPL(sk_msg_return);
+
+static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i,
+ bool charge)
+{
+ struct scatterlist *sge = sk_msg_elem(msg, i);
+ u32 len = sge->length;
+
+ if (charge)
+ sk_mem_uncharge(sk, len);
+ if (!msg->skb)
+ put_page(sg_page(sge));
+ memset(sge, 0, sizeof(*sge));
+ return len;
+}
+
+static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i,
+ bool charge)
+{
+ struct scatterlist *sge = sk_msg_elem(msg, i);
+ int freed = 0;
+
+ while (msg->sg.size) {
+ msg->sg.size -= sge->length;
+ freed += sk_msg_free_elem(sk, msg, i, charge);
+ sk_msg_iter_var_next(i);
+ sk_msg_check_to_free(msg, i, msg->sg.size);
+ sge = sk_msg_elem(msg, i);
+ }
+ if (msg->skb)
+ consume_skb(msg->skb);
+ sk_msg_init(msg);
+ return freed;
+}
+
+int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg)
+{
+ return __sk_msg_free(sk, msg, msg->sg.start, false);
+}
+EXPORT_SYMBOL_GPL(sk_msg_free_nocharge);
+
+int sk_msg_free(struct sock *sk, struct sk_msg *msg)
+{
+ return __sk_msg_free(sk, msg, msg->sg.start, true);
+}
+EXPORT_SYMBOL_GPL(sk_msg_free);
+
+static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg,
+ u32 bytes, bool charge)
+{
+ struct scatterlist *sge;
+ u32 i = msg->sg.start;
+
+ while (bytes) {
+ sge = sk_msg_elem(msg, i);
+ if (!sge->length)
+ break;
+ if (bytes < sge->length) {
+ if (charge)
+ sk_mem_uncharge(sk, bytes);
+ sge->length -= bytes;
+ sge->offset += bytes;
+ msg->sg.size -= bytes;
+ break;
+ }
+
+ msg->sg.size -= sge->length;
+ bytes -= sge->length;
+ sk_msg_free_elem(sk, msg, i, charge);
+ sk_msg_iter_var_next(i);
+ sk_msg_check_to_free(msg, i, bytes);
+ }
+ msg->sg.start = i;
+}
+
+void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes)
+{
+ __sk_msg_free_partial(sk, msg, bytes, true);
+}
+EXPORT_SYMBOL_GPL(sk_msg_free_partial);
+
+void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
+ u32 bytes)
+{
+ __sk_msg_free_partial(sk, msg, bytes, false);
+}
+
+void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len)
+{
+ int trim = msg->sg.size - len;
+ u32 i = msg->sg.end;
+
+ if (trim <= 0) {
+ WARN_ON(trim < 0);
+ return;
+ }
+
+ sk_msg_iter_var_prev(i);
+ msg->sg.size = len;
+ while (msg->sg.data[i].length &&
+ trim >= msg->sg.data[i].length) {
+ trim -= msg->sg.data[i].length;
+ sk_msg_free_elem(sk, msg, i, true);
+ sk_msg_iter_var_prev(i);
+ if (!trim)
+ goto out;
+ }
+
+ msg->sg.data[i].length -= trim;
+ sk_mem_uncharge(sk, trim);
+out:
+ /* If we trim data before curr pointer update copybreak and current
+ * so that any future copy operations start at new copy location.
+ * However trimed data that has not yet been used in a copy op
+ * does not require an update.
+ */
+ if (msg->sg.curr >= i) {
+ msg->sg.curr = i;
+ msg->sg.copybreak = msg->sg.data[i].length;
+ }
+ sk_msg_iter_var_next(i);
+ msg->sg.end = i;
+}
+EXPORT_SYMBOL_GPL(sk_msg_trim);
+
+int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
+ struct sk_msg *msg, u32 bytes)
+{
+ int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg);
+ const int to_max_pages = MAX_MSG_FRAGS;
+ struct page *pages[MAX_MSG_FRAGS];
+ ssize_t orig, copied, use, offset;
+
+ orig = msg->sg.size;
+ while (bytes > 0) {
+ i = 0;
+ maxpages = to_max_pages - num_elems;
+ if (maxpages == 0) {
+ ret = -EFAULT;
+ goto out;
+ }
+
+ copied = iov_iter_get_pages(from, pages, bytes, maxpages,
+ &offset);
+ if (copied <= 0) {
+ ret = -EFAULT;
+ goto out;
+ }
+
+ iov_iter_advance(from, copied);
+ bytes -= copied;
+ msg->sg.size += copied;
+
+ while (copied) {
+ use = min_t(int, copied, PAGE_SIZE - offset);
+ sg_set_page(&msg->sg.data[msg->sg.end],
+ pages[i], use, offset);
+ sg_unmark_end(&msg->sg.data[msg->sg.end]);
+ sk_mem_charge(sk, use);
+
+ offset = 0;
+ copied -= use;
+ sk_msg_iter_next(msg, end);
+ num_elems++;
+ i++;
+ }
+ /* When zerocopy is mixed with sk_msg_*copy* operations we
+ * may have a copybreak set in this case clear and prefer
+ * zerocopy remainder when possible.
+ */
+ msg->sg.copybreak = 0;
+ msg->sg.curr = msg->sg.end;
+ }
+out:
+ /* Revert iov_iter updates, msg will need to use 'trim' later if it
+ * also needs to be cleared.
+ */
+ if (ret)
+ iov_iter_revert(from, msg->sg.size - orig);
+ return ret;
+}
+EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter);
+
+int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
+ struct sk_msg *msg, u32 bytes)
+{
+ int ret = -ENOSPC, i = msg->sg.curr;
+ struct scatterlist *sge;
+ u32 copy, buf_size;
+ void *to;
+
+ do {
+ sge = sk_msg_elem(msg, i);
+ /* This is possible if a trim operation shrunk the buffer */
+ if (msg->sg.copybreak >= sge->length) {
+ msg->sg.copybreak = 0;
+ sk_msg_iter_var_next(i);
+ if (i == msg->sg.end)
+ break;
+ sge = sk_msg_elem(msg, i);
+ }
+
+ buf_size = sge->length - msg->sg.copybreak;
+ copy = (buf_size > bytes) ? bytes : buf_size;
+ to = sg_virt(sge) + msg->sg.copybreak;
+ msg->sg.copybreak += copy;
+ if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
+ ret = copy_from_iter_nocache(to, copy, from);
+ else
+ ret = copy_from_iter(to, copy, from);
+ if (ret != copy) {
+ ret = -EFAULT;
+ goto out;
+ }
+ bytes -= copy;
+ if (!bytes)
+ break;
+ msg->sg.copybreak = 0;
+ sk_msg_iter_var_next(i);
+ } while (i != msg->sg.end);
+out:
+ msg->sg.curr = i;
+ return ret;
+}
+EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);
+
+static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb)
+{
+ struct sock *sk = psock->sk;
+ int copied = 0, num_sge;
+ struct sk_msg *msg;
+
+ msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC);
+ if (unlikely(!msg))
+ return -EAGAIN;
+ if (!sk_rmem_schedule(sk, skb, skb->len)) {
+ kfree(msg);
+ return -EAGAIN;
+ }
+
+ sk_msg_init(msg);
+ num_sge = skb_to_sgvec(skb, msg->sg.data, 0, skb->len);
+ if (unlikely(num_sge < 0)) {
+ kfree(msg);
+ return num_sge;
+ }
+
+ sk_mem_charge(sk, skb->len);
+ copied = skb->len;
+ msg->sg.start = 0;
+ msg->sg.end = num_sge == MAX_MSG_FRAGS ? 0 : num_sge;
+ msg->skb = skb;
+
+ sk_psock_queue_msg(psock, msg);
+ sk->sk_data_ready(sk);
+ return copied;
+}
+
+static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb,
+ u32 off, u32 len, bool ingress)
+{
+ if (ingress)
+ return sk_psock_skb_ingress(psock, skb);
+ else
+ return skb_send_sock_locked(psock->sk, skb, off, len);
+}
+
+static void sk_psock_backlog(struct work_struct *work)
+{
+ struct sk_psock *psock = container_of(work, struct sk_psock, work);
+ struct sk_psock_work_state *state = &psock->work_state;
+ struct sk_buff *skb;
+ bool ingress;
+ u32 len, off;
+ int ret;
+
+ /* Lock sock to avoid losing sk_socket during loop. */
+ lock_sock(psock->sk);
+ if (state->skb) {
+ skb = state->skb;
+ len = state->len;
+ off = state->off;
+ state->skb = NULL;
+ goto start;
+ }
+
+ while ((skb = skb_dequeue(&psock->ingress_skb))) {
+ len = skb->len;
+ off = 0;
+start:
+ ingress = tcp_skb_bpf_ingress(skb);
+ do {
+ ret = -EIO;
+ if (likely(psock->sk->sk_socket))
+ ret = sk_psock_handle_skb(psock, skb, off,
+ len, ingress);
+ if (ret <= 0) {
+ if (ret == -EAGAIN) {
+ state->skb = skb;
+ state->len = len;
+ state->off = off;
+ goto end;
+ }
+ /* Hard errors break pipe and stop xmit. */
+ sk_psock_report_error(psock, ret ? -ret : EPIPE);
+ sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
+ kfree_skb(skb);
+ goto end;
+ }
+ off += ret;
+ len -= ret;
+ } while (len);
+
+ if (!ingress)
+ kfree_skb(skb);
+ }
+end:
+ release_sock(psock->sk);
+}
+
+struct sk_psock *sk_psock_init(struct sock *sk, int node)
+{
+ struct sk_psock *psock = kzalloc_node(sizeof(*psock),
+ GFP_ATOMIC | __GFP_NOWARN,
+ node);
+ if (!psock)
+ return NULL;
+
+ psock->sk = sk;
+ psock->eval = __SK_NONE;
+
+ INIT_LIST_HEAD(&psock->link);
+ spin_lock_init(&psock->link_lock);
+
+ INIT_WORK(&psock->work, sk_psock_backlog);
+ INIT_LIST_HEAD(&psock->ingress_msg);
+ skb_queue_head_init(&psock->ingress_skb);
+
+ sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED);
+ refcount_set(&psock->refcnt, 1);
+
+ rcu_assign_sk_user_data(sk, psock);
+ sock_hold(sk);
+
+ return psock;
+}
+EXPORT_SYMBOL_GPL(sk_psock_init);
+
+struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock)
+{
+ struct sk_psock_link *link;
+
+ spin_lock_bh(&psock->link_lock);
+ link = list_first_entry_or_null(&psock->link, struct sk_psock_link,
+ list);
+ if (link)
+ list_del(&link->list);
+ spin_unlock_bh(&psock->link_lock);
+ return link;
+}
+
+void __sk_psock_purge_ingress_msg(struct sk_psock *psock)
+{
+ struct sk_msg *msg, *tmp;
+
+ list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) {
+ list_del(&msg->list);
+ sk_msg_free(psock->sk, msg);
+ kfree(msg);
+ }
+}
+
+static void sk_psock_zap_ingress(struct sk_psock *psock)
+{
+ __skb_queue_purge(&psock->ingress_skb);
+ __sk_psock_purge_ingress_msg(psock);
+}
+
+static void sk_psock_link_destroy(struct sk_psock *psock)
+{
+ struct sk_psock_link *link, *tmp;
+
+ list_for_each_entry_safe(link, tmp, &psock->link, list) {
+ list_del(&link->list);
+ sk_psock_free_link(link);
+ }
+}
+
+static void sk_psock_destroy_deferred(struct work_struct *gc)
+{
+ struct sk_psock *psock = container_of(gc, struct sk_psock, gc);
+
+ /* No sk_callback_lock since already detached. */
+ if (psock->parser.enabled)
+ strp_done(&psock->parser.strp);
+
+ cancel_work_sync(&psock->work);
+
+ psock_progs_drop(&psock->progs);
+
+ sk_psock_link_destroy(psock);
+ sk_psock_cork_free(psock);
+ sk_psock_zap_ingress(psock);
+
+ if (psock->sk_redir)
+ sock_put(psock->sk_redir);
+ sock_put(psock->sk);
+ kfree(psock);
+}
+
+void sk_psock_destroy(struct rcu_head *rcu)
+{
+ struct sk_psock *psock = container_of(rcu, struct sk_psock, rcu);
+
+ INIT_WORK(&psock->gc, sk_psock_destroy_deferred);
+ schedule_work(&psock->gc);
+}
+EXPORT_SYMBOL_GPL(sk_psock_destroy);
+
+void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
+{
+ rcu_assign_sk_user_data(sk, NULL);
+ sk_psock_cork_free(psock);
+ sk_psock_restore_proto(sk, psock);
+
+ write_lock_bh(&sk->sk_callback_lock);
+ if (psock->progs.skb_parser)
+ sk_psock_stop_strp(sk, psock);
+ write_unlock_bh(&sk->sk_callback_lock);
+ sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
+
+ call_rcu_sched(&psock->rcu, sk_psock_destroy);
+}
+EXPORT_SYMBOL_GPL(sk_psock_drop);
+
+static int sk_psock_map_verd(int verdict, bool redir)
+{
+ switch (verdict) {
+ case SK_PASS:
+ return redir ? __SK_REDIRECT : __SK_PASS;
+ case SK_DROP:
+ default:
+ break;
+ }
+
+ return __SK_DROP;
+}
+
+int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
+ struct sk_msg *msg)
+{
+ struct bpf_prog *prog;
+ int ret;
+
+ preempt_disable();
+ rcu_read_lock();
+ prog = READ_ONCE(psock->progs.msg_parser);
+ if (unlikely(!prog)) {
+ ret = __SK_PASS;
+ goto out;
+ }
+
+ sk_msg_compute_data_pointers(msg);
+ msg->sk = sk;
+ ret = BPF_PROG_RUN(prog, msg);
+ ret = sk_psock_map_verd(ret, msg->sk_redir);
+ psock->apply_bytes = msg->apply_bytes;
+ if (ret == __SK_REDIRECT) {
+ if (psock->sk_redir)
+ sock_put(psock->sk_redir);
+ psock->sk_redir = msg->sk_redir;
+ if (!psock->sk_redir) {
+ ret = __SK_DROP;
+ goto out;
+ }
+ sock_hold(psock->sk_redir);
+ }
+out:
+ rcu_read_unlock();
+ preempt_enable();
+ return ret;
+}
+EXPORT_SYMBOL_GPL(sk_psock_msg_verdict);
+
+static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog,
+ struct sk_buff *skb)
+{
+ int ret;
+
+ skb->sk = psock->sk;
+ bpf_compute_data_end_sk_skb(skb);
+ preempt_disable();
+ ret = BPF_PROG_RUN(prog, skb);
+ preempt_enable();
+ /* strparser clones the skb before handing it to a upper layer,
+ * meaning skb_orphan has been called. We NULL sk on the way out
+ * to ensure we don't trigger a BUG_ON() in skb/sk operations
+ * later and because we are not charging the memory of this skb
+ * to any socket yet.
+ */
+ skb->sk = NULL;
+ return ret;
+}
+
+static struct sk_psock *sk_psock_from_strp(struct strparser *strp)
+{
+ struct sk_psock_parser *parser;
+
+ parser = container_of(strp, struct sk_psock_parser, strp);
+ return container_of(parser, struct sk_psock, parser);
+}
+
+static void sk_psock_verdict_apply(struct sk_psock *psock,
+ struct sk_buff *skb, int verdict)
+{
+ struct sk_psock *psock_other;
+ struct sock *sk_other;
+ bool ingress;
+
+ switch (verdict) {
+ case __SK_REDIRECT:
+ sk_other = tcp_skb_bpf_redirect_fetch(skb);
+ if (unlikely(!sk_other))
+ goto out_free;
+ psock_other = sk_psock(sk_other);
+ if (!psock_other || sock_flag(sk_other, SOCK_DEAD) ||
+ !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED))
+ goto out_free;
+ ingress = tcp_skb_bpf_ingress(skb);
+ if ((!ingress && sock_writeable(sk_other)) ||
+ (ingress &&
+ atomic_read(&sk_other->sk_rmem_alloc) <=
+ sk_other->sk_rcvbuf)) {
+ if (!ingress)
+ skb_set_owner_w(skb, sk_other);
+ skb_queue_tail(&psock_other->ingress_skb, skb);
+ schedule_work(&psock_other->work);
+ break;
+ }
+ /* fall-through */
+ case __SK_DROP:
+ /* fall-through */
+ default:
+out_free:
+ kfree_skb(skb);
+ }
+}
+
+static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb)
+{
+ struct sk_psock *psock = sk_psock_from_strp(strp);
+ struct bpf_prog *prog;
+ int ret = __SK_DROP;
+
+ rcu_read_lock();
+ prog = READ_ONCE(psock->progs.skb_verdict);
+ if (likely(prog)) {
+ skb_orphan(skb);
+ tcp_skb_bpf_redirect_clear(skb);
+ ret = sk_psock_bpf_run(psock, prog, skb);
+ ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));
+ }
+ rcu_read_unlock();
+ sk_psock_verdict_apply(psock, skb, ret);
+}
+
+static int sk_psock_strp_read_done(struct strparser *strp, int err)
+{
+ return err;
+}
+
+static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb)
+{
+ struct sk_psock *psock = sk_psock_from_strp(strp);
+ struct bpf_prog *prog;
+ int ret = skb->len;
+
+ rcu_read_lock();
+ prog = READ_ONCE(psock->progs.skb_parser);
+ if (likely(prog))
+ ret = sk_psock_bpf_run(psock, prog, skb);
+ rcu_read_unlock();
+ return ret;
+}
+
+/* Called with socket lock held. */
+static void sk_psock_data_ready(struct sock *sk)
+{
+ struct sk_psock *psock;
+
+ rcu_read_lock();
+ psock = sk_psock(sk);
+ if (likely(psock)) {
+ write_lock_bh(&sk->sk_callback_lock);
+ strp_data_ready(&psock->parser.strp);
+ write_unlock_bh(&sk->sk_callback_lock);
+ }
+ rcu_read_unlock();
+}
+
+static void sk_psock_write_space(struct sock *sk)
+{
+ struct sk_psock *psock;
+ void (*write_space)(struct sock *sk);
+
+ rcu_read_lock();
+ psock = sk_psock(sk);
+ if (likely(psock && sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)))
+ schedule_work(&psock->work);
+ write_space = psock->saved_write_space;
+ rcu_read_unlock();
+ write_space(sk);
+}
+
+int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
+{
+ static const struct strp_callbacks cb = {
+ .rcv_msg = sk_psock_strp_read,
+ .read_sock_done = sk_psock_strp_read_done,
+ .parse_msg = sk_psock_strp_parse,
+ };
+
+ psock->parser.enabled = false;
+ return strp_init(&psock->parser.strp, sk, &cb);
+}
+
+void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
+{
+ struct sk_psock_parser *parser = &psock->parser;
+
+ if (parser->enabled)
+ return;
+
+ parser->saved_data_ready = sk->sk_data_ready;
+ sk->sk_data_ready = sk_psock_data_ready;
+ sk->sk_write_space = sk_psock_write_space;
+ parser->enabled = true;
+}
+
+void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
+{
+ struct sk_psock_parser *parser = &psock->parser;
+
+ if (!parser->enabled)
+ return;
+
+ sk->sk_data_ready = parser->saved_data_ready;
+ parser->saved_data_ready = NULL;
+ strp_stop(&parser->strp);
+ parser->enabled = false;
+}
diff --git a/net/core/sock.c b/net/core/sock.c
index fdf9fc7d3f98..6fcc4bc07d19 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -2239,67 +2239,6 @@ bool sk_page_frag_refill(struct sock *sk, struct page_frag *pfrag)
}
EXPORT_SYMBOL(sk_page_frag_refill);
-int sk_alloc_sg(struct sock *sk, int len, struct scatterlist *sg,
- int sg_start, int *sg_curr_index, unsigned int *sg_curr_size,
- int first_coalesce)
-{
- int sg_curr = *sg_curr_index, use = 0, rc = 0;
- unsigned int size = *sg_curr_size;
- struct page_frag *pfrag;
- struct scatterlist *sge;
-
- len -= size;
- pfrag = sk_page_frag(sk);
-
- while (len > 0) {
- unsigned int orig_offset;
-
- if (!sk_page_frag_refill(sk, pfrag)) {
- rc = -ENOMEM;
- goto out;
- }
-
- use = min_t(int, len, pfrag->size - pfrag->offset);
-
- if (!sk_wmem_schedule(sk, use)) {
- rc = -ENOMEM;
- goto out;
- }
-
- sk_mem_charge(sk, use);
- size += use;
- orig_offset = pfrag->offset;
- pfrag->offset += use;
-
- sge = sg + sg_curr - 1;
- if (sg_curr > first_coalesce && sg_page(sge) == pfrag->page &&
- sge->offset + sge->length == orig_offset) {
- sge->length += use;
- } else {
- sge = sg + sg_curr;
- sg_unmark_end(sge);
- sg_set_page(sge, pfrag->page, use, orig_offset);
- get_page(pfrag->page);
- sg_curr++;
-
- if (sg_curr == MAX_SKB_FRAGS)
- sg_curr = 0;
-
- if (sg_curr == sg_start) {
- rc = -ENOSPC;
- break;
- }
- }
-
- len -= use;
- }
-out:
- *sg_curr_size = size;
- *sg_curr_index = sg_curr;
- return rc;
-}
-EXPORT_SYMBOL(sk_alloc_sg);
-
static void __lock_sock(struct sock *sk)
__releases(&sk->sk_lock.slock)
__acquires(&sk->sk_lock.slock)
diff --git a/net/core/sock_map.c b/net/core/sock_map.c
new file mode 100644
index 000000000000..3c0e44cb811a
--- /dev/null
+++ b/net/core/sock_map.c
@@ -0,0 +1,1002 @@
+// SPDX-License-Identifier: GPL-2.0
+/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
+
+#include <linux/bpf.h>
+#include <linux/filter.h>
+#include <linux/errno.h>
+#include <linux/file.h>
+#include <linux/net.h>
+#include <linux/workqueue.h>
+#include <linux/skmsg.h>
+#include <linux/list.h>
+#include <linux/jhash.h>
+
+struct bpf_stab {
+ struct bpf_map map;
+ struct sock **sks;
+ struct sk_psock_progs progs;
+ raw_spinlock_t lock;
+};
+
+#define SOCK_CREATE_FLAG_MASK \
+ (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
+
+static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
+{
+ struct bpf_stab *stab;
+ u64 cost;
+ int err;
+
+ if (!capable(CAP_NET_ADMIN))
+ return ERR_PTR(-EPERM);
+ if (attr->max_entries == 0 ||
+ attr->key_size != 4 ||
+ attr->value_size != 4 ||
+ attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
+ return ERR_PTR(-EINVAL);
+
+ stab = kzalloc(sizeof(*stab), GFP_USER);
+ if (!stab)
+ return ERR_PTR(-ENOMEM);
+
+ bpf_map_init_from_attr(&stab->map, attr);
+ raw_spin_lock_init(&stab->lock);
+
+ /* Make sure page count doesn't overflow. */
+ cost = (u64) stab->map.max_entries * sizeof(struct sock *);
+ if (cost >= U32_MAX - PAGE_SIZE) {
+ err = -EINVAL;
+ goto free_stab;
+ }
+
+ stab->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
+ err = bpf_map_precharge_memlock(stab->map.pages);
+ if (err)
+ goto free_stab;
+
+ stab->sks = bpf_map_area_alloc(stab->map.max_entries *
+ sizeof(struct sock *),
+ stab->map.numa_node);
+ if (stab->sks)
+ return &stab->map;
+ err = -ENOMEM;
+free_stab:
+ kfree(stab);
+ return ERR_PTR(err);
+}
+
+int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
+{
+ u32 ufd = attr->target_fd;
+ struct bpf_map *map;
+ struct fd f;
+ int ret;
+
+ f = fdget(ufd);
+ map = __bpf_map_get(f);
+ if (IS_ERR(map))
+ return PTR_ERR(map);
+ ret = sock_map_prog_update(map, prog, attr->attach_type);
+ fdput(f);
+ return ret;
+}
+
+static void sock_map_sk_acquire(struct sock *sk)
+ __acquires(&sk->sk_lock.slock)
+{
+ lock_sock(sk);
+ preempt_disable();
+ rcu_read_lock();
+}
+
+static void sock_map_sk_release(struct sock *sk)
+ __releases(&sk->sk_lock.slock)
+{
+ rcu_read_unlock();
+ preempt_enable();
+ release_sock(sk);
+}
+
+static void sock_map_add_link(struct sk_psock *psock,
+ struct sk_psock_link *link,
+ struct bpf_map *map, void *link_raw)
+{
+ link->link_raw = link_raw;
+ link->map = map;
+ spin_lock_bh(&psock->link_lock);
+ list_add_tail(&link->list, &psock->link);
+ spin_unlock_bh(&psock->link_lock);
+}
+
+static void sock_map_del_link(struct sock *sk,
+ struct sk_psock *psock, void *link_raw)
+{
+ struct sk_psock_link *link, *tmp;
+ bool strp_stop = false;
+
+ spin_lock_bh(&psock->link_lock);
+ list_for_each_entry_safe(link, tmp, &psock->link, list) {
+ if (link->link_raw == link_raw) {
+ struct bpf_map *map = link->map;
+ struct bpf_stab *stab = container_of(map, struct bpf_stab,
+ map);
+ if (psock->parser.enabled && stab->progs.skb_parser)
+ strp_stop = true;
+ list_del(&link->list);
+ sk_psock_free_link(link);
+ }
+ }
+ spin_unlock_bh(&psock->link_lock);
+ if (strp_stop) {
+ write_lock_bh(&sk->sk_callback_lock);
+ sk_psock_stop_strp(sk, psock);
+ write_unlock_bh(&sk->sk_callback_lock);
+ }
+}
+
+static void sock_map_unref(struct sock *sk, void *link_raw)
+{
+ struct sk_psock *psock = sk_psock(sk);
+
+ if (likely(psock)) {
+ sock_map_del_link(sk, psock, link_raw);
+ sk_psock_put(sk, psock);
+ }
+}
+
+static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
+ struct sock *sk)
+{
+ struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
+ bool skb_progs, sk_psock_is_new = false;
+ struct sk_psock *psock;
+ int ret;
+
+ skb_verdict = READ_ONCE(progs->skb_verdict);
+ skb_parser = READ_ONCE(progs->skb_parser);
+ skb_progs = skb_parser && skb_verdict;
+ if (skb_progs) {
+ skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
+ if (IS_ERR(skb_verdict))
+ return PTR_ERR(skb_verdict);
+ skb_parser = bpf_prog_inc_not_zero(skb_parser);
+ if (IS_ERR(skb_parser)) {
+ bpf_prog_put(skb_verdict);
+ return PTR_ERR(skb_parser);
+ }
+ }
+
+ msg_parser = READ_ONCE(progs->msg_parser);
+ if (msg_parser) {
+ msg_parser = bpf_prog_inc_not_zero(msg_parser);
+ if (IS_ERR(msg_parser)) {
+ ret = PTR_ERR(msg_parser);
+ goto out;
+ }
+ }
+
+ psock = sk_psock_get(sk);
+ if (psock) {
+ if (!sk_has_psock(sk)) {
+ ret = -EBUSY;
+ goto out_progs;
+ }
+ if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
+ (skb_progs && READ_ONCE(psock->progs.skb_parser))) {
+ sk_psock_put(sk, psock);
+ ret = -EBUSY;
+ goto out_progs;
+ }
+ } else {
+ psock = sk_psock_init(sk, map->numa_node);
+ if (!psock) {
+ ret = -ENOMEM;
+ goto out_progs;
+ }
+ sk_psock_is_new = true;
+ }
+
+ if (msg_parser)
+ psock_set_prog(&psock->progs.msg_parser, msg_parser);
+ if (sk_psock_is_new) {
+ ret = tcp_bpf_init(sk);
+ if (ret < 0)
+ goto out_drop;
+ } else {
+ tcp_bpf_reinit(sk);
+ }
+
+ write_lock_bh(&sk->sk_callback_lock);
+ if (skb_progs && !psock->parser.enabled) {
+ ret = sk_psock_init_strp(sk, psock);
+ if (ret) {
+ write_unlock_bh(&sk->sk_callback_lock);
+ goto out_drop;
+ }
+ psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
+ psock_set_prog(&psock->progs.skb_parser, skb_parser);
+ sk_psock_start_strp(sk, psock);
+ }
+ write_unlock_bh(&sk->sk_callback_lock);
+ return 0;
+out_drop:
+ sk_psock_put(sk, psock);
+out_progs:
+ if (msg_parser)
+ bpf_prog_put(msg_parser);
+out:
+ if (skb_progs) {
+ bpf_prog_put(skb_verdict);
+ bpf_prog_put(skb_parser);
+ }
+ return ret;
+}
+
+static void sock_map_free(struct bpf_map *map)
+{
+ struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
+ int i;
+
+ synchronize_rcu();
+ rcu_read_lock();
+ raw_spin_lock_bh(&stab->lock);
+ for (i = 0; i < stab->map.max_entries; i++) {
+ struct sock **psk = &stab->sks[i];
+ struct sock *sk;
+
+ sk = xchg(psk, NULL);
+ if (sk)
+ sock_map_unref(sk, psk);
+ }
+ raw_spin_unlock_bh(&stab->lock);
+ rcu_read_unlock();
+
+ bpf_map_area_free(stab->sks);
+ kfree(stab);
+}
+
+static void sock_map_release_progs(struct bpf_map *map)
+{
+ psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
+}
+
+static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
+{
+ struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
+
+ WARN_ON_ONCE(!rcu_read_lock_held());
+
+ if (unlikely(key >= map->max_entries))
+ return NULL;
+ return READ_ONCE(stab->sks[key]);
+}
+
+static void *sock_map_lookup(struct bpf_map *map, void *key)
+{
+ return ERR_PTR(-EOPNOTSUPP);
+}
+
+static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
+ struct sock **psk)
+{
+ struct sock *sk;
+
+ raw_spin_lock_bh(&stab->lock);
+ sk = *psk;
+ if (!sk_test || sk_test == sk)
+ *psk = NULL;
+ raw_spin_unlock_bh(&stab->lock);
+ if (unlikely(!sk))
+ return -EINVAL;
+ sock_map_unref(sk, psk);
+ return 0;
+}
+
+static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
+ void *link_raw)
+{
+ struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
+
+ __sock_map_delete(stab, sk, link_raw);
+}
+
+static int sock_map_delete_elem(struct bpf_map *map, void *key)
+{
+ struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
+ u32 i = *(u32 *)key;
+ struct sock **psk;
+
+ if (unlikely(i >= map->max_entries))
+ return -EINVAL;
+
+ psk = &stab->sks[i];
+ return __sock_map_delete(stab, NULL, psk);
+}
+
+static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
+{
+ struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
+ u32 i = key ? *(u32 *)key : U32_MAX;
+ u32 *key_next = next;
+
+ if (i == stab->map.max_entries - 1)
+ return -ENOENT;
+ if (i >= stab->map.max_entries)
+ *key_next = 0;
+ else
+ *key_next = i + 1;
+ return 0;
+}
+
+static int sock_map_update_common(struct bpf_map *map, u32 idx,
+ struct sock *sk, u64 flags)
+{
+ struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
+ struct sk_psock_link *link;
+ struct sk_psock *psock;
+ struct sock *osk;
+ int ret;
+
+ WARN_ON_ONCE(!rcu_read_lock_held());
+ if (unlikely(flags > BPF_EXIST))
+ return -EINVAL;
+ if (unlikely(idx >= map->max_entries))
+ return -E2BIG;
+
+ link = sk_psock_init_link();
+ if (!link)
+ return -ENOMEM;
+
+ ret = sock_map_link(map, &stab->progs, sk);
+ if (ret < 0)
+ goto out_free;
+
+ psock = sk_psock(sk);
+ WARN_ON_ONCE(!psock);
+
+ raw_spin_lock_bh(&stab->lock);
+ osk = stab->sks[idx];
+ if (osk && flags == BPF_NOEXIST) {
+ ret = -EEXIST;
+ goto out_unlock;
+ } else if (!osk && flags == BPF_EXIST) {
+ ret = -ENOENT;
+ goto out_unlock;
+ }
+
+ sock_map_add_link(psock, link, map, &stab->sks[idx]);
+ stab->sks[idx] = sk;
+ if (osk)
+ sock_map_unref(osk, &stab->sks[idx]);
+ raw_spin_unlock_bh(&stab->lock);
+ return 0;
+out_unlock:
+ raw_spin_unlock_bh(&stab->lock);
+ if (psock)
+ sk_psock_put(sk, psock);
+out_free:
+ sk_psock_free_link(link);
+ return ret;
+}
+
+static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
+{
+ return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
+ ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB;
+}
+
+static bool sock_map_sk_is_suitable(const struct sock *sk)
+{
+ return sk->sk_type == SOCK_STREAM &&
+ sk->sk_protocol == IPPROTO_TCP;
+}
+
+static int sock_map_update_elem(struct bpf_map *map, void *key,
+ void *value, u64 flags)
+{
+ u32 ufd = *(u32 *)value;
+ u32 idx = *(u32 *)key;
+ struct socket *sock;
+ struct sock *sk;
+ int ret;
+
+ sock = sockfd_lookup(ufd, &ret);
+ if (!sock)
+ return ret;
+ sk = sock->sk;
+ if (!sk) {
+ ret = -EINVAL;
+ goto out;
+ }
+ if (!sock_map_sk_is_suitable(sk) ||
+ sk->sk_state != TCP_ESTABLISHED) {
+ ret = -EOPNOTSUPP;
+ goto out;
+ }
+
+ sock_map_sk_acquire(sk);
+ ret = sock_map_update_common(map, idx, sk, flags);
+ sock_map_sk_release(sk);
+out:
+ fput(sock->file);
+ return ret;
+}
+
+BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
+ struct bpf_map *, map, void *, key, u64, flags)
+{
+ WARN_ON_ONCE(!rcu_read_lock_held());
+
+ if (likely(sock_map_sk_is_suitable(sops->sk) &&
+ sock_map_op_okay(sops)))
+ return sock_map_update_common(map, *(u32 *)key, sops->sk,
+ flags);
+ return -EOPNOTSUPP;
+}
+
+const struct bpf_func_proto bpf_sock_map_update_proto = {
+ .func = bpf_sock_map_update,
+ .gpl_only = false,
+ .pkt_access = true,
+ .ret_type = RET_INTEGER,
+ .arg1_type = ARG_PTR_TO_CTX,
+ .arg2_type = ARG_CONST_MAP_PTR,
+ .arg3_type = ARG_PTR_TO_MAP_KEY,
+ .arg4_type = ARG_ANYTHING,
+};
+
+BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
+ struct bpf_map *, map, u32, key, u64, flags)
+{
+ struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
+
+ if (unlikely(flags & ~(BPF_F_INGRESS)))
+ return SK_DROP;
+ tcb->bpf.flags = flags;
+ tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key);
+ if (!tcb->bpf.sk_redir)
+ return SK_DROP;
+ return SK_PASS;
+}
+
+const struct bpf_func_proto bpf_sk_redirect_map_proto = {
+ .func = bpf_sk_redirect_map,
+ .gpl_only = false,
+ .ret_type = RET_INTEGER,
+ .arg1_type = ARG_PTR_TO_CTX,
+ .arg2_type = ARG_CONST_MAP_PTR,
+ .arg3_type = ARG_ANYTHING,
+ .arg4_type = ARG_ANYTHING,
+};
+
+BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
+ struct bpf_map *, map, u32, key, u64, flags)
+{
+ if (unlikely(flags & ~(BPF_F_INGRESS)))
+ return SK_DROP;
+ msg->flags = flags;
+ msg->sk_redir = __sock_map_lookup_elem(map, key);
+ if (!msg->sk_redir)
+ return SK_DROP;
+ return SK_PASS;
+}
+
+const struct bpf_func_proto bpf_msg_redirect_map_proto = {
+ .func = bpf_msg_redirect_map,
+ .gpl_only = false,
+ .ret_type = RET_INTEGER,
+ .arg1_type = ARG_PTR_TO_CTX,
+ .arg2_type = ARG_CONST_MAP_PTR,
+ .arg3_type = ARG_ANYTHING,
+ .arg4_type = ARG_ANYTHING,
+};
+
+const struct bpf_map_ops sock_map_ops = {
+ .map_alloc = sock_map_alloc,
+ .map_free = sock_map_free,
+ .map_get_next_key = sock_map_get_next_key,
+ .map_update_elem = sock_map_update_elem,
+ .map_delete_elem = sock_map_delete_elem,
+ .map_lookup_elem = sock_map_lookup,
+ .map_release_uref = sock_map_release_progs,
+ .map_check_btf = map_check_no_btf,
+};
+
+struct bpf_htab_elem {
+ struct rcu_head rcu;
+ u32 hash;
+ struct sock *sk;
+ struct hlist_node node;
+ u8 key[0];
+};
+
+struct bpf_htab_bucket {
+ struct hlist_head head;
+ raw_spinlock_t lock;
+};
+
+struct bpf_htab {
+ struct bpf_map map;
+ struct bpf_htab_bucket *buckets;
+ u32 buckets_num;
+ u32 elem_size;
+ struct sk_psock_progs progs;
+ atomic_t count;
+};
+
+static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
+{
+ return jhash(key, len, 0);
+}
+
+static struct bpf_htab_bucket *sock_hash_select_bucket(struct bpf_htab *htab,
+ u32 hash)
+{
+ return &htab->buckets[hash & (htab->buckets_num - 1)];
+}
+
+static struct bpf_htab_elem *
+sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
+ u32 key_size)
+{
+ struct bpf_htab_elem *elem;
+
+ hlist_for_each_entry_rcu(elem, head, node) {
+ if (elem->hash == hash &&
+ !memcmp(&elem->key, key, key_size))
+ return elem;
+ }
+
+ return NULL;
+}
+
+static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
+{
+ struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
+ u32 key_size = map->key_size, hash;
+ struct bpf_htab_bucket *bucket;
+ struct bpf_htab_elem *elem;
+
+ WARN_ON_ONCE(!rcu_read_lock_held());
+
+ hash = sock_hash_bucket_hash(key, key_size);
+ bucket = sock_hash_select_bucket(htab, hash);
+ elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
+
+ return elem ? elem->sk : NULL;
+}
+
+static void sock_hash_free_elem(struct bpf_htab *htab,
+ struct bpf_htab_elem *elem)
+{
+ atomic_dec(&htab->count);
+ kfree_rcu(elem, rcu);
+}
+
+static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
+ void *link_raw)
+{
+ struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
+ struct bpf_htab_elem *elem_probe, *elem = link_raw;
+ struct bpf_htab_bucket *bucket;
+
+ WARN_ON_ONCE(!rcu_read_lock_held());
+ bucket = sock_hash_select_bucket(htab, elem->hash);
+
+ /* elem may be deleted in parallel from the map, but access here
+ * is okay since it's going away only after RCU grace period.
+ * However, we need to check whether it's still present.
+ */
+ raw_spin_lock_bh(&bucket->lock);
+ elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
+ elem->key, map->key_size);
+ if (elem_probe && elem_probe == elem) {
+ hlist_del_rcu(&elem->node);
+ sock_map_unref(elem->sk, elem);
+ sock_hash_free_elem(htab, elem);
+ }
+ raw_spin_unlock_bh(&bucket->lock);
+}
+
+static int sock_hash_delete_elem(struct bpf_map *map, void *key)
+{
+ struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
+ u32 hash, key_size = map->key_size;
+ struct bpf_htab_bucket *bucket;
+ struct bpf_htab_elem *elem;
+ int ret = -ENOENT;
+
+ hash = sock_hash_bucket_hash(key, key_size);
+ bucket = sock_hash_select_bucket(htab, hash);
+
+ raw_spin_lock_bh(&bucket->lock);
+ elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
+ if (elem) {
+ hlist_del_rcu(&elem->node);
+ sock_map_unref(elem->sk, elem);
+ sock_hash_free_elem(htab, elem);
+ ret = 0;
+ }
+ raw_spin_unlock_bh(&bucket->lock);
+ return ret;
+}
+
+static struct bpf_htab_elem *sock_hash_alloc_elem(struct bpf_htab *htab,
+ void *key, u32 key_size,
+ u32 hash, struct sock *sk,
+ struct bpf_htab_elem *old)
+{
+ struct bpf_htab_elem *new;
+
+ if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
+ if (!old) {
+ atomic_dec(&htab->count);
+ return ERR_PTR(-E2BIG);
+ }
+ }
+
+ new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
+ htab->map.numa_node);
+ if (!new) {
+ atomic_dec(&htab->count);
+ return ERR_PTR(-ENOMEM);
+ }
+ memcpy(new->key, key, key_size);
+ new->sk = sk;
+ new->hash = hash;
+ return new;
+}
+
+static int sock_hash_update_common(struct bpf_map *map, void *key,
+ struct sock *sk, u64 flags)
+{
+ struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
+ u32 key_size = map->key_size, hash;
+ struct bpf_htab_elem *elem, *elem_new;
+ struct bpf_htab_bucket *bucket;
+ struct sk_psock_link *link;
+ struct sk_psock *psock;
+ int ret;
+
+ WARN_ON_ONCE(!rcu_read_lock_held());
+ if (unlikely(flags > BPF_EXIST))
+ return -EINVAL;
+
+ link = sk_psock_init_link();
+ if (!link)
+ return -ENOMEM;
+
+ ret = sock_map_link(map, &htab->progs, sk);
+ if (ret < 0)
+ goto out_free;
+
+ psock = sk_psock(sk);
+ WARN_ON_ONCE(!psock);
+
+ hash = sock_hash_bucket_hash(key, key_size);
+ bucket = sock_hash_select_bucket(htab, hash);
+
+ raw_spin_lock_bh(&bucket->lock);
+ elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
+ if (elem && flags == BPF_NOEXIST) {
+ ret = -EEXIST;
+ goto out_unlock;
+ } else if (!elem && flags == BPF_EXIST) {
+ ret = -ENOENT;
+ goto out_unlock;
+ }
+
+ elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
+ if (IS_ERR(elem_new)) {
+ ret = PTR_ERR(elem_new);
+ goto out_unlock;
+ }
+
+ sock_map_add_link(psock, link, map, elem_new);
+ /* Add new element to the head of the list, so that
+ * concurrent search will find it before old elem.
+ */
+ hlist_add_head_rcu(&elem_new->node, &bucket->head);
+ if (elem) {
+ hlist_del_rcu(&elem->node);
+ sock_map_unref(elem->sk, elem);
+ sock_hash_free_elem(htab, elem);
+ }
+ raw_spin_unlock_bh(&bucket->lock);
+ return 0;
+out_unlock:
+ raw_spin_unlock_bh(&bucket->lock);
+ sk_psock_put(sk, psock);
+out_free:
+ sk_psock_free_link(link);
+ return ret;
+}
+
+static int sock_hash_update_elem(struct bpf_map *map, void *key,
+ void *value, u64 flags)
+{
+ u32 ufd = *(u32 *)value;
+ struct socket *sock;
+ struct sock *sk;
+ int ret;
+
+ sock = sockfd_lookup(ufd, &ret);
+ if (!sock)
+ return ret;
+ sk = sock->sk;
+ if (!sk) {
+ ret = -EINVAL;
+ goto out;
+ }
+ if (!sock_map_sk_is_suitable(sk) ||
+ sk->sk_state != TCP_ESTABLISHED) {
+ ret = -EOPNOTSUPP;
+ goto out;
+ }
+
+ sock_map_sk_acquire(sk);
+ ret = sock_hash_update_common(map, key, sk, flags);
+ sock_map_sk_release(sk);
+out:
+ fput(sock->file);
+ return ret;
+}
+
+static int sock_hash_get_next_key(struct bpf_map *map, void *key,
+ void *key_next)
+{
+ struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
+ struct bpf_htab_elem *elem, *elem_next;
+ u32 hash, key_size = map->key_size;
+ struct hlist_head *head;
+ int i = 0;
+
+ if (!key)
+ goto find_first_elem;
+ hash = sock_hash_bucket_hash(key, key_size);
+ head = &sock_hash_select_bucket(htab, hash)->head;
+ elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
+ if (!elem)
+ goto find_first_elem;
+
+ elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&elem->node)),
+ struct bpf_htab_elem, node);
+ if (elem_next) {
+ memcpy(key_next, elem_next->key, key_size);
+ return 0;
+ }
+
+ i = hash & (htab->buckets_num - 1);
+ i++;
+find_first_elem:
+ for (; i < htab->buckets_num; i++) {
+ head = &sock_hash_select_bucket(htab, i)->head;
+ elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)),
+ struct bpf_htab_elem, node);
+ if (elem_next) {
+ memcpy(key_next, elem_next->key, key_size);
+ return 0;
+ }
+ }
+
+ return -ENOENT;
+}
+
+static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
+{
+ struct bpf_htab *htab;
+ int i, err;
+ u64 cost;
+
+ if (!capable(CAP_NET_ADMIN))
+ return ERR_PTR(-EPERM);
+ if (attr->max_entries == 0 ||
+ attr->key_size == 0 ||
+ attr->value_size != 4 ||
+ attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
+ return ERR_PTR(-EINVAL);
+ if (attr->key_size > MAX_BPF_STACK)
+ return ERR_PTR(-E2BIG);
+
+ htab = kzalloc(sizeof(*htab), GFP_USER);
+ if (!htab)
+ return ERR_PTR(-ENOMEM);
+
+ bpf_map_init_from_attr(&htab->map, attr);
+
+ htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
+ htab->elem_size = sizeof(struct bpf_htab_elem) +
+ round_up(htab->map.key_size, 8);
+ if (htab->buckets_num == 0 ||
+ htab->buckets_num > U32_MAX / sizeof(struct bpf_htab_bucket)) {
+ err = -EINVAL;
+ goto free_htab;
+ }
+
+ cost = (u64) htab->buckets_num * sizeof(struct bpf_htab_bucket) +
+ (u64) htab->elem_size * htab->map.max_entries;
+ if (cost >= U32_MAX - PAGE_SIZE) {
+ err = -EINVAL;
+ goto free_htab;
+ }
+
+ htab->buckets = bpf_map_area_alloc(htab->buckets_num *
+ sizeof(struct bpf_htab_bucket),
+ htab->map.numa_node);
+ if (!htab->buckets) {
+ err = -ENOMEM;
+ goto free_htab;
+ }
+
+ for (i = 0; i < htab->buckets_num; i++) {
+ INIT_HLIST_HEAD(&htab->buckets[i].head);
+ raw_spin_lock_init(&htab->buckets[i].lock);
+ }
+
+ return &htab->map;
+free_htab:
+ kfree(htab);
+ return ERR_PTR(err);
+}
+
+static void sock_hash_free(struct bpf_map *map)
+{
+ struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
+ struct bpf_htab_bucket *bucket;
+ struct bpf_htab_elem *elem;
+ struct hlist_node *node;
+ int i;
+
+ synchronize_rcu();
+ rcu_read_lock();
+ for (i = 0; i < htab->buckets_num; i++) {
+ bucket = sock_hash_select_bucket(htab, i);
+ raw_spin_lock_bh(&bucket->lock);
+ hlist_for_each_entry_safe(elem, node, &bucket->head, node) {
+ hlist_del_rcu(&elem->node);
+ sock_map_unref(elem->sk, elem);
+ }
+ raw_spin_unlock_bh(&bucket->lock);
+ }
+ rcu_read_unlock();
+
+ bpf_map_area_free(htab->buckets);
+ kfree(htab);
+}
+
+static void sock_hash_release_progs(struct bpf_map *map)
+{
+ psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs);
+}
+
+BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
+ struct bpf_map *, map, void *, key, u64, flags)
+{
+ WARN_ON_ONCE(!rcu_read_lock_held());
+
+ if (likely(sock_map_sk_is_suitable(sops->sk) &&
+ sock_map_op_okay(sops)))
+ return sock_hash_update_common(map, key, sops->sk, flags);
+ return -EOPNOTSUPP;
+}
+
+const struct bpf_func_proto bpf_sock_hash_update_proto = {
+ .func = bpf_sock_hash_update,
+ .gpl_only = false,
+ .pkt_access = true,
+ .ret_type = RET_INTEGER,
+ .arg1_type = ARG_PTR_TO_CTX,
+ .arg2_type = ARG_CONST_MAP_PTR,
+ .arg3_type = ARG_PTR_TO_MAP_KEY,
+ .arg4_type = ARG_ANYTHING,
+};
+
+BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
+ struct bpf_map *, map, void *, key, u64, flags)
+{
+ struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
+
+ if (unlikely(flags & ~(BPF_F_INGRESS)))
+ return SK_DROP;
+ tcb->bpf.flags = flags;
+ tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key);
+ if (!tcb->bpf.sk_redir)
+ return SK_DROP;
+ return SK_PASS;
+}
+
+const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
+ .func = bpf_sk_redirect_hash,
+ .gpl_only = false,
+ .ret_type = RET_INTEGER,
+ .arg1_type = ARG_PTR_TO_CTX,
+ .arg2_type = ARG_CONST_MAP_PTR,
+ .arg3_type = ARG_PTR_TO_MAP_KEY,
+ .arg4_type = ARG_ANYTHING,
+};
+
+BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
+ struct bpf_map *, map, void *, key, u64, flags)
+{
+ if (unlikely(flags & ~(BPF_F_INGRESS)))
+ return SK_DROP;
+ msg->flags = flags;
+ msg->sk_redir = __sock_hash_lookup_elem(map, key);
+ if (!msg->sk_redir)
+ return SK_DROP;
+ return SK_PASS;
+}
+
+const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
+ .func = bpf_msg_redirect_hash,
+ .gpl_only = false,
+ .ret_type = RET_INTEGER,
+ .arg1_type = ARG_PTR_TO_CTX,
+ .arg2_type = ARG_CONST_MAP_PTR,
+ .arg3_type = ARG_PTR_TO_MAP_KEY,
+ .arg4_type = ARG_ANYTHING,
+};
+
+const struct bpf_map_ops sock_hash_ops = {
+ .map_alloc = sock_hash_alloc,
+ .map_free = sock_hash_free,
+ .map_get_next_key = sock_hash_get_next_key,
+ .map_update_elem = sock_hash_update_elem,
+ .map_delete_elem = sock_hash_delete_elem,
+ .map_lookup_elem = sock_map_lookup,
+ .map_release_uref = sock_hash_release_progs,
+ .map_check_btf = map_check_no_btf,
+};
+
+static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
+{
+ switch (map->map_type) {
+ case BPF_MAP_TYPE_SOCKMAP:
+ return &container_of(map, struct bpf_stab, map)->progs;
+ case BPF_MAP_TYPE_SOCKHASH:
+ return &container_of(map, struct bpf_htab, map)->progs;
+ default:
+ break;
+ }
+
+ return NULL;
+}
+
+int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
+ u32 which)
+{
+ struct sk_psock_progs *progs = sock_map_progs(map);
+
+ if (!progs)
+ return -EOPNOTSUPP;
+
+ switch (which) {
+ case BPF_SK_MSG_VERDICT:
+ psock_set_prog(&progs->msg_parser, prog);
+ break;
+ case BPF_SK_SKB_STREAM_PARSER:
+ psock_set_prog(&progs->skb_parser, prog);
+ break;
+ case BPF_SK_SKB_STREAM_VERDICT:
+ psock_set_prog(&progs->skb_verdict, prog);
+ break;
+ default:
+ return -EOPNOTSUPP;
+ }
+
+ return 0;
+}
+
+void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
+{
+ switch (link->map->map_type) {
+ case BPF_MAP_TYPE_SOCKMAP:
+ return sock_map_delete_from_link(link->map, sk,
+ link->link_raw);
+ case BPF_MAP_TYPE_SOCKHASH:
+ return sock_hash_delete_from_link(link->map, sk,
+ link->link_raw);
+ default:
+ break;
+ }
+}
OpenPOWER on IntegriCloud