Don't start HTTP server in Scheduler.__init__
fjetter committed Jul 8, 2021
1 parent 04b6be4 commit d6e9b99
Showing 13 changed files with 134 additions and 66 deletions.
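
This commit moves HTTP server startup out of Scheduler.__init__ and into the async Scheduler.start(). Constructing a Scheduler is now side-effect free: no port is bound and no dashboard routes are installed until the scheduler is started, so creating one twice emits no port warnings (see test_init_twice_no_warning below) and a failed startup can be unwound cleanly. A minimal sketch of the construct-then-start pattern, illustrative only (MiniServer and its methods are invented for this example, not distributed's API):

import asyncio


class MiniServer:
    def __init__(self, host="127.0.0.1"):
        # __init__ only records configuration; nothing binds a port here.
        self._host = host
        self._server = None

    async def start(self):
        # Side effects (binding the socket) happen on first await.
        if self._server is None:
            self._server = await asyncio.start_server(self._on_conn, self._host, 0)
        return self

    async def _on_conn(self, reader, writer):
        writer.close()

    def __await__(self):
        # "await server" starts it, mirroring how distributed's Server is awaited.
        return self.start().__await__()


async def main():
    # Two constructions, no "port already in use" warning: nothing binds
    # until start() runs.
    a, b = MiniServer(), MiniServer()
    await a
    await b


asyncio.run(main())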
4 changes: 2 additions & 2 deletions distributed/cli/dask_scheduler.py
@@ -184,7 +184,6 @@ def del_pid_file():
resource.setrlimit(resource.RLIMIT_NOFILE, (limit, hard))

loop = IOLoop.current()
logger.info("-" * 47)

scheduler = Scheduler(
loop=loop,
@@ -196,12 +195,13 @@ def del_pid_file():
http_prefix=dashboard_prefix,
**kwargs,
)
logger.info("-" * 47)

install_signal_handlers(loop)

async def run():
logger.info("-" * 47)
await scheduler
logger.info("-" * 47)
await scheduler.finished()

try:
3 changes: 3 additions & 0 deletions distributed/comm/tests/test_ws.py
@@ -128,6 +128,9 @@ async def test_large_transfer(cleanup):


@pytest.mark.asyncio
@pytest.mark.filterwarnings(
"ignore:Dashboard and Scheduler are using the same server on port"
)
@pytest.mark.parametrize(
"dashboard,protocol,security,port",
[
2 changes: 1 addition & 1 deletion distributed/core.py
@@ -426,7 +426,7 @@ async def handle_comm(self, comm):

logger.debug("Connection from %r to %s", address, type(self).__name__)
self._comms[comm] = op
await self

try:
while True:
try:
20 changes: 15 additions & 5 deletions distributed/deploy/spec.py
@@ -393,10 +393,16 @@ async def _():
if self.status == Status.created:
await self._start()
await self.scheduler
await self._correct_state()
if self.workers:
await asyncio.wait(list(self.workers.values())) # maybe there are more
return self
try:
await self._correct_state()
if self.workers:
await asyncio.wait(
list(self.workers.values())
) # maybe there are more
return self
except Exception:
await self._close()
raise

return _().__await__()

@@ -428,7 +434,11 @@ async def _close(self):

await self.scheduler.close()
for w in self._created:
assert w.status == Status.closed, w.status
assert w.status in [
Status.closed,
# Failure during startup
Status.undefined,
], w.status

if hasattr(self, "_old_logging_level"):
silence_logging(self._old_logging_level)
17 changes: 11 additions & 6 deletions distributed/deploy/tests/test_local.py
@@ -248,8 +248,8 @@ def test_Client_unused_kwargs_with_address(loop):


def test_Client_twice(loop):
with Client(loop=loop, silence_logs=False, dashboard_address=None) as c:
with Client(loop=loop, silence_logs=False, dashboard_address=None) as f:
with Client(loop=loop, silence_logs=False, dashboard_address=":0") as c:
with Client(loop=loop, silence_logs=False, dashboard_address=":0") as f:
assert c.cluster.scheduler.port != f.cluster.scheduler.port


@@ -1048,7 +1048,9 @@ async def test_no_workers(cleanup):

@pytest.mark.asyncio
async def test_cluster_names():
async with LocalCluster(processes=False, asynchronous=True) as unnamed_cluster:
async with LocalCluster(
processes=False, asynchronous=True, dashboard_address=":0"
) as unnamed_cluster:
async with LocalCluster(
processes=False, asynchronous=True, name="mycluster"
) as named_cluster:
@@ -1070,12 +1072,15 @@ async def test_local_cluster_redundant_kwarg(nanny):
# Extra arguments are forwarded to the worker class. Depending on
# whether we use the nanny or not, the error treatment is quite
# different and we should assert that an exception is raised
async with await LocalCluster(
typo_kwarg="foo", processes=nanny, n_workers=1
async with LocalCluster(
typo_kwarg="foo",
processes=nanny,
n_workers=1,
asynchronous=True,
) as cluster:

# This will never work but is a reliable way to block without hard
# coding any sleep values
async with Client(cluster) as c:
async with Client(cluster, asynchronous=True) as c:
f = c.submit(sleep, 0)
await f
3 changes: 0 additions & 3 deletions distributed/deploy/tests/test_spec_cluster.py
@@ -9,7 +9,6 @@
import dask
from dask.distributed import Client, Nanny, Scheduler, SpecCluster, Worker

from distributed.compatibility import WINDOWS
from distributed.core import Status
from distributed.deploy.spec import ProcessInterface, close_clusters, run_spec
from distributed.metrics import time
@@ -218,7 +217,6 @@ async def test_restart(cleanup):
assert time() < start + 60


@pytest.mark.skipif(WINDOWS, reason="HTTP Server doesn't close out")
@pytest.mark.asyncio
async def test_broken_worker():
with pytest.raises(Exception) as info:
@@ -232,7 +230,6 @@ async def test_broken_worker():
assert "Broken" in str(info.value)


@pytest.mark.skipif(WINDOWS, reason="HTTP Server doesn't close out")
@pytest.mark.slow
def test_spec_close_clusters(loop):
workers = {0: {"cls": Worker}}
8 changes: 6 additions & 2 deletions distributed/deploy/utils_test.py
@@ -25,10 +25,14 @@ def test_submit(self):
assert future.result() == 2

def test_context_manager(self):
with self.Cluster(**self.kwargs) as c:
kwargs = self.kwargs.copy()
kwargs.pop("dashboard_address")
with self.Cluster(dashboard_address=":54321", **kwargs) as c:
with Client(c) as e:
assert e.nthreads()

def test_no_workers(self):
with self.Cluster(0, scheduler_port=0, **self.kwargs):
kwargs = self.kwargs.copy()
kwargs.pop("dashboard_address")
with self.Cluster(0, dashboard_address=":54321", scheduler_port=0, **kwargs):
pass
47 changes: 29 additions & 18 deletions distributed/scheduler.py
@@ -83,6 +83,13 @@
from .utils_perf import disable_gc_diagnosis, enable_gc_diagnosis
from .variable import VariableExtension

try:
import bokeh # noqa: F401

HAS_BOKEH = True
except ImportError:
HAS_BOKEH = False

try:
from cython import compiled
except ImportError:
@@ -3438,24 +3445,11 @@ def __init__(
default_port=self.default_port,
)

http_server_modules = dask.config.get("distributed.scheduler.http.routes")
show_dashboard = dashboard or (dashboard is None and dashboard_address)
missing_bokeh = False
# install vanilla route if show_dashboard but bokeh is not installed
if show_dashboard:
try:
import distributed.dashboard.scheduler
except ImportError:
missing_bokeh = True
http_server_modules.append("distributed.http.scheduler.missing_bokeh")
routes = get_handlers(
server=self, modules=http_server_modules, prefix=http_prefix
self._show_dashboard: bool = dashboard or (
dashboard is None and dashboard_address
)
self.start_http_server(routes, dashboard_address, default_port=8787)
if show_dashboard and not missing_bokeh:
distributed.dashboard.scheduler.connect(
self.http_application, self.http_server, self, prefix=http_prefix
)
self._dashboard_address = dashboard_address
self._http_prefix = http_prefix

# Communication state
self.loop = loop or IOLoop.current()
@@ -3759,6 +3753,23 @@ async def start(self):

self.clear_task_state()

http_server_modules = dask.config.get("distributed.scheduler.http.routes")
assert isinstance(http_server_modules, list)

# install vanilla route if show_dashboard but bokeh is not installed
if self._show_dashboard and not HAS_BOKEH:
http_server_modules.append("distributed.http.scheduler.missing_bokeh")
routes = get_handlers(
server=self, modules=http_server_modules, prefix=self._http_prefix
)
self.start_http_server(routes, self._dashboard_address, default_port=8787)
if self._show_dashboard and HAS_BOKEH:
import distributed.dashboard.scheduler

distributed.dashboard.scheduler.connect(
self.http_application, self.http_server, self, prefix=self._http_prefix
)

with suppress(AttributeError):
for c in self._worker_coroutines:
c.cancel()
@@ -5463,7 +5474,7 @@ async def gather(self, comm=None, keys=None, serializers=None):
def clear_task_state(self):
# XXX what about nested state such as ClientState.wants_what
# (see also fire-and-forget...)
logger.info("Clear task state")
logger.debug("Clear task state")
for collection in self._task_state_collections:
collection.clear()

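A hedged usage sketch of the new lifecycle, relying only on behavior visible in this diff (async with Scheduler works, per test_tls_scheduler below; dashboard_address=":0" picks a free port, as the updated tests do):

import asyncio

from distributed import Scheduler


async def main():
    # __init__ no longer starts the HTTP server; entering the async
    # context manager runs start(), which now does.
    async with Scheduler(dashboard_address=":0") as s:
        print("scheduler listening at", s.address)


asyncio.run(main())
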
3 changes: 1 addition & 2 deletions distributed/tests/test_client.py
@@ -6378,7 +6378,6 @@ async def f(stacklevel):
assert "Dask Performance Report" in data
assert "x = da.random" in data
assert "Threads: 4" in data
assert "distributed.scheduler - INFO - Clear task state" in data
assert dask.__version__ in data

# Stacklevel two captures code two frames back -- which in this case
@@ -6745,7 +6744,7 @@ def f(x):
with LocalCluster(
n_workers=1,
processes=False,
dashboard_address=False,
dashboard_address=":0",
worker_dashboard_address=False,
) as cluster2:
with Client(cluster2) as c1:
7 changes: 7 additions & 0 deletions distributed/tests/test_scheduler.py
@@ -2936,3 +2936,10 @@ async def test_transition_counter(c, s, a, b):
assert s.transition_counter == 0
await c.submit(inc, 1)
assert s.transition_counter > 1


def test_init_twice_no_warning():
with pytest.warns(None) as records:
for _ in range(2):
Scheduler()
assert not records
26 changes: 26 additions & 0 deletions distributed/tests/test_utils_test.py
@@ -7,6 +7,7 @@

import pytest
from tornado import gen
from tornado.ioloop import IOLoop

from distributed import Client, Nanny, Scheduler, Worker, config, default_client
from distributed.core import rpc
@@ -18,6 +19,7 @@
gen_test,
inc,
new_config,
start_cluster,
tls_only_security,
wait_for_port,
)
@@ -267,3 +269,27 @@ def test_tls_cluster(tls_client):
async def test_tls_scheduler(security, cleanup):
async with Scheduler(security=security, host="localhost") as s:
assert s.address.startswith("tls")


from distributed.core import Status


@pytest.mark.asyncio
@pytest.mark.parametrize("w_cls", [Worker, Nanny])
async def test_start_cluster_closes_scheduler_worker_failure(w_cls):
nthreads = [("127.0.0.1", 0)]
scheduler = "127.0.0.1"
loop = IOLoop.current()
for _ in range(2):
with pytest.raises(TypeError, match="got an unexpected keyword argument"):
await start_cluster(
nthreads,
scheduler,
loop,
security=None,
Worker=w_cls,
scheduler_kwargs={},
worker_kwargs={"dont": "start"},
)
assert all([s.status == Status.closed for s in Scheduler._instances])
assert all([w.status == Status.closed for w in Worker._instances])
57 changes: 30 additions & 27 deletions distributed/utils_test.py
@@ -796,34 +796,37 @@ async def start_cluster(
host=scheduler_addr,
**scheduler_kwargs,
)
workers = [
Worker(
s.address,
nthreads=ncore[1],
name=i,
security=security,
loop=loop,
validate=True,
host=ncore[0],
**(merge(worker_kwargs, ncore[2]) if len(ncore) > 2 else worker_kwargs),
)
for i, ncore in enumerate(nthreads)
]
# for w in workers:
# w.rpc = workers[0].rpc

await asyncio.gather(*workers)
try:
workers = [
Worker(
s.address,
nthreads=ncore[1],
name=i,
security=security,
loop=loop,
validate=True,
host=ncore[0],
**(merge(worker_kwargs, ncore[2]) if len(ncore) > 2 else worker_kwargs),
)
for i, ncore in enumerate(nthreads)
]
# for w in workers:
# w.rpc = workers[0].rpc
await asyncio.gather(*workers)

start = time()
while len(s.workers) < len(nthreads) or any(
comm.comm is None for comm in s.stream_comms.values()
):
await asyncio.sleep(0.01)
if time() - start > 5:
await asyncio.gather(*[w.close(timeout=1) for w in workers])
await s.close(fast=True)
raise Exception("Cluster creation timeout")
return s, workers
start = time()
while len(s.workers) < len(nthreads) or any(
comm.comm is None for comm in s.stream_comms.values()
):
await asyncio.sleep(0.01)
if time() - start > 5:
await asyncio.gather(*[w.close(timeout=1) for w in workers])
await s.close(fast=True)
raise Exception("Cluster creation timeout")
return s, workers
except Exception:
await s.close()
raise


async def end_cluster(s, workers):
Expand Down
3 changes: 3 additions & 0 deletions setup.cfg
@@ -45,6 +45,9 @@ parentdir_prefix = distributed-
addopts = -v -rsxfE --durations=20
filterwarnings =
error:Since distributed.*:PendingDeprecationWarning

# See https://github.com/dask/distributed/issues/4806
error:Port:UserWarning:distributed.node
minversion = 4
markers =
slow: marks tests as slow (deselect with '-m "not slow"')
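The new filter uses pytest's action:message:category:module syntax: any UserWarning whose message starts with "Port", raised from distributed.node, is now escalated to a test error. Individual tests can still opt out with a marker, as test_ws.py does above; a sketch with a hypothetical test name:

import pytest


@pytest.mark.filterwarnings(
    "ignore:Dashboard and Scheduler are using the same server on port"
)
def test_dashboard_shares_scheduler_port():
    ...  # body elided; the real tests parametrize dashboard/protocol/port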
