#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "qt_idf_tls.h"
#include "dlg/dlg.h"
#include "mbedtls/net_sockets.h"
#include "mbedtls/ssl.h"
#include "mbedtls/ctr_drbg.h"
#include "mbedtls/debug.h"
#include "mbedtls/platform.h"
#include "mbedtls/timing.h"
#include "mbedtls/entropy.h"
#include "mbedtls/error.h"
/*
 * State of one TLS client connection. An instance is heap-allocated by
 * qtf_tls_connect(), handed to callers as an opaque void pointer, and
 * released by qtf_tls_close().
 */
typedef struct
{
/* TCP socket; opened by mbedtls_net_connect() and used as the SSL BIO context. */
mbedtls_net_context socket_fd;
/* Entropy source handed to mbedtls_ctr_drbg_seed(). */
mbedtls_entropy_context entropy;
/* CTR-DRBG random generator registered with the SSL configuration. */
mbedtls_ctr_drbg_context ctr_drbg;
/* SSL session used for the handshake and for all reads and writes. */
mbedtls_ssl_context ssl;
/* SSL configuration that `ssl` is set up with. */
mbedtls_ssl_config ssl_conf;
/* CA chain parsed from the connection parameters (certificate auth mode). */
mbedtls_x509_crt ca_cert;
/* Client certificate parsed from the connection parameters (mutual authentication). */
mbedtls_x509_crt client_cert;
/* Private key that belongs to `client_cert`. */
mbedtls_pk_context private_key;
}qtf_tls_handle_t;
#if defined(MBEDTLS_DEBUG_C)
/*
 * Debug callback registered with mbedtls_ssl_conf_dbg(): prints one mbedTLS
 * debug message to stdout, prefixed with the source file and line it came from.
 *
 * ctx   - user context supplied at registration time (not used)
 * level - debug level of the message (not printed)
 * file  - name of the mbedTLS source file that produced the message
 * line  - line number inside that file
 * str   - the message text
 */
static void _ssl_debug(void *ctx, int level, const char *file, int line, const char *str)
{
    /* Neither of these influences the output. */
    (void)ctx;
    (void)level;

    printf("[mbedTLS]:[%s]:[%d]: %s\r\n", file, line, str);
}
#endif
/*
 * Open a TCP connection to host:port on `ctx` and switch the socket to
 * blocking mode.
 *
 * ctx  - mbedTLS network context that receives the connected socket
 * host - server host name or address, passed straight to mbedtls_net_connect()
 * port - server TCP port
 *
 * Returns 0 on success, otherwise the non-zero code reported by the failing
 * mbedTLS call.
 */
static int _mbedtls_tcp_connect(mbedtls_net_context *ctx, const char *host, uint16_t port)
{
    /* mbedtls_net_connect() takes the port as text; "65535" plus NUL fits in 6 bytes. */
    char service[6] = {0};
    int rc;

    snprintf(service, sizeof(service), "%d", port);

    rc = mbedtls_net_connect(ctx, host, service, MBEDTLS_NET_PROTO_TCP);
    if (rc != 0)
    {
        dlg_error("mbedtls_net_connect connect failed returned 0x%04x errno: %d", rc < 0 ? -rc : rc, errno);
        return rc;
    }

    rc = mbedtls_net_set_block(ctx);
    if (rc != 0)
    {
        dlg_error("mbedtls_net_set_block failed returned 0x%04x errno: %d", rc < 0 ? -rc : rc, errno);
        return rc;
    }

    return 0;
}
/*
 * Initialise every mbedTLS context stored in `handle`, seed the random
 * generator and load the credentials selected by `param->auth_mode` into the
 * SSL configuration.
 *
 * handle - connection state to initialise; its previous contents are ignored
 * param  - connection parameters (auth mode, certificates / PSK, fragment
 *          length, debug level)
 *
 * Returns 0 on success and a non-zero value on failure: the code of the
 * failing mbedTLS / PSA call, or -1 when the parameters themselves are
 * invalid. Nothing is released here on failure; every context is left in a
 * state that __tls_net_deinit() can free safely.
 */
static int _tls_net_init(qtf_tls_handle_t *handle, qtf_tls_conn_param_t *param)
{
    int ret = -1;

    /*
     * Initialise all contexts up front, including the certificate and key
     * contexts that only some auth modes fill in. The cleanup routine frees
     * every one of them unconditionally, so none may be left uninitialised.
     */
    mbedtls_net_init(&(handle->socket_fd));
    mbedtls_ssl_init(&(handle->ssl));
    mbedtls_ssl_config_init(&(handle->ssl_conf));
    mbedtls_ctr_drbg_init(&(handle->ctr_drbg));
    mbedtls_entropy_init(&(handle->entropy));
    mbedtls_x509_crt_init(&(handle->ca_cert));
    mbedtls_x509_crt_init(&(handle->client_cert));
    mbedtls_pk_init(&(handle->private_key));

#if defined(MBEDTLS_DEBUG_C)
    mbedtls_debug_set_threshold(param->debug_level);
    mbedtls_ssl_conf_dbg(&handle->ssl_conf, _ssl_debug, NULL);
#endif

#if defined(MBEDTLS_USE_PSA_CRYPTO) || defined(MBEDTLS_SSL_PROTO_TLS1_3)
    // TLS 1.3 requires psa_crypto_init() to have been called
    ret = psa_crypto_init();
    if (ret != PSA_SUCCESS)
    {
        dlg_error("psa_crypto_init failed");
        goto error;
    }
#endif

    ret = mbedtls_ssl_conf_max_frag_len(&(handle->ssl_conf), param->max_frag_len);
    if (ret != 0)
    {
        dlg_error("mbedtls_ssl_conf_max_frag_len failed");
        goto error;
    }

    ret = mbedtls_ctr_drbg_seed(&(handle->ctr_drbg), mbedtls_entropy_func, &(handle->entropy), NULL, 0);
    if (ret != 0)
    {
        dlg_error("mbedtls_ctr_drbg_seed failed");
        goto error;
    }

    // Certificate based authentication
    if (param->auth_mode == QTF_TLS_AUTH_MODE_CERT)
    {
        if (param->ca_cert && param->ca_cert_len)
        {
            // The length passed to the parser is ca_cert_len + 1, i.e. one
            // byte more than the field value (presumably the terminating NUL
            // that PEM parsing requires).
            ret = mbedtls_x509_crt_parse(&(handle->ca_cert), (const unsigned char *)param->ca_cert, param->ca_cert_len + 1);
            if (ret != 0)
            {
                dlg_error("mbedtls_x509_crt_parse failed");
                goto error;
            }
        }
        else
        {
            if (param->verify_mode != QTF_TLS_VERIFY_MODE_NONE)
            {
                dlg_error("invalid ca cert");
                // ret is still 0 from the successful seed call above; report
                // the failure explicitly so the caller does not see success.
                ret = -1;
                goto error;
            }
            dlg_info("verify mode is none");
        }

        // When no CA was supplied this registers an initialised, empty chain.
        mbedtls_ssl_conf_ca_chain(&(handle->ssl_conf), &(handle->ca_cert), NULL);

        if (param->client_cert && param->client_cert_len && param->client_key && param->client_key_len)
        {
            // Mutual authentication: load the client certificate and its key
            const unsigned char *key_pwd = (const unsigned char *)param->client_key_passwd;
            // The password is raw key-derivation input, so its length must not
            // include a terminating NUL. Assumes client_key_passwd_len is the
            // password's strlen, as the certificate lengths appear to be.
            size_t key_pwd_len = key_pwd ? (size_t)param->client_key_passwd_len : 0;

            ret = mbedtls_x509_crt_parse(&(handle->client_cert), (const unsigned char *)param->client_cert, param->client_cert_len + 1);
            if (ret != 0)
            {
                dlg_error("mbedtls_x509_crt_parse failed");
                goto error;
            }

            // mbedTLS 3.x requires a working RNG here; the DRBG was seeded above.
            ret = mbedtls_pk_parse_key(&(handle->private_key), (const unsigned char *)param->client_key,
                                       param->client_key_len + 1, key_pwd, key_pwd_len,
                                       mbedtls_ctr_drbg_random, &(handle->ctr_drbg));
            if (ret != 0)
            {
                dlg_error("mbedtls_pk_parse_key failed");
                goto error;
            }

            ret = mbedtls_ssl_conf_own_cert(&(handle->ssl_conf), &(handle->client_cert), &(handle->private_key));
            if (ret != 0)
            {
                dlg_error("mbedtls_ssl_conf_own_cert failed");
                goto error;
            }
        }
    }
    // Pre-shared key authentication
    else if (param->auth_mode == QTF_TLS_AUTH_MODE_PSK)
    {
        if (param->psk && param->psk_len && param->psk_id)
        {
            ret = mbedtls_ssl_conf_psk(&(handle->ssl_conf), (const unsigned char *)param->psk, param->psk_len,
                                       (const unsigned char *)param->psk_id, strlen(param->psk_id));
            if (ret != 0)
            {
                dlg_error("mbedtls_ssl_conf_psk failed");
                goto error;
            }
        }
        else
        {
            dlg_error("invalid psk");
            ret = -1;
            goto error;
        }
    }
    else
    {
        dlg_error("invalid auth mode");
        ret = -1;
        goto error;
    }

    return 0;

error:
    // Cleanup is left to the caller, which runs __tls_net_deinit() on failure.
    return ret;
}
/*
 * Release every mbedTLS context owned by `handle`: the socket, the SSL
 * session and its configuration, the random generator with its entropy
 * source, and the certificate / key material.
 *
 * The handle structure itself is not freed here; that is the caller's job.
 * Always returns 0.
 */
static int __tls_net_deinit(qtf_tls_handle_t *handle)
{
    /* Transport and TLS session state */
    mbedtls_net_free(&handle->socket_fd);
    mbedtls_ssl_free(&handle->ssl);
    mbedtls_ssl_config_free(&handle->ssl_conf);

    /* Random number generation */
    mbedtls_ctr_drbg_free(&handle->ctr_drbg);
    mbedtls_entropy_free(&handle->entropy);

    /* Credentials */
    mbedtls_x509_crt_free(&handle->ca_cert);
    mbedtls_x509_crt_free(&handle->client_cert);
    mbedtls_pk_free(&handle->private_key);

    return 0;
}
/*
 * Open a TLS client connection to host:port.
 *
 * Allocates the connection state, loads the credentials described by `param`,
 * connects the TCP socket, applies the client configuration (handshake read
 * timeout, verification mode, TLS version range, RNG) and runs the handshake
 * to completion.
 *
 * host  - server name; used for the TCP connection and as the TLS host name (SNI)
 * port  - server TCP port
 * param - connection parameters
 *
 * Returns an opaque handle for qtf_tls_send(), qtf_tls_recv() and
 * qtf_tls_close(), or NULL on any failure. On failure everything acquired so
 * far has already been released.
 */
void *qtf_tls_connect(const char *host, uint16_t port, qtf_tls_conn_param_t *param)
{
    qtf_tls_handle_t *handle = NULL;
    int ret = 0;

    if (!host || !param)
    {
        dlg_error("invalid param");
        goto error;
    }

    // Zero-filled allocation: no context in the handle ever holds
    // indeterminate bytes when the error path below frees them.
    handle = calloc(1, sizeof(*handle));
    if (!handle)
    {
        dlg_error("malloc failed");
        goto error;
    }

    // Initialise the contexts and load credentials
    ret = _tls_net_init(handle, param);
    if (ret != 0)
    {
        dlg_error("tls net init failed");
        goto error;
    }

    dlg_info("Connecting to %s:%d...", host, port);

    // Establish the TCP connection
    ret = _mbedtls_tcp_connect(&(handle->socket_fd), host, port);
    if (ret != 0)
    {
        dlg_error("mbedtls_tcp_connect failed");
        goto error;
    }

    // Load the TLS client defaults; this call can fail and must be checked
    ret = mbedtls_ssl_config_defaults(&(handle->ssl_conf), MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT);
    if (ret != 0)
    {
        dlg_error("mbedtls_ssl_config_defaults failed 0x%04x", ret < 0 ? -ret : ret);
        goto error;
    }

    mbedtls_ssl_conf_read_timeout(&(handle->ssl_conf), param->hanshake_timeout_ms);

    // Peer verification level
    mbedtls_ssl_conf_authmode(&(handle->ssl_conf), param->verify_mode);

    // Protocol version range: pin to 1.2 or 1.3 if requested, otherwise allow both
    if (param->tls_version == QTF_TLS_VERSION_TLS1_2)
    {
        mbedtls_ssl_conf_max_tls_version(&(handle->ssl_conf), MBEDTLS_SSL_VERSION_TLS1_2);
        mbedtls_ssl_conf_min_tls_version(&(handle->ssl_conf), MBEDTLS_SSL_VERSION_TLS1_2);
    }
    else if (param->tls_version == QTF_TLS_VERSION_TLS1_3)
    {
        mbedtls_ssl_conf_max_tls_version(&(handle->ssl_conf), MBEDTLS_SSL_VERSION_TLS1_3);
        mbedtls_ssl_conf_min_tls_version(&(handle->ssl_conf), MBEDTLS_SSL_VERSION_TLS1_3);
    }
    else
    {
        mbedtls_ssl_conf_max_tls_version(&(handle->ssl_conf), MBEDTLS_SSL_VERSION_TLS1_3);
        mbedtls_ssl_conf_min_tls_version(&(handle->ssl_conf), MBEDTLS_SSL_VERSION_TLS1_2);
    }

    mbedtls_ssl_conf_rng(&(handle->ssl_conf), mbedtls_ctr_drbg_random, &(handle->ctr_drbg));

    // todo: config ciphersuites
    ret = mbedtls_ssl_setup(&(handle->ssl), &(handle->ssl_conf));
    if (ret != 0)
    {
        dlg_error("mbedtls_ssl_setup failed");
        goto error;
    }

    // Socket send/receive callbacks; these are the mbedTLS implementations,
    // custom ones could be plugged in instead
    mbedtls_ssl_set_bio(&(handle->ssl), &(handle->socket_fd), mbedtls_net_send, mbedtls_net_recv, mbedtls_net_recv_timeout);

    // Set the server host name so the SNI (Server Name Indication) extension is sent
    ret = mbedtls_ssl_set_hostname(&(handle->ssl), host);
    if (ret != 0)
    {
        dlg_error("mbedtls_ssl_set_hostname failed");
        goto error;
    }

    // Run the handshake, retrying while mbedTLS asks for more I/O
    while ((ret = mbedtls_ssl_handshake(&(handle->ssl))) != 0)
    {
        if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE)
        {
            dlg_error("mbedtls_ssl_handshake failed 0x%04x", ret < 0 ? -ret : ret);
            goto error;
        }
    }

    // Certificate verification result. Only a negative value is treated as an error.
    // NOTE(review): the result is a flag mask, so a non-zero positive value
    // (a possible outcome with an "optional" verify mode) is accepted here -
    // confirm that this is intended.
    ret = mbedtls_ssl_get_verify_result(&(handle->ssl));
    if (ret < 0)
    {
        dlg_error("mbedtls_ssl_get_verify_result 0x%04x", ret < 0 ? -ret : ret);
        goto error;
    }

    return handle;

error:
    if (handle)
    {
        __tls_net_deinit(handle);
        free(handle);
    }
    return NULL;
}
int qtf_tls_send(void *handle, const void *buf, uint32_t len, uint32_t timeout_ms)
{
qtf_tls_handle_t *tls_handle = (qtf_tls_handle_t *)handle;
int ret = 0;
if(!handle || !buf || !len)
{
dlg_error("invalid param");
return -1;
}
// 发送数据
while ((ret = mbedtls_ssl_write(&(tls_handle->ssl), (const unsigned char *)buf, len)) <= 0)
{
if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE)
{
dlg_error("mbedtls_ssl_write failed 0x%04x", ret < 0 ? -ret : ret);
return -1;
}
}
return ret;
}
/*
 * Read application data from an established TLS connection.
 *
 * handle     - handle returned by qtf_tls_connect()
 * buf        - destination buffer
 * len        - capacity of buf in bytes; must be non-zero
 * timeout_ms - not referenced by this implementation; the read timeout in
 *              effect is the one stored in the SSL configuration
 *
 * Returns the positive number of bytes read, or -1 on invalid arguments or
 * when mbedtls_ssl_read() reports anything other than a retryable condition
 * (this includes a return value of 0).
 */
int qtf_tls_recv(void *handle, void *buf, uint32_t len, uint32_t timeout_ms)
{
    qtf_tls_handle_t *tls_handle = (qtf_tls_handle_t *)handle;
    int ret = 0;

    (void)timeout_ms;

    if (!handle || !buf || !len)
    {
        dlg_error("invalid param");
        return -1;
    }

    // Receive data
    for (;;)
    {
        ret = mbedtls_ssl_read(&(tls_handle->ssl), (unsigned char *)buf, len);
        if (ret > 0)
        {
            return ret;
        }

        // The stack needs more I/O before it can deliver data: try again
        if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE)
        {
            continue;
        }

#if defined(MBEDTLS_ERR_SSL_RECEIVED_NEW_SESSION_TICKET)
        // A TLS 1.3 server may send NewSessionTicket messages after the
        // handshake. mbedTLS reports them through this code; it is
        // informational, the connection is intact and the read must simply
        // be repeated.
        if (ret == MBEDTLS_ERR_SSL_RECEIVED_NEW_SESSION_TICKET)
        {
            continue;
        }
#endif

        dlg_error("mbedtls_ssl_read failed 0x%04x", ret < 0 ? -ret : ret);
        return -1;
    }
}
int qtf_tls_close(void *handle)
{
int ret = 0;
qtf_tls_handle_t *tls_handle = (qtf_tls_handle_t *)handle;
if (!handle)
{
dlg_error("invalid param");
return -1;
}
// 关闭连接,释放资源
do
{
ret = mbedtls_ssl_close_notify(&(tls_handle->ssl));
} while (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE);
__tls_net_deinit(tls_handle);
free(tls_handle);
return 0;
}