diff options
Diffstat (limited to 'drivers/xen/pvcalls-front.c')
-rw-r--r-- | drivers/xen/pvcalls-front.c | 104 |
1 files changed, 75 insertions, 29 deletions
diff --git a/drivers/xen/pvcalls-front.c b/drivers/xen/pvcalls-front.c index 77224d8f3e6f..8a249c95c193 100644 --- a/drivers/xen/pvcalls-front.c +++ b/drivers/xen/pvcalls-front.c @@ -31,6 +31,12 @@ #define PVCALLS_NR_RSP_PER_RING __CONST_RING_SIZE(xen_pvcalls, XEN_PAGE_SIZE) #define PVCALLS_FRONT_MAX_SPIN 5000 +static struct proto pvcalls_proto = { + .name = "PVCalls", + .owner = THIS_MODULE, + .obj_size = sizeof(struct sock), +}; + struct pvcalls_bedata { struct xen_pvcalls_front_ring ring; grant_ref_t ref; @@ -335,6 +341,42 @@ int pvcalls_front_socket(struct socket *sock) return ret; } +static void free_active_ring(struct sock_mapping *map) +{ + if (!map->active.ring) + return; + + free_pages((unsigned long)map->active.data.in, + map->active.ring->ring_order); + free_page((unsigned long)map->active.ring); +} + +static int alloc_active_ring(struct sock_mapping *map) +{ + void *bytes; + + map->active.ring = (struct pvcalls_data_intf *) + get_zeroed_page(GFP_KERNEL); + if (!map->active.ring) + goto out; + + map->active.ring->ring_order = PVCALLS_RING_ORDER; + bytes = (void *)__get_free_pages(GFP_KERNEL | __GFP_ZERO, + PVCALLS_RING_ORDER); + if (!bytes) + goto out; + + map->active.data.in = bytes; + map->active.data.out = bytes + + XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER); + + return 0; + +out: + free_active_ring(map); + return -ENOMEM; +} + static int create_active(struct sock_mapping *map, int *evtchn) { void *bytes; @@ -343,15 +385,7 @@ static int create_active(struct sock_mapping *map, int *evtchn) *evtchn = -1; init_waitqueue_head(&map->active.inflight_conn_req); - map->active.ring = (struct pvcalls_data_intf *) - __get_free_page(GFP_KERNEL | __GFP_ZERO); - if (map->active.ring == NULL) - goto out_error; - map->active.ring->ring_order = PVCALLS_RING_ORDER; - bytes = (void *)__get_free_pages(GFP_KERNEL | __GFP_ZERO, - PVCALLS_RING_ORDER); - if (bytes == NULL) - goto out_error; + bytes = map->active.data.in; for (i = 0; i < (1 << PVCALLS_RING_ORDER); i++) map->active.ring->ref[i] = gnttab_grant_foreign_access( pvcalls_front_dev->otherend_id, @@ -361,10 +395,6 @@ static int create_active(struct sock_mapping *map, int *evtchn) pvcalls_front_dev->otherend_id, pfn_to_gfn(virt_to_pfn((void *)map->active.ring)), 0); - map->active.data.in = bytes; - map->active.data.out = bytes + - XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER); - ret = xenbus_alloc_evtchn(pvcalls_front_dev, evtchn); if (ret) goto out_error; @@ -385,8 +415,6 @@ static int create_active(struct sock_mapping *map, int *evtchn) out_error: if (*evtchn >= 0) xenbus_free_evtchn(pvcalls_front_dev, *evtchn); - free_pages((unsigned long)map->active.data.in, PVCALLS_RING_ORDER); - free_page((unsigned long)map->active.ring); return ret; } @@ -406,17 +434,24 @@ int pvcalls_front_connect(struct socket *sock, struct sockaddr *addr, return PTR_ERR(map); bedata = dev_get_drvdata(&pvcalls_front_dev->dev); + ret = alloc_active_ring(map); + if (ret < 0) { + pvcalls_exit_sock(sock); + return ret; + } spin_lock(&bedata->socket_lock); ret = get_request(bedata, &req_id); if (ret < 0) { spin_unlock(&bedata->socket_lock); + free_active_ring(map); pvcalls_exit_sock(sock); return ret; } ret = create_active(map, &evtchn); if (ret < 0) { spin_unlock(&bedata->socket_lock); + free_active_ring(map); pvcalls_exit_sock(sock); return ret; } @@ -469,8 +504,10 @@ static int __write_ring(struct pvcalls_data_intf *intf, virt_mb(); size = pvcalls_queued(prod, cons, array_size); - if (size >= array_size) + if (size > array_size) return -EINVAL; + if (size == array_size) + return 0; if (len > array_size - size) len = array_size - size; @@ -560,15 +597,13 @@ static int __read_ring(struct pvcalls_data_intf *intf, error = intf->in_error; /* get pointers before reading from the ring */ virt_rmb(); - if (error < 0) - return error; size = pvcalls_queued(prod, cons, array_size); masked_prod = pvcalls_mask(prod, array_size); masked_cons = pvcalls_mask(cons, array_size); if (size == 0) - return 0; + return error ?: size; if (len > size) len = size; @@ -780,25 +815,36 @@ int pvcalls_front_accept(struct socket *sock, struct socket *newsock, int flags) } } - spin_lock(&bedata->socket_lock); - ret = get_request(bedata, &req_id); - if (ret < 0) { + map2 = kzalloc(sizeof(*map2), GFP_KERNEL); + if (map2 == NULL) { clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, (void *)&map->passive.flags); - spin_unlock(&bedata->socket_lock); + pvcalls_exit_sock(sock); + return -ENOMEM; + } + ret = alloc_active_ring(map2); + if (ret < 0) { + clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, + (void *)&map->passive.flags); + kfree(map2); pvcalls_exit_sock(sock); return ret; } - map2 = kzalloc(sizeof(*map2), GFP_ATOMIC); - if (map2 == NULL) { + spin_lock(&bedata->socket_lock); + ret = get_request(bedata, &req_id); + if (ret < 0) { clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, (void *)&map->passive.flags); spin_unlock(&bedata->socket_lock); + free_active_ring(map2); + kfree(map2); pvcalls_exit_sock(sock); - return -ENOMEM; + return ret; } + ret = create_active(map2, &evtchn); if (ret < 0) { + free_active_ring(map2); kfree(map2); clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, (void *)&map->passive.flags); @@ -839,7 +885,7 @@ int pvcalls_front_accept(struct socket *sock, struct socket *newsock, int flags) received: map2->sock = newsock; - newsock->sk = kzalloc(sizeof(*newsock->sk), GFP_KERNEL); + newsock->sk = sk_alloc(sock_net(sock->sk), PF_INET, GFP_KERNEL, &pvcalls_proto, false); if (!newsock->sk) { bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID; map->passive.inflight_req_id = PVCALLS_INVALID_ID; @@ -1032,8 +1078,8 @@ int pvcalls_front_release(struct socket *sock) spin_lock(&bedata->socket_lock); list_del(&map->list); spin_unlock(&bedata->socket_lock); - if (READ_ONCE(map->passive.inflight_req_id) != - PVCALLS_INVALID_ID) { + if (READ_ONCE(map->passive.inflight_req_id) != PVCALLS_INVALID_ID && + READ_ONCE(map->passive.inflight_req_id) != 0) { pvcalls_front_free_map(bedata, map->passive.accept_map); } |