diff options
Diffstat (limited to 'net/tls/tls_sw.c')
-rw-r--r-- | net/tls/tls_sw.c | 713 |
1 files changed, 569 insertions, 144 deletions
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index f26376e954ae..4dc766b03f00 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -34,11 +34,60 @@ * SOFTWARE. */ +#include <linux/sched/signal.h> #include <linux/module.h> #include <crypto/aead.h> +#include <net/strparser.h> #include <net/tls.h> +static int tls_do_decryption(struct sock *sk, + struct scatterlist *sgin, + struct scatterlist *sgout, + char *iv_recv, + size_t data_len, + struct sk_buff *skb, + gfp_t flags) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct strp_msg *rxm = strp_msg(skb); + struct aead_request *aead_req; + + int ret; + unsigned int req_size = sizeof(struct aead_request) + + crypto_aead_reqsize(ctx->aead_recv); + + aead_req = kzalloc(req_size, flags); + if (!aead_req) + return -ENOMEM; + + aead_request_set_tfm(aead_req, ctx->aead_recv); + aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE); + aead_request_set_crypt(aead_req, sgin, sgout, + data_len + tls_ctx->rx.tag_size, + (u8 *)iv_recv); + aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG, + crypto_req_done, &ctx->async_wait); + + ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait); + + if (ret < 0) + goto out; + + rxm->offset += tls_ctx->rx.prepend_size; + rxm->full_len -= tls_ctx->rx.overhead_size; + tls_advance_record_sn(sk, &tls_ctx->rx); + + ctx->decrypted = true; + + ctx->saved_data_ready(sk); + +out: + kfree(aead_req); + return ret; +} + static void trim_sg(struct sock *sk, struct scatterlist *sg, int *sg_num_elem, unsigned int *sg_size, int target_size) { @@ -79,7 +128,7 @@ static void trim_both_sgl(struct sock *sk, int target_size) target_size); if (target_size > 0) - target_size += tls_ctx->overhead_size; + target_size += tls_ctx->tx.overhead_size; trim_sg(sk, ctx->sg_encrypted_data, &ctx->sg_encrypted_num_elem, @@ -87,71 +136,16 @@ static void trim_both_sgl(struct sock *sk, int target_size) target_size); } -static int alloc_sg(struct sock *sk, int len, struct scatterlist *sg, - int *sg_num_elem, unsigned int *sg_size, - int first_coalesce) -{ - struct page_frag *pfrag; - unsigned int size = *sg_size; - int num_elem = *sg_num_elem, use = 0, rc = 0; - struct scatterlist *sge; - unsigned int orig_offset; - - len -= size; - pfrag = sk_page_frag(sk); - - while (len > 0) { - if (!sk_page_frag_refill(sk, pfrag)) { - rc = -ENOMEM; - goto out; - } - - use = min_t(int, len, pfrag->size - pfrag->offset); - - if (!sk_wmem_schedule(sk, use)) { - rc = -ENOMEM; - goto out; - } - - sk_mem_charge(sk, use); - size += use; - orig_offset = pfrag->offset; - pfrag->offset += use; - - sge = sg + num_elem - 1; - if (num_elem > first_coalesce && sg_page(sg) == pfrag->page && - sg->offset + sg->length == orig_offset) { - sg->length += use; - } else { - sge++; - sg_unmark_end(sge); - sg_set_page(sge, pfrag->page, use, orig_offset); - get_page(pfrag->page); - ++num_elem; - if (num_elem == MAX_SKB_FRAGS) { - rc = -ENOSPC; - break; - } - } - - len -= use; - } - goto out; - -out: - *sg_size = size; - *sg_num_elem = num_elem; - return rc; -} - static int alloc_encrypted_sg(struct sock *sk, int len) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); int rc = 0; - rc = alloc_sg(sk, len, ctx->sg_encrypted_data, - &ctx->sg_encrypted_num_elem, &ctx->sg_encrypted_size, 0); + rc = sk_alloc_sg(sk, len, + ctx->sg_encrypted_data, 0, + &ctx->sg_encrypted_num_elem, + &ctx->sg_encrypted_size, 0); return rc; } @@ -162,9 +156,9 @@ static int alloc_plaintext_sg(struct sock *sk, int len) struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); int rc = 0; - rc = alloc_sg(sk, len, ctx->sg_plaintext_data, - &ctx->sg_plaintext_num_elem, &ctx->sg_plaintext_size, - tls_ctx->pending_open_record_frags); + rc = sk_alloc_sg(sk, len, ctx->sg_plaintext_data, 0, + &ctx->sg_plaintext_num_elem, &ctx->sg_plaintext_size, + tls_ctx->pending_open_record_frags); return rc; } @@ -207,21 +201,21 @@ static int tls_do_encryption(struct tls_context *tls_ctx, if (!aead_req) return -ENOMEM; - ctx->sg_encrypted_data[0].offset += tls_ctx->prepend_size; - ctx->sg_encrypted_data[0].length -= tls_ctx->prepend_size; + ctx->sg_encrypted_data[0].offset += tls_ctx->tx.prepend_size; + ctx->sg_encrypted_data[0].length -= tls_ctx->tx.prepend_size; aead_request_set_tfm(aead_req, ctx->aead_send); aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE); aead_request_set_crypt(aead_req, ctx->sg_aead_in, ctx->sg_aead_out, - data_len, tls_ctx->iv); + data_len, tls_ctx->tx.iv); aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG, crypto_req_done, &ctx->async_wait); rc = crypto_wait_req(crypto_aead_encrypt(aead_req), &ctx->async_wait); - ctx->sg_encrypted_data[0].offset -= tls_ctx->prepend_size; - ctx->sg_encrypted_data[0].length += tls_ctx->prepend_size; + ctx->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size; + ctx->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size; kfree(aead_req); return rc; @@ -238,7 +232,7 @@ static int tls_push_record(struct sock *sk, int flags, sg_mark_end(ctx->sg_encrypted_data + ctx->sg_encrypted_num_elem - 1); tls_make_aad(ctx->aad_space, ctx->sg_plaintext_size, - tls_ctx->rec_seq, tls_ctx->rec_seq_size, + tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size, record_type); tls_fill_prepend(tls_ctx, @@ -269,9 +263,9 @@ static int tls_push_record(struct sock *sk, int flags, /* Only pass through MSG_DONTWAIT and MSG_NOSIGNAL flags */ rc = tls_push_sg(sk, tls_ctx, ctx->sg_encrypted_data, 0, flags); if (rc < 0 && rc != -EAGAIN) - tls_err_abort(sk); + tls_err_abort(sk, EBADMSG); - tls_advance_record_sn(sk, tls_ctx); + tls_advance_record_sn(sk, &tls_ctx->tx); return rc; } @@ -281,23 +275,24 @@ static int tls_sw_push_pending_record(struct sock *sk, int flags) } static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from, - int length) + int length, int *pages_used, + unsigned int *size_used, + struct scatterlist *to, int to_max_pages, + bool charge) { - struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); struct page *pages[MAX_SKB_FRAGS]; size_t offset; ssize_t copied, use; int i = 0; - unsigned int size = ctx->sg_plaintext_size; - int num_elem = ctx->sg_plaintext_num_elem; + unsigned int size = *size_used; + int num_elem = *pages_used; int rc = 0; int maxpages; while (length > 0) { i = 0; - maxpages = ARRAY_SIZE(ctx->sg_plaintext_data) - num_elem; + maxpages = to_max_pages - num_elem; if (maxpages == 0) { rc = -EFAULT; goto out; @@ -317,10 +312,11 @@ static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from, while (copied) { use = min_t(int, copied, PAGE_SIZE - offset); - sg_set_page(&ctx->sg_plaintext_data[num_elem], + sg_set_page(&to[num_elem], pages[i], use, offset); - sg_unmark_end(&ctx->sg_plaintext_data[num_elem]); - sk_mem_charge(sk, use); + sg_unmark_end(&to[num_elem]); + if (charge) + sk_mem_charge(sk, use); offset = 0; copied -= use; @@ -331,8 +327,9 @@ static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from, } out: - ctx->sg_plaintext_size = size; - ctx->sg_plaintext_num_elem = num_elem; + *size_used = size; + *pages_used = num_elem; + return rc; } @@ -409,7 +406,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) } required_size = ctx->sg_plaintext_size + try_to_copy + - tls_ctx->overhead_size; + tls_ctx->tx.overhead_size; if (!sk_stream_memory_free(sk)) goto wait_for_sndbuf; @@ -429,7 +426,11 @@ alloc_encrypted: if (full_record || eor) { ret = zerocopy_from_iter(sk, &msg->msg_iter, - try_to_copy); + try_to_copy, &ctx->sg_plaintext_num_elem, + &ctx->sg_plaintext_size, + ctx->sg_plaintext_data, + ARRAY_SIZE(ctx->sg_plaintext_data), + true); if (ret) goto fallback_to_reg_send; @@ -468,7 +469,7 @@ alloc_plaintext: &ctx->sg_encrypted_num_elem, &ctx->sg_encrypted_size, ctx->sg_plaintext_size + - tls_ctx->overhead_size); + tls_ctx->tx.overhead_size); } ret = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy); @@ -560,7 +561,7 @@ int tls_sw_sendpage(struct sock *sk, struct page *page, full_record = true; } required_size = ctx->sg_plaintext_size + copy + - tls_ctx->overhead_size; + tls_ctx->tx.overhead_size; if (!sk_stream_memory_free(sk)) goto wait_for_sndbuf; @@ -629,13 +630,404 @@ sendpage_end: return ret; } -void tls_sw_free_tx_resources(struct sock *sk) +static struct sk_buff *tls_wait_data(struct sock *sk, int flags, + long timeo, int *err) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct sk_buff *skb; + DEFINE_WAIT_FUNC(wait, woken_wake_function); + + while (!(skb = ctx->recv_pkt)) { + if (sk->sk_err) { + *err = sock_error(sk); + return NULL; + } + + if (sock_flag(sk, SOCK_DONE)) + return NULL; + + if ((flags & MSG_DONTWAIT) || !timeo) { + *err = -EAGAIN; + return NULL; + } + + add_wait_queue(sk_sleep(sk), &wait); + sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); + sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait); + sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); + remove_wait_queue(sk_sleep(sk), &wait); + + /* Handle signals */ + if (signal_pending(current)) { + *err = sock_intr_errno(timeo); + return NULL; + } + } + + return skb; +} + +static int decrypt_skb(struct sock *sk, struct sk_buff *skb, + struct scatterlist *sgout) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + char iv[TLS_CIPHER_AES_GCM_128_SALT_SIZE + tls_ctx->rx.iv_size]; + struct scatterlist sgin_arr[MAX_SKB_FRAGS + 2]; + struct scatterlist *sgin = &sgin_arr[0]; + struct strp_msg *rxm = strp_msg(skb); + int ret, nsg = ARRAY_SIZE(sgin_arr); + char aad_recv[TLS_AAD_SPACE_SIZE]; + struct sk_buff *unused; + + ret = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE, + iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, + tls_ctx->rx.iv_size); + if (ret < 0) + return ret; + + memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE); + if (!sgout) { + nsg = skb_cow_data(skb, 0, &unused) + 1; + sgin = kmalloc_array(nsg, sizeof(*sgin), sk->sk_allocation); + if (!sgout) + sgout = sgin; + } + + sg_init_table(sgin, nsg); + sg_set_buf(&sgin[0], aad_recv, sizeof(aad_recv)); + + nsg = skb_to_sgvec(skb, &sgin[1], + rxm->offset + tls_ctx->rx.prepend_size, + rxm->full_len - tls_ctx->rx.prepend_size); + + tls_make_aad(aad_recv, + rxm->full_len - tls_ctx->rx.overhead_size, + tls_ctx->rx.rec_seq, + tls_ctx->rx.rec_seq_size, + ctx->control); + + ret = tls_do_decryption(sk, sgin, sgout, iv, + rxm->full_len - tls_ctx->rx.overhead_size, + skb, sk->sk_allocation); + + if (sgin != &sgin_arr[0]) + kfree(sgin); + + return ret; +} + +static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb, + unsigned int len) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct strp_msg *rxm = strp_msg(skb); + + if (len < rxm->full_len) { + rxm->offset += len; + rxm->full_len -= len; + + return false; + } + + /* Finished with message */ + ctx->recv_pkt = NULL; + kfree_skb(skb); + strp_unpause(&ctx->strp); + + return true; +} + +int tls_sw_recvmsg(struct sock *sk, + struct msghdr *msg, + size_t len, + int nonblock, + int flags, + int *addr_len) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + unsigned char control; + struct strp_msg *rxm; + struct sk_buff *skb; + ssize_t copied = 0; + bool cmsg = false; + int err = 0; + long timeo; + + flags |= nonblock; + + if (unlikely(flags & MSG_ERRQUEUE)) + return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR); + + lock_sock(sk); + + timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + do { + bool zc = false; + int chunk = 0; + + skb = tls_wait_data(sk, flags, timeo, &err); + if (!skb) + goto recv_end; + + rxm = strp_msg(skb); + if (!cmsg) { + int cerr; + + cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE, + sizeof(ctx->control), &ctx->control); + cmsg = true; + control = ctx->control; + if (ctx->control != TLS_RECORD_TYPE_DATA) { + if (cerr || msg->msg_flags & MSG_CTRUNC) { + err = -EIO; + goto recv_end; + } + } + } else if (control != ctx->control) { + goto recv_end; + } + + if (!ctx->decrypted) { + int page_count; + int to_copy; + + page_count = iov_iter_npages(&msg->msg_iter, + MAX_SKB_FRAGS); + to_copy = rxm->full_len - tls_ctx->rx.overhead_size; + if (to_copy <= len && page_count < MAX_SKB_FRAGS && + likely(!(flags & MSG_PEEK))) { + struct scatterlist sgin[MAX_SKB_FRAGS + 1]; + char unused[21]; + int pages = 0; + + zc = true; + sg_init_table(sgin, MAX_SKB_FRAGS + 1); + sg_set_buf(&sgin[0], unused, 13); + + err = zerocopy_from_iter(sk, &msg->msg_iter, + to_copy, &pages, + &chunk, &sgin[1], + MAX_SKB_FRAGS, false); + if (err < 0) + goto fallback_to_reg_recv; + + err = decrypt_skb(sk, skb, sgin); + for (; pages > 0; pages--) + put_page(sg_page(&sgin[pages])); + if (err < 0) { + tls_err_abort(sk, EBADMSG); + goto recv_end; + } + } else { +fallback_to_reg_recv: + err = decrypt_skb(sk, skb, NULL); + if (err < 0) { + tls_err_abort(sk, EBADMSG); + goto recv_end; + } + } + ctx->decrypted = true; + } + + if (!zc) { + chunk = min_t(unsigned int, rxm->full_len, len); + err = skb_copy_datagram_msg(skb, rxm->offset, msg, + chunk); + if (err < 0) + goto recv_end; + } + + copied += chunk; + len -= chunk; + if (likely(!(flags & MSG_PEEK))) { + u8 control = ctx->control; + + if (tls_sw_advance_skb(sk, skb, chunk)) { + /* Return full control message to + * userspace before trying to parse + * another message type + */ + msg->msg_flags |= MSG_EOR; + if (control != TLS_RECORD_TYPE_DATA) + goto recv_end; + } + } + } while (len); + +recv_end: + release_sock(sk); + return copied ? : err; +} + +ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, + struct pipe_inode_info *pipe, + size_t len, unsigned int flags) +{ + struct tls_context *tls_ctx = tls_get_ctx(sock->sk); + struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct strp_msg *rxm = NULL; + struct sock *sk = sock->sk; + struct sk_buff *skb; + ssize_t copied = 0; + int err = 0; + long timeo; + int chunk; + + lock_sock(sk); + + timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + + skb = tls_wait_data(sk, flags, timeo, &err); + if (!skb) + goto splice_read_end; + + /* splice does not support reading control messages */ + if (ctx->control != TLS_RECORD_TYPE_DATA) { + err = -ENOTSUPP; + goto splice_read_end; + } + + if (!ctx->decrypted) { + err = decrypt_skb(sk, skb, NULL); + + if (err < 0) { + tls_err_abort(sk, EBADMSG); + goto splice_read_end; + } + ctx->decrypted = true; + } + rxm = strp_msg(skb); + + chunk = min_t(unsigned int, rxm->full_len, len); + copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags); + if (copied < 0) + goto splice_read_end; + + if (likely(!(flags & MSG_PEEK))) + tls_sw_advance_skb(sk, skb, copied); + +splice_read_end: + release_sock(sk); + return copied ? : err; +} + +unsigned int tls_sw_poll(struct file *file, struct socket *sock, + struct poll_table_struct *wait) +{ + unsigned int ret; + struct sock *sk = sock->sk; + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + + /* Grab POLLOUT and POLLHUP from the underlying socket */ + ret = ctx->sk_poll(file, sock, wait); + + /* Clear POLLIN bits, and set based on recv_pkt */ + ret &= ~(POLLIN | POLLRDNORM); + if (ctx->recv_pkt) + ret |= POLLIN | POLLRDNORM; + + return ret; +} + +static int tls_read_size(struct strparser *strp, struct sk_buff *skb) +{ + struct tls_context *tls_ctx = tls_get_ctx(strp->sk); + struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + char header[tls_ctx->rx.prepend_size]; + struct strp_msg *rxm = strp_msg(skb); + size_t cipher_overhead; + size_t data_len = 0; + int ret; + + /* Verify that we have a full TLS header, or wait for more data */ + if (rxm->offset + tls_ctx->rx.prepend_size > skb->len) + return 0; + + /* Linearize header to local buffer */ + ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size); + + if (ret < 0) + goto read_failure; + + ctx->control = header[0]; + + data_len = ((header[4] & 0xFF) | (header[3] << 8)); + + cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size; + + if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) { + ret = -EMSGSIZE; + goto read_failure; + } + if (data_len < cipher_overhead) { + ret = -EBADMSG; + goto read_failure; + } + + if (header[1] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.version) || + header[2] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.version)) { + ret = -EINVAL; + goto read_failure; + } + + return data_len + TLS_HEADER_SIZE; + +read_failure: + tls_err_abort(strp->sk, ret); + + return ret; +} + +static void tls_queue(struct strparser *strp, struct sk_buff *skb) +{ + struct tls_context *tls_ctx = tls_get_ctx(strp->sk); + struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + struct strp_msg *rxm; + + rxm = strp_msg(skb); + + ctx->decrypted = false; + + ctx->recv_pkt = skb; + strp_pause(strp); + + strp->sk->sk_state_change(strp->sk); +} + +static void tls_data_ready(struct sock *sk) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); + + strp_data_ready(&ctx->strp); +} + +void tls_sw_free_resources(struct sock *sk) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); if (ctx->aead_send) crypto_free_aead(ctx->aead_send); + if (ctx->aead_recv) { + if (ctx->recv_pkt) { + kfree_skb(ctx->recv_pkt); + ctx->recv_pkt = NULL; + } + crypto_free_aead(ctx->aead_recv); + strp_stop(&ctx->strp); + write_lock_bh(&sk->sk_callback_lock); + sk->sk_data_ready = ctx->saved_data_ready; + write_unlock_bh(&sk->sk_callback_lock); + release_sock(sk); + strp_done(&ctx->strp); + lock_sock(sk); + } tls_free_both_sg(sk); @@ -643,12 +1035,15 @@ void tls_sw_free_tx_resources(struct sock *sk) kfree(tls_ctx); } -int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx) +int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) { char keyval[TLS_CIPHER_AES_GCM_128_KEY_SIZE]; struct tls_crypto_info *crypto_info; struct tls12_crypto_info_aes_gcm_128 *gcm_128_info; struct tls_sw_context *sw_ctx; + struct cipher_context *cctx; + struct crypto_aead **aead; + struct strp_callbacks cb; u16 nonce_size, tag_size, iv_size, rec_seq_size; char *iv, *rec_seq; int rc = 0; @@ -658,22 +1053,29 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx) goto out; } - if (ctx->priv_ctx) { - rc = -EEXIST; - goto out; - } - - sw_ctx = kzalloc(sizeof(*sw_ctx), GFP_KERNEL); - if (!sw_ctx) { - rc = -ENOMEM; - goto out; + if (!ctx->priv_ctx) { + sw_ctx = kzalloc(sizeof(*sw_ctx), GFP_KERNEL); + if (!sw_ctx) { + rc = -ENOMEM; + goto out; + } + crypto_init_wait(&sw_ctx->async_wait); + } else { + sw_ctx = ctx->priv_ctx; } - crypto_init_wait(&sw_ctx->async_wait); - ctx->priv_ctx = (struct tls_offload_context *)sw_ctx; - crypto_info = &ctx->crypto_send; + if (tx) { + crypto_info = &ctx->crypto_send; + cctx = &ctx->tx; + aead = &sw_ctx->aead_send; + } else { + crypto_info = &ctx->crypto_recv; + cctx = &ctx->rx; + aead = &sw_ctx->aead_recv; + } + switch (crypto_info->cipher_type) { case TLS_CIPHER_AES_GCM_128: { nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE; @@ -692,46 +1094,49 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx) goto free_priv; } - ctx->prepend_size = TLS_HEADER_SIZE + nonce_size; - ctx->tag_size = tag_size; - ctx->overhead_size = ctx->prepend_size + ctx->tag_size; - ctx->iv_size = iv_size; - ctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE, GFP_KERNEL); - if (!ctx->iv) { + cctx->prepend_size = TLS_HEADER_SIZE + nonce_size; + cctx->tag_size = tag_size; + cctx->overhead_size = cctx->prepend_size + cctx->tag_size; + cctx->iv_size = iv_size; + cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE, + GFP_KERNEL); + if (!cctx->iv) { rc = -ENOMEM; goto free_priv; } - memcpy(ctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE); - memcpy(ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size); - ctx->rec_seq_size = rec_seq_size; - ctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL); - if (!ctx->rec_seq) { + memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE); + memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size); + cctx->rec_seq_size = rec_seq_size; + cctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL); + if (!cctx->rec_seq) { rc = -ENOMEM; goto free_iv; } - memcpy(ctx->rec_seq, rec_seq, rec_seq_size); - - sg_init_table(sw_ctx->sg_encrypted_data, - ARRAY_SIZE(sw_ctx->sg_encrypted_data)); - sg_init_table(sw_ctx->sg_plaintext_data, - ARRAY_SIZE(sw_ctx->sg_plaintext_data)); - - sg_init_table(sw_ctx->sg_aead_in, 2); - sg_set_buf(&sw_ctx->sg_aead_in[0], sw_ctx->aad_space, - sizeof(sw_ctx->aad_space)); - sg_unmark_end(&sw_ctx->sg_aead_in[1]); - sg_chain(sw_ctx->sg_aead_in, 2, sw_ctx->sg_plaintext_data); - sg_init_table(sw_ctx->sg_aead_out, 2); - sg_set_buf(&sw_ctx->sg_aead_out[0], sw_ctx->aad_space, - sizeof(sw_ctx->aad_space)); - sg_unmark_end(&sw_ctx->sg_aead_out[1]); - sg_chain(sw_ctx->sg_aead_out, 2, sw_ctx->sg_encrypted_data); - - if (!sw_ctx->aead_send) { - sw_ctx->aead_send = crypto_alloc_aead("gcm(aes)", 0, 0); - if (IS_ERR(sw_ctx->aead_send)) { - rc = PTR_ERR(sw_ctx->aead_send); - sw_ctx->aead_send = NULL; + memcpy(cctx->rec_seq, rec_seq, rec_seq_size); + + if (tx) { + sg_init_table(sw_ctx->sg_encrypted_data, + ARRAY_SIZE(sw_ctx->sg_encrypted_data)); + sg_init_table(sw_ctx->sg_plaintext_data, + ARRAY_SIZE(sw_ctx->sg_plaintext_data)); + + sg_init_table(sw_ctx->sg_aead_in, 2); + sg_set_buf(&sw_ctx->sg_aead_in[0], sw_ctx->aad_space, + sizeof(sw_ctx->aad_space)); + sg_unmark_end(&sw_ctx->sg_aead_in[1]); + sg_chain(sw_ctx->sg_aead_in, 2, sw_ctx->sg_plaintext_data); + sg_init_table(sw_ctx->sg_aead_out, 2); + sg_set_buf(&sw_ctx->sg_aead_out[0], sw_ctx->aad_space, + sizeof(sw_ctx->aad_space)); + sg_unmark_end(&sw_ctx->sg_aead_out[1]); + sg_chain(sw_ctx->sg_aead_out, 2, sw_ctx->sg_encrypted_data); + } + + if (!*aead) { + *aead = crypto_alloc_aead("gcm(aes)", 0, 0); + if (IS_ERR(*aead)) { + rc = PTR_ERR(*aead); + *aead = NULL; goto free_rec_seq; } } @@ -740,24 +1145,44 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx) memcpy(keyval, gcm_128_info->key, TLS_CIPHER_AES_GCM_128_KEY_SIZE); - rc = crypto_aead_setkey(sw_ctx->aead_send, keyval, + rc = crypto_aead_setkey(*aead, keyval, TLS_CIPHER_AES_GCM_128_KEY_SIZE); if (rc) goto free_aead; - rc = crypto_aead_setauthsize(sw_ctx->aead_send, ctx->tag_size); - if (!rc) - return 0; + rc = crypto_aead_setauthsize(*aead, cctx->tag_size); + if (rc) + goto free_aead; + + if (!tx) { + /* Set up strparser */ + memset(&cb, 0, sizeof(cb)); + cb.rcv_msg = tls_queue; + cb.parse_msg = tls_read_size; + + strp_init(&sw_ctx->strp, sk, &cb); + + write_lock_bh(&sk->sk_callback_lock); + sw_ctx->saved_data_ready = sk->sk_data_ready; + sk->sk_data_ready = tls_data_ready; + write_unlock_bh(&sk->sk_callback_lock); + + sw_ctx->sk_poll = sk->sk_socket->ops->poll; + + strp_check_rcv(&sw_ctx->strp); + } + + goto out; free_aead: - crypto_free_aead(sw_ctx->aead_send); - sw_ctx->aead_send = NULL; + crypto_free_aead(*aead); + *aead = NULL; free_rec_seq: - kfree(ctx->rec_seq); - ctx->rec_seq = NULL; + kfree(cctx->rec_seq); + cctx->rec_seq = NULL; free_iv: - kfree(ctx->iv); - ctx->iv = NULL; + kfree(ctx->tx.iv); + ctx->tx.iv = NULL; free_priv: kfree(ctx->priv_ctx); ctx->priv_ctx = NULL; |