Skip to content

Commit

Permalink
Merge pull request python-trio#814 from njsmith/fixup-743
Browse files Browse the repository at this point in the history
Tweaks for python-trio#743
  • Loading branch information
njsmith authored Dec 29, 2018
2 parents 0e6d2e6 + 0fd4716 commit 2cde5fb
Show file tree
Hide file tree
Showing 9 changed files with 47 additions and 65 deletions.
5 changes: 5 additions & 0 deletions docs/source/reference-io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,11 @@ And if you're implementing a server, you can use :class:`SSLListener`:
:show-inheritance:
:members:

Some methods on :class:`SSLStream` raise :exc:`NeedHandshakeError` if
you call them before the handshake completes:

.. autoexception:: NeedHandshakeError


.. module:: trio.socket

Expand Down
7 changes: 7 additions & 0 deletions newsfragments/735.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
There are a number of methods on :class:`trio.ssl.SSLStream` that
report information about the negotiated TLS connection, like
``selected_alpn_protocol``, and thus cannot succeed until after the
handshake has been performed. Previously, we returned None from these
methods, like the stdlib :mod:`ssl` module does, but this is
confusing, because that can also be a valid return value. Now we raise
:exc:`trio.ssl.NeedHandshakeError` instead.
15 changes: 0 additions & 15 deletions newsfragments/743.misc.rst

This file was deleted.

2 changes: 1 addition & 1 deletion trio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
TrioInternalError, RunFinishedError, WouldBlock, Cancelled,
BusyResourceError, ClosedResourceError, MultiError, run, open_nursery,
open_cancel_scope, current_effective_deadline, TASK_STATUS_IGNORED,
current_time, BrokenResourceError, EndOfChannel, NoHandshakeError
current_time, BrokenResourceError, EndOfChannel
)

from ._timeouts import (
Expand Down
3 changes: 1 addition & 2 deletions trio/_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ def _public(fn):

from ._exceptions import (
TrioInternalError, RunFinishedError, WouldBlock, Cancelled,
BusyResourceError, ClosedResourceError, BrokenResourceError, EndOfChannel,
NoHandshakeError
BusyResourceError, ClosedResourceError, BrokenResourceError, EndOfChannel
)

from ._multierror import MultiError
Expand Down
28 changes: 0 additions & 28 deletions trio/_core/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,34 +125,6 @@ class BrokenResourceError(Exception):
"""


class NoHandshakeError(Exception):
"""Raised when a method like ``select_alpn_protocol`` is called
before the handshake is established.
Some methods defined in the :class:`ssl.SSLSocket` from the stdlib
return ``None`` if the handshake hasn't happened yet.
These are:
- ``get_channel_binding``: https://docs.python.org/3/library/ssl.html#ssl.SSLSocket.get_channel_binding
- ``selected_alpn_protocol``: https://docs.python.org/3/library/ssl.html#ssl.SSLSocket.selected_alpn_protocol
- ``selected_npn_protocol``: https://docs.python.org/3/library/ssl.html#ssl.SSLSocket.selected_npn_protocol
Note that these methods might also return ``None```in other cases.
In case of calling `selected_alpn_protocol`` and ``selected_npn_protocol``
other cases of returning ``None`` are:
- If the other party does not support ALPN/NPN.
- If ``SSLContext.set_alpn_protocols()`` or ``SSLContext.set_npn_protocols()`` was not called.
and in the case of ``get_channel_binding``:
- If not connected.
"""


class EndOfChannel(Exception):
"""Raised when trying to receive from a :class:`trio.abc.ReceiveChannel`
that has no more data to receive.
Expand Down
36 changes: 25 additions & 11 deletions trio/_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,14 @@
################################################################


class NeedHandshakeError(Exception):
"""Some :class:`SSLStream` methods can't return any meaningful data until
after the handshake. If you call them before the handshake, they raise
this error.
"""


class _Once:
def __init__(self, afn, *args):
self._afn = afn
Expand Down Expand Up @@ -207,15 +215,6 @@ class SSLStream(Stream):
documentation on SSL/TLS as well. SSL/TLS is subtle and quick to
anger. Really. I'm not kidding.
To illustrate the point with an example, some of the methods of the
:class:`~ssl.SSLContext` return ``None`` when no handshake is established.
To make it behave more explicitly, we decided to raise `trio.core.NoHandshakeError`
in the :mod:`ssl` methods defined in ``_after_handshake``,
in case no handshake is established.
Note that these methods still return ``None`` in other cases, as detailed
in ``trio.core.NoHandshakeError``.
Args:
transport_stream (~trio.abc.Stream): The stream used to transport
encrypted data. Required.
Expand Down Expand Up @@ -290,6 +289,13 @@ class SSLStream(Stream):
Internally, this class is implemented using an instance of
:class:`ssl.SSLObject`, and all of :class:`~ssl.SSLObject`'s methods and
attributes are re-exported as methods and attributes on this class.
However, there is one difference: :class:`~ssl.SSLObject` has several
methods that return information about the encrypted connection, like
:meth:`~ssl.SSLSocket.cipher` or
:meth:`~ssl.SSLSocket.selected_alpn_protocol`. If you call them before the
handshake, when they can't possibly return useful data, then
:class:`ssl.SSLObject` returns None, but :class:`trio.ssl.SSLStream`
raises :exc:`NeedHandshakeError`.
This also means that if you register a SNI callback using
:obj:`~ssl.SSLContext.sni_callback`, then the first argument your callback
Expand Down Expand Up @@ -357,15 +363,23 @@ def __init__(
}

_after_handshake = {
"get_channel_binding",
"session_reused",
"getpeercert",
"selected_npn_protocol",
"cipher",
"shared_ciphers",
"compression",
"get_channel_binding",
"selected_alpn_protocol",
"version",
}

def __getattr__(self, name):
if name in self._forwarded:
if name in self._after_handshake and not self._handshook.done:
raise _core.NoHandshakeError
raise NeedHandshakeError(
"call do_handshake() before calling {!r}".format(name)
)

return getattr(self._ssl_object, name)
else:
Expand Down
2 changes: 1 addition & 1 deletion trio/ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# here.

# Trio-specific symbols:
from ._ssl import (SSLStream, SSLListener)
from ._ssl import SSLStream, SSLListener, NeedHandshakeError

# Symbols re-exported from the stdlib ssl module:

Expand Down
14 changes: 7 additions & 7 deletions trio/tests/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .. import _core
from .._highlevel_socket import SocketStream, SocketListener
from .._highlevel_generic import aclose_forcefully
from .._core import ClosedResourceError, BrokenResourceError, NoHandshakeError
from .._core import ClosedResourceError, BrokenResourceError
from .._highlevel_open_tcp_stream import open_tcp_stream
from .. import ssl as tssl
from .. import socket as tsocket
Expand Down Expand Up @@ -1112,10 +1112,10 @@ async def client_side(cancel_scope):
async def test_selected_alpn_protocol_before_handshake():
client, server = ssl_memory_stream_pair()

with pytest.raises(NoHandshakeError):
with pytest.raises(tssl.NeedHandshakeError):
client.selected_alpn_protocol()

with pytest.raises(NoHandshakeError):
with pytest.raises(tssl.NeedHandshakeError):
server.selected_alpn_protocol()


Expand All @@ -1138,10 +1138,10 @@ async def test_selected_alpn_protocol_when_not_set():
async def test_selected_npn_protocol_before_handshake():
client, server = ssl_memory_stream_pair()

with pytest.raises(NoHandshakeError):
with pytest.raises(tssl.NeedHandshakeError):
client.selected_npn_protocol()

with pytest.raises(NoHandshakeError):
with pytest.raises(tssl.NeedHandshakeError):
server.selected_npn_protocol()


Expand All @@ -1164,10 +1164,10 @@ async def test_selected_npn_protocol_when_not_set():
async def test_get_channel_binding_before_handshake():
client, server = ssl_memory_stream_pair()

with pytest.raises(NoHandshakeError):
with pytest.raises(tssl.NeedHandshakeError):
client.get_channel_binding()

with pytest.raises(NoHandshakeError):
with pytest.raises(tssl.NeedHandshakeError):
server.get_channel_binding()


Expand Down

0 comments on commit 2cde5fb

Please sign in to comment.