diff options
Diffstat (limited to 'drivers/vhost')
-rw-r--r-- | drivers/vhost/net.c | 67 | ||||
-rw-r--r-- | drivers/vhost/scsi.c | 32 | ||||
-rw-r--r-- | drivers/vhost/vhost.c | 144 | ||||
-rw-r--r-- | drivers/vhost/vhost.h | 3 | ||||
-rw-r--r-- | drivers/vhost/vsock.c | 18 |
5 files changed, 203 insertions, 61 deletions
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c index ab11b2bee273..bca86bf7189f 100644 --- a/drivers/vhost/net.c +++ b/drivers/vhost/net.c @@ -141,6 +141,10 @@ struct vhost_net { unsigned tx_zcopy_err; /* Flush in progress. Protected by tx vq lock. */ bool tx_flush; + /* Private page frag */ + struct page_frag page_frag; + /* Refcount bias of page frag */ + int refcnt_bias; }; static unsigned vhost_net_zcopy_mask __read_mostly; @@ -513,7 +517,13 @@ static void vhost_net_busy_poll(struct vhost_net *net, struct socket *sock; struct vhost_virtqueue *vq = poll_rx ? tvq : rvq; - mutex_lock_nested(&vq->mutex, poll_rx ? VHOST_NET_VQ_TX: VHOST_NET_VQ_RX); + /* Try to hold the vq mutex of the paired virtqueue. We can't + * use mutex_lock() here since we could not guarantee a + * consistenet lock ordering. + */ + if (!mutex_trylock(&vq->mutex)) + return; + vhost_disable_notify(&net->dev, vq); sock = rvq->private_data; @@ -637,14 +647,53 @@ static bool tx_can_batch(struct vhost_virtqueue *vq, size_t total_len) !vhost_vq_avail_empty(vq->dev, vq); } +#define SKB_FRAG_PAGE_ORDER get_order(32768) + +static bool vhost_net_page_frag_refill(struct vhost_net *net, unsigned int sz, + struct page_frag *pfrag, gfp_t gfp) +{ + if (pfrag->page) { + if (pfrag->offset + sz <= pfrag->size) + return true; + __page_frag_cache_drain(pfrag->page, net->refcnt_bias); + } + + pfrag->offset = 0; + net->refcnt_bias = 0; + if (SKB_FRAG_PAGE_ORDER) { + /* Avoid direct reclaim but allow kswapd to wake */ + pfrag->page = alloc_pages((gfp & ~__GFP_DIRECT_RECLAIM) | + __GFP_COMP | __GFP_NOWARN | + __GFP_NORETRY, + SKB_FRAG_PAGE_ORDER); + if (likely(pfrag->page)) { + pfrag->size = PAGE_SIZE << SKB_FRAG_PAGE_ORDER; + goto done; + } + } + pfrag->page = alloc_page(gfp); + if (likely(pfrag->page)) { + pfrag->size = PAGE_SIZE; + goto done; + } + return false; + +done: + net->refcnt_bias = USHRT_MAX; + page_ref_add(pfrag->page, USHRT_MAX - 1); + return true; +} + #define VHOST_NET_RX_PAD (NET_IP_ALIGN + NET_SKB_PAD) static int vhost_net_build_xdp(struct vhost_net_virtqueue *nvq, struct iov_iter *from) { struct vhost_virtqueue *vq = &nvq->vq; + struct vhost_net *net = container_of(vq->dev, struct vhost_net, + dev); struct socket *sock = vq->private_data; - struct page_frag *alloc_frag = ¤t->task_frag; + struct page_frag *alloc_frag = &net->page_frag; struct virtio_net_hdr *gso; struct xdp_buff *xdp = &nvq->xdp[nvq->batched_xdp]; struct tun_xdp_hdr *hdr; @@ -665,7 +714,8 @@ static int vhost_net_build_xdp(struct vhost_net_virtqueue *nvq, buflen += SKB_DATA_ALIGN(len + pad); alloc_frag->offset = ALIGN((u64)alloc_frag->offset, SMP_CACHE_BYTES); - if (unlikely(!skb_page_frag_refill(buflen, alloc_frag, GFP_KERNEL))) + if (unlikely(!vhost_net_page_frag_refill(net, buflen, + alloc_frag, GFP_KERNEL))) return -ENOMEM; buf = (char *)page_address(alloc_frag->page) + alloc_frag->offset; @@ -703,7 +753,7 @@ static int vhost_net_build_xdp(struct vhost_net_virtqueue *nvq, xdp->data_end = xdp->data + len; hdr->buflen = buflen; - get_page(alloc_frag->page); + --net->refcnt_bias; alloc_frag->offset += buflen; ++nvq->batched_xdp; @@ -1186,7 +1236,8 @@ static void handle_rx(struct vhost_net *net) if (nvq->done_idx > VHOST_NET_BATCH) vhost_net_signal_used(nvq); if (unlikely(vq_log)) - vhost_log_write(vq, vq_log, log, vhost_len); + vhost_log_write(vq, vq_log, log, vhost_len, + vq->iov, in); total_len += vhost_len; if (unlikely(vhost_exceeds_weight(++recv_pkts, total_len))) { vhost_poll_queue(&vq->poll); @@ -1292,6 +1343,8 @@ static int vhost_net_open(struct inode *inode, struct file *f) vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, EPOLLIN, dev); f->private_data = n; + n->page_frag.page = NULL; + n->refcnt_bias = 0; return 0; } @@ -1359,13 +1412,15 @@ static int vhost_net_release(struct inode *inode, struct file *f) if (rx_sock) sockfd_put(rx_sock); /* Make sure no callbacks are outstanding */ - synchronize_rcu_bh(); + synchronize_rcu(); /* We do an extra flush before freeing memory, * since jobs can re-queue themselves. */ vhost_net_flush(n); kfree(n->vqs[VHOST_NET_VQ_RX].rxq.queue); kfree(n->vqs[VHOST_NET_VQ_TX].xdp); kfree(n->dev.vqs); + if (n->page_frag.page) + __page_frag_cache_drain(n->page_frag.page, n->refcnt_bias); kvfree(n); return 0; } diff --git a/drivers/vhost/scsi.c b/drivers/vhost/scsi.c index 50dffe83714c..344684f3e2e4 100644 --- a/drivers/vhost/scsi.c +++ b/drivers/vhost/scsi.c @@ -285,11 +285,6 @@ static int vhost_scsi_check_false(struct se_portal_group *se_tpg) return 0; } -static char *vhost_scsi_get_fabric_name(void) -{ - return "vhost"; -} - static char *vhost_scsi_get_fabric_wwn(struct se_portal_group *se_tpg) { struct vhost_scsi_tpg *tpg = container_of(se_tpg, @@ -889,7 +884,7 @@ vhost_scsi_get_req(struct vhost_virtqueue *vq, struct vhost_scsi_ctx *vc, if (unlikely(!copy_from_iter_full(vc->req, vc->req_size, &vc->out_iter))) { - vq_err(vq, "Faulted on copy_from_iter\n"); + vq_err(vq, "Faulted on copy_from_iter_full\n"); } else if (unlikely(*vc->lunp != 1)) { /* virtio-scsi spec requires byte 0 of the lun to be 1 */ vq_err(vq, "Illegal virtio-scsi lun: %u\n", *vc->lunp); @@ -1132,16 +1127,18 @@ vhost_scsi_send_tmf_reject(struct vhost_scsi *vs, struct vhost_virtqueue *vq, struct vhost_scsi_ctx *vc) { - struct virtio_scsi_ctrl_tmf_resp __user *resp; struct virtio_scsi_ctrl_tmf_resp rsp; + struct iov_iter iov_iter; int ret; pr_debug("%s\n", __func__); memset(&rsp, 0, sizeof(rsp)); rsp.response = VIRTIO_SCSI_S_FUNCTION_REJECTED; - resp = vq->iov[vc->out].iov_base; - ret = __copy_to_user(resp, &rsp, sizeof(rsp)); - if (!ret) + + iov_iter_init(&iov_iter, READ, &vq->iov[vc->out], vc->in, sizeof(rsp)); + + ret = copy_to_iter(&rsp, sizeof(rsp), &iov_iter); + if (likely(ret == sizeof(rsp))) vhost_add_used_and_signal(&vs->dev, vq, vc->head, 0); else pr_err("Faulted on virtio_scsi_ctrl_tmf_resp\n"); @@ -1152,16 +1149,18 @@ vhost_scsi_send_an_resp(struct vhost_scsi *vs, struct vhost_virtqueue *vq, struct vhost_scsi_ctx *vc) { - struct virtio_scsi_ctrl_an_resp __user *resp; struct virtio_scsi_ctrl_an_resp rsp; + struct iov_iter iov_iter; int ret; pr_debug("%s\n", __func__); memset(&rsp, 0, sizeof(rsp)); /* event_actual = 0 */ rsp.response = VIRTIO_SCSI_S_OK; - resp = vq->iov[vc->out].iov_base; - ret = __copy_to_user(resp, &rsp, sizeof(rsp)); - if (!ret) + + iov_iter_init(&iov_iter, READ, &vq->iov[vc->out], vc->in, sizeof(rsp)); + + ret = copy_to_iter(&rsp, sizeof(rsp), &iov_iter); + if (likely(ret == sizeof(rsp))) vhost_add_used_and_signal(&vs->dev, vq, vc->head, 0); else pr_err("Faulted on virtio_scsi_ctrl_an_resp\n"); @@ -1441,7 +1440,7 @@ vhost_scsi_set_endpoint(struct vhost_scsi *vs, se_tpg = &tpg->se_tpg; ret = target_depend_item(&se_tpg->tpg_group.cg_item); if (ret) { - pr_warn("configfs_depend_item() failed: %d\n", ret); + pr_warn("target_depend_item() failed: %d\n", ret); kfree(vs_tpg); mutex_unlock(&tpg->tv_tpg_mutex); goto out; @@ -2289,8 +2288,7 @@ static struct configfs_attribute *vhost_scsi_wwn_attrs[] = { static const struct target_core_fabric_ops vhost_scsi_ops = { .module = THIS_MODULE, - .name = "vhost", - .get_fabric_name = vhost_scsi_get_fabric_name, + .fabric_name = "vhost", .tpg_get_wwn = vhost_scsi_get_fabric_wwn, .tpg_get_tag = vhost_scsi_get_tpgt, .tpg_check_demo_mode = vhost_scsi_check_true, diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index 6b98d8e3a5bf..15a216cdd507 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -295,11 +295,8 @@ static void vhost_vq_meta_reset(struct vhost_dev *d) { int i; - for (i = 0; i < d->nvqs; ++i) { - mutex_lock(&d->vqs[i]->mutex); + for (i = 0; i < d->nvqs; ++i) __vhost_vq_meta_reset(d->vqs[i]); - mutex_unlock(&d->vqs[i]->mutex); - } } static void vhost_vq_reset(struct vhost_dev *dev, @@ -658,7 +655,7 @@ static bool log_access_ok(void __user *log_base, u64 addr, unsigned long sz) a + (unsigned long)log_base > ULONG_MAX) return false; - return access_ok(VERIFY_WRITE, log_base + a, + return access_ok(log_base + a, (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8); } @@ -684,7 +681,7 @@ static bool vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem, return false; - if (!access_ok(VERIFY_WRITE, (void __user *)a, + if (!access_ok((void __user *)a, node->size)) return false; else if (log_all && !log_access_ok(log_base, @@ -895,6 +892,20 @@ static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq, #define vhost_get_used(vq, x, ptr) \ vhost_get_user(vq, x, ptr, VHOST_ADDR_USED) +static void vhost_dev_lock_vqs(struct vhost_dev *d) +{ + int i = 0; + for (i = 0; i < d->nvqs; ++i) + mutex_lock_nested(&d->vqs[i]->mutex, i); +} + +static void vhost_dev_unlock_vqs(struct vhost_dev *d) +{ + int i = 0; + for (i = 0; i < d->nvqs; ++i) + mutex_unlock(&d->vqs[i]->mutex); +} + static int vhost_new_umem_range(struct vhost_umem *umem, u64 start, u64 size, u64 end, u64 userspace_addr, int perm) @@ -962,10 +973,10 @@ static bool umem_access_ok(u64 uaddr, u64 size, int access) return false; if ((access & VHOST_ACCESS_RO) && - !access_ok(VERIFY_READ, (void __user *)a, size)) + !access_ok((void __user *)a, size)) return false; if ((access & VHOST_ACCESS_WO) && - !access_ok(VERIFY_WRITE, (void __user *)a, size)) + !access_ok((void __user *)a, size)) return false; return true; } @@ -976,6 +987,7 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev, int ret = 0; mutex_lock(&dev->mutex); + vhost_dev_lock_vqs(dev); switch (msg->type) { case VHOST_IOTLB_UPDATE: if (!dev->iotlb) { @@ -1009,6 +1021,7 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev, break; } + vhost_dev_unlock_vqs(dev); mutex_unlock(&dev->mutex); return ret; @@ -1021,8 +1034,10 @@ ssize_t vhost_chr_write_iter(struct vhost_dev *dev, int type, ret; ret = copy_from_iter(&type, sizeof(type), from); - if (ret != sizeof(type)) + if (ret != sizeof(type)) { + ret = -EINVAL; goto done; + } switch (type) { case VHOST_IOTLB_MSG: @@ -1041,8 +1056,10 @@ ssize_t vhost_chr_write_iter(struct vhost_dev *dev, iov_iter_advance(from, offset); ret = copy_from_iter(&msg, sizeof(msg), from); - if (ret != sizeof(msg)) + if (ret != sizeof(msg)) { + ret = -EINVAL; goto done; + } if (vhost_process_iotlb_msg(dev, &msg)) { ret = -EFAULT; goto done; @@ -1172,10 +1189,10 @@ static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num, { size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; - return access_ok(VERIFY_READ, desc, num * sizeof *desc) && - access_ok(VERIFY_READ, avail, + return access_ok(desc, num * sizeof *desc) && + access_ok(avail, sizeof *avail + num * sizeof *avail->ring + s) && - access_ok(VERIFY_WRITE, used, + access_ok(used, sizeof *used + num * sizeof *used->ring + s); } @@ -1720,13 +1737,87 @@ static int log_write(void __user *log_base, return r; } +static int log_write_hva(struct vhost_virtqueue *vq, u64 hva, u64 len) +{ + struct vhost_umem *umem = vq->umem; + struct vhost_umem_node *u; + u64 start, end, l, min; + int r; + bool hit = false; + + while (len) { + min = len; + /* More than one GPAs can be mapped into a single HVA. So + * iterate all possible umems here to be safe. + */ + list_for_each_entry(u, &umem->umem_list, link) { + if (u->userspace_addr > hva - 1 + len || + u->userspace_addr - 1 + u->size < hva) + continue; + start = max(u->userspace_addr, hva); + end = min(u->userspace_addr - 1 + u->size, + hva - 1 + len); + l = end - start + 1; + r = log_write(vq->log_base, + u->start + start - u->userspace_addr, + l); + if (r < 0) + return r; + hit = true; + min = min(l, min); + } + + if (!hit) + return -EFAULT; + + len -= min; + hva += min; + } + + return 0; +} + +static int log_used(struct vhost_virtqueue *vq, u64 used_offset, u64 len) +{ + struct iovec iov[64]; + int i, ret; + + if (!vq->iotlb) + return log_write(vq->log_base, vq->log_addr + used_offset, len); + + ret = translate_desc(vq, (uintptr_t)vq->used + used_offset, + len, iov, 64, VHOST_ACCESS_WO); + if (ret) + return ret; + + for (i = 0; i < ret; i++) { + ret = log_write_hva(vq, (uintptr_t)iov[i].iov_base, + iov[i].iov_len); + if (ret) + return ret; + } + + return 0; +} + int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log, - unsigned int log_num, u64 len) + unsigned int log_num, u64 len, struct iovec *iov, int count) { int i, r; /* Make sure data written is seen before log. */ smp_wmb(); + + if (vq->iotlb) { + for (i = 0; i < count; i++) { + r = log_write_hva(vq, (uintptr_t)iov[i].iov_base, + iov[i].iov_len); + if (r < 0) + return r; + } + return 0; + } + for (i = 0; i < log_num; ++i) { u64 l = min(log[i].len, len); r = log_write(vq->log_base, log[i].addr, l); @@ -1756,9 +1847,8 @@ static int vhost_update_used_flags(struct vhost_virtqueue *vq) smp_wmb(); /* Log used flag write. */ used = &vq->used->flags; - log_write(vq->log_base, vq->log_addr + - (used - (void __user *)vq->used), - sizeof vq->used->flags); + log_used(vq, (used - (void __user *)vq->used), + sizeof vq->used->flags); if (vq->log_ctx) eventfd_signal(vq->log_ctx, 1); } @@ -1776,9 +1866,8 @@ static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event) smp_wmb(); /* Log avail event write */ used = vhost_avail_event(vq); - log_write(vq->log_base, vq->log_addr + - (used - (void __user *)vq->used), - sizeof *vhost_avail_event(vq)); + log_used(vq, (used - (void __user *)vq->used), + sizeof *vhost_avail_event(vq)); if (vq->log_ctx) eventfd_signal(vq->log_ctx, 1); } @@ -1801,7 +1890,7 @@ int vhost_vq_init_access(struct vhost_virtqueue *vq) goto err; vq->signalled_used_valid = false; if (!vq->iotlb && - !access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) { + !access_ok(&vq->used->idx, sizeof vq->used->idx)) { r = -EFAULT; goto err; } @@ -2178,10 +2267,8 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq, /* Make sure data is seen before log. */ smp_wmb(); /* Log used ring entry write. */ - log_write(vq->log_base, - vq->log_addr + - ((void __user *)used - (void __user *)vq->used), - count * sizeof *used); + log_used(vq, ((void __user *)used - (void __user *)vq->used), + count * sizeof *used); } old = vq->last_used_idx; new = (vq->last_used_idx += count); @@ -2220,10 +2307,11 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads, return -EFAULT; } if (unlikely(vq->log_used)) { + /* Make sure used idx is seen before log. */ + smp_wmb(); /* Log used index update. */ - log_write(vq->log_base, - vq->log_addr + offsetof(struct vring_used, idx), - sizeof vq->used->idx); + log_used(vq, offsetof(struct vring_used, idx), + sizeof vq->used->idx); if (vq->log_ctx) eventfd_signal(vq->log_ctx, 1); } diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h index 466ef7542291..1b675dad5e05 100644 --- a/drivers/vhost/vhost.h +++ b/drivers/vhost/vhost.h @@ -205,7 +205,8 @@ bool vhost_vq_avail_empty(struct vhost_dev *, struct vhost_virtqueue *); bool vhost_enable_notify(struct vhost_dev *, struct vhost_virtqueue *); int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log, - unsigned int log_num, u64 len); + unsigned int log_num, u64 len, + struct iovec *iov, int count); int vq_iotlb_prefetch(struct vhost_virtqueue *vq); struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type); diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c index 98ed5be132c6..3fbc068eaa9b 100644 --- a/drivers/vhost/vsock.c +++ b/drivers/vhost/vsock.c @@ -27,14 +27,14 @@ enum { }; /* Used to track all the vhost_vsock instances on the system. */ -static DEFINE_SPINLOCK(vhost_vsock_lock); +static DEFINE_MUTEX(vhost_vsock_mutex); static DEFINE_READ_MOSTLY_HASHTABLE(vhost_vsock_hash, 8); struct vhost_vsock { struct vhost_dev dev; struct vhost_virtqueue vqs[2]; - /* Link to global vhost_vsock_hash, writes use vhost_vsock_lock */ + /* Link to global vhost_vsock_hash, writes use vhost_vsock_mutex */ struct hlist_node hash; struct vhost_work send_pkt_work; @@ -51,7 +51,7 @@ static u32 vhost_transport_get_local_cid(void) return VHOST_VSOCK_DEFAULT_HOST_CID; } -/* Callers that dereference the return value must hold vhost_vsock_lock or the +/* Callers that dereference the return value must hold vhost_vsock_mutex or the * RCU read lock. */ static struct vhost_vsock *vhost_vsock_get(u32 guest_cid) @@ -584,10 +584,10 @@ static int vhost_vsock_dev_release(struct inode *inode, struct file *file) { struct vhost_vsock *vsock = file->private_data; - spin_lock_bh(&vhost_vsock_lock); + mutex_lock(&vhost_vsock_mutex); if (vsock->guest_cid) hash_del_rcu(&vsock->hash); - spin_unlock_bh(&vhost_vsock_lock); + mutex_unlock(&vhost_vsock_mutex); /* Wait for other CPUs to finish using vsock */ synchronize_rcu(); @@ -631,10 +631,10 @@ static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid) return -EINVAL; /* Refuse if CID is already in use */ - spin_lock_bh(&vhost_vsock_lock); + mutex_lock(&vhost_vsock_mutex); other = vhost_vsock_get(guest_cid); if (other && other != vsock) { - spin_unlock_bh(&vhost_vsock_lock); + mutex_unlock(&vhost_vsock_mutex); return -EADDRINUSE; } @@ -642,8 +642,8 @@ static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid) hash_del_rcu(&vsock->hash); vsock->guest_cid = guest_cid; - hash_add_rcu(vhost_vsock_hash, &vsock->hash, guest_cid); - spin_unlock_bh(&vhost_vsock_lock); + hash_add_rcu(vhost_vsock_hash, &vsock->hash, vsock->guest_cid); + mutex_unlock(&vhost_vsock_mutex); return 0; } |