diff options
Diffstat (limited to 'net/tls/tls_main.c')
-rw-r--r-- | net/tls/tls_main.c | 96 |
1 files changed, 57 insertions, 39 deletions
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 60aff60e30ad..e07ee3ae0023 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -45,8 +45,18 @@ MODULE_AUTHOR("Mellanox Technologies"); MODULE_DESCRIPTION("Transport Layer Security Support"); MODULE_LICENSE("Dual BSD/GPL"); -static struct proto tls_base_prot; -static struct proto tls_sw_prot; +enum { + TLS_BASE_TX, + TLS_SW_TX, + TLS_NUM_CONFIG, +}; + +static struct proto tls_prots[TLS_NUM_CONFIG]; + +static inline void update_sk_prot(struct sock *sk, struct tls_context *ctx) +{ + sk->sk_prot = &tls_prots[ctx->tx_conf]; +} int wait_on_pending_writer(struct sock *sk, long *timeo) { @@ -216,6 +226,12 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) void (*sk_proto_close)(struct sock *sk, long timeout); lock_sock(sk); + sk_proto_close = ctx->sk_proto_close; + + if (ctx->tx_conf == TLS_BASE_TX) { + kfree(ctx); + goto skip_tx_cleanup; + } if (!tls_complete_pending_work(sk, ctx, 0, &timeo)) tls_handle_open_record(sk, 0); @@ -232,13 +248,14 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) sg++; } } - ctx->free_resources(sk); + kfree(ctx->rec_seq); kfree(ctx->iv); - sk_proto_close = ctx->sk_proto_close; - kfree(ctx); + if (ctx->tx_conf == TLS_SW_TX) + tls_sw_free_tx_resources(sk); +skip_tx_cleanup: release_sock(sk); sk_proto_close(sk, timeout); } @@ -338,46 +355,41 @@ static int tls_getsockopt(struct sock *sk, int level, int optname, static int do_tls_setsockopt_tx(struct sock *sk, char __user *optval, unsigned int optlen) { - struct tls_crypto_info *crypto_info, tmp_crypto_info; + struct tls_crypto_info *crypto_info; struct tls_context *ctx = tls_get_ctx(sk); - struct proto *prot = NULL; int rc = 0; + int tx_conf; if (!optval || (optlen < sizeof(*crypto_info))) { rc = -EINVAL; goto out; } - rc = copy_from_user(&tmp_crypto_info, optval, sizeof(*crypto_info)); + crypto_info = &ctx->crypto_send; + /* Currently we don't support set crypto info more than one time */ + if (TLS_CRYPTO_INFO_READY(crypto_info)) + goto out; + + rc = copy_from_user(crypto_info, optval, sizeof(*crypto_info)); if (rc) { rc = -EFAULT; goto out; } /* check version */ - if (tmp_crypto_info.version != TLS_1_2_VERSION) { + if (crypto_info->version != TLS_1_2_VERSION) { rc = -ENOTSUPP; - goto out; + goto err_crypto_info; } - /* get user crypto info */ - crypto_info = &ctx->crypto_send; - - /* Currently we don't support set crypto info more than one time */ - if (TLS_CRYPTO_INFO_READY(crypto_info)) - goto out; - - switch (tmp_crypto_info.cipher_type) { + switch (crypto_info->cipher_type) { case TLS_CIPHER_AES_GCM_128: { if (optlen != sizeof(struct tls12_crypto_info_aes_gcm_128)) { rc = -EINVAL; goto out; } - rc = copy_from_user( - crypto_info, - optval, - sizeof(struct tls12_crypto_info_aes_gcm_128)); - + rc = copy_from_user(crypto_info + 1, optval + sizeof(*crypto_info), + optlen - sizeof(*crypto_info)); if (rc) { rc = -EFAULT; goto err_crypto_info; @@ -389,18 +401,16 @@ static int do_tls_setsockopt_tx(struct sock *sk, char __user *optval, goto out; } - ctx->sk_write_space = sk->sk_write_space; - sk->sk_write_space = tls_write_space; - - ctx->sk_proto_close = sk->sk_prot->close; - /* currently SW is default, we will have ethtool in future */ rc = tls_set_sw_offload(sk, ctx); - prot = &tls_sw_prot; + tx_conf = TLS_SW_TX; if (rc) goto err_crypto_info; - sk->sk_prot = prot; + ctx->tx_conf = tx_conf; + update_sk_prot(sk, ctx); + ctx->sk_write_space = sk->sk_write_space; + sk->sk_write_space = tls_write_space; goto out; err_crypto_info: @@ -453,7 +463,10 @@ static int tls_init(struct sock *sk) icsk->icsk_ulp_data = ctx; ctx->setsockopt = sk->sk_prot->setsockopt; ctx->getsockopt = sk->sk_prot->getsockopt; - sk->sk_prot = &tls_base_prot; + ctx->sk_proto_close = sk->sk_prot->close; + + ctx->tx_conf = TLS_BASE_TX; + update_sk_prot(sk, ctx); out: return rc; } @@ -464,16 +477,21 @@ static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = { .init = tls_init, }; +static void build_protos(struct proto *prot, struct proto *base) +{ + prot[TLS_BASE_TX] = *base; + prot[TLS_BASE_TX].setsockopt = tls_setsockopt; + prot[TLS_BASE_TX].getsockopt = tls_getsockopt; + prot[TLS_BASE_TX].close = tls_sk_proto_close; + + prot[TLS_SW_TX] = prot[TLS_BASE_TX]; + prot[TLS_SW_TX].sendmsg = tls_sw_sendmsg; + prot[TLS_SW_TX].sendpage = tls_sw_sendpage; +} + static int __init tls_register(void) { - tls_base_prot = tcp_prot; - tls_base_prot.setsockopt = tls_setsockopt; - tls_base_prot.getsockopt = tls_getsockopt; - - tls_sw_prot = tls_base_prot; - tls_sw_prot.sendmsg = tls_sw_sendmsg; - tls_sw_prot.sendpage = tls_sw_sendpage; - tls_sw_prot.close = tls_sk_proto_close; + build_protos(tls_prots, &tcp_prot); tcp_register_ulp(&tcp_tls_ulp_ops); |