Diffstat (limited to 'net/tls')
-rw-r--r--	net/tls/tls_device.c          | 237
-rw-r--r--	net/tls/tls_device_fallback.c |   2
-rw-r--r--	net/tls/tls_main.c            | 217
-rw-r--r--	net/tls/tls_sw.c              |  89
4 files changed, 353 insertions, 192 deletions
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index 7c0b2b778703..f959487c5cd1 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -61,7 +61,7 @@ static void tls_device_free_ctx(struct tls_context *ctx)
 	if (ctx->rx_conf == TLS_HW)
 		kfree(tls_offload_ctx_rx(ctx));
 
-	tls_ctx_free(ctx);
+	tls_ctx_free(NULL, ctx);
 }
 
 static void tls_device_gc_task(struct work_struct *work)
@@ -122,13 +122,10 @@ static struct net_device *get_netdev_for_sock(struct sock *sk)
 
 static void destroy_record(struct tls_record_info *record)
 {
-	int nr_frags = record->num_frags;
-	skb_frag_t *frag;
+	int i;
 
-	while (nr_frags-- > 0) {
-		frag = &record->frags[nr_frags];
-		__skb_frag_unref(frag);
-	}
+	for (i = 0; i < record->num_frags; i++)
+		__skb_frag_unref(&record->frags[i]);
 	kfree(record);
 }
 
@@ -159,12 +156,8 @@ static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq)
 
 	spin_lock_irqsave(&ctx->lock, flags);
 	info = ctx->retransmit_hint;
-	if (info && !before(acked_seq, info->end_seq)) {
+	if (info && !before(acked_seq, info->end_seq))
 		ctx->retransmit_hint = NULL;
-		list_del(&info->list);
-		destroy_record(info);
-		deleted_records++;
-	}
 
 	list_for_each_entry_safe(info, temp, &ctx->records_list, list) {
 		if (before(acked_seq, info->end_seq))
@@ -243,14 +236,14 @@ static void tls_append_frag(struct tls_record_info *record,
 	skb_frag_t *frag;
 
 	frag = &record->frags[record->num_frags - 1];
-	if (frag->page.p == pfrag->page &&
-	    frag->page_offset + frag->size == pfrag->offset) {
-		frag->size += size;
+	if (skb_frag_page(frag) == pfrag->page &&
+	    skb_frag_off(frag) + skb_frag_size(frag) == pfrag->offset) {
+		skb_frag_size_add(frag, size);
 	} else {
 		++frag;
-		frag->page.p = pfrag->page;
-		frag->page_offset = pfrag->offset;
-		frag->size = size;
+		__skb_frag_set_page(frag, pfrag->page);
+		skb_frag_off_set(frag, pfrag->offset);
+		skb_frag_size_set(frag, size);
 		++record->num_frags;
 		get_page(pfrag->page);
 	}
@@ -263,33 +256,15 @@ static void tls_append_frag(struct tls_record_info *record,
 static int tls_push_record(struct sock *sk,
 			   struct tls_context *ctx,
 			   struct tls_offload_context_tx *offload_ctx,
 			   struct tls_record_info *record,
-			   struct page_frag *pfrag,
-			   int flags,
-			   unsigned char record_type)
+			   int flags)
 {
 	struct tls_prot_info *prot = &ctx->prot_info;
 	struct tcp_sock *tp = tcp_sk(sk);
-	struct page_frag dummy_tag_frag;
 	skb_frag_t *frag;
 	int i;
 
-	/* fill prepend */
-	frag = &record->frags[0];
-	tls_fill_prepend(ctx,
-			 skb_frag_address(frag),
-			 record->len - prot->prepend_size,
-			 record_type,
-			 prot->version);
-
-	/* HW doesn't care about the data in the tag, because it fills it. */
-	dummy_tag_frag.page = skb_frag_page(frag);
-	dummy_tag_frag.offset = 0;
-
-	tls_append_frag(record, &dummy_tag_frag, prot->tag_size);
 	record->end_seq = tp->write_seq + record->len;
-	spin_lock_irq(&offload_ctx->lock);
-	list_add_tail(&record->list, &offload_ctx->records_list);
-	spin_unlock_irq(&offload_ctx->lock);
+	list_add_tail_rcu(&record->list, &offload_ctx->records_list);
 	offload_ctx->open_record = NULL;
 	if (test_bit(TLS_TX_SYNC_SCHED, &ctx->flags))
@@ -301,8 +276,8 @@ static int tls_push_record(struct sock *sk,
 		frag = &record->frags[i];
 		sg_unmark_end(&offload_ctx->sg_tx_data[i]);
 		sg_set_page(&offload_ctx->sg_tx_data[i], skb_frag_page(frag),
-			    frag->size, frag->page_offset);
-		sk_mem_charge(sk, frag->size);
+			    skb_frag_size(frag), skb_frag_off(frag));
+		sk_mem_charge(sk, skb_frag_size(frag));
 		get_page(skb_frag_page(frag));
 	}
 	sg_mark_end(&offload_ctx->sg_tx_data[record->num_frags - 1]);
@@ -311,6 +286,38 @@ static int tls_push_record(struct sock *sk,
 	return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags);
 }
 
+static int tls_device_record_close(struct sock *sk,
+				   struct tls_context *ctx,
+				   struct tls_record_info *record,
+				   struct page_frag *pfrag,
+				   unsigned char record_type)
+{
+	struct tls_prot_info *prot = &ctx->prot_info;
+	int ret;
+
+	/* append tag
+	 * device will fill in the tag, we just need to append a placeholder
+	 * use socket memory to improve coalescing (re-using a single buffer
+	 * increases frag count)
+	 * if we can't allocate memory now, steal some back from data
+	 */
+	if (likely(skb_page_frag_refill(prot->tag_size, pfrag,
+					sk->sk_allocation))) {
+		ret = 0;
+		tls_append_frag(record, pfrag, prot->tag_size);
+	} else {
+		ret = prot->tag_size;
+		if (record->len <= prot->overhead_size)
+			return -ENOMEM;
+	}
+
+	/* fill prepend */
+	tls_fill_prepend(ctx, skb_frag_address(&record->frags[0]),
+			 record->len - prot->overhead_size,
+			 record_type, prot->version);
+	return ret;
+}
+
 static int tls_create_new_record(struct tls_offload_context_tx *offload_ctx,
 				 struct page_frag *pfrag,
 				 size_t prepend_size)
@@ -324,7 +331,7 @@ static int tls_create_new_record(struct tls_offload_context_tx *offload_ctx,
 
 	frag = &record->frags[0];
 	__skb_frag_set_page(frag, pfrag->page);
-	frag->page_offset = pfrag->offset;
+	skb_frag_off_set(frag, pfrag->offset);
 	skb_frag_size_set(frag, prepend_size);
 
 	get_page(pfrag->page);
@@ -365,6 +372,31 @@ static int tls_do_allocation(struct sock *sk,
 	return 0;
 }
 
+static int tls_device_copy_data(void *addr, size_t bytes, struct iov_iter *i)
+{
+	size_t pre_copy, nocache;
+
+	pre_copy = ~((unsigned long)addr - 1) & (SMP_CACHE_BYTES - 1);
+	if (pre_copy) {
+		pre_copy = min(pre_copy, bytes);
+		if (copy_from_iter(addr, pre_copy, i) != pre_copy)
+			return -EFAULT;
+		bytes -= pre_copy;
+		addr += pre_copy;
+	}
+
+	nocache = round_down(bytes, SMP_CACHE_BYTES);
+	if (copy_from_iter_nocache(addr, nocache, i) != nocache)
+		return -EFAULT;
+	bytes -= nocache;
+	addr += nocache;
+
+	if (bytes && copy_from_iter(addr, bytes, i) != bytes)
+		return -EFAULT;
+
+	return 0;
+}
+
 static int tls_push_data(struct sock *sk,
 			 struct iov_iter *msg_iter,
 			 size_t size, int flags,
@@ -373,9 +405,9 @@ static int tls_push_data(struct sock *sk,
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
-	int tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST;
 	int more = flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE);
 	struct tls_record_info *record = ctx->open_record;
+	int tls_push_record_flags;
 	struct page_frag *pfrag;
 	size_t orig_size = size;
 	u32 max_open_record_len;
@@ -390,6 +422,9 @@ static int tls_push_data(struct sock *sk,
 	if (sk->sk_err)
 		return -sk->sk_err;
 
+	flags |= MSG_SENDPAGE_DECRYPTED;
+	tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST;
+
 	timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
 	if (tls_is_partially_sent_record(tls_ctx)) {
 		rc = tls_push_partial_record(sk, tls_ctx, flags);
@@ -435,12 +470,10 @@ handle_error:
 		copy = min_t(size_t, size, (pfrag->size - pfrag->offset));
 		copy = min_t(size_t, copy, (max_open_record_len - record->len));
 
-		if (copy_from_iter_nocache(page_address(pfrag->page) +
-					   pfrag->offset,
-					   copy, msg_iter) != copy) {
-			rc = -EFAULT;
+		rc = tls_device_copy_data(page_address(pfrag->page) +
+					  pfrag->offset, copy, msg_iter);
+		if (rc)
 			goto handle_error;
-		}
 		tls_append_frag(record, pfrag, copy);
 
 		size -= copy;
@@ -458,13 +491,24 @@ last_record:
 		if (done || record->len >= max_open_record_len ||
 		    (record->num_frags >= MAX_SKB_FRAGS - 1)) {
+			rc = tls_device_record_close(sk, tls_ctx, record,
+						     pfrag, record_type);
+			if (rc) {
+				if (rc > 0) {
+					size += rc;
+				} else {
+					size = orig_size;
+					destroy_record(record);
+					ctx->open_record = NULL;
+					break;
+				}
+			}
+
 			rc = tls_push_record(sk,
 					     tls_ctx,
 					     ctx,
 					     record,
-					     pfrag,
-					     tls_push_record_flags,
-					     record_type);
+					     tls_push_record_flags);
 			if (rc < 0)
 				break;
 		}
@@ -539,12 +583,16 @@ struct tls_record_info *tls_get_record(struct tls_offload_context_tx *context,
 		/* if retransmit_hint is irrelevant start
 		 * from the beggining of the list
 		 */
-		info = list_first_entry(&context->records_list,
-					struct tls_record_info, list);
+		info = list_first_entry_or_null(&context->records_list,
+						struct tls_record_info, list);
+		if (!info)
+			return NULL;
 		record_sn = context->unacked_record_sn;
 	}
 
-	list_for_each_entry_from(info, &context->records_list, list) {
+	/* We just need the _rcu for the READ_ONCE() */
+	rcu_read_lock();
+	list_for_each_entry_from_rcu(info, &context->records_list, list) {
 		if (before(seq, info->end_seq)) {
 			if (!context->retransmit_hint ||
 			    after(info->end_seq,
@@ -553,12 +601,15 @@ struct tls_record_info *tls_get_record(struct tls_offload_context_tx *context,
 				context->retransmit_hint = info;
 			}
 			*p_record_sn = record_sn;
-			return info;
+			goto exit_rcu_unlock;
 		}
 		record_sn++;
 	}
+	info = NULL;
 
-	return NULL;
+exit_rcu_unlock:
+	rcu_read_unlock();
+	return info;
 }
 EXPORT_SYMBOL(tls_get_record);
 
@@ -576,7 +627,9 @@ void tls_device_write_space(struct sock *sk, struct tls_context *ctx)
 		gfp_t sk_allocation = sk->sk_allocation;
 
 		sk->sk_allocation = GFP_ATOMIC;
-		tls_push_partial_record(sk, ctx, MSG_DONTWAIT | MSG_NOSIGNAL);
+		tls_push_partial_record(sk, ctx,
+					MSG_DONTWAIT | MSG_NOSIGNAL |
+					MSG_SENDPAGE_DECRYPTED);
 		sk->sk_allocation = sk_allocation;
 	}
 }
@@ -833,22 +886,18 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
 	struct net_device *netdev;
 	char *iv, *rec_seq;
 	struct sk_buff *skb;
-	int rc = -EINVAL;
 	__be64 rcd_sn;
+	int rc;
 
 	if (!ctx)
-		goto out;
+		return -EINVAL;
 
-	if (ctx->priv_ctx_tx) {
-		rc = -EEXIST;
-		goto out;
-	}
+	if (ctx->priv_ctx_tx)
+		return -EEXIST;
 
 	start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL);
-	if (!start_marker_record) {
-		rc = -ENOMEM;
-		goto out;
-	}
+	if (!start_marker_record)
+		return -ENOMEM;
 
 	offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_TX, GFP_KERNEL);
 	if (!offload_ctx) {
@@ -934,17 +983,11 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
 	if (skb)
 		TCP_SKB_CB(skb)->eor = 1;
 
-	/* We support starting offload on multiple sockets
-	 * concurrently, so we only need a read lock here.
-	 * This lock must precede get_netdev_for_sock to prevent races between
-	 * NETDEV_DOWN and setsockopt.
-	 */
-	down_read(&device_offload_lock);
 	netdev = get_netdev_for_sock(sk);
 	if (!netdev) {
 		pr_err_ratelimited("%s: netdev not found\n", __func__);
 		rc = -EINVAL;
-		goto release_lock;
+		goto disable_cad;
 	}
 
 	if (!(netdev->features & NETIF_F_HW_TLS_TX)) {
@@ -955,10 +998,15 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
 	/* Avoid offloading if the device is down
 	 * We don't want to offload new flows after
 	 * the NETDEV_DOWN event
+	 *
+	 * device_offload_lock is taken in tls_devices's NETDEV_DOWN
+	 * handler thus protecting from the device going down before
+	 * ctx was added to tls_device_list.
 	 */
+	down_read(&device_offload_lock);
 	if (!(netdev->flags & IFF_UP)) {
 		rc = -EINVAL;
-		goto release_netdev;
+		goto release_lock;
 	}
 
 	ctx->priv_ctx_tx = offload_ctx;
@@ -966,9 +1014,10 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
 					     &ctx->crypto_send.info,
 					     tcp_sk(sk)->write_seq);
 	if (rc)
-		goto release_netdev;
+		goto release_lock;
 
 	tls_device_attach(ctx, sk, netdev);
+	up_read(&device_offload_lock);
 
 	/* following this assignment tls_is_sk_tx_device_offloaded
 	 * will return true and the context might be accessed
@@ -976,13 +1025,14 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
 	 */
 	smp_store_release(&sk->sk_validate_xmit_skb, tls_validate_xmit_skb);
 	dev_put(netdev);
-	up_read(&device_offload_lock);
-	goto out;
 
-release_netdev:
-	dev_put(netdev);
+	return 0;
+
 release_lock:
 	up_read(&device_offload_lock);
+release_netdev:
+	dev_put(netdev);
+disable_cad:
 	clean_acked_data_disable(inet_csk(sk));
 	crypto_free_aead(offload_ctx->aead_send);
 free_rec_seq:
@@ -994,7 +1044,6 @@ free_offload_ctx:
 	ctx->priv_ctx_tx = NULL;
 free_marker_record:
 	kfree(start_marker_record);
-out:
 	return rc;
 }
 
@@ -1007,17 +1056,10 @@ int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
 	if (ctx->crypto_recv.info.version != TLS_1_2_VERSION)
 		return -EOPNOTSUPP;
 
-	/* We support starting offload on multiple sockets
-	 * concurrently, so we only need a read lock here.
-	 * This lock must precede get_netdev_for_sock to prevent races between
-	 * NETDEV_DOWN and setsockopt.
-	 */
-	down_read(&device_offload_lock);
 	netdev = get_netdev_for_sock(sk);
 	if (!netdev) {
 		pr_err_ratelimited("%s: netdev not found\n", __func__);
-		rc = -EINVAL;
-		goto release_lock;
+		return -EINVAL;
 	}
 
 	if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
@@ -1028,16 +1070,21 @@ int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
 	/* Avoid offloading if the device is down
 	 * We don't want to offload new flows after
 	 * the NETDEV_DOWN event
+	 *
+	 * device_offload_lock is taken in tls_devices's NETDEV_DOWN
+	 * handler thus protecting from the device going down before
+	 * ctx was added to tls_device_list.
 	 */
+	down_read(&device_offload_lock);
 	if (!(netdev->flags & IFF_UP)) {
 		rc = -EINVAL;
-		goto release_netdev;
+		goto release_lock;
 	}
 
 	context = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_RX, GFP_KERNEL);
 	if (!context) {
 		rc = -ENOMEM;
-		goto release_netdev;
+		goto release_lock;
 	}
 
 	context->resync_nh_reset = 1;
@@ -1053,7 +1100,11 @@ int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
 		goto free_sw_resources;
 
 	tls_device_attach(ctx, sk, netdev);
-	goto release_netdev;
+	up_read(&device_offload_lock);
+
+	dev_put(netdev);
+
+	return 0;
 
 free_sw_resources:
 	up_read(&device_offload_lock);
@@ -1061,10 +1112,10 @@ free_sw_resources:
 	down_read(&device_offload_lock);
 release_ctx:
 	ctx->priv_ctx_rx = NULL;
-release_netdev:
-	dev_put(netdev);
 release_lock:
 	up_read(&device_offload_lock);
+release_netdev:
+	dev_put(netdev);
 	return rc;
 }
 
diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c
index 9070d68a92a4..28895333701e 100644
--- a/net/tls/tls_device_fallback.c
+++ b/net/tls/tls_device_fallback.c
@@ -273,7 +273,7 @@ static int fill_sg_in(struct scatterlist *sg_in,
 
 		__skb_frag_ref(frag);
 		sg_set_page(sg_in + i, skb_frag_page(frag),
-			    skb_frag_size(frag), frag->page_offset);
+			    skb_frag_size(frag), skb_frag_off(frag));
 
 		remaining -= skb_frag_size(frag);
 
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 4674e57e66b0..ac88877dcade 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -39,6 +39,7 @@
 #include <linux/netdevice.h>
 #include <linux/sched/signal.h>
 #include <linux/inetdevice.h>
+#include <linux/inet_diag.h>
 
 #include <net/tls.h>
 
@@ -251,34 +252,31 @@ static void tls_write_space(struct sock *sk)
 	ctx->sk_write_space(sk);
 }
 
-void tls_ctx_free(struct tls_context *ctx)
+/**
+ * tls_ctx_free() - free TLS ULP context
+ * @sk:  socket to with @ctx is attached
+ * @ctx: TLS context structure
+ *
+ * Free TLS context. If @sk is %NULL caller guarantees that the socket
+ * to which @ctx was attached has no outstanding references.
+ */
+void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
 {
 	if (!ctx)
 		return;
 
 	memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
 	memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
-	kfree(ctx);
+
+	if (sk)
+		kfree_rcu(ctx, rcu);
+	else
+		kfree(ctx);
 }
 
-static void tls_sk_proto_close(struct sock *sk, long timeout)
+static void tls_sk_proto_cleanup(struct sock *sk,
+				 struct tls_context *ctx, long timeo)
 {
-	struct tls_context *ctx = tls_get_ctx(sk);
-	long timeo = sock_sndtimeo(sk, 0);
-	void (*sk_proto_close)(struct sock *sk, long timeout);
-	bool free_ctx = false;
-
-	lock_sock(sk);
-	sk_proto_close = ctx->sk_proto_close;
-
-	if (ctx->tx_conf == TLS_HW_RECORD && ctx->rx_conf == TLS_HW_RECORD)
-		goto skip_tx_cleanup;
-
-	if (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE) {
-		free_ctx = true;
-		goto skip_tx_cleanup;
-	}
-
 	if (unlikely(sk->sk_write_pending) &&
 	    !wait_on_pending_writer(sk, &timeo))
 		tls_handle_open_record(sk, 0);
@@ -287,36 +285,51 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
 	if (ctx->tx_conf == TLS_SW) {
 		kfree(ctx->tx.rec_seq);
 		kfree(ctx->tx.iv);
-		tls_sw_free_resources_tx(sk);
-#ifdef CONFIG_TLS_DEVICE
+		tls_sw_release_resources_tx(sk);
 	} else if (ctx->tx_conf == TLS_HW) {
 		tls_device_free_resources_tx(sk);
-#endif
 	}
 
 	if (ctx->rx_conf == TLS_SW)
-		tls_sw_free_resources_rx(sk);
-
-#ifdef CONFIG_TLS_DEVICE
-	if (ctx->rx_conf == TLS_HW)
+		tls_sw_release_resources_rx(sk);
+	else if (ctx->rx_conf == TLS_HW)
 		tls_device_offload_cleanup_rx(sk);
+}
 
-	if (ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW) {
-#else
-	{
-#endif
-		tls_ctx_free(ctx);
-		ctx = NULL;
-	}
+static void tls_sk_proto_close(struct sock *sk, long timeout)
+{
+	struct inet_connection_sock *icsk = inet_csk(sk);
+	struct tls_context *ctx = tls_get_ctx(sk);
+	long timeo = sock_sndtimeo(sk, 0);
+	bool free_ctx;
+
+	if (ctx->tx_conf == TLS_SW)
+		tls_sw_cancel_work_tx(ctx);
+
+	lock_sock(sk);
+	free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;
+
+	if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
+		tls_sk_proto_cleanup(sk, ctx, timeo);
 
-skip_tx_cleanup:
+	write_lock_bh(&sk->sk_callback_lock);
+	if (free_ctx)
+		rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
+	sk->sk_prot = ctx->sk_proto;
+	if (sk->sk_write_space == tls_write_space)
+		sk->sk_write_space = ctx->sk_write_space;
+	write_unlock_bh(&sk->sk_callback_lock);
 	release_sock(sk);
-	sk_proto_close(sk, timeout);
-	/* free ctx for TLS_HW_RECORD, used by tcp_set_state
-	 * for sk->sk_prot->unhash [tls_hw_unhash]
-	 */
+	if (ctx->tx_conf == TLS_SW)
+		tls_sw_free_ctx_tx(ctx);
+	if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
+		tls_sw_strparser_done(ctx);
+	if (ctx->rx_conf == TLS_SW)
+		tls_sw_free_ctx_rx(ctx);
+	ctx->sk_proto->close(sk, timeout);
+
 	if (free_ctx)
-		tls_ctx_free(ctx);
+		tls_ctx_free(sk, ctx);
 }
 
 static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval,
@@ -433,7 +446,8 @@ static int tls_getsockopt(struct sock *sk, int level, int optname,
 	struct tls_context *ctx = tls_get_ctx(sk);
 
 	if (level != SOL_TLS)
-		return ctx->getsockopt(sk, level, optname, optval, optlen);
+		return ctx->sk_proto->getsockopt(sk, level,
+						 optname, optval, optlen);
 
 	return do_tls_getsockopt(sk, optname, optval, optlen);
 }
@@ -518,32 +532,26 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
 	}
 
 	if (tx) {
-#ifdef CONFIG_TLS_DEVICE
 		rc = tls_set_device_offload(sk, ctx);
 		conf = TLS_HW;
 		if (rc) {
-#else
-		{
-#endif
 			rc = tls_set_sw_offload(sk, ctx, 1);
+			if (rc)
+				goto err_crypto_info;
 			conf = TLS_SW;
 		}
 	} else {
-#ifdef CONFIG_TLS_DEVICE
 		rc = tls_set_device_offload_rx(sk, ctx);
 		conf = TLS_HW;
 		if (rc) {
-#else
-		{
-#endif
 			rc = tls_set_sw_offload(sk, ctx, 0);
+			if (rc)
+				goto err_crypto_info;
 			conf = TLS_SW;
 		}
+		tls_sw_strparser_arm(sk, ctx);
 	}
 
-	if (rc)
-		goto err_crypto_info;
-
 	if (tx)
 		ctx->tx_conf = conf;
 	else
@@ -589,7 +597,8 @@ static int tls_setsockopt(struct sock *sk, int level, int optname,
 	struct tls_context *ctx = tls_get_ctx(sk);
 
 	if (level != SOL_TLS)
-		return ctx->setsockopt(sk, level, optname, optval, optlen);
+		return ctx->sk_proto->setsockopt(sk, level, optname, optval,
+						 optlen);
 
 	return do_tls_setsockopt(sk, optname, optval, optlen);
 }
@@ -603,10 +612,8 @@ static struct tls_context *create_ctx(struct sock *sk)
 	if (!ctx)
 		return NULL;
 
-	icsk->icsk_ulp_data = ctx;
-	ctx->setsockopt = sk->sk_prot->setsockopt;
-	ctx->getsockopt = sk->sk_prot->getsockopt;
-	ctx->sk_proto_close = sk->sk_prot->close;
+	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
+	ctx->sk_proto = sk->sk_prot;
 	return ctx;
 }
 
@@ -643,8 +650,8 @@ static void tls_hw_sk_destruct(struct sock *sk)
 
 	ctx->sk_destruct(sk);
 	/* Free ctx */
-	tls_ctx_free(ctx);
-	icsk->icsk_ulp_data = NULL;
+	rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
+	tls_ctx_free(sk, ctx);
 }
 
 static int tls_hw_prot(struct sock *sk)
@@ -662,9 +669,6 @@ static int tls_hw_prot(struct sock *sk)
 
 			spin_unlock_bh(&device_spinlock);
 			tls_build_proto(sk);
-			ctx->hash = sk->sk_prot->hash;
-			ctx->unhash = sk->sk_prot->unhash;
-			ctx->sk_proto_close = sk->sk_prot->close;
 			ctx->sk_destruct = sk->sk_destruct;
 			sk->sk_destruct = tls_hw_sk_destruct;
 			ctx->rx_conf = TLS_HW_RECORD;
@@ -696,7 +700,7 @@ static void tls_hw_unhash(struct sock *sk)
 		}
 	}
 	spin_unlock_bh(&device_spinlock);
-	ctx->unhash(sk);
+	ctx->sk_proto->unhash(sk);
 }
 
@@ -705,7 +709,7 @@ static int tls_hw_hash(struct sock *sk)
 	struct tls_device *dev;
 	int err;
 
-	err = ctx->hash(sk);
+	err = ctx->sk_proto->hash(sk);
 	spin_lock_bh(&device_spinlock);
 	list_for_each_entry(dev, &device_list, dev_list) {
 		if (dev->hash) {
@@ -764,7 +768,6 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
 
 	prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
 	prot[TLS_HW_RECORD][TLS_HW_RECORD].hash = tls_hw_hash;
 	prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash = tls_hw_unhash;
-	prot[TLS_HW_RECORD][TLS_HW_RECORD].close = tls_sk_proto_close;
 }
 
@@ -773,7 +776,7 @@ static int tls_init(struct sock *sk)
 	int rc = 0;
 
 	if (tls_hw_prot(sk))
-		goto out;
+		return 0;
 
 	/* The TLS ulp is currently supported only for TCP sockets
 	 * in ESTABLISHED state.
@@ -784,21 +787,96 @@ static int tls_init(struct sock *sk)
 	if (sk->sk_state != TCP_ESTABLISHED)
 		return -ENOTSUPP;
 
+	tls_build_proto(sk);
+
 	/* allocate tls context */
+	write_lock_bh(&sk->sk_callback_lock);
 	ctx = create_ctx(sk);
 	if (!ctx) {
 		rc = -ENOMEM;
 		goto out;
 	}
 
-	tls_build_proto(sk);
 	ctx->tx_conf = TLS_BASE;
 	ctx->rx_conf = TLS_BASE;
 	update_sk_prot(sk, ctx);
 out:
+	write_unlock_bh(&sk->sk_callback_lock);
 	return rc;
 }
 
+static void tls_update(struct sock *sk, struct proto *p)
+{
+	struct tls_context *ctx;
+
+	ctx = tls_get_ctx(sk);
+	if (likely(ctx))
+		ctx->sk_proto = p;
+	else
+		sk->sk_prot = p;
+}
+
+static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
+{
+	u16 version, cipher_type;
+	struct tls_context *ctx;
+	struct nlattr *start;
+	int err;
+
+	start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS);
+	if (!start)
+		return -EMSGSIZE;
+
+	rcu_read_lock();
+	ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data);
+	if (!ctx) {
+		err = 0;
+		goto nla_failure;
+	}
+	version = ctx->prot_info.version;
+	if (version) {
+		err = nla_put_u16(skb, TLS_INFO_VERSION, version);
+		if (err)
+			goto nla_failure;
+	}
+	cipher_type = ctx->prot_info.cipher_type;
+	if (cipher_type) {
+		err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type);
+		if (err)
+			goto nla_failure;
+	}
+	err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true));
+	if (err)
+		goto nla_failure;
+
+	err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false));
+	if (err)
+		goto nla_failure;
+
+	rcu_read_unlock();
+	nla_nest_end(skb, start);
+	return 0;
+
+nla_failure:
+	rcu_read_unlock();
+	nla_nest_cancel(skb, start);
+	return err;
+}
+
+static size_t tls_get_info_size(const struct sock *sk)
+{
+	size_t size = 0;
+
+	size += nla_total_size(0) +		/* INET_ULP_INFO_TLS */
+		nla_total_size(sizeof(u16)) +	/* TLS_INFO_VERSION */
+		nla_total_size(sizeof(u16)) +	/* TLS_INFO_CIPHER */
+		nla_total_size(sizeof(u16)) +	/* TLS_INFO_RXCONF */
+		nla_total_size(sizeof(u16)) +	/* TLS_INFO_TXCONF */
+		0;
+
+	return size;
+}
+
 void tls_register_device(struct tls_device *device)
 {
 	spin_lock_bh(&device_spinlock);
@@ -819,6 +897,9 @@ static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
 	.name = "tls",
 	.owner = THIS_MODULE,
 	.init = tls_init,
+	.update = tls_update,
+	.get_info = tls_get_info,
+	.get_info_size = tls_get_info_size,
 };
 
 static int __init tls_register(void)
@@ -826,9 +907,7 @@
 	tls_sw_proto_ops = inet_stream_ops;
 	tls_sw_proto_ops.splice_read = tls_sw_splice_read;
 
-#ifdef CONFIG_TLS_DEVICE
 	tls_device_init();
-#endif
 	tcp_register_ulp(&tcp_tls_ulp_ops);
 
 	return 0;
@@ -837,9 +916,7 @@ static void __exit tls_unregister(void)
 {
 	tcp_unregister_ulp(&tcp_tls_ulp_ops);
-#ifdef CONFIG_TLS_DEVICE
 	tls_device_cleanup();
-#endif
 }
 
 module_init(tls_register);
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 53b4ad94e74a..c2b5e0d2ba1a 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -1489,13 +1489,12 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
 	int pad, err = 0;
 
 	if (!ctx->decrypted) {
-#ifdef CONFIG_TLS_DEVICE
 		if (tls_ctx->rx_conf == TLS_HW) {
 			err = tls_device_decrypted(sk, skb);
 			if (err < 0)
 				return err;
 		}
-#endif
+
 		/* Still not decrypted after tls_device */
 		if (!ctx->decrypted) {
 			err = decrypt_internal(sk, skb, dest, NULL, chunk, zc,
@@ -2014,10 +2013,9 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
 		ret = -EINVAL;
 		goto read_failure;
 	}
-#ifdef CONFIG_TLS_DEVICE
+
 	tls_device_rx_resync_new_rec(strp->sk, data_len + TLS_HEADER_SIZE,
 				     TCP_SKB_CB(skb)->seq + rxm->offset);
-#endif
 	return data_len + TLS_HEADER_SIZE;
 
 read_failure:
@@ -2054,7 +2052,16 @@ static void tls_data_ready(struct sock *sk)
 	}
 }
 
-void tls_sw_free_resources_tx(struct sock *sk)
+void tls_sw_cancel_work_tx(struct tls_context *tls_ctx)
+{
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+
+	set_bit(BIT_TX_CLOSING, &ctx->tx_bitmask);
+	set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask);
+	cancel_delayed_work_sync(&ctx->tx_work.work);
+}
+
+void tls_sw_release_resources_tx(struct sock *sk)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
@@ -2065,11 +2072,6 @@ void tls_sw_free_resources_tx(struct sock *sk)
 	if (atomic_read(&ctx->encrypt_pending))
 		crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
 
-	release_sock(sk);
-	cancel_delayed_work_sync(&ctx->tx_work.work);
-	lock_sock(sk);
-
-	/* Tx whatever records we can transmit and abandon the rest */
 	tls_tx_records(sk, -1);
 
 	/* Free up un-sent records in tx_list. First, free
@@ -2092,6 +2094,11 @@ void tls_sw_free_resources_tx(struct sock *sk)
 
 	crypto_free_aead(ctx->aead_send);
 	tls_free_open_rec(sk);
+}
+
+void tls_sw_free_ctx_tx(struct tls_context *tls_ctx)
+{
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
 
 	kfree(ctx);
 }
@@ -2110,25 +2117,40 @@ void tls_sw_release_resources_rx(struct sock *sk)
 		skb_queue_purge(&ctx->rx_list);
 		crypto_free_aead(ctx->aead_recv);
 		strp_stop(&ctx->strp);
-		write_lock_bh(&sk->sk_callback_lock);
-		sk->sk_data_ready = ctx->saved_data_ready;
-		write_unlock_bh(&sk->sk_callback_lock);
-		release_sock(sk);
-		strp_done(&ctx->strp);
-		lock_sock(sk);
+		/* If tls_sw_strparser_arm() was not called (cleanup paths)
+		 * we still want to strp_stop(), but sk->sk_data_ready was
+		 * never swapped.
+		 */
+		if (ctx->saved_data_ready) {
+			write_lock_bh(&sk->sk_callback_lock);
+			sk->sk_data_ready = ctx->saved_data_ready;
+			write_unlock_bh(&sk->sk_callback_lock);
+		}
 	}
 }
 
-void tls_sw_free_resources_rx(struct sock *sk)
+void tls_sw_strparser_done(struct tls_context *tls_ctx)
 {
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 
-	tls_sw_release_resources_rx(sk);
+	strp_done(&ctx->strp);
+}
+
+void tls_sw_free_ctx_rx(struct tls_context *tls_ctx)
+{
+	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 
 	kfree(ctx);
 }
 
+void tls_sw_free_resources_rx(struct sock *sk)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+
+	tls_sw_release_resources_rx(sk);
+	tls_sw_free_ctx_rx(tls_ctx);
+}
+
 /* The work handler to transmitt the encrypted records in tx_list */
 static void tx_work_handler(struct work_struct *work)
 {
@@ -2137,11 +2159,17 @@ static void tx_work_handler(struct work_struct *work)
 					       struct tx_work, work);
 	struct sock *sk = tx_work->sk;
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
-	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct tls_sw_context_tx *ctx;
 
-	if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
+	if (unlikely(!tls_ctx))
 		return;
 
+	ctx = tls_sw_ctx_tx(tls_ctx);
+	if (test_bit(BIT_TX_CLOSING, &ctx->tx_bitmask))
+		return;
+
+	if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
+		return;
 	lock_sock(sk);
 	tls_tx_records(sk, -1);
 	release_sock(sk);
@@ -2160,6 +2188,18 @@ void tls_sw_write_space(struct sock *sk, struct tls_context *ctx)
 	}
 }
 
+void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx)
+{
+	struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);
+
+	write_lock_bh(&sk->sk_callback_lock);
+	rx_ctx->saved_data_ready = sk->sk_data_ready;
+	sk->sk_data_ready = tls_data_ready;
+	write_unlock_bh(&sk->sk_callback_lock);
+
+	strp_check_rcv(&rx_ctx->strp);
+}
+
 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
@@ -2357,13 +2397,6 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 		cb.parse_msg = tls_read_size;
 
 		strp_init(&sw_ctx_rx->strp, sk, &cb);
-
-		write_lock_bh(&sk->sk_callback_lock);
-		sw_ctx_rx->saved_data_ready = sk->sk_data_ready;
-		sk->sk_data_ready = tls_data_ready;
-		write_unlock_bh(&sk->sk_callback_lock);
-
-		strp_check_rcv(&sw_ctx_rx->strp);
 	}
 
 	goto out;