Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge connection info into existing connection file if it already exists #1133

Merged
merged 6 commits into from
Jul 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions ipykernel/kernelapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
)
from IPython.core.profiledir import ProfileDir
from IPython.core.shellapp import InteractiveShellApp, shell_aliases, shell_flags
from jupyter_client import write_connection_file
from jupyter_client.connect import ConnectionFileMixin
from jupyter_client.session import Session, session_aliases, session_flags
from jupyter_core.paths import jupyter_runtime_dir
Expand All @@ -44,10 +43,11 @@
from traitlets.utils.importstring import import_item
from zmq.eventloop.zmqstream import ZMQStream

from .control import ControlThread
from .heartbeat import Heartbeat
from .connect import get_connection_info, write_connection_file

# local imports
from .control import ControlThread
from .heartbeat import Heartbeat
from .iostream import IOPubThread
from .ipkernel import IPythonKernel
from .parentpoller import ParentPollerUnix, ParentPollerWindows
Expand Down Expand Up @@ -260,12 +260,7 @@ def _bind_socket(self, s, port):
def write_connection_file(self):
"""write connection info to JSON file"""
cf = self.abs_connection_file
if os.path.exists(cf):
self.log.debug("Connection file %s already exists", cf)
return
self.log.debug("Writing connection file: %s", cf)
write_connection_file(
cf,
connection_info = dict(
ip=self.ip,
key=self.session.key,
transport=self.transport,
Expand All @@ -275,6 +270,19 @@ def write_connection_file(self):
iopub_port=self.iopub_port,
control_port=self.control_port,
)
if os.path.exists(cf):
# If the file exists, merge our info into it. For example, if the
# original file had port number 0, we update with the actual port
# used.
existing_connection_info = get_connection_info(cf, unpack=True)
connection_info = dict(existing_connection_info, **connection_info)
if connection_info == existing_connection_info:
self.log.debug("Connection file %s with current information already exists", cf)
return

self.log.debug("Writing connection file: %s", cf)

write_connection_file(cf, **connection_info)

def cleanup_connection_file(self):
"""Clean up our connection file."""
Expand Down
2 changes: 1 addition & 1 deletion ipykernel/tests/test_ipkernel_direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class user_mod:
__dict__ = {}


async def test_properities(ipkernel: IPythonKernel) -> None:
async def test_properties(ipkernel: IPythonKernel) -> None:
ipkernel.user_module = user_mod()
ipkernel.user_ns = {}

Expand Down
71 changes: 71 additions & 0 deletions ipykernel/tests/test_kernelapp.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import json
import os
import threading
import time
from unittest.mock import patch

import pytest
from jupyter_core.paths import secure_write
from traitlets.config.loader import Config

from ipykernel.kernelapp import IPKernelApp

from .conftest import MockKernel
from .utils import TemporaryWorkingDirectory

try:
import trio
Expand Down Expand Up @@ -47,6 +51,73 @@ def trigger_stop():
app.close()


@pytest.mark.skipif(os.name == "nt", reason="permission errors on windows")
def test_merge_connection_file():
cfg = Config()
with TemporaryWorkingDirectory() as d:
cfg.ProfileDir.location = d
cf = os.path.join(d, "kernel.json")
initial_connection_info = {
"ip": "*",
"transport": "tcp",
"shell_port": 0,
"hb_port": 0,
"iopub_port": 0,
"stdin_port": 0,
"control_port": 53555,
"key": "abc123",
"signature_scheme": "hmac-sha256",
"kernel_name": "My Kernel",
}
# We cannot use connect.write_connection_file since
# it replaces port number 0 with a random port
# and we want IPKernelApp to do that replacement.
with secure_write(cf) as f:
json.dump(initial_connection_info, f)
assert os.path.exists(cf)

app = IPKernelApp(config=cfg, connection_file=cf)

# Calling app.initialize() does not work in the test, so we call the relevant functions that initialize() calls
# We must pass in an empty argv, otherwise the default is to try to parse the test runner's argv
super(IPKernelApp, app).initialize(argv=[""])
app.init_connection_file()
app.init_sockets()
app.init_heartbeat()
app.write_connection_file()

# Initialize should have merged the actual connection info
# with the connection info in the file
assert cf == app.abs_connection_file
assert os.path.exists(cf)

with open(cf) as f:
new_connection_info = json.load(f)

# ports originally set as 0 have been replaced
for port in ("shell", "hb", "iopub", "stdin"):
key = f"{port}_port"
# We initially had the port as 0
assert initial_connection_info[key] == 0
# the port is not 0 now
assert new_connection_info[key] > 0
# the port matches the port the kernel actually used
assert new_connection_info[key] == getattr(app, key), f"{key}"
del new_connection_info[key]
del initial_connection_info[key]

# The wildcard ip address was also replaced
assert new_connection_info["ip"] != "*"
del new_connection_info["ip"]
del initial_connection_info["ip"]

# everything else in the connection file is the same
assert initial_connection_info == new_connection_info

app.close()
os.remove(cf)


@pytest.mark.skipif(trio is None, reason="requires trio")
def test_trio_loop():
app = IPKernelApp(trio_loop=True)
Expand Down