diff options
Diffstat (limited to 'include/linux/skmsg.h')
-rw-r--r-- | include/linux/skmsg.h | 42 |
1 files changed, 33 insertions, 9 deletions
diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h index 0b919f0bc6d6..2a11e9d91dfa 100644 --- a/include/linux/skmsg.h +++ b/include/linux/skmsg.h @@ -176,6 +176,7 @@ static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src, { dst->sg.data[which] = src->sg.data[which]; dst->sg.data[which].length = size; + dst->sg.size += size; src->sg.data[which].length -= size; src->sg.data[which].offset += size; } @@ -186,21 +187,29 @@ static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src) sk_msg_init(src); } +static inline bool sk_msg_full(const struct sk_msg *msg) +{ + return (msg->sg.end == msg->sg.start) && msg->sg.size; +} + static inline u32 sk_msg_elem_used(const struct sk_msg *msg) { + if (sk_msg_full(msg)) + return MAX_MSG_FRAGS; + return msg->sg.end >= msg->sg.start ? msg->sg.end - msg->sg.start : msg->sg.end + (MAX_MSG_FRAGS - msg->sg.start); } -static inline bool sk_msg_full(const struct sk_msg *msg) +static inline struct scatterlist *sk_msg_elem(struct sk_msg *msg, int which) { - return (msg->sg.end == msg->sg.start) && msg->sg.size; + return &msg->sg.data[which]; } -static inline struct scatterlist *sk_msg_elem(struct sk_msg *msg, int which) +static inline struct scatterlist sk_msg_elem_cpy(struct sk_msg *msg, int which) { - return &msg->sg.data[which]; + return msg->sg.data[which]; } static inline struct page *sk_msg_page(struct sk_msg *msg, int which) @@ -266,11 +275,6 @@ static inline struct sk_psock *sk_psock(const struct sock *sk) return rcu_dereference_sk_user_data(sk); } -static inline bool sk_has_psock(struct sock *sk) -{ - return sk_psock(sk) != NULL && sk->sk_prot->recvmsg == tcp_bpf_recvmsg; -} - static inline void sk_psock_queue_msg(struct sk_psock *psock, struct sk_msg *msg) { @@ -370,6 +374,26 @@ static inline bool sk_psock_test_state(const struct sk_psock *psock, return test_bit(bit, &psock->state); } +static inline struct sk_psock *sk_psock_get_checked(struct sock *sk) +{ + struct sk_psock *psock; + + rcu_read_lock(); + psock = sk_psock(sk); + if (psock) { + if (sk->sk_prot->recvmsg != tcp_bpf_recvmsg) { + psock = ERR_PTR(-EBUSY); + goto out; + } + + if (!refcount_inc_not_zero(&psock->refcnt)) + psock = ERR_PTR(-EBUSY); + } +out: + rcu_read_unlock(); + return psock; +} + static inline struct sk_psock *sk_psock_get(struct sock *sk) { struct sk_psock *psock; |