summaryrefslogtreecommitdiffstats
path: root/drivers/vhost/vhost.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vhost/vhost.c')
-rw-r--r--drivers/vhost/vhost.c147
1 files changed, 116 insertions, 31 deletions
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 3a5f81a66d34..15a216cdd507 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -295,11 +295,8 @@ static void vhost_vq_meta_reset(struct vhost_dev *d)
{
int i;
- for (i = 0; i < d->nvqs; ++i) {
- mutex_lock(&d->vqs[i]->mutex);
+ for (i = 0; i < d->nvqs; ++i)
__vhost_vq_meta_reset(d->vqs[i]);
- mutex_unlock(&d->vqs[i]->mutex);
- }
}
static void vhost_vq_reset(struct vhost_dev *dev,
@@ -658,7 +655,7 @@ static bool log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
a + (unsigned long)log_base > ULONG_MAX)
return false;
- return access_ok(VERIFY_WRITE, log_base + a,
+ return access_ok(log_base + a,
(sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
}
@@ -684,7 +681,7 @@ static bool vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem,
return false;
- if (!access_ok(VERIFY_WRITE, (void __user *)a,
+ if (!access_ok((void __user *)a,
node->size))
return false;
else if (log_all && !log_access_ok(log_base,
@@ -895,6 +892,20 @@ static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq,
#define vhost_get_used(vq, x, ptr) \
vhost_get_user(vq, x, ptr, VHOST_ADDR_USED)
+static void vhost_dev_lock_vqs(struct vhost_dev *d)
+{
+ int i = 0;
+ for (i = 0; i < d->nvqs; ++i)
+ mutex_lock_nested(&d->vqs[i]->mutex, i);
+}
+
+static void vhost_dev_unlock_vqs(struct vhost_dev *d)
+{
+ int i = 0;
+ for (i = 0; i < d->nvqs; ++i)
+ mutex_unlock(&d->vqs[i]->mutex);
+}
+
static int vhost_new_umem_range(struct vhost_umem *umem,
u64 start, u64 size, u64 end,
u64 userspace_addr, int perm)
@@ -944,10 +955,7 @@ static void vhost_iotlb_notify_vq(struct vhost_dev *d,
if (msg->iova <= vq_msg->iova &&
msg->iova + msg->size - 1 >= vq_msg->iova &&
vq_msg->type == VHOST_IOTLB_MISS) {
- mutex_lock(&node->vq->mutex);
vhost_poll_queue(&node->vq->poll);
- mutex_unlock(&node->vq->mutex);
-
list_del(&node->node);
kfree(node);
}
@@ -965,10 +973,10 @@ static bool umem_access_ok(u64 uaddr, u64 size, int access)
return false;
if ((access & VHOST_ACCESS_RO) &&
- !access_ok(VERIFY_READ, (void __user *)a, size))
+ !access_ok((void __user *)a, size))
return false;
if ((access & VHOST_ACCESS_WO) &&
- !access_ok(VERIFY_WRITE, (void __user *)a, size))
+ !access_ok((void __user *)a, size))
return false;
return true;
}
@@ -979,6 +987,7 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev,
int ret = 0;
mutex_lock(&dev->mutex);
+ vhost_dev_lock_vqs(dev);
switch (msg->type) {
case VHOST_IOTLB_UPDATE:
if (!dev->iotlb) {
@@ -1012,6 +1021,7 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev,
break;
}
+ vhost_dev_unlock_vqs(dev);
mutex_unlock(&dev->mutex);
return ret;
@@ -1024,8 +1034,10 @@ ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
int type, ret;
ret = copy_from_iter(&type, sizeof(type), from);
- if (ret != sizeof(type))
+ if (ret != sizeof(type)) {
+ ret = -EINVAL;
goto done;
+ }
switch (type) {
case VHOST_IOTLB_MSG:
@@ -1044,8 +1056,10 @@ ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
iov_iter_advance(from, offset);
ret = copy_from_iter(&msg, sizeof(msg), from);
- if (ret != sizeof(msg))
+ if (ret != sizeof(msg)) {
+ ret = -EINVAL;
goto done;
+ }
if (vhost_process_iotlb_msg(dev, &msg)) {
ret = -EFAULT;
goto done;
@@ -1175,10 +1189,10 @@ static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
{
size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
- return access_ok(VERIFY_READ, desc, num * sizeof *desc) &&
- access_ok(VERIFY_READ, avail,
+ return access_ok(desc, num * sizeof *desc) &&
+ access_ok(avail,
sizeof *avail + num * sizeof *avail->ring + s) &&
- access_ok(VERIFY_WRITE, used,
+ access_ok(used,
sizeof *used + num * sizeof *used->ring + s);
}
@@ -1723,13 +1737,87 @@ static int log_write(void __user *log_base,
return r;
}
+static int log_write_hva(struct vhost_virtqueue *vq, u64 hva, u64 len)
+{
+ struct vhost_umem *umem = vq->umem;
+ struct vhost_umem_node *u;
+ u64 start, end, l, min;
+ int r;
+ bool hit = false;
+
+ while (len) {
+ min = len;
+ /* More than one GPAs can be mapped into a single HVA. So
+ * iterate all possible umems here to be safe.
+ */
+ list_for_each_entry(u, &umem->umem_list, link) {
+ if (u->userspace_addr > hva - 1 + len ||
+ u->userspace_addr - 1 + u->size < hva)
+ continue;
+ start = max(u->userspace_addr, hva);
+ end = min(u->userspace_addr - 1 + u->size,
+ hva - 1 + len);
+ l = end - start + 1;
+ r = log_write(vq->log_base,
+ u->start + start - u->userspace_addr,
+ l);
+ if (r < 0)
+ return r;
+ hit = true;
+ min = min(l, min);
+ }
+
+ if (!hit)
+ return -EFAULT;
+
+ len -= min;
+ hva += min;
+ }
+
+ return 0;
+}
+
+static int log_used(struct vhost_virtqueue *vq, u64 used_offset, u64 len)
+{
+ struct iovec iov[64];
+ int i, ret;
+
+ if (!vq->iotlb)
+ return log_write(vq->log_base, vq->log_addr + used_offset, len);
+
+ ret = translate_desc(vq, (uintptr_t)vq->used + used_offset,
+ len, iov, 64, VHOST_ACCESS_WO);
+ if (ret)
+ return ret;
+
+ for (i = 0; i < ret; i++) {
+ ret = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
+ iov[i].iov_len);
+ if (ret)
+ return ret;
+ }
+
+ return 0;
+}
+
int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
- unsigned int log_num, u64 len)
+ unsigned int log_num, u64 len, struct iovec *iov, int count)
{
int i, r;
/* Make sure data written is seen before log. */
smp_wmb();
+
+ if (vq->iotlb) {
+ for (i = 0; i < count; i++) {
+ r = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
+ iov[i].iov_len);
+ if (r < 0)
+ return r;
+ }
+ return 0;
+ }
+
for (i = 0; i < log_num; ++i) {
u64 l = min(log[i].len, len);
r = log_write(vq->log_base, log[i].addr, l);
@@ -1759,9 +1847,8 @@ static int vhost_update_used_flags(struct vhost_virtqueue *vq)
smp_wmb();
/* Log used flag write. */
used = &vq->used->flags;
- log_write(vq->log_base, vq->log_addr +
- (used - (void __user *)vq->used),
- sizeof vq->used->flags);
+ log_used(vq, (used - (void __user *)vq->used),
+ sizeof vq->used->flags);
if (vq->log_ctx)
eventfd_signal(vq->log_ctx, 1);
}
@@ -1779,9 +1866,8 @@ static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event)
smp_wmb();
/* Log avail event write */
used = vhost_avail_event(vq);
- log_write(vq->log_base, vq->log_addr +
- (used - (void __user *)vq->used),
- sizeof *vhost_avail_event(vq));
+ log_used(vq, (used - (void __user *)vq->used),
+ sizeof *vhost_avail_event(vq));
if (vq->log_ctx)
eventfd_signal(vq->log_ctx, 1);
}
@@ -1804,7 +1890,7 @@ int vhost_vq_init_access(struct vhost_virtqueue *vq)
goto err;
vq->signalled_used_valid = false;
if (!vq->iotlb &&
- !access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) {
+ !access_ok(&vq->used->idx, sizeof vq->used->idx)) {
r = -EFAULT;
goto err;
}
@@ -2181,10 +2267,8 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq,
/* Make sure data is seen before log. */
smp_wmb();
/* Log used ring entry write. */
- log_write(vq->log_base,
- vq->log_addr +
- ((void __user *)used - (void __user *)vq->used),
- count * sizeof *used);
+ log_used(vq, ((void __user *)used - (void __user *)vq->used),
+ count * sizeof *used);
}
old = vq->last_used_idx;
new = (vq->last_used_idx += count);
@@ -2223,10 +2307,11 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
return -EFAULT;
}
if (unlikely(vq->log_used)) {
+ /* Make sure used idx is seen before log. */
+ smp_wmb();
/* Log used index update. */
- log_write(vq->log_base,
- vq->log_addr + offsetof(struct vring_used, idx),
- sizeof vq->used->idx);
+ log_used(vq, offsetof(struct vring_used, idx),
+ sizeof vq->used->idx);
if (vq->log_ctx)
eventfd_signal(vq->log_ctx, 1);
}
OpenPOWER on IntegriCloud