diff options
Diffstat (limited to 'net/tls/tls_main.c')
-rw-r--r-- | net/tls/tls_main.c | 318 |
1 files changed, 161 insertions, 157 deletions
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 43252a801c3f..94774c0e5ff3 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -39,8 +39,11 @@ #include <linux/netdevice.h> #include <linux/sched/signal.h> #include <linux/inetdevice.h> +#include <linux/inet_diag.h> +#include <net/snmp.h> #include <net/tls.h> +#include <net/tls_toe.h> MODULE_AUTHOR("Mellanox Technologies"); MODULE_DESCRIPTION("Transport Layer Security Support"); @@ -57,14 +60,12 @@ static struct proto *saved_tcpv6_prot; static DEFINE_MUTEX(tcpv6_prot_mutex); static struct proto *saved_tcpv4_prot; static DEFINE_MUTEX(tcpv4_prot_mutex); -static LIST_HEAD(device_list); -static DEFINE_SPINLOCK(device_spinlock); static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG]; static struct proto_ops tls_sw_proto_ops; static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], struct proto *base); -static void update_sk_prot(struct sock *sk, struct tls_context *ctx) +void update_sk_prot(struct sock *sk, struct tls_context *ctx) { int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; @@ -208,24 +209,15 @@ int tls_push_partial_record(struct sock *sk, struct tls_context *ctx, return tls_push_sg(sk, ctx, sg, offset, flags); } -bool tls_free_partial_record(struct sock *sk, struct tls_context *ctx) +void tls_free_partial_record(struct sock *sk, struct tls_context *ctx) { struct scatterlist *sg; - sg = ctx->partially_sent_record; - if (!sg) - return false; - - while (1) { + for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) { put_page(sg_page(sg)); sk_mem_uncharge(sk, sg->length); - - if (sg_is_last(sg)) - break; - sg++; } ctx->partially_sent_record = NULL; - return true; } static void tls_write_space(struct sock *sk) @@ -251,14 +243,27 @@ static void tls_write_space(struct sock *sk) ctx->sk_write_space(sk); } -void tls_ctx_free(struct tls_context *ctx) +/** + * tls_ctx_free() - free TLS ULP context + * @sk: socket to with @ctx is attached + * @ctx: TLS context structure + * + * Free TLS context. If @sk is %NULL caller guarantees that the socket + * to which @ctx was attached has no outstanding references. + */ +void tls_ctx_free(struct sock *sk, struct tls_context *ctx) { if (!ctx) return; memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send)); memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv)); - kfree(ctx); + mutex_destroy(&ctx->tx_lock); + + if (sk) + kfree_rcu(ctx, rcu); + else + kfree(ctx); } static void tls_sk_proto_cleanup(struct sock *sk, @@ -273,19 +278,19 @@ static void tls_sk_proto_cleanup(struct sock *sk, kfree(ctx->tx.rec_seq); kfree(ctx->tx.iv); tls_sw_release_resources_tx(sk); -#ifdef CONFIG_TLS_DEVICE + TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW); } else if (ctx->tx_conf == TLS_HW) { tls_device_free_resources_tx(sk); -#endif + TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE); } - if (ctx->rx_conf == TLS_SW) + if (ctx->rx_conf == TLS_SW) { tls_sw_release_resources_rx(sk); - -#ifdef CONFIG_TLS_DEVICE - if (ctx->rx_conf == TLS_HW) + TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW); + } else if (ctx->rx_conf == TLS_HW) { tls_device_offload_cleanup_rx(sk); -#endif + TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE); + } } static void tls_sk_proto_close(struct sock *sk, long timeout) @@ -306,7 +311,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) write_lock_bh(&sk->sk_callback_lock); if (free_ctx) - icsk->icsk_ulp_data = NULL; + rcu_assign_pointer(icsk->icsk_ulp_data, NULL); sk->sk_prot = ctx->sk_proto; if (sk->sk_write_space == tls_write_space) sk->sk_write_space = ctx->sk_write_space; @@ -318,10 +323,10 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) tls_sw_strparser_done(ctx); if (ctx->rx_conf == TLS_SW) tls_sw_free_ctx_rx(ctx); - ctx->sk_proto_close(sk, timeout); + ctx->sk_proto->close(sk, timeout); if (free_ctx) - tls_ctx_free(ctx); + tls_ctx_free(sk, ctx); } static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval, @@ -438,7 +443,8 @@ static int tls_getsockopt(struct sock *sk, int level, int optname, struct tls_context *ctx = tls_get_ctx(sk); if (level != SOL_TLS) - return ctx->getsockopt(sk, level, optname, optval, optlen); + return ctx->sk_proto->getsockopt(sk, level, + optname, optval, optlen); return do_tls_getsockopt(sk, optname, optval, optlen); } @@ -481,7 +487,7 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval, /* check version */ if (crypto_info->version != TLS_1_2_VERSION && crypto_info->version != TLS_1_3_VERSION) { - rc = -ENOTSUPP; + rc = -EINVAL; goto err_crypto_info; } @@ -523,29 +529,31 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval, } if (tx) { -#ifdef CONFIG_TLS_DEVICE rc = tls_set_device_offload(sk, ctx); conf = TLS_HW; - if (rc) { -#else - { -#endif + if (!rc) { + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE); + } else { rc = tls_set_sw_offload(sk, ctx, 1); if (rc) goto err_crypto_info; + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW); conf = TLS_SW; } } else { -#ifdef CONFIG_TLS_DEVICE rc = tls_set_device_offload_rx(sk, ctx); conf = TLS_HW; - if (rc) { -#else - { -#endif + if (!rc) { + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE); + } else { rc = tls_set_sw_offload(sk, ctx, 0); if (rc) goto err_crypto_info; + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW); conf = TLS_SW; } tls_sw_strparser_arm(sk, ctx); @@ -596,12 +604,13 @@ static int tls_setsockopt(struct sock *sk, int level, int optname, struct tls_context *ctx = tls_get_ctx(sk); if (level != SOL_TLS) - return ctx->setsockopt(sk, level, optname, optval, optlen); + return ctx->sk_proto->setsockopt(sk, level, optname, optval, + optlen); return do_tls_setsockopt(sk, optname, optval, optlen); } -static struct tls_context *create_ctx(struct sock *sk) +struct tls_context *tls_ctx_create(struct sock *sk) { struct inet_connection_sock *icsk = inet_csk(sk); struct tls_context *ctx; @@ -610,11 +619,9 @@ static struct tls_context *create_ctx(struct sock *sk) if (!ctx) return NULL; - icsk->icsk_ulp_data = ctx; - ctx->setsockopt = sk->sk_prot->setsockopt; - ctx->getsockopt = sk->sk_prot->getsockopt; - ctx->sk_proto_close = sk->sk_prot->close; - ctx->unhash = sk->sk_prot->unhash; + mutex_init(&ctx->tx_lock); + rcu_assign_pointer(icsk->icsk_ulp_data, ctx); + ctx->sk_proto = sk->sk_prot; return ctx; } @@ -644,93 +651,6 @@ static void tls_build_proto(struct sock *sk) } } -static void tls_hw_sk_destruct(struct sock *sk) -{ - struct tls_context *ctx = tls_get_ctx(sk); - struct inet_connection_sock *icsk = inet_csk(sk); - - ctx->sk_destruct(sk); - /* Free ctx */ - tls_ctx_free(ctx); - icsk->icsk_ulp_data = NULL; -} - -static int tls_hw_prot(struct sock *sk) -{ - struct tls_context *ctx; - struct tls_device *dev; - int rc = 0; - - spin_lock_bh(&device_spinlock); - list_for_each_entry(dev, &device_list, dev_list) { - if (dev->feature && dev->feature(dev)) { - ctx = create_ctx(sk); - if (!ctx) - goto out; - - spin_unlock_bh(&device_spinlock); - tls_build_proto(sk); - ctx->hash = sk->sk_prot->hash; - ctx->unhash = sk->sk_prot->unhash; - ctx->sk_proto_close = sk->sk_prot->close; - ctx->sk_destruct = sk->sk_destruct; - sk->sk_destruct = tls_hw_sk_destruct; - ctx->rx_conf = TLS_HW_RECORD; - ctx->tx_conf = TLS_HW_RECORD; - update_sk_prot(sk, ctx); - spin_lock_bh(&device_spinlock); - rc = 1; - break; - } - } -out: - spin_unlock_bh(&device_spinlock); - return rc; -} - -static void tls_hw_unhash(struct sock *sk) -{ - struct tls_context *ctx = tls_get_ctx(sk); - struct tls_device *dev; - - spin_lock_bh(&device_spinlock); - list_for_each_entry(dev, &device_list, dev_list) { - if (dev->unhash) { - kref_get(&dev->kref); - spin_unlock_bh(&device_spinlock); - dev->unhash(dev, sk); - kref_put(&dev->kref, dev->release); - spin_lock_bh(&device_spinlock); - } - } - spin_unlock_bh(&device_spinlock); - ctx->unhash(sk); -} - -static int tls_hw_hash(struct sock *sk) -{ - struct tls_context *ctx = tls_get_ctx(sk); - struct tls_device *dev; - int err; - - err = ctx->hash(sk); - spin_lock_bh(&device_spinlock); - list_for_each_entry(dev, &device_list, dev_list) { - if (dev->hash) { - kref_get(&dev->kref); - spin_unlock_bh(&device_spinlock); - err |= dev->hash(dev, sk); - kref_put(&dev->kref, dev->release); - spin_lock_bh(&device_spinlock); - } - } - spin_unlock_bh(&device_spinlock); - - if (err) - tls_hw_unhash(sk); - return err; -} - static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], struct proto *base) { @@ -768,10 +688,11 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW]; #endif - +#ifdef CONFIG_TLS_TOE prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base; - prot[TLS_HW_RECORD][TLS_HW_RECORD].hash = tls_hw_hash; - prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash = tls_hw_unhash; + prot[TLS_HW_RECORD][TLS_HW_RECORD].hash = tls_toe_hash; + prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash = tls_toe_unhash; +#endif } static int tls_init(struct sock *sk) @@ -779,8 +700,12 @@ static int tls_init(struct sock *sk) struct tls_context *ctx; int rc = 0; - if (tls_hw_prot(sk)) + tls_build_proto(sk); + +#ifdef CONFIG_TLS_TOE + if (tls_toe_bypass(sk)) return 0; +#endif /* The TLS ulp is currently supported only for TCP sockets * in ESTABLISHED state. @@ -789,13 +714,11 @@ static int tls_init(struct sock *sk) * share the ulp context. */ if (sk->sk_state != TCP_ESTABLISHED) - return -ENOTSUPP; - - tls_build_proto(sk); + return -ENOTCONN; /* allocate tls context */ write_lock_bh(&sk->sk_callback_lock); - ctx = create_ctx(sk); + ctx = tls_ctx_create(sk); if (!ctx) { rc = -ENOMEM; goto out; @@ -803,57 +726,139 @@ static int tls_init(struct sock *sk) ctx->tx_conf = TLS_BASE; ctx->rx_conf = TLS_BASE; - ctx->sk_proto = sk->sk_prot; update_sk_prot(sk, ctx); out: write_unlock_bh(&sk->sk_callback_lock); return rc; } -static void tls_update(struct sock *sk, struct proto *p) +static void tls_update(struct sock *sk, struct proto *p, + void (*write_space)(struct sock *sk)) { struct tls_context *ctx; ctx = tls_get_ctx(sk); if (likely(ctx)) { - ctx->sk_proto_close = p->close; + ctx->sk_write_space = write_space; ctx->sk_proto = p; } else { sk->sk_prot = p; + sk->sk_write_space = write_space; } } -void tls_register_device(struct tls_device *device) +static int tls_get_info(const struct sock *sk, struct sk_buff *skb) { - spin_lock_bh(&device_spinlock); - list_add_tail(&device->dev_list, &device_list); - spin_unlock_bh(&device_spinlock); + u16 version, cipher_type; + struct tls_context *ctx; + struct nlattr *start; + int err; + + start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS); + if (!start) + return -EMSGSIZE; + + rcu_read_lock(); + ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data); + if (!ctx) { + err = 0; + goto nla_failure; + } + version = ctx->prot_info.version; + if (version) { + err = nla_put_u16(skb, TLS_INFO_VERSION, version); + if (err) + goto nla_failure; + } + cipher_type = ctx->prot_info.cipher_type; + if (cipher_type) { + err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type); + if (err) + goto nla_failure; + } + err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true)); + if (err) + goto nla_failure; + + err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false)); + if (err) + goto nla_failure; + + rcu_read_unlock(); + nla_nest_end(skb, start); + return 0; + +nla_failure: + rcu_read_unlock(); + nla_nest_cancel(skb, start); + return err; } -EXPORT_SYMBOL(tls_register_device); -void tls_unregister_device(struct tls_device *device) +static size_t tls_get_info_size(const struct sock *sk) { - spin_lock_bh(&device_spinlock); - list_del(&device->dev_list); - spin_unlock_bh(&device_spinlock); + size_t size = 0; + + size += nla_total_size(0) + /* INET_ULP_INFO_TLS */ + nla_total_size(sizeof(u16)) + /* TLS_INFO_VERSION */ + nla_total_size(sizeof(u16)) + /* TLS_INFO_CIPHER */ + nla_total_size(sizeof(u16)) + /* TLS_INFO_RXCONF */ + nla_total_size(sizeof(u16)) + /* TLS_INFO_TXCONF */ + 0; + + return size; } -EXPORT_SYMBOL(tls_unregister_device); + +static int __net_init tls_init_net(struct net *net) +{ + int err; + + net->mib.tls_statistics = alloc_percpu(struct linux_tls_mib); + if (!net->mib.tls_statistics) + return -ENOMEM; + + err = tls_proc_init(net); + if (err) + goto err_free_stats; + + return 0; +err_free_stats: + free_percpu(net->mib.tls_statistics); + return err; +} + +static void __net_exit tls_exit_net(struct net *net) +{ + tls_proc_fini(net); + free_percpu(net->mib.tls_statistics); +} + +static struct pernet_operations tls_proc_ops = { + .init = tls_init_net, + .exit = tls_exit_net, +}; static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = { .name = "tls", .owner = THIS_MODULE, .init = tls_init, .update = tls_update, + .get_info = tls_get_info, + .get_info_size = tls_get_info_size, }; static int __init tls_register(void) { + int err; + + err = register_pernet_subsys(&tls_proc_ops); + if (err) + return err; + tls_sw_proto_ops = inet_stream_ops; tls_sw_proto_ops.splice_read = tls_sw_splice_read; + tls_sw_proto_ops.sendpage_locked = tls_sw_sendpage_locked, -#ifdef CONFIG_TLS_DEVICE tls_device_init(); -#endif tcp_register_ulp(&tcp_tls_ulp_ops); return 0; @@ -862,9 +867,8 @@ static int __init tls_register(void) static void __exit tls_unregister(void) { tcp_unregister_ulp(&tcp_tls_ulp_ops); -#ifdef CONFIG_TLS_DEVICE tls_device_cleanup(); -#endif + unregister_pernet_subsys(&tls_proc_ops); } module_init(tls_register); |