diff options
Diffstat (limited to 'drivers/vhost/net.c')
-rw-r--r-- | drivers/vhost/net.c | 159 |
1 files changed, 26 insertions, 133 deletions
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c index f616cefc95ba..2f7c76a85e53 100644 --- a/drivers/vhost/net.c +++ b/drivers/vhost/net.c @@ -60,6 +60,7 @@ static int move_iovec_hdr(struct iovec *from, struct iovec *to, { int seg = 0; size_t size; + while (len && seg < iov_count) { size = min(from->iov_len, len); to->iov_base = from->iov_base; @@ -79,6 +80,7 @@ static void copy_iovec_hdr(const struct iovec *from, struct iovec *to, { int seg = 0; size_t size; + while (len && seg < iovcount) { size = min(from->iov_len, len); to->iov_base = from->iov_base; @@ -211,12 +213,13 @@ static int peek_head_len(struct sock *sk) { struct sk_buff *head; int len = 0; + unsigned long flags; - lock_sock(sk); + spin_lock_irqsave(&sk->sk_receive_queue.lock, flags); head = skb_peek(&sk->sk_receive_queue); - if (head) + if (likely(head)) len = head->len; - release_sock(sk); + spin_unlock_irqrestore(&sk->sk_receive_queue.lock, flags); return len; } @@ -227,6 +230,7 @@ static int peek_head_len(struct sock *sk) * @iovcount - returned count of io vectors we fill * @log - vhost log * @log_num - log offset + * @quota - headcount quota, 1 for big buffer * returns number of buffer heads allocated, negative on error */ static int get_rx_bufs(struct vhost_virtqueue *vq, @@ -234,7 +238,8 @@ static int get_rx_bufs(struct vhost_virtqueue *vq, int datalen, unsigned *iovcount, struct vhost_log *log, - unsigned *log_num) + unsigned *log_num, + unsigned int quota) { unsigned int out, in; int seg = 0; @@ -242,7 +247,7 @@ static int get_rx_bufs(struct vhost_virtqueue *vq, unsigned d; int r, nlogs = 0; - while (datalen > 0) { + while (datalen > 0 && headcount < quota) { if (unlikely(seg >= UIO_MAXIOV)) { r = -ENOBUFS; goto err; @@ -282,117 +287,7 @@ err: /* Expects to be always run from workqueue - which acts as * read-size critical section for our kind of RCU. */ -static void handle_rx_big(struct vhost_net *net) -{ - struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX]; - unsigned out, in, log, s; - int head; - struct vhost_log *vq_log; - struct msghdr msg = { - .msg_name = NULL, - .msg_namelen = 0, - .msg_control = NULL, /* FIXME: get and handle RX aux data. */ - .msg_controllen = 0, - .msg_iov = vq->iov, - .msg_flags = MSG_DONTWAIT, - }; - - struct virtio_net_hdr hdr = { - .flags = 0, - .gso_type = VIRTIO_NET_HDR_GSO_NONE - }; - - size_t len, total_len = 0; - int err; - size_t hdr_size; - /* TODO: check that we are running from vhost_worker? */ - struct socket *sock = rcu_dereference_check(vq->private_data, 1); - if (!sock || skb_queue_empty(&sock->sk->sk_receive_queue)) - return; - - mutex_lock(&vq->mutex); - vhost_disable_notify(vq); - hdr_size = vq->vhost_hlen; - - vq_log = unlikely(vhost_has_feature(&net->dev, VHOST_F_LOG_ALL)) ? - vq->log : NULL; - - for (;;) { - head = vhost_get_vq_desc(&net->dev, vq, vq->iov, - ARRAY_SIZE(vq->iov), - &out, &in, - vq_log, &log); - /* On error, stop handling until the next kick. */ - if (unlikely(head < 0)) - break; - /* OK, now we need to know about added descriptors. */ - if (head == vq->num) { - if (unlikely(vhost_enable_notify(vq))) { - /* They have slipped one in as we were - * doing that: check again. */ - vhost_disable_notify(vq); - continue; - } - /* Nothing new? Wait for eventfd to tell us - * they refilled. */ - break; - } - /* We don't need to be notified again. */ - if (out) { - vq_err(vq, "Unexpected descriptor format for RX: " - "out %d, int %d\n", - out, in); - break; - } - /* Skip header. TODO: support TSO/mergeable rx buffers. */ - s = move_iovec_hdr(vq->iov, vq->hdr, hdr_size, in); - msg.msg_iovlen = in; - len = iov_length(vq->iov, in); - /* Sanity check */ - if (!len) { - vq_err(vq, "Unexpected header len for RX: " - "%zd expected %zd\n", - iov_length(vq->hdr, s), hdr_size); - break; - } - err = sock->ops->recvmsg(NULL, sock, &msg, - len, MSG_DONTWAIT | MSG_TRUNC); - /* TODO: Check specific error and bomb out unless EAGAIN? */ - if (err < 0) { - vhost_discard_vq_desc(vq, 1); - break; - } - /* TODO: Should check and handle checksum. */ - if (err > len) { - pr_debug("Discarded truncated rx packet: " - " len %d > %zd\n", err, len); - vhost_discard_vq_desc(vq, 1); - continue; - } - len = err; - err = memcpy_toiovec(vq->hdr, (unsigned char *)&hdr, hdr_size); - if (err) { - vq_err(vq, "Unable to write vnet_hdr at addr %p: %d\n", - vq->iov->iov_base, err); - break; - } - len += hdr_size; - vhost_add_used_and_signal(&net->dev, vq, head, len); - if (unlikely(vq_log)) - vhost_log_write(vq, vq_log, log, len); - total_len += len; - if (unlikely(total_len >= VHOST_NET_WEIGHT)) { - vhost_poll_queue(&vq->poll); - break; - } - } - - mutex_unlock(&vq->mutex); -} - -/* Expects to be always run from workqueue - which acts as - * read-size critical section for our kind of RCU. */ -static void handle_rx_mergeable(struct vhost_net *net) +static void handle_rx(struct vhost_net *net) { struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX]; unsigned uninitialized_var(in), log; @@ -405,19 +300,18 @@ static void handle_rx_mergeable(struct vhost_net *net) .msg_iov = vq->iov, .msg_flags = MSG_DONTWAIT, }; - struct virtio_net_hdr_mrg_rxbuf hdr = { .hdr.flags = 0, .hdr.gso_type = VIRTIO_NET_HDR_GSO_NONE }; - size_t total_len = 0; - int err, headcount; + int err, headcount, mergeable; size_t vhost_hlen, sock_hlen; size_t vhost_len, sock_len; /* TODO: check that we are running from vhost_worker? */ struct socket *sock = rcu_dereference_check(vq->private_data, 1); - if (!sock || skb_queue_empty(&sock->sk->sk_receive_queue)) + + if (!sock) return; mutex_lock(&vq->mutex); @@ -427,12 +321,14 @@ static void handle_rx_mergeable(struct vhost_net *net) vq_log = unlikely(vhost_has_feature(&net->dev, VHOST_F_LOG_ALL)) ? vq->log : NULL; + mergeable = vhost_has_feature(&net->dev, VIRTIO_NET_F_MRG_RXBUF); while ((sock_len = peek_head_len(sock->sk))) { sock_len += sock_hlen; vhost_len = sock_len + vhost_hlen; headcount = get_rx_bufs(vq, vq->heads, vhost_len, - &in, vq_log, &log); + &in, vq_log, &log, + likely(mergeable) ? UIO_MAXIOV : 1); /* On error, stop handling until the next kick. */ if (unlikely(headcount < 0)) break; @@ -476,7 +372,7 @@ static void handle_rx_mergeable(struct vhost_net *net) break; } /* TODO: Should check and handle checksum. */ - if (vhost_has_feature(&net->dev, VIRTIO_NET_F_MRG_RXBUF) && + if (likely(mergeable) && memcpy_toiovecend(vq->hdr, (unsigned char *)&headcount, offsetof(typeof(hdr), num_buffers), sizeof hdr.num_buffers)) { @@ -498,14 +394,6 @@ static void handle_rx_mergeable(struct vhost_net *net) mutex_unlock(&vq->mutex); } -static void handle_rx(struct vhost_net *net) -{ - if (vhost_has_feature(&net->dev, VIRTIO_NET_F_MRG_RXBUF)) - handle_rx_mergeable(net); - else - handle_rx_big(net); -} - static void handle_tx_kick(struct vhost_work *work) { struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue, @@ -654,6 +542,7 @@ static struct socket *get_raw_socket(int fd) } uaddr; int uaddr_len = sizeof uaddr, r; struct socket *sock = sockfd_lookup(fd, &r); + if (!sock) return ERR_PTR(-ENOTSOCK); @@ -682,6 +571,7 @@ static struct socket *get_tap_socket(int fd) { struct file *file = fget(fd); struct socket *sock; + if (!file) return ERR_PTR(-EBADF); sock = tun_get_socket(file); @@ -696,6 +586,7 @@ static struct socket *get_tap_socket(int fd) static struct socket *get_socket(int fd) { struct socket *sock; + /* special case to disable backend */ if (fd == -1) return NULL; @@ -741,9 +632,9 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd) oldsock = rcu_dereference_protected(vq->private_data, lockdep_is_held(&vq->mutex)); if (sock != oldsock) { - vhost_net_disable_vq(n, vq); - rcu_assign_pointer(vq->private_data, sock); - vhost_net_enable_vq(n, vq); + vhost_net_disable_vq(n, vq); + rcu_assign_pointer(vq->private_data, sock); + vhost_net_enable_vq(n, vq); } mutex_unlock(&vq->mutex); @@ -768,6 +659,7 @@ static long vhost_net_reset_owner(struct vhost_net *n) struct socket *tx_sock = NULL; struct socket *rx_sock = NULL; long err; + mutex_lock(&n->dev.mutex); err = vhost_dev_check_owner(&n->dev); if (err) @@ -829,6 +721,7 @@ static long vhost_net_ioctl(struct file *f, unsigned int ioctl, struct vhost_vring_file backend; u64 features; int r; + switch (ioctl) { case VHOST_NET_SET_BACKEND: if (copy_from_user(&backend, argp, sizeof backend)) |