"""Kernel managers that operate against a remote process."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from __future__ import annotations
import abc
import asyncio
import base64
import errno
import getpass
import json
import logging
import os
import random
import re
import signal
import subprocess
import sys
import time
import warnings
from calendar import timegm
from enum import Enum
from socket import (
AF_INET,
SHUT_RDWR,
SHUT_WR,
SO_REUSEADDR,
SOCK_STREAM,
SOL_SOCKET,
gethostbyname,
gethostname,
socket,
timeout,
)
from typing import Any
import paramiko
import pexpect
from Cryptodome.Cipher import AES, PKCS1_v1_5
from Cryptodome.PublicKey import RSA
from Cryptodome.Util.Padding import unpad
from jupyter_client import launch_kernel, localinterfaces
from jupyter_server import _tz
from jupyter_server.serverapp import random_ports
from paramiko.client import SSHClient
from tornado import web
from tornado.ioloop import PeriodicCallback
from traitlets.config import SingletonConfigurable
from zmq.ssh import tunnel
from ..sessions.kernelsessionmanager import KernelSessionManager
# Default logging level of paramiko produces too much noise - raise to warning only.
logging.getLogger("paramiko").setLevel(os.getenv("EG_SSH_LOG_LEVEL", logging.WARNING))
# Pop certain env variables that don't need to be logged, e.g. remote_pwd
env_pop_list = ["EG_REMOTE_PWD", "LS_COLORS"]
# Comma separated list of env variables that shouldn't be logged
sensitive_env_keys = [k for k in os.getenv("EG_SENSITIVE_ENV_KEYS", "").lower().split(",") if k]
redaction_mask = os.getenv("EG_REDACTION_MASK", "********")
default_kernel_launch_timeout = float(os.getenv("EG_KERNEL_LAUNCH_TIMEOUT", "30"))
max_poll_attempts = int(os.getenv("EG_MAX_POLL_ATTEMPTS", "10"))
poll_interval = float(os.getenv("EG_POLL_INTERVAL", "0.5"))
socket_timeout = float(os.getenv("EG_SOCKET_TIMEOUT", "0.005"))
tunneling_enabled = bool(os.getenv("EG_ENABLE_TUNNELING", "False").lower() == "true")
ssh_port = int(os.getenv("EG_SSH_PORT", "22"))
eg_response_ip = os.getenv("EG_RESPONSE_IP", None)
desired_response_port = int(os.getenv("EG_RESPONSE_PORT", "8877"))
response_port_retries = int(os.getenv("EG_RESPONSE_PORT_RETRIES", "10"))
response_addr_any = bool(os.getenv("EG_RESPONSE_ADDR_ANY", "False").lower() == "true")
connection_interval = (
poll_interval / 100.0
) # already polling, so make connection timeout a fraction of outer poll
# Minimum port range size and max retries
min_port_range_size = int(os.getenv("EG_MIN_PORT_RANGE_SIZE", "1000"))
max_port_range_retries = int(os.getenv("EG_MAX_PORT_RANGE_RETRIES", "5"))
# Number of seconds in 100 years as the max keep-alive interval value.
max_keep_alive_interval = 100 * 365 * 24 * 60 * 60
# Allow users to specify local IPs (regular expressions can be used) that should not be included
# when determining the response address. For example, on systems with many network interfaces,
# some interfaces may have IPs that appear in the local interfaces list (e.g., docker's 172.17.0.*)
# yet should not be used. This env can be used to indicate such IPs.
prohibited_local_ips = [ip for ip in os.getenv("EG_PROHIBITED_LOCAL_IPS", "").split(",") if ip]
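# An illustrative (hypothetical) configuration that would skip docker and 10.0.x.x
# interfaces when selecting the response address - each entry is applied with
# re.match() against every public IP in _get_local_ip() below:
#
#   export EG_PROHIBITED_LOCAL_IPS="172\.17\.0\..*,10\.0\..*"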
def _get_local_ip() -> str:
"""
Honor the prohibited IPs, locating the first not in the list.
"""
for ip in localinterfaces.public_ips():
is_prohibited = False
for prohibited_ip in prohibited_local_ips: # exhaust prohibited list, applying regexs
if re.match(prohibited_ip, ip):
is_prohibited = True
break
if not is_prohibited:
return ip
return localinterfaces.public_ips()[0] # all were prohibited, so go with the first
local_ip = _get_local_ip()
random.seed()
class KernelChannel(Enum):
"""
Enumeration used to better manage tunneling
"""
SHELL = "SHELL"
IOPUB = "IOPUB"
STDIN = "STDIN"
HEARTBEAT = "HB"
CONTROL = "CONTROL"
COMMUNICATION = (
"EG_COMM" # Optional channel for remote launcher to issue interrupts - NOT a ZMQ channel
)
class Response(asyncio.Event):
"""Combines the event behavior with the kernel launch response."""
_response = None
@property
def response(self):
return self._response
@response.setter
def response(self, value):
"""Set the response. NOTE: this marks the event as set."""
self._response = value
self.set()
class ResponseManager(SingletonConfigurable):
"""Singleton that manages the responses from each kernel launcher at startup.
This singleton does the following:
1. Acquires a public and private RSA key pair at first use to encrypt and decrypt the
received responses. The public key is sent to the launcher during startup
and is used by the launcher to encrypt the AES key the launcher uses to encrypt
the connection information, while the private key remains in the server and is
used to decrypt the AES key from the response - which it then uses to decrypt
the connection information.
2. Creates a single socket based on the configuration settings that is listened on
via a periodic callback.
3. On receipt, it decrypts the response (key then connection info) and posts the
response payload to a map identified by the kernel_id embedded in the response.
4. Provides a wait mechanism for callers to poll to get their connection info
based on their registration (of kernel_id).
"""
KEY_SIZE = 1024 # Can be small since it's only used to {en,de}crypt the AES key.
_instance = None
def __init__(self, **kwargs: dict[str, Any] | None):
"""Initialize the manager."""
super().__init__(**kwargs)
self._response_ip = None
self._response_port = None
self._response_socket = None
self._connection_processor = None
# Create encryption keys...
self._private_key = RSA.generate(ResponseManager.KEY_SIZE)
self._public_key = self._private_key.publickey()
self._public_pem = self._public_key.export_key("PEM")
# Event facility...
self._response_registry = {}
# Start the response manager (create socket, periodic callback, etc.) ...
self._start_response_manager()
@property
def public_key(self) -> str:
"""Provides the string-form of public key PEM with header/footer/newlines stripped."""
return (
self._public_pem.decode()
.replace("-----BEGIN PUBLIC KEY-----", "")
.replace("-----END PUBLIC KEY-----", "")
.replace("\n", "")
)
@property
def response_address(self) -> str:
return self._response_ip + ":" + str(self._response_port)
def register_event(self, kernel_id: str) -> None:
"""Register kernel_id so its connection information can be processed."""
self._response_registry[kernel_id] = Response()
async def get_connection_info(self, kernel_id: str) -> dict:
"""Performs a timeout wait on the event, returning the conenction information on completion."""
await asyncio.wait_for(self._response_registry[kernel_id].wait(), connection_interval)
return self._response_registry.pop(kernel_id).response
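# A typical (illustrative) interaction with this class from a process proxy: register
# the kernel before launch, convey the public key and response_address to the launcher,
# then poll for the decrypted connection information within the launch timeout:
#
#   response_manager = ResponseManager.instance()
#   response_manager.register_event(kernel_id)
#   ...launch the kernel, conveying response_manager.response_address...
#   connect_info = await response_manager.get_connection_info(kernel_id)
#
# Note that get_connection_info() raises asyncio.TimeoutError when the response has not
# yet arrived (each wait is bounded by connection_interval), so callers re-poll until
# their overall kernel launch timeout is exhausted.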
def _prepare_response_socket(self) -> None:
"""Prepares the response socket on which connection info arrives from remote kernel launcher."""
s = socket(AF_INET, SOCK_STREAM)
s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
# If response_addr_any is enabled (default disabled), we will permit the server to listen
# on all addresses, else we will honor a configured response IP (via env) over the local IP
# (which is the default).
# Multiple IP bindings should be configured for containerized configurations (k8s) that need to
# launch kernels into external YARN clusters.
bind_ip = local_ip if eg_response_ip is None else eg_response_ip
bind_ip = bind_ip if response_addr_any is False else ""
response_port = desired_response_port
for port in random_ports(response_port, response_port_retries + 1):
try:
s.bind((bind_ip, port))
except OSError as e:
if e.errno == errno.EADDRINUSE:
self.log.info(f"Response port {port} is already in use, trying another port...")
continue
elif e.errno in (errno.EACCES, getattr(errno, "WSAEACCES", errno.EACCES)):
self.log.warning(
f"Permission to bind to response port {port} denied - continuing..."
)
continue
else:
msg = f"Failed to bind to port '{port}' for response address due to: '{e}'"
raise RuntimeError(msg) from e
else:
response_port = port
break
else:
msg = f"No available response port could be found after {response_port_retries + 1} attempts"
self.log.critical(msg)
raise RuntimeError(msg)
self.log.info(
f"Enterprise Gateway is bound to port {response_port} "
f"for remote kernel connection information."
)
s.listen(128)
s.settimeout(socket_timeout)
self._response_socket = s
self._response_port = response_port
self._response_ip = local_ip if eg_response_ip is None else eg_response_ip
def _start_response_manager(self) -> None:
"""If not already started, creates and starts the periodic callback to process connections."""
if self._response_socket is None:
self._prepare_response_socket()
if self._connection_processor is None:
self._connection_processor = PeriodicCallback(self._process_connections, 0.1, 0.1)
self._connection_processor.start()
def stop_response_manager(self) -> None:
"""Stops the connection processor."""
if self._connection_processor is not None:
self._connection_processor.stop()
self._connection_processor = None
if self._response_socket is not None:
self._response_socket = None
async def _process_connections(self) -> None:
"""Checks the socket for data, if found, decrypts the payload and posts to 'wait map'."""
loop = asyncio.get_event_loop()
data = ""
try:
conn, addr = await loop.sock_accept(self._response_socket)
while True:
buffer = await loop.sock_recv(conn, 1024)
if not buffer: # send is complete, process payload
self.log.debug(f"Received payload '{data}'")
payload = self._decode_payload(data)
self.log.debug(f"Decrypted payload '{payload}'")
self._post_connection(payload)
break
data = data + buffer.decode(
encoding="utf-8"
) # append what we received until we get no more...
conn.close()
except timeout:
pass
except Exception as ex:
self.log.error(f"Failure occurred processing connection: {ex}")
def _decode_payload(self, data: str) -> dict:
"""
Decodes the payload.
Decodes the payload, identifying the payload's version and returns a dictionary
representing the kernel's connection information.
Version "0" payloads do not specify a kernel-id within the payload, nor do they
include a 'key', 'version' or 'conn_info' fields. They are purely an AES encrypted
form of the base64-encoded JSON connection information, and encrypted using the
kernel-id as a key. Since no kernel-id is in the payload, we will capture the keys
of registered kernel-ids and attempt to decrypt the payload until we find the
appropriate registrant.
Version "1+" payloads are a base64-encoded JSON string consisting of a 'version', 'key'
and 'conn_info' fields. The 'key' field will be decrypted using the private key to
reveal the AES key, which is then used to decrypt the `conn_info` field.
Once decryption has taken place, the connection information string is loaded into a
dictionary and returned.
"""
payload_str = base64.b64decode(data)
try:
payload = json.loads(payload_str)
# Get the version
version = payload.get("version")
if version is None:
msg = "Payload received from kernel does not include a version indicator!"
raise ValueError(msg)
self.log.debug(f"Version {version} payload received.")
if version == 1:
# Decrypt the AES key using the RSA private key
encrypted_aes_key = base64.b64decode(payload["key"].encode())
cipher = PKCS1_v1_5.new(self._private_key)
aes_key = cipher.decrypt(encrypted_aes_key, b"\x42")
# Per docs, don't convey that decryption returned sentinel. So just let
# things fail "naturally".
# Decrypt and unpad the connection information using the just-decrypted AES key
cipher = AES.new(aes_key, AES.MODE_ECB)
encrypted_connection_info = base64.b64decode(payload["conn_info"].encode())
connection_info_str = unpad(cipher.decrypt(encrypted_connection_info), 16).decode()
else:
msg = f"Unexpected version indicator received: {version}!"
raise ValueError(msg)
except Exception as ex:
# Could be version "0", walk the registrant kernel-ids and attempt to decrypt using each as a key.
# If none are found, re-raise the triggering exception.
self.log.debug(f"decode_payload exception - {ex.__class__.__name__}: {ex}")
connection_info_str = None
for kernel_id in self._response_registry:
aes_key = kernel_id[0:16]
try:
cipher = AES.new(aes_key.encode("utf-8"), AES.MODE_ECB)
decrypted_payload = cipher.decrypt(payload_str)
# Version "0" responses use custom padding, so remove that here.
connection_info_str = "".join(
[decrypted_payload.decode("utf-8").rsplit("}", 1)[0], "}"]
)
# Try to load as JSON
new_connection_info = json.loads(connection_info_str)
# Add kernel_id into dict, then dump back to string so this can be processed as valid response
new_connection_info["kernel_id"] = kernel_id
connection_info_str = json.dumps(new_connection_info)
self.log.warning(
"WARNING!!!! Legacy kernel response received for kernel_id '{}'! "
"Update kernel launchers to current version!".format(kernel_id)
)
break # If we're here, we made it!
except Exception as ex2:
# Any exception fails this experiment and we continue
self.log.debug(
"Received the following exception detecting legacy kernel response - {}: {}".format(
ex2.__class__.__name__, ex2
)
)
connection_info_str = None
if connection_info_str is None:
raise ex
# and convert to usable dictionary
connection_info = json.loads(connection_info_str)
return connection_info
def _post_connection(self, connection_info: dict) -> None:
"""Posts connection information into "wait map" based on kernel_id value."""
kernel_id = connection_info.get("kernel_id")
if kernel_id is None:
self.log.error("No kernel id found in response! Kernel launch will fail.")
return
if kernel_id not in self._response_registry:
self.log.error(
f"Kernel id '{kernel_id}' has not been registered and will not be processed!"
)
return
self.log.debug(f"Connection info received for kernel '{kernel_id}': {connection_info}")
self._response_registry[kernel_id].response = connection_info
class BaseProcessProxyABC(metaclass=abc.ABCMeta):
"""
Process Proxy Abstract Base Class.
Defines the required methods for process proxy classes. Some implementation is also performed
by these methods - common to all subclasses.
"""
def __init__(self, kernel_manager: RemoteKernelManager, proxy_config: dict): # noqa: F821
"""
Initialize the process proxy instance.
Parameters
----------
kernel_manager : RemoteKernelManager
The kernel manager instance tied to this process proxy. This drives the process proxy method calls.
proxy_config : dict
The dictionary of per-kernel config settings. If none are specified, this will be an empty dict.
"""
self.kernel_manager = kernel_manager
self.proxy_config = proxy_config
# Initialize to 0 IP primarily so restarts of remote kernels don't encounter local-only enforcement during
# relaunch (see jupyter_client.manager.start_kernel()).
self.kernel_manager.ip = "0.0.0.0" # noqa
self.log = kernel_manager.log
# extract the kernel_id string from the connection file and set the KERNEL_ID environment variable
if self.kernel_manager.kernel_id is None:
self.kernel_manager.kernel_id = (
os.path.basename(self.kernel_manager.connection_file)
.replace("kernel-", "")
.replace(".json", "")
)
self.kernel_id = self.kernel_manager.kernel_id
self.kernel_launch_timeout = default_kernel_launch_timeout
self.lower_port = 0
self.upper_port = 0
self._validate_port_range()
# Handle authorization sets...
# Take union of unauthorized users...
self.unauthorized_users = self.kernel_manager.unauthorized_users
if proxy_config.get("unauthorized_users"):
self.unauthorized_users = self.unauthorized_users.union(
proxy_config.get("unauthorized_users").split(",")
)
# Let authorized users override global value - if set on kernelspec...
if proxy_config.get("authorized_users"):
self.authorized_users = set(proxy_config.get("authorized_users").split(","))
else:
self.authorized_users = self.kernel_manager.authorized_users
# Represents the local process (from popen) if applicable. Note that we could have local_proc = None even when
# the subclass is a LocalProcessProxy (or YarnProcessProxy). This will happen if EG is restarted and the
# persisted kernel-sessions indicate that it's now running on a different server. In those cases, we use the ip
# member variable to determine if the persisted state is local or remote and use signals with the pid to
# implement the poll, kill and send_signal methods. As a result, what was a local kernel with one EG instance
# could be a remote kernel in a restarted EG instance - and vice versa.
self.local_proc = None
self.ip = None
self.pid = 0
self.pgid = 0
_remote_user = os.getenv("EG_REMOTE_USER")
self.remote_pwd = os.getenv("EG_REMOTE_PWD")
self._use_gss_raw = os.getenv("EG_REMOTE_GSS_SSH", "False")
if self._use_gss_raw.lower() not in ("", "true", "false"):
msg = (
"Invalid Value for EG_REMOTE_GSS_SSH expected one of "
'"", "True", "False", got {!r}'.format(self._use_gss_raw)
)
raise ValueError(msg)
self.use_gss = self._use_gss_raw.lower() == "true"
if self.use_gss:
if self.remote_pwd or _remote_user:
warnings.warn(
"Both `EG_REMOTE_GSS_SSH` and one of `EG_REMOTE_PWD` or "
"`EG_REMOTE_USER` is set. "
"Those options are mutually exclusive, you configuration may be incorrect. "
"EG_REMOTE_GSS_SSH will take priority.",
stacklevel=2,
)
self.remote_user = None
else:
self.remote_user = _remote_user if _remote_user else getpass.getuser()
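# Summarizing the remote authentication options resolved above:
#   EG_REMOTE_GSS_SSH=True          -> Kerberos/GSS-API; user/password settings are ignored
#   EG_REMOTE_USER + EG_REMOTE_PWD  -> password-based ssh as EG_REMOTE_USER
#   EG_REMOTE_USER alone (or unset) -> key-based ssh (defaulting to the gateway user)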
@abc.abstractmethod
async def launch_process(self, kernel_cmd: str, **kwargs: dict[str, Any] | None) -> None:
"""
Provides basic implementation for launching the process corresponding to the process proxy.
All overrides should call this method via `super()` so that basic/common operations can be
performed. Leaf class implementations are required to perform the actual process launch
depending on the type of process proxy.
Parameters
----------
kernel_cmd : str
The properly formatted string composed from the argv stanza of the kernelspec with
all curly-braced substitutions performed.
kwargs : optional
Additional arguments used during the launch - primarily the env to use for the kernel.
"""
env_dict = kwargs.get("env")
if env_dict is None:
env_dict = dict(os.environ.copy())
kwargs.update({"env": env_dict})
# see if KERNEL_LAUNCH_TIMEOUT was included from user. If so, override default
if env_dict.get("KERNEL_LAUNCH_TIMEOUT"):
self.kernel_launch_timeout = float(env_dict.get("KERNEL_LAUNCH_TIMEOUT"))
# add the applicable kernel_id and language to the env dict
env_dict["KERNEL_ID"] = self.kernel_id
kernel_language = "unknown-kernel-language"
if len(self.kernel_manager.kernel_spec.language) > 0:
kernel_language = self.kernel_manager.kernel_spec.language.lower()
# if already set in env: stanza, let that override.
env_dict["KERNEL_LANGUAGE"] = env_dict.get("KERNEL_LANGUAGE", kernel_language)
# Remove any potentially sensitive (e.g., passwords) or annoying values (e.g., LS_COLORS)
for k in env_pop_list:
env_dict.pop(k, None)
self._enforce_authorization(**kwargs)
# Filter sensitive values from being logged
env_copy = kwargs.get("env").copy()
if sensitive_env_keys:
for key in list(env_copy):
if any(phrase in key.lower() for phrase in sensitive_env_keys):
env_copy[key] = redaction_mask
self.log.debug(f"BaseProcessProxy.launch_process() env: {env_copy}")
def launch_kernel(
self, cmd: list[str], **kwargs: dict[str, Any] | None
) -> subprocess.Popen[str | bytes]:
"""
Returns the result of launching the kernel via Popen.
This method exists to allow process proxies to perform any final preparations for
launch, including the removal of any arguments that are not recognized by Popen.
"""
# Remove kernel_headers
kwargs.pop("kernel_headers", None)
return launch_kernel(cmd, **kwargs)
def cleanup(self) -> None: # noqa
"""Performs optional cleanup after kernel is shutdown. Child classes are responsible for implementations."""
pass
def poll(self) -> Any | None:
"""
Determines if process proxy is still alive.
If this corresponds to a local (popen) process, poll() is called on the subprocess.
Otherwise, the zero signal is used to determine if active.
"""
if self.local_proc:
return self.local_proc.poll()
return self.send_signal(0)
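# Note: mirroring Popen semantics, poll() yields None while the process is believed to
# be alive (signal 0 delivered successfully) and a non-None value once it is gone.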
def wait(self) -> int | None:
"""
Wait for the process to become inactive.
"""
# If we have a local_proc, call its wait method. This will clean up any defunct processes when the kernel
# is shutdown (when using waitAppCompletion = false). Otherwise (if no local_proc) we'll use polling to
# determine if a (remote or revived) process is still active.
if self.local_proc:
return self.local_proc.wait()
for _ in range(max_poll_attempts):
if self.poll():
time.sleep(poll_interval)
else:
break
else:
self.log.warning(
"Wait timeout of {} seconds exhausted. Continuing...".format(
max_poll_attempts * poll_interval
)
)
return None
def send_signal(self, signum: int) -> bool | None:
"""
Send signal `signum` to process proxy.
Parameters
----------
signum : int
The signal number to send. Zero is used to determine heartbeat.
"""
# if we have a local process, use its method, else determine if the ip is local or remote and issue
# the appropriate version to signal the process.
result = None
if self.local_proc:
if self.pgid > 0 and hasattr(os, "killpg"):
try:
os.killpg(self.pgid, signum)
return result
except OSError:
pass
result = self.local_proc.send_signal(signum)
else:
if self.ip and self.pid > 0:
if BaseProcessProxyABC.ip_is_local(self.ip):
result = self.local_signal(signum)
else:
result = self.remote_signal(signum)
return result
def kill(self) -> bool | None:
"""
Terminate the process proxy process.
First attempts graceful termination, then forced termination.
Note that this should only be necessary if the message-based kernel termination has
proven unsuccessful.
"""
# If we have a local process, use its method, else signal soft kill first before hard kill.
result = self.terminate() # Send -15 signal first
i = 1
while self.poll() is None and i <= max_poll_attempts:
time.sleep(poll_interval)
i = i + 1
if i > max_poll_attempts: # Send -9 signal if process is still alive
if self.local_proc:
result = self.local_proc.kill()
self.log.debug(f"BaseProcessProxy.kill(): {result}")
else:
if self.ip and self.pid > 0:
if BaseProcessProxyABC.ip_is_local(self.ip):
result = self.local_signal(signal.SIGKILL)
else:
result = self.remote_signal(signal.SIGKILL)
self.log.debug(f"SIGKILL signal sent to pid: {self.pid}")
return result
def terminate(self) -> bool | None:
"""
Gracefully terminate the process proxy process.
Note that this should only be necessary if the message-based kernel termination has
proven unsuccessful.
"""
# If we have a local process, use its method, else send signal SIGTERM to soft kill.
result = None
if self.local_proc:
result = self.local_proc.terminate()
self.log.debug(f"BaseProcessProxy.terminate(): {result}")
else:
if self.ip and self.pid > 0:
if BaseProcessProxyABC.ip_is_local(self.ip):
result = self.local_signal(signal.SIGTERM)
else:
result = self.remote_signal(signal.SIGTERM)
self.log.debug(f"SIGTERM signal sent to pid: {self.pid}")
return result
@staticmethod
def ip_is_local(ip: str) -> bool:
"""
Returns True if `ip` is considered local to this server, False otherwise.
"""
return localinterfaces.is_public_ip(ip) or localinterfaces.is_local_ip(ip)
def _get_ssh_client(self, host: str) -> SSHClient | None:
"""
Create a SSH Client based on host, username and password if provided.
If there is any AuthenticationException/SSHException, raise HTTP Error 403 as permission denied.
:param host:
:return: ssh client instance
"""
ssh = None
try:
ssh = paramiko.SSHClient()
ssh.load_system_host_keys()
host_ip = gethostbyname(host)
if self.use_gss:
self.log.debug("Connecting to remote host via GSS.")
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
ssh.connect(host_ip, port=ssh_port, gss_auth=True)
else:
ssh.set_missing_host_key_policy(paramiko.RejectPolicy())
if self.remote_pwd:
self.log.debug("Connecting to remote host with username and password.")
ssh.connect(
host_ip,
port=ssh_port,
username=self.remote_user,
password=self.remote_pwd,
)
else:
self.log.debug("Connecting to remote host with ssh key.")
ssh.connect(host_ip, port=ssh_port, username=self.remote_user)
except Exception as e:
http_status_code = 500
current_host = gethostbyname(gethostname())
error_message = (
"Exception '{}' occurred when creating a SSHClient at {} connecting "
"to '{}:{}' with user '{}', message='{}'.".format(
type(e).__name__, current_host, host, ssh_port, self.remote_user, e
)
)
if isinstance(e, (paramiko.SSHException, paramiko.AuthenticationException)):
http_status_code = 403
error_message_prefix = "Failed to authenticate SSHClient with password"
error_message = error_message_prefix + (
" provided" if self.remote_pwd else "-less SSH"
)
error_message = error_message + " and EG_REMOTE_GSS_SSH={!r} ({})".format(
self._use_gss_raw, self.use_gss
)
self.log_and_raise(http_status_code=http_status_code, reason=error_message)
return ssh
def rsh(self, host: str, command: str) -> list[str]:
"""
Executes a command on a remote host using ssh.
Parameters
----------
host : str
The host on which the command is executed.
command : str
The command to execute.
Returns
-------
lines : List
The command's output. If stdout is zero length, the stderr output is returned.
"""
ssh = self._get_ssh_client(host)
try:
stdin, stdout, stderr = ssh.exec_command(command, timeout=30)
lines = stdout.readlines()
if len(lines) == 0: # if nothing in stdout, return stderr
lines = stderr.readlines()
except Exception:
# Let caller decide if exception should be logged
raise
finally:
if ssh:
ssh.close()
return lines
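# An illustrative use, as issued by remote_signal() below (host and pid are examples):
#
#   lines = self.rsh("node1.example.com", "kill -0 12345; echo $?")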
def remote_signal(self, signum: int) -> bool | None:
"""
Sends signal `signum` to process proxy on remote host.
"""
val = None
# if we have a process group, use that, else use the pid...
target = "-" + str(self.pgid) if self.pgid > 0 and signum > 0 else str(self.pid)
cmd = f"kill -{signum} {target}; echo $?"
if signum > 0: # only log if meaningful signal (not for poll)
self.log.debug(f"Sending signal: {signum} to target: {target} on host: {self.ip}")
try:
result = self.rsh(self.ip, cmd)
except Exception as e:
self.log.warning(
"Remote signal({}) to '{}' on host '{}' failed with exception '{}'.".format(
signum, target, self.ip, e
)
)
return False
for line in result:
val = line.strip()
if val == "0":
return None
return False
def local_signal(self, signum: int) -> bool | None:
"""
Sends signal `signum` to local process.
"""
# if we have a process group, use that, else use the pid...
target = "-" + str(self.pgid) if self.pgid > 0 and signum > 0 else str(self.pid)
if signum > 0: # only log if meaningful signal (not for poll)
self.log.debug(f"Sending signal: {signum} to target: {target}")
cmd = ["kill", "-" + str(signum), target]
with open(os.devnull, "w") as devnull:
result = subprocess.call(cmd, stderr=devnull)
if result == 0:
return None
return False
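# For both remote_signal() and local_signal(), None indicates the signal was delivered
# (kill exited 0) while False indicates failure - consistent with the None-means-alive
# convention that poll() relies upon when passing signal 0.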
def _enforce_authorization(self, **kwargs: dict[str, Any] | None) -> None:
"""
Applies any authorization configuration using the kernel user.
Regardless of impersonation enablement, this method first adds the appropriate value for
EG_IMPERSONATION_ENABLED into environment (for use by kernelspecs), then ensures that KERNEL_USERNAME
has a value and is present in the environment (again, for use by kernelspecs). If unset, KERNEL_USERNAME
will be defaulted to the current user.
Authorization is performed by comparing the value of KERNEL_USERNAME with each value in the set of
unauthorized users. If any (case-sensitive) matches are found, HTTP error 403 (Forbidden) will be raised
- preventing the launch of the kernel. If the authorized_users set is non-empty, it is then checked to
ensure the value of KERNEL_USERNAME is present in that list. If not found, HTTP error 403 will be raised.
It is assumed that the kernelspec logic will take the appropriate steps to impersonate the user identified
by KERNEL_USERNAME when impersonation_enabled is True.
"""
# Get the env
env_dict = kwargs.get("env")
# Although it may already be set in the env, just override in case it was only set via command line or config
# Convert to string since execve() (called by Popen in base classes) wants string values.
env_dict["EG_IMPERSONATION_ENABLED"] = str(self.kernel_manager.impersonation_enabled)
# Ensure KERNEL_USERNAME is set
kernel_username = KernelSessionManager.get_kernel_username(**kwargs)
# Now perform authorization checks
if kernel_username in self.unauthorized_users:
self._raise_authorization_error(kernel_username, "not authorized")
# If authorized users are non-empty, ensure user is in that set.
if len(self.authorized_users) > 0 and kernel_username not in self.authorized_users:
self._raise_authorization_error(kernel_username, "not in the set of users authorized")
def _raise_authorization_error(self, kernel_username: str, differentiator_clause: str) -> None:
"""
Raises a 403 status code after building the appropriate message.
"""
kernel_name = self.kernel_manager.kernel_spec.display_name
kernel_clause = f" '{kernel_name}'." if kernel_name is not None else "s."
error_message = (
"User '{}' is {} to start kernel{} "
"Ensure KERNEL_USERNAME is set to an appropriate value and retry the request.".format(
kernel_username, differentiator_clause, kernel_clause
)
)
self.log_and_raise(http_status_code=403, reason=error_message)
def get_process_info(self) -> dict[str, Any]:
"""
Captures the base information necessary for kernel persistence relative to process proxies.
The superclass method must always be called first to ensure proper ordering. Since this is the
most base class, no call to `super()` is necessary.
"""
process_info = {"pid": self.pid, "pgid": self.pgid, "ip": self.ip}
return process_info
def load_process_info(self, process_info: dict[str, Any]) -> None:
"""
Loads the base information necessary for kernel persistence relative to process proxies.
The superclass method must always be called first to ensure proper ordering. Since this is the
most base class, no call to `super()` is necessary.
"""
self.pid = process_info["pid"]
self.pgid = process_info["pgid"]
self.ip = process_info["ip"]
self.kernel_manager.ip = process_info["ip"]
def _validate_port_range(self) -> None:
"""
Validates the port range configuration option to ensure appropriate values.
"""
# Let port_range override global value - if set on kernelspec...
port_range = self.kernel_manager.port_range
if self.proxy_config.get("port_range"):
port_range = self.proxy_config.get("port_range")
try:
port_ranges = port_range.split("..")
self.lower_port = int(port_ranges[0])
self.upper_port = int(port_ranges[1])
port_range_size = self.upper_port - self.lower_port
if port_range_size != 0:
if port_range_size < min_port_range_size:
self.log_and_raise(
http_status_code=500,
reason="Port range validation failed for range: '{}'. "
"Range size must be at least {} as specified by env EG_MIN_PORT_RANGE_SIZE".format(
port_range, min_port_range_size
),
)
# According to RFC 793, port is a 16-bit unsigned int. Which means the port
# numbers must be in the range (0, 65535). However, within that range,
# ports 0 - 1023 are called "well-known ports" and are typically reserved for
# specific purposes. For example, 0 is reserved for random port assignment,
# 80 is used for HTTP, 443 for TLS/SSL, 25 for SMTP, etc. But, there is
# flexibility as one can choose any port with the aforementioned protocols.
# Ports 1024 - 49151 are called "user or registered ports" that are bound to
# services running on the server listening to client connections. And, ports
# 49152 - 65535 are called "dynamic or ephemeral ports". A TCP connection
# has two endpoints. Each endpoint consists of an IP address and a port number.
# And, each connection is made up of a 4-tuple consisting of -- client-IP,
# client-port, server-IP, and server-port. A service runs on a server with a
# specific IP and is bound to a specific "user or registered port" that is
# advertised for clients to connect. So, when a client connects to a service
# running on a server, three out of 4-tuple - client-IP, client-port, server-IP -
# are already known. To be able to serve multiple clients concurrently, the
# server's IP stack assigns an ephemeral port for the connection to complete
# the 4-tuple.
#
# In case of JEG, we will accept ports in the range 1024 - 65535 as these days
# admins use dedicated hosts for individual services.
if self.lower_port < 1024 or self.lower_port > 65535:
self.log_and_raise(
http_status_code=500,
reason="Invalid port range '{}' specified. "
"Range for valid port numbers is (1024, 65535).".format(port_range),
)
if self.upper_port < 1024 or self.upper_port > 65535:
self.log_and_raise(
http_status_code=500,
reason="Invalid port range '{}' specified. "
"Range for valid port numbers is (1024, 65535).".format(port_range),
)
except ValueError as ve:
self.log_and_raise(
http_status_code=500,
reason="Port range validation failed for range: '{}'. "
"Error was: {}".format(port_range, ve),
)
except IndexError as ie:
self.log_and_raise(
http_status_code=500,
reason="Port range validation failed for range: '{}'. "
"Error was: {}".format(port_range, ie),
)
self.kernel_manager.port_range = port_range
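# An illustrative configuration: port_range = "40000..41000" yields lower_port=40000 and
# upper_port=41000 (a size of 1000, satisfying the default EG_MIN_PORT_RANGE_SIZE), while
# a zero-size range such as "0..0" skips range enforcement altogether.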
def select_ports(self, count: int) -> list[int]:
"""
Selects and returns n random ports that adhere to the configured port range, if applicable.
Parameters
----------
count : int
The number of ports to return
Returns
-------
List - ports available and adhering to the configured port range
"""
ports = []
sockets = []
for _ in range(count):
sock = self.select_socket()
ports.append(sock.getsockname()[1])
sockets.append(sock)
for sock in sockets:
sock.close()
return ports
def select_socket(self, ip: str = "") -> socket:
"""
Creates and returns a socket whose port adheres to the configured port range, if applicable.
Parameters
----------
ip : str
Optional ip address to which the port is bound
Returns
-------
socket - Bound socket that is available and adheres to configured port range
"""
sock = socket(AF_INET, SOCK_STREAM)
found_port = False
retries = 0
while not found_port:
try:
sock.bind((ip, self._get_candidate_port()))
found_port = True
except Exception:
retries = retries + 1
if retries > max_port_range_retries:
self.log_and_raise(
http_status_code=500,
reason="Failed to locate port within range {} after {} "
"retries!".format(self.kernel_manager.port_range, max_port_range_retries),
)
return sock