summaryrefslogtreecommitdiffstats
path: root/drivers/vhost
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vhost')
-rw-r--r--drivers/vhost/net.c67
-rw-r--r--drivers/vhost/scsi.c32
-rw-r--r--drivers/vhost/vhost.c144
-rw-r--r--drivers/vhost/vhost.h3
-rw-r--r--drivers/vhost/vsock.c18
5 files changed, 203 insertions, 61 deletions
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index ab11b2bee273..bca86bf7189f 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -141,6 +141,10 @@ struct vhost_net {
unsigned tx_zcopy_err;
/* Flush in progress. Protected by tx vq lock. */
bool tx_flush;
+ /* Private page frag */
+ struct page_frag page_frag;
+ /* Refcount bias of page frag */
+ int refcnt_bias;
};
static unsigned vhost_net_zcopy_mask __read_mostly;
@@ -513,7 +517,13 @@ static void vhost_net_busy_poll(struct vhost_net *net,
struct socket *sock;
struct vhost_virtqueue *vq = poll_rx ? tvq : rvq;
- mutex_lock_nested(&vq->mutex, poll_rx ? VHOST_NET_VQ_TX: VHOST_NET_VQ_RX);
+ /* Try to hold the vq mutex of the paired virtqueue. We can't
+ * use mutex_lock() here since we could not guarantee a
+ * consistenet lock ordering.
+ */
+ if (!mutex_trylock(&vq->mutex))
+ return;
+
vhost_disable_notify(&net->dev, vq);
sock = rvq->private_data;
@@ -637,14 +647,53 @@ static bool tx_can_batch(struct vhost_virtqueue *vq, size_t total_len)
!vhost_vq_avail_empty(vq->dev, vq);
}
+#define SKB_FRAG_PAGE_ORDER get_order(32768)
+
+static bool vhost_net_page_frag_refill(struct vhost_net *net, unsigned int sz,
+ struct page_frag *pfrag, gfp_t gfp)
+{
+ if (pfrag->page) {
+ if (pfrag->offset + sz <= pfrag->size)
+ return true;
+ __page_frag_cache_drain(pfrag->page, net->refcnt_bias);
+ }
+
+ pfrag->offset = 0;
+ net->refcnt_bias = 0;
+ if (SKB_FRAG_PAGE_ORDER) {
+ /* Avoid direct reclaim but allow kswapd to wake */
+ pfrag->page = alloc_pages((gfp & ~__GFP_DIRECT_RECLAIM) |
+ __GFP_COMP | __GFP_NOWARN |
+ __GFP_NORETRY,
+ SKB_FRAG_PAGE_ORDER);
+ if (likely(pfrag->page)) {
+ pfrag->size = PAGE_SIZE << SKB_FRAG_PAGE_ORDER;
+ goto done;
+ }
+ }
+ pfrag->page = alloc_page(gfp);
+ if (likely(pfrag->page)) {
+ pfrag->size = PAGE_SIZE;
+ goto done;
+ }
+ return false;
+
+done:
+ net->refcnt_bias = USHRT_MAX;
+ page_ref_add(pfrag->page, USHRT_MAX - 1);
+ return true;
+}
+
#define VHOST_NET_RX_PAD (NET_IP_ALIGN + NET_SKB_PAD)
static int vhost_net_build_xdp(struct vhost_net_virtqueue *nvq,
struct iov_iter *from)
{
struct vhost_virtqueue *vq = &nvq->vq;
+ struct vhost_net *net = container_of(vq->dev, struct vhost_net,
+ dev);
struct socket *sock = vq->private_data;
- struct page_frag *alloc_frag = &current->task_frag;
+ struct page_frag *alloc_frag = &net->page_frag;
struct virtio_net_hdr *gso;
struct xdp_buff *xdp = &nvq->xdp[nvq->batched_xdp];
struct tun_xdp_hdr *hdr;
@@ -665,7 +714,8 @@ static int vhost_net_build_xdp(struct vhost_net_virtqueue *nvq,
buflen += SKB_DATA_ALIGN(len + pad);
alloc_frag->offset = ALIGN((u64)alloc_frag->offset, SMP_CACHE_BYTES);
- if (unlikely(!skb_page_frag_refill(buflen, alloc_frag, GFP_KERNEL)))
+ if (unlikely(!vhost_net_page_frag_refill(net, buflen,
+ alloc_frag, GFP_KERNEL)))
return -ENOMEM;
buf = (char *)page_address(alloc_frag->page) + alloc_frag->offset;
@@ -703,7 +753,7 @@ static int vhost_net_build_xdp(struct vhost_net_virtqueue *nvq,
xdp->data_end = xdp->data + len;
hdr->buflen = buflen;
- get_page(alloc_frag->page);
+ --net->refcnt_bias;
alloc_frag->offset += buflen;
++nvq->batched_xdp;
@@ -1186,7 +1236,8 @@ static void handle_rx(struct vhost_net *net)
if (nvq->done_idx > VHOST_NET_BATCH)
vhost_net_signal_used(nvq);
if (unlikely(vq_log))
- vhost_log_write(vq, vq_log, log, vhost_len);
+ vhost_log_write(vq, vq_log, log, vhost_len,
+ vq->iov, in);
total_len += vhost_len;
if (unlikely(vhost_exceeds_weight(++recv_pkts, total_len))) {
vhost_poll_queue(&vq->poll);
@@ -1292,6 +1343,8 @@ static int vhost_net_open(struct inode *inode, struct file *f)
vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, EPOLLIN, dev);
f->private_data = n;
+ n->page_frag.page = NULL;
+ n->refcnt_bias = 0;
return 0;
}
@@ -1359,13 +1412,15 @@ static int vhost_net_release(struct inode *inode, struct file *f)
if (rx_sock)
sockfd_put(rx_sock);
/* Make sure no callbacks are outstanding */
- synchronize_rcu_bh();
+ synchronize_rcu();
/* We do an extra flush before freeing memory,
* since jobs can re-queue themselves. */
vhost_net_flush(n);
kfree(n->vqs[VHOST_NET_VQ_RX].rxq.queue);
kfree(n->vqs[VHOST_NET_VQ_TX].xdp);
kfree(n->dev.vqs);
+ if (n->page_frag.page)
+ __page_frag_cache_drain(n->page_frag.page, n->refcnt_bias);
kvfree(n);
return 0;
}
diff --git a/drivers/vhost/scsi.c b/drivers/vhost/scsi.c
index 50dffe83714c..344684f3e2e4 100644
--- a/drivers/vhost/scsi.c
+++ b/drivers/vhost/scsi.c
@@ -285,11 +285,6 @@ static int vhost_scsi_check_false(struct se_portal_group *se_tpg)
return 0;
}
-static char *vhost_scsi_get_fabric_name(void)
-{
- return "vhost";
-}
-
static char *vhost_scsi_get_fabric_wwn(struct se_portal_group *se_tpg)
{
struct vhost_scsi_tpg *tpg = container_of(se_tpg,
@@ -889,7 +884,7 @@ vhost_scsi_get_req(struct vhost_virtqueue *vq, struct vhost_scsi_ctx *vc,
if (unlikely(!copy_from_iter_full(vc->req, vc->req_size,
&vc->out_iter))) {
- vq_err(vq, "Faulted on copy_from_iter\n");
+ vq_err(vq, "Faulted on copy_from_iter_full\n");
} else if (unlikely(*vc->lunp != 1)) {
/* virtio-scsi spec requires byte 0 of the lun to be 1 */
vq_err(vq, "Illegal virtio-scsi lun: %u\n", *vc->lunp);
@@ -1132,16 +1127,18 @@ vhost_scsi_send_tmf_reject(struct vhost_scsi *vs,
struct vhost_virtqueue *vq,
struct vhost_scsi_ctx *vc)
{
- struct virtio_scsi_ctrl_tmf_resp __user *resp;
struct virtio_scsi_ctrl_tmf_resp rsp;
+ struct iov_iter iov_iter;
int ret;
pr_debug("%s\n", __func__);
memset(&rsp, 0, sizeof(rsp));
rsp.response = VIRTIO_SCSI_S_FUNCTION_REJECTED;
- resp = vq->iov[vc->out].iov_base;
- ret = __copy_to_user(resp, &rsp, sizeof(rsp));
- if (!ret)
+
+ iov_iter_init(&iov_iter, READ, &vq->iov[vc->out], vc->in, sizeof(rsp));
+
+ ret = copy_to_iter(&rsp, sizeof(rsp), &iov_iter);
+ if (likely(ret == sizeof(rsp)))
vhost_add_used_and_signal(&vs->dev, vq, vc->head, 0);
else
pr_err("Faulted on virtio_scsi_ctrl_tmf_resp\n");
@@ -1152,16 +1149,18 @@ vhost_scsi_send_an_resp(struct vhost_scsi *vs,
struct vhost_virtqueue *vq,
struct vhost_scsi_ctx *vc)
{
- struct virtio_scsi_ctrl_an_resp __user *resp;
struct virtio_scsi_ctrl_an_resp rsp;
+ struct iov_iter iov_iter;
int ret;
pr_debug("%s\n", __func__);
memset(&rsp, 0, sizeof(rsp)); /* event_actual = 0 */
rsp.response = VIRTIO_SCSI_S_OK;
- resp = vq->iov[vc->out].iov_base;
- ret = __copy_to_user(resp, &rsp, sizeof(rsp));
- if (!ret)
+
+ iov_iter_init(&iov_iter, READ, &vq->iov[vc->out], vc->in, sizeof(rsp));
+
+ ret = copy_to_iter(&rsp, sizeof(rsp), &iov_iter);
+ if (likely(ret == sizeof(rsp)))
vhost_add_used_and_signal(&vs->dev, vq, vc->head, 0);
else
pr_err("Faulted on virtio_scsi_ctrl_an_resp\n");
@@ -1441,7 +1440,7 @@ vhost_scsi_set_endpoint(struct vhost_scsi *vs,
se_tpg = &tpg->se_tpg;
ret = target_depend_item(&se_tpg->tpg_group.cg_item);
if (ret) {
- pr_warn("configfs_depend_item() failed: %d\n", ret);
+ pr_warn("target_depend_item() failed: %d\n", ret);
kfree(vs_tpg);
mutex_unlock(&tpg->tv_tpg_mutex);
goto out;
@@ -2289,8 +2288,7 @@ static struct configfs_attribute *vhost_scsi_wwn_attrs[] = {
static const struct target_core_fabric_ops vhost_scsi_ops = {
.module = THIS_MODULE,
- .name = "vhost",
- .get_fabric_name = vhost_scsi_get_fabric_name,
+ .fabric_name = "vhost",
.tpg_get_wwn = vhost_scsi_get_fabric_wwn,
.tpg_get_tag = vhost_scsi_get_tpgt,
.tpg_check_demo_mode = vhost_scsi_check_true,
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 6b98d8e3a5bf..15a216cdd507 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -295,11 +295,8 @@ static void vhost_vq_meta_reset(struct vhost_dev *d)
{
int i;
- for (i = 0; i < d->nvqs; ++i) {
- mutex_lock(&d->vqs[i]->mutex);
+ for (i = 0; i < d->nvqs; ++i)
__vhost_vq_meta_reset(d->vqs[i]);
- mutex_unlock(&d->vqs[i]->mutex);
- }
}
static void vhost_vq_reset(struct vhost_dev *dev,
@@ -658,7 +655,7 @@ static bool log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
a + (unsigned long)log_base > ULONG_MAX)
return false;
- return access_ok(VERIFY_WRITE, log_base + a,
+ return access_ok(log_base + a,
(sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
}
@@ -684,7 +681,7 @@ static bool vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem,
return false;
- if (!access_ok(VERIFY_WRITE, (void __user *)a,
+ if (!access_ok((void __user *)a,
node->size))
return false;
else if (log_all && !log_access_ok(log_base,
@@ -895,6 +892,20 @@ static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq,
#define vhost_get_used(vq, x, ptr) \
vhost_get_user(vq, x, ptr, VHOST_ADDR_USED)
+static void vhost_dev_lock_vqs(struct vhost_dev *d)
+{
+ int i = 0;
+ for (i = 0; i < d->nvqs; ++i)
+ mutex_lock_nested(&d->vqs[i]->mutex, i);
+}
+
+static void vhost_dev_unlock_vqs(struct vhost_dev *d)
+{
+ int i = 0;
+ for (i = 0; i < d->nvqs; ++i)
+ mutex_unlock(&d->vqs[i]->mutex);
+}
+
static int vhost_new_umem_range(struct vhost_umem *umem,
u64 start, u64 size, u64 end,
u64 userspace_addr, int perm)
@@ -962,10 +973,10 @@ static bool umem_access_ok(u64 uaddr, u64 size, int access)
return false;
if ((access & VHOST_ACCESS_RO) &&
- !access_ok(VERIFY_READ, (void __user *)a, size))
+ !access_ok((void __user *)a, size))
return false;
if ((access & VHOST_ACCESS_WO) &&
- !access_ok(VERIFY_WRITE, (void __user *)a, size))
+ !access_ok((void __user *)a, size))
return false;
return true;
}
@@ -976,6 +987,7 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev,
int ret = 0;
mutex_lock(&dev->mutex);
+ vhost_dev_lock_vqs(dev);
switch (msg->type) {
case VHOST_IOTLB_UPDATE:
if (!dev->iotlb) {
@@ -1009,6 +1021,7 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev,
break;
}
+ vhost_dev_unlock_vqs(dev);
mutex_unlock(&dev->mutex);
return ret;
@@ -1021,8 +1034,10 @@ ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
int type, ret;
ret = copy_from_iter(&type, sizeof(type), from);
- if (ret != sizeof(type))
+ if (ret != sizeof(type)) {
+ ret = -EINVAL;
goto done;
+ }
switch (type) {
case VHOST_IOTLB_MSG:
@@ -1041,8 +1056,10 @@ ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
iov_iter_advance(from, offset);
ret = copy_from_iter(&msg, sizeof(msg), from);
- if (ret != sizeof(msg))
+ if (ret != sizeof(msg)) {
+ ret = -EINVAL;
goto done;
+ }
if (vhost_process_iotlb_msg(dev, &msg)) {
ret = -EFAULT;
goto done;
@@ -1172,10 +1189,10 @@ static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
{
size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
- return access_ok(VERIFY_READ, desc, num * sizeof *desc) &&
- access_ok(VERIFY_READ, avail,
+ return access_ok(desc, num * sizeof *desc) &&
+ access_ok(avail,
sizeof *avail + num * sizeof *avail->ring + s) &&
- access_ok(VERIFY_WRITE, used,
+ access_ok(used,
sizeof *used + num * sizeof *used->ring + s);
}
@@ -1720,13 +1737,87 @@ static int log_write(void __user *log_base,
return r;
}
+static int log_write_hva(struct vhost_virtqueue *vq, u64 hva, u64 len)
+{
+ struct vhost_umem *umem = vq->umem;
+ struct vhost_umem_node *u;
+ u64 start, end, l, min;
+ int r;
+ bool hit = false;
+
+ while (len) {
+ min = len;
+ /* More than one GPAs can be mapped into a single HVA. So
+ * iterate all possible umems here to be safe.
+ */
+ list_for_each_entry(u, &umem->umem_list, link) {
+ if (u->userspace_addr > hva - 1 + len ||
+ u->userspace_addr - 1 + u->size < hva)
+ continue;
+ start = max(u->userspace_addr, hva);
+ end = min(u->userspace_addr - 1 + u->size,
+ hva - 1 + len);
+ l = end - start + 1;
+ r = log_write(vq->log_base,
+ u->start + start - u->userspace_addr,
+ l);
+ if (r < 0)
+ return r;
+ hit = true;
+ min = min(l, min);
+ }
+
+ if (!hit)
+ return -EFAULT;
+
+ len -= min;
+ hva += min;
+ }
+
+ return 0;
+}
+
+static int log_used(struct vhost_virtqueue *vq, u64 used_offset, u64 len)
+{
+ struct iovec iov[64];
+ int i, ret;
+
+ if (!vq->iotlb)
+ return log_write(vq->log_base, vq->log_addr + used_offset, len);
+
+ ret = translate_desc(vq, (uintptr_t)vq->used + used_offset,
+ len, iov, 64, VHOST_ACCESS_WO);
+ if (ret)
+ return ret;
+
+ for (i = 0; i < ret; i++) {
+ ret = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
+ iov[i].iov_len);
+ if (ret)
+ return ret;
+ }
+
+ return 0;
+}
+
int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
- unsigned int log_num, u64 len)
+ unsigned int log_num, u64 len, struct iovec *iov, int count)
{
int i, r;
/* Make sure data written is seen before log. */
smp_wmb();
+
+ if (vq->iotlb) {
+ for (i = 0; i < count; i++) {
+ r = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
+ iov[i].iov_len);
+ if (r < 0)
+ return r;
+ }
+ return 0;
+ }
+
for (i = 0; i < log_num; ++i) {
u64 l = min(log[i].len, len);
r = log_write(vq->log_base, log[i].addr, l);
@@ -1756,9 +1847,8 @@ static int vhost_update_used_flags(struct vhost_virtqueue *vq)
smp_wmb();
/* Log used flag write. */
used = &vq->used->flags;
- log_write(vq->log_base, vq->log_addr +
- (used - (void __user *)vq->used),
- sizeof vq->used->flags);
+ log_used(vq, (used - (void __user *)vq->used),
+ sizeof vq->used->flags);
if (vq->log_ctx)
eventfd_signal(vq->log_ctx, 1);
}
@@ -1776,9 +1866,8 @@ static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event)
smp_wmb();
/* Log avail event write */
used = vhost_avail_event(vq);
- log_write(vq->log_base, vq->log_addr +
- (used - (void __user *)vq->used),
- sizeof *vhost_avail_event(vq));
+ log_used(vq, (used - (void __user *)vq->used),
+ sizeof *vhost_avail_event(vq));
if (vq->log_ctx)
eventfd_signal(vq->log_ctx, 1);
}
@@ -1801,7 +1890,7 @@ int vhost_vq_init_access(struct vhost_virtqueue *vq)
goto err;
vq->signalled_used_valid = false;
if (!vq->iotlb &&
- !access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) {
+ !access_ok(&vq->used->idx, sizeof vq->used->idx)) {
r = -EFAULT;
goto err;
}
@@ -2178,10 +2267,8 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq,
/* Make sure data is seen before log. */
smp_wmb();
/* Log used ring entry write. */
- log_write(vq->log_base,
- vq->log_addr +
- ((void __user *)used - (void __user *)vq->used),
- count * sizeof *used);
+ log_used(vq, ((void __user *)used - (void __user *)vq->used),
+ count * sizeof *used);
}
old = vq->last_used_idx;
new = (vq->last_used_idx += count);
@@ -2220,10 +2307,11 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
return -EFAULT;
}
if (unlikely(vq->log_used)) {
+ /* Make sure used idx is seen before log. */
+ smp_wmb();
/* Log used index update. */
- log_write(vq->log_base,
- vq->log_addr + offsetof(struct vring_used, idx),
- sizeof vq->used->idx);
+ log_used(vq, offsetof(struct vring_used, idx),
+ sizeof vq->used->idx);
if (vq->log_ctx)
eventfd_signal(vq->log_ctx, 1);
}
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index 466ef7542291..1b675dad5e05 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -205,7 +205,8 @@ bool vhost_vq_avail_empty(struct vhost_dev *, struct vhost_virtqueue *);
bool vhost_enable_notify(struct vhost_dev *, struct vhost_virtqueue *);
int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
- unsigned int log_num, u64 len);
+ unsigned int log_num, u64 len,
+ struct iovec *iov, int count);
int vq_iotlb_prefetch(struct vhost_virtqueue *vq);
struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type);
diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
index 98ed5be132c6..3fbc068eaa9b 100644
--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -27,14 +27,14 @@ enum {
};
/* Used to track all the vhost_vsock instances on the system. */
-static DEFINE_SPINLOCK(vhost_vsock_lock);
+static DEFINE_MUTEX(vhost_vsock_mutex);
static DEFINE_READ_MOSTLY_HASHTABLE(vhost_vsock_hash, 8);
struct vhost_vsock {
struct vhost_dev dev;
struct vhost_virtqueue vqs[2];
- /* Link to global vhost_vsock_hash, writes use vhost_vsock_lock */
+ /* Link to global vhost_vsock_hash, writes use vhost_vsock_mutex */
struct hlist_node hash;
struct vhost_work send_pkt_work;
@@ -51,7 +51,7 @@ static u32 vhost_transport_get_local_cid(void)
return VHOST_VSOCK_DEFAULT_HOST_CID;
}
-/* Callers that dereference the return value must hold vhost_vsock_lock or the
+/* Callers that dereference the return value must hold vhost_vsock_mutex or the
* RCU read lock.
*/
static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
@@ -584,10 +584,10 @@ static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
{
struct vhost_vsock *vsock = file->private_data;
- spin_lock_bh(&vhost_vsock_lock);
+ mutex_lock(&vhost_vsock_mutex);
if (vsock->guest_cid)
hash_del_rcu(&vsock->hash);
- spin_unlock_bh(&vhost_vsock_lock);
+ mutex_unlock(&vhost_vsock_mutex);
/* Wait for other CPUs to finish using vsock */
synchronize_rcu();
@@ -631,10 +631,10 @@ static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid)
return -EINVAL;
/* Refuse if CID is already in use */
- spin_lock_bh(&vhost_vsock_lock);
+ mutex_lock(&vhost_vsock_mutex);
other = vhost_vsock_get(guest_cid);
if (other && other != vsock) {
- spin_unlock_bh(&vhost_vsock_lock);
+ mutex_unlock(&vhost_vsock_mutex);
return -EADDRINUSE;
}
@@ -642,8 +642,8 @@ static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid)
hash_del_rcu(&vsock->hash);
vsock->guest_cid = guest_cid;
- hash_add_rcu(vhost_vsock_hash, &vsock->hash, guest_cid);
- spin_unlock_bh(&vhost_vsock_lock);
+ hash_add_rcu(vhost_vsock_hash, &vsock->hash, vsock->guest_cid);
+ mutex_unlock(&vhost_vsock_mutex);
return 0;
}
OpenPOWER on IntegriCloud