This repository has been archived by the owner on Sep 18, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1.8k
/
Copy pathbase_channel.py
158 lines (134 loc) · 5.2 KB
/
base_channel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import threading
import time
from abc import ABC, abstractmethod
from queue import Empty, Queue
from .log_utils import LogType, nni_log
from .commands import CommandType
# Interval, in seconds, that the receive loop sleeps between polls and that the
# send loop waits on its queue before re-checking whether the channel is running.
INTERVAL_SECONDS = 0.5
class BaseChannel(ABC):
    """Abstract, thread-backed message channel.

    Subclasses implement the transport through ``_inner_open``,
    ``_inner_close``, ``_inner_send`` and ``_inner_receive``.  This base class
    handles message framing and runs two worker threads that move messages
    between the transport and in-memory queues.

    Wire format of one message (bytes):
        2 bytes command code + 14 ASCII digits of payload length + payload,
        where the payload is UTF-8 encoded JSON.
    """

    def __init__(self, args):
        """Store the runner arguments.

        args: object that provides ``node_count`` and ``node_id`` here, and
            ``runner_id`` / ``exp_id`` when ``open`` is called.
        """
        self.is_keep_parsed = args.node_count > 1
        self.args = args
        self.node_id = self.args.node_id

    @abstractmethod
    def _inner_send(self, message):
        """Transmit one already-framed message (bytes) over the transport."""
        pass

    @abstractmethod
    def _inner_receive(self):
        """Return an iterable of framed messages received so far, or None."""
        return []

    @abstractmethod
    def _inner_open(self):
        """Open the underlying transport."""
        pass

    @abstractmethod
    def _inner_close(self):
        """Close the underlying transport."""
        pass

    def open(self):
        """Start the worker threads, open the transport and announce readiness."""
        # initialize receive, send threads.
        self.is_running = True

        self.receive_queue = Queue()
        self.receive_thread = threading.Thread(target=self._receive_loop)
        self.receive_thread.start()

        self.send_queue = Queue()
        self.send_thread = threading.Thread(target=self._send_loop)
        self.send_thread.start()

        self._inner_open()

        client_info = {
            "isReady": True,
            "runnerId": self.args.runner_id,
            "expId": self.args.exp_id,
        }
        nni_log(LogType.Info, 'Channel: send ready information %s' % client_info)
        self.send(CommandType.Initialized, client_info)

    def close(self):
        """Ask the worker loops to stop and close the transport (best effort)."""
        # Both loops re-check this flag on every iteration.
        self.is_running = False
        try:
            self._inner_close()
        except Exception as err:
            # ignore any error on closing
            print("error on closing channel: %s" % err)

    def send(self, command, data):
        """Queue a command for delivery to Training Service.

        command: CommandType object; its ``value`` is the 2-byte command code.
        data: dict payload; it is serialized as JSON together with a "node"
            entry holding this channel's node id.

        The message is only put on the send queue here; the send thread
        delivers it later. Use ``sent()`` to check whether the queue drained.
        """
        # Stamp a shallow copy so the caller's dict is not modified.
        payload = dict(data)
        payload["node"] = self.node_id
        encoded = json.dumps(payload).encode('utf8')
        message = b'%b%014d%b' % (command.value, len(encoded), encoded)
        self.send_queue.put(message)

    def sent(self):
        """Return True when no message is waiting in the send queue."""
        return self.send_queue.empty()

    def received(self):
        """Return True when at least one message is waiting to be read."""
        return self.receive_queue.qsize() > 0

    def receive(self):
        """Receive a command from Training Service without blocking.

        Returns a tuple of command (CommandType) and payload (decoded JSON).
        Returns (None, None) when nothing is pending or the pending message
        cannot be parsed.
        """
        command = None
        data = None
        try:
            command_content = self.receive_queue.get(False)
            if command_content is not None:
                if len(command_content) < 16:
                    # invalid header
                    nni_log(LogType.Error, 'incorrect command is found, command must be greater than 16 bytes!')
                    return None, None
                header = command_content[:16]
                parsed_command = CommandType(header[:2])
                length = int(header[2:])
                if len(command_content) - 16 != length:
                    nni_log(LogType.Error, 'incorrect command length, length {}, actual data length is {}, header {}.'
                            .format(length, len(command_content) - 16, header))
                    return None, None
                payload = json.loads(command_content[16:16 + length].decode('utf8'))
                # Commit the result only once the whole message parsed, so a
                # failure above can never leak a half-parsed (command, data).
                command, data = parsed_command, payload
                if self.node_id is None:
                    nni_log(LogType.Info, 'Received command, header: [%s], data: [%s]' % (header, data))
                else:
                    nni_log(LogType.Info, 'Received command(%s), header: [%s], data: [%s]' % (self.node_id, header, data))
        except Empty:
            # do nothing, if no command received.
            pass
        except Exception as identifier:
            nni_log(LogType.Error, 'meet unhandled exception in base_channel: %s' % identifier)
        return command, data

    def _fetch_message(self, buffer, has_new_line=False):
        """Split complete framed messages off the front of ``buffer``.

        buffer: bytes that may hold zero or more messages plus a partial tail.
        has_new_line: when True, every message is expected to be followed by
            one newline byte, which is consumed but not returned.

        Returns (messages, remaining_buffer); each message is header + payload.
        """
        messages = []
        while len(buffer) >= 16:
            header = buffer[:16]
            length = int(header[2:])
            message_length = length + 16
            total_length = message_length
            if has_new_line:
                total_length += 1

            # break, if buffer is too short.
            if len(buffer) < total_length:
                break
            data = buffer[16:message_length]
            # 10 is the byte value of '\n'.
            if has_new_line and 10 != buffer[total_length - 1]:
                # Report the offending byte from the buffer being parsed; this
                # class defines no other cache to read it from.
                nni_log(LogType.Error, 'end of message should be \\n, but got {}'.format(buffer[total_length - 1]))
            buffer = buffer[total_length:]
            messages.append(header + data)

        return messages, buffer

    def _receive_loop(self):
        """Worker: poll the transport and enqueue incoming messages until closed."""
        while self.is_running:
            messages = self._inner_receive()
            if messages is not None:
                for message in messages:
                    self.receive_queue.put(message)
            time.sleep(INTERVAL_SECONDS)

    def _send_loop(self):
        """Worker: forward queued messages to the transport until closed."""
        while self.is_running:
            message = None
            try:
                # no sleep, since it's a block call with INTERVAL_SECONDS second timeout
                message = self.send_queue.get(True, INTERVAL_SECONDS)
            except Empty:
                # do nothing, if no command received.
                pass
            if message is not None:
                self._inner_send(message)