Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SSL improvements #5716

Merged
merged 5 commits into from
Dec 15, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion ports/espressif/common-hal/ssl/SSLContext.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

#include "bindings/espidf/__init__.h"

#include "components/mbedtls/esp_crt_bundle/include/esp_crt_bundle.h"

#include "py/runtime.h"

void common_hal_ssl_sslcontext_construct(ssl_sslcontext_obj_t *self) {
Expand All @@ -47,6 +49,11 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t
sock->ssl_context = self;
sock->sock = socket;

// Create a copy of the ESP-TLS config object and store the server hostname
// Note that ESP-TLS will use common_name for both SNI and verification
memcpy(&sock->ssl_config, &self->ssl_config, sizeof(self->ssl_config));
sock->ssl_config.common_name = server_hostname;

esp_tls_t *tls_handle = esp_tls_init();
if (tls_handle == NULL) {
mp_raise_espidf_MemoryError();
Expand All @@ -55,6 +62,36 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t

// TODO: do something with the original socket? Don't call a close on the internal LWIP.

// Should we store server hostname on the socket in case connect is called with an ip?
return sock;
}

void common_hal_ssl_sslcontext_load_verify_locations(ssl_sslcontext_obj_t *self,
const char *cadata) {
self->ssl_config.crt_bundle_attach = NULL;
self->ssl_config.use_global_ca_store = false;
self->ssl_config.cacert_buf = (const unsigned char *)cadata;
self->ssl_config.cacert_bytes = strlen(cadata) + 1;
}

void common_hal_ssl_sslcontext_set_default_verify_paths(ssl_sslcontext_obj_t *self) {
self->ssl_config.crt_bundle_attach = esp_crt_bundle_attach;
self->ssl_config.use_global_ca_store = true;
self->ssl_config.cacert_buf = NULL;
self->ssl_config.cacert_bytes = 0;
}

bool common_hal_ssl_sslcontext_get_check_hostname(ssl_sslcontext_obj_t *self) {
if (self->ssl_config.skip_common_name) {
return 0;
} else {
return 1;
}
timhawes marked this conversation as resolved.
Show resolved Hide resolved
}

void common_hal_ssl_sslcontext_set_check_hostname(ssl_sslcontext_obj_t *self, bool value) {
if (value) {
self->ssl_config.skip_common_name = 0;
} else {
self->ssl_config.skip_common_name = 1;
}
timhawes marked this conversation as resolved.
Show resolved Hide resolved
}
4 changes: 1 addition & 3 deletions ports/espressif/common-hal/ssl/SSLSocket.c
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,7 @@ void common_hal_ssl_sslsocket_close(ssl_sslsocket_obj_t *self) {

void common_hal_ssl_sslsocket_connect(ssl_sslsocket_obj_t *self,
const char *host, size_t hostlen, uint32_t port) {
esp_tls_cfg_t *tls_config = NULL;
tls_config = &self->ssl_context->ssl_config;
int result = esp_tls_conn_new_sync(host, hostlen, port, tls_config, self->tls);
int result = esp_tls_conn_new_sync(host, hostlen, port, &self->ssl_config, self->tls);
self->sock->connected = result >= 0;
if (result < 0) {
int esp_tls_code;
Expand Down
1 change: 1 addition & 0 deletions ports/espressif/common-hal/ssl/SSLSocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ typedef struct {
socketpool_socket_obj_t *sock;
esp_tls_t *tls;
ssl_sslcontext_obj_t *ssl_context;
esp_tls_cfg_t ssl_config;
} ssl_sslsocket_obj_t;

#endif // MICROPY_INCLUDED_ESPRESSIF_COMMON_HAL_SSL_SSLSOCKET_H
76 changes: 71 additions & 5 deletions shared-bindings/ssl/SSLContext.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

#include "py/objtuple.h"
#include "py/objlist.h"
#include "py/objproperty.h"
#include "py/runtime.h"
#include "py/mperrno.h"

Expand All @@ -51,10 +52,69 @@ STATIC mp_obj_t ssl_sslcontext_make_new(const mp_obj_type_t *type, size_t n_args
return MP_OBJ_FROM_PTR(s);
}

//| def wrap_socket(sock: socketpool.Socket, *, server_side: bool = False, server_hostname: Optional[str] = None) -> ssl.SSLSocket:
//| """Wraps the socket into a socket-compatible class that handles SSL negotiation.
//| The socket must be of type SOCK_STREAM."""
//| ...
//| def load_verify_locations(self, cadata: Optional[str] = None) -> None:
//| """Load a set of certification authority (CA) certificates used to validate
//| other peers' certificates."""
//|

STATIC mp_obj_t ssl_sslcontext_load_verify_locations(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
enum { ARG_cadata };
static const mp_arg_t allowed_args[] = {
{ MP_QSTR_cadata, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_obj = mp_const_none} },
};
ssl_sslcontext_obj_t *self = MP_OBJ_TO_PTR(pos_args[0]);

mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
mp_arg_parse_all(n_args - 1, pos_args + 1, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);

const char *cadata = mp_obj_str_get_str(args[ARG_cadata].u_obj);

common_hal_ssl_sslcontext_load_verify_locations(self, cadata);
return mp_const_none;
}
STATIC MP_DEFINE_CONST_FUN_OBJ_KW(ssl_sslcontext_load_verify_locations_obj, 1, ssl_sslcontext_load_verify_locations);

//| def set_default_verify_paths(self) -> None:
//| """Load a set of default certification authority (CA) certificates."""
//|

STATIC mp_obj_t ssl_sslcontext_set_default_verify_paths(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
ssl_sslcontext_obj_t *self = MP_OBJ_TO_PTR(pos_args[0]);

common_hal_ssl_sslcontext_set_default_verify_paths(self);
return mp_const_none;
}
STATIC MP_DEFINE_CONST_FUN_OBJ_KW(ssl_sslcontext_set_default_verify_paths_obj, 1, ssl_sslcontext_set_default_verify_paths);

//| check_hostname: bool
//| """Whether to match the peer certificate's hostname."""
//|

STATIC mp_obj_t ssl_sslcontext_get_check_hostname(mp_obj_t self_in) {
ssl_sslcontext_obj_t *self = MP_OBJ_TO_PTR(self_in);

return mp_obj_new_bool(common_hal_ssl_sslcontext_get_check_hostname(self));
}
STATIC MP_DEFINE_CONST_FUN_OBJ_1(ssl_sslcontext_get_check_hostname_obj, ssl_sslcontext_get_check_hostname);

STATIC mp_obj_t ssl_sslcontext_set_check_hostname(mp_obj_t self_in, mp_obj_t value) {
ssl_sslcontext_obj_t *self = MP_OBJ_TO_PTR(self_in);

common_hal_ssl_sslcontext_set_check_hostname(self, mp_obj_is_true(value));
return mp_const_none;
}
STATIC MP_DEFINE_CONST_FUN_OBJ_2(ssl_sslcontext_set_check_hostname_obj, ssl_sslcontext_set_check_hostname);

const mp_obj_property_t ssl_sslcontext_check_hostname_obj = {
.base.type = &mp_type_property,
.proxy = {(mp_obj_t)&ssl_sslcontext_get_check_hostname_obj,
(mp_obj_t)&ssl_sslcontext_set_check_hostname_obj,
MP_ROM_NONE},
};

//| def wrap_socket(self, sock: socketpool.Socket, *, server_side: bool = False, server_hostname: Optional[str] = None) -> ssl.SSLSocket:
//| """Wraps the socket into a socket-compatible class that handles SSL negotiation.
//| The socket must be of type SOCK_STREAM."""
//|

STATIC mp_obj_t ssl_sslcontext_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
Expand All @@ -69,7 +129,10 @@ STATIC mp_obj_t ssl_sslcontext_wrap_socket(size_t n_args, const mp_obj_t *pos_ar
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
mp_arg_parse_all(n_args - 1, pos_args + 1, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);

const char *server_hostname = mp_obj_str_get_str(args[ARG_server_hostname].u_obj);
const char *server_hostname = NULL;
if (args[ARG_server_hostname].u_obj != mp_const_none) {
server_hostname = mp_obj_str_get_str(args[ARG_server_hostname].u_obj);
}
bool server_side = args[ARG_server_side].u_bool;
if (server_side && server_hostname != NULL) {
mp_raise_ValueError(translate("Server side context cannot have hostname"));
Expand All @@ -83,6 +146,9 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_KW(ssl_sslcontext_wrap_socket_obj, 1, ssl_sslcont

STATIC const mp_rom_map_elem_t ssl_sslcontext_locals_dict_table[] = {
{ MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&ssl_sslcontext_wrap_socket_obj) },
{ MP_ROM_QSTR(MP_QSTR_load_verify_locations), MP_ROM_PTR(&ssl_sslcontext_load_verify_locations_obj) },
{ MP_ROM_QSTR(MP_QSTR_set_default_verify_paths), MP_ROM_PTR(&ssl_sslcontext_set_default_verify_paths_obj) },
{ MP_ROM_QSTR(MP_QSTR_check_hostname), MP_ROM_PTR(&ssl_sslcontext_check_hostname_obj) },
};

STATIC MP_DEFINE_CONST_DICT(ssl_sslcontext_locals_dict, ssl_sslcontext_locals_dict_table);
Expand Down
8 changes: 8 additions & 0 deletions shared-bindings/ssl/SSLContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,12 @@ void common_hal_ssl_sslcontext_construct(ssl_sslcontext_obj_t *self);
ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t *self,
socketpool_socket_obj_t *sock, bool server_side, const char *server_hostname);

void common_hal_ssl_sslcontext_load_verify_locations(ssl_sslcontext_obj_t *self,
const char *cadata);

void common_hal_ssl_sslcontext_set_default_verify_paths(ssl_sslcontext_obj_t *self);

bool common_hal_ssl_sslcontext_get_check_hostname(ssl_sslcontext_obj_t *self);
void common_hal_ssl_sslcontext_set_check_hostname(ssl_sslcontext_obj_t *self, bool value);

#endif // MICROPY_INCLUDED_SHARED_BINDINGS_SSL_SSLCONTEXT_H