diff options
-rw-r--r-- | drivers/vhost/net.c | 4 | ||||
-rw-r--r-- | drivers/vhost/test.c | 5 | ||||
-rw-r--r-- | drivers/vhost/vhost.c | 14 | ||||
-rw-r--r-- | drivers/vhost/vhost.h | 1 |
4 files changed, 16 insertions, 8 deletions
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c index f0fd52cdfadc..70ac60437d17 100644 --- a/drivers/vhost/net.c +++ b/drivers/vhost/net.c @@ -703,6 +703,10 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd) vhost_net_disable_vq(n, vq); rcu_assign_pointer(vq->private_data, sock); vhost_net_enable_vq(n, vq); + + r = vhost_init_used(vq); + if (r) + goto err_vq; } mutex_unlock(&vq->mutex); diff --git a/drivers/vhost/test.c b/drivers/vhost/test.c index 734e1d74ad80..fc9a1d75281f 100644 --- a/drivers/vhost/test.c +++ b/drivers/vhost/test.c @@ -195,8 +195,13 @@ static long vhost_test_run(struct vhost_test *n, int test) lockdep_is_held(&vq->mutex)); rcu_assign_pointer(vq->private_data, priv); + r = vhost_init_used(&n->vqs[index]); + mutex_unlock(&vq->mutex); + if (r) + goto err; + if (oldpriv) { vhost_test_flush_vq(n, index); } diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index 5ef2f62becf4..9a108038fe52 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -629,15 +629,17 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) return 0; } -static int init_used(struct vhost_virtqueue *vq, - struct vring_used __user *used) +int vhost_init_used(struct vhost_virtqueue *vq) { - int r = put_user(vq->used_flags, &used->flags); + int r; + if (!vq->private_data) + return 0; + r = put_user(vq->used_flags, &vq->used->flags); if (r) return r; vq->signalled_used_valid = false; - return get_user(vq->last_used_idx, &used->idx); + return get_user(vq->last_used_idx, &vq->used->idx); } static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp) @@ -752,10 +754,6 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp) } } - r = init_used(vq, (struct vring_used __user *)(unsigned long) - a.used_user_addr); - if (r) - break; vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG)); vq->desc = (void __user *)(unsigned long)a.desc_user_addr; vq->avail = (void __user *)(unsigned long)a.avail_user_addr; diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h index 1544b782529b..14c9abf0d800 100644 --- a/drivers/vhost/vhost.h +++ b/drivers/vhost/vhost.h @@ -174,6 +174,7 @@ int vhost_get_vq_desc(struct vhost_dev *, struct vhost_virtqueue *, struct vhost_log *log, unsigned int *log_num); void vhost_discard_vq_desc(struct vhost_virtqueue *, int n); +int vhost_init_used(struct vhost_virtqueue *); int vhost_add_used(struct vhost_virtqueue *, unsigned int head, int len); int vhost_add_used_n(struct vhost_virtqueue *, struct vring_used_elem *heads, unsigned count); |