Skip to content

Commit

Permalink
Optionally force gathering of only relay (STUN/TURN) candidates.
Browse files Browse the repository at this point in the history
  • Loading branch information
eerimoq authored and jlaine committed Jan 30, 2023
1 parent daedc1e commit e56c96c
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/aioice/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from .about import __version__
from .candidate import Candidate
from .ice import Connection, ConnectionClosed
from .ice import Connection, ConnectionClosed, TransportPolicy

# Set default logging handler to avoid "No handler found" warnings.
logging.getLogger(__name__).addHandler(logging.NullHandler())
30 changes: 29 additions & 1 deletion src/aioice/ice.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,20 @@
_mdns = threading.local()


class TransportPolicy(enum.Enum):
ALL = 0
"""
All ICE candidates will be considered.
"""

RELAY = 1
"""
Only ICE candidates whose IP addresses are being relayed,
such as those being passed through a STUN or TURN server,
will be considered.
"""


async def get_or_create_mdns_protocol(subscriber: object) -> mdns.MDnsProtocol:
if not hasattr(_mdns, "lock"):
_mdns.lock = asyncio.Lock()
Expand Down Expand Up @@ -282,6 +296,7 @@ class Connection:
:param turn_transport: The transport for TURN server, `"udp"` or `"tcp"`.
:param use_ipv4: Whether to use IPv4 candidates.
:param use_ipv6: Whether to use IPv6 candidates.
:param transport_policy: Transport policy.
"""

def __init__(
Expand All @@ -296,6 +311,7 @@ def __init__(
turn_transport: str = "udp",
use_ipv4: bool = True,
use_ipv6: bool = True,
transport_policy: TransportPolicy = TransportPolicy.ALL,
) -> None:
self.ice_controlling = ice_controlling
#: Local username, automatically set to a random value.
Expand Down Expand Up @@ -342,6 +358,17 @@ def __init__(
self._use_ipv4 = use_ipv4
self._use_ipv6 = use_ipv6

if (
stun_server is None
and turn_server is None
and transport_policy == TransportPolicy.RELAY
):
raise ValueError(
"Relay transport policy requires a STUN and/or TURN server."
)

self._transport_policy = transport_policy

@property
def local_candidates(self) -> List[Candidate]:
"""
Expand Down Expand Up @@ -880,7 +907,8 @@ async def get_component_candidates(
port=candidate_address[1],
type="host",
)
candidates.append(protocol.local_candidate)
if self._transport_policy == TransportPolicy.ALL:
candidates.append(protocol.local_candidate)
self._protocols += host_protocols

# query STUN server for server-reflexive candidates (IPv4 only)
Expand Down
45 changes: 44 additions & 1 deletion tests/test_ice.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import unittest
from unittest import mock

from aioice import Candidate, ice, mdns, stun
from aioice import Candidate, TransportPolicy, ice, mdns, stun

from .turnserver import run_turn_server
from .utils import asynctest, invite_accept
Expand Down Expand Up @@ -1200,6 +1200,49 @@ async def test_gather_candidates_oserror(self, mock_create):
await conn.gather_candidates()
self.assertEqual(conn.local_candidates, [])

@asynctest
async def test_gather_candidates_relay_only_no_servers(self):
with self.assertRaises(ValueError) as cm:
ice.Connection(ice_controlling=True, transport_policy=TransportPolicy.RELAY)
self.assertEqual(
str(cm.exception),
"Relay transport policy requires a STUN and/or TURN server.",
)

@asynctest
async def test_gather_candidates_relay_only_with_stun_server(self):
async with run_turn_server() as stun_server:
conn_a = ice.Connection(
ice_controlling=True,
stun_server=stun_server.udp_address,
transport_policy=TransportPolicy.RELAY,
)
conn_b = ice.Connection(ice_controlling=False)

# invite / accept
await invite_accept(conn_a, conn_b)

# we whould only have a server-reflexive candidate in connection a
self.assertCandidateTypes(conn_a, set(["srflx"]))

@asynctest
async def test_gather_candidates_relay_only_with_turn_server(self):
async with run_turn_server(users={"foo": "bar"}) as turn_server:
conn_a = ice.Connection(
ice_controlling=True,
turn_server=turn_server.udp_address,
turn_username="foo",
turn_password="bar",
transport_policy=TransportPolicy.RELAY,
)
conn_b = ice.Connection(ice_controlling=False)

# invite / accept
await invite_accept(conn_a, conn_b)

# we whould only have a server-reflexive candidate in connection a
self.assertCandidateTypes(conn_a, set(["relay"]))

@asynctest
async def test_repr(self):
conn = ice.Connection(ice_controlling=True)
Expand Down

0 comments on commit e56c96c

Please sign in to comment.