-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcomm_socket.py
152 lines (131 loc) · 5.38 KB
/
comm_socket.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""Implementation of the CommSocket class for data exchange with the Simulink model."""
import array
import socket
import struct
import threading
import numpy as np
from .. import logger
class CommSocket:
    """TCP server socket for data exchange with the Simulink simulation.

    A listening socket on ``HOST:port`` accepts a single connection in a
    background thread (see ``open_socket``). Afterwards flat sequences of
    doubles are sent with ``send_data`` and received with ``receive``.
    """

    HOST = "localhost"

    def __init__(self, port: int, name: str = None):
        """
        Class defining the sockets for communication with the Simulink simulation.

        Args:
            port: TCP port on ``HOST`` to listen on once the socket is opened
            name: optional name of the socket, used as a prefix of log
                messages for debugging purposes, default: None
        """
        self._debug_prefix = f"{name}: " if name else ""
        self.port = port
        self.connection = None
        self.address = None
        # Placeholder listening socket; _open_socket() releases it and creates
        # a fresh one for every connection attempt.
        self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # Never-started dummy thread so that is_alive() is always callable.
        self.connect_socket_thread = threading.Thread()

    @staticmethod
    def _release_socket(sock, shutdown: bool) -> bool:
        """
        Best-effort release of ``sock``: optional shutdown, then close.

        Every step is attempted even if an earlier one fails, so the file
        descriptor is always released.

        Args:
            sock: socket to release; None is ignored
            shutdown: whether to call ``shutdown(SHUT_RDWR)`` before closing
                (only meaningful for connected sockets)

        Returns:
            False if the shutdown or the close raised an OSError, else True
        """
        if sock is None:
            return True
        clean = True
        if shutdown:
            try:
                sock.shutdown(socket.SHUT_RDWR)
            except OSError:
                # E.g. the peer already closed its end; closing still works.
                clean = False
        try:
            sock.close()
        except OSError:
            clean = False
        return clean

    def _open_socket(self, timeout=300):
        """
        Open the listening socket and block until the simulation connects.

        Normally run in the thread started by ``open_socket``; may also be
        called directly. Failures other than a timeout (e.g. the port cannot
        be bound) are logged and leave the socket unconnected.

        Args:
            timeout: timeout in seconds for waiting for connection, default: 300 s

        Raises:
            TimeoutError: if the socket does not connect within the specified timeout
        """
        # Checked directly: is_connected() is always False inside the
        # connection thread because that thread is alive.
        if self.connection is not None:
            logger.info(f"{self._debug_prefix}Socket already connected")
            return
        # Release the previous (idle) listening socket before replacing it.
        self._release_socket(self.server, shutdown=False)
        self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self.server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.server.bind((self.HOST, self.port))
            self.server.listen(1)
            # settimeout() also switches the socket into timeout mode, so an
            # extra setblocking() call is unnecessary.
            self.server.settimeout(timeout)
            self.connection, self.address = self.server.accept()
        except socket.timeout:
            # A listening socket is never connected, so it is only closed;
            # shutdown() would raise ENOTCONN on several platforms.
            self._release_socket(self.server, shutdown=False)
            self.connection = None
            raise TimeoutError(
                f"{self._debug_prefix}No connection on port {self.port} "
                f"within {timeout} s"
            ) from None
        except Exception as err:
            self._release_socket(self.server, shutdown=False)
            self.connection = None
            logger.error(
                f"{self._debug_prefix}Could not open socket on port "
                f"{self.port}: {err!r}"
            )

    def open_socket(self, timeout: float = 300):
        """
        Start a thread that waits for the simulation to connect.

        Args:
            timeout: timeout in seconds for waiting for connection, default: 300 s
        """
        # Refuse while a connection attempt is still running, not only when a
        # connection already exists; a second thread would rebind the port.
        if self.connect_socket_thread.is_alive() or self.connection is not None:
            logger.error(f"{self._debug_prefix}Socket already opened or connected")
            return
        self.connect_socket_thread = threading.Thread(
            name="socket._open_socket()", target=self._open_socket, args=(timeout,)
        )
        self.connect_socket_thread.start()

    def receive(self, bufsize: int = 2048):
        """
        Receive data from the simulation.

        Args:
            bufsize: maximum number of bytes requested by the first read,
                default: 2048

        Returns:
            ``array.array("d")`` holding the received doubles (empty if the
            peer closed the connection), or None if the socket is not connected
        """
        if not self.is_connected():
            logger.error(
                f"{self._debug_prefix}Socket not connected, nothing to receive"
            )
            return None
        data = self.connection.recv(bufsize)
        # TCP is a byte stream, so a read may end in the middle of a double.
        # Fetch the missing bytes so the buffer decodes into whole items.
        item_size = array.array("d").itemsize
        while len(data) % item_size:
            chunk = self.connection.recv(item_size - len(data) % item_size)
            if not chunk:
                # Peer closed mid-value; array() below raises ValueError.
                break
            data += chunk
        # NOTE(review): "d" decodes in native byte order while send_data()
        # packs little-endian explicitly; confirm on big-endian hosts.
        return array.array("d", data)

    def send_data(self, set_values: np.ndarray, stop: bool = False):
        """
        Send data over the socket.

        The message is a flat sequence of little-endian doubles: first the
        stop flag (1.0 or 0.0), then the flattened ``set_values``.

        Args:
            set_values: array-like containing the data (flattened before sending)
            stop: flag for stopping the simulation, default: False
        """
        if not self.is_connected():
            logger.error(f"{self._debug_prefix}Socket not connected, data not sent")
            return
        values = np.asarray(set_values).ravel()
        msg = struct.pack(f"<{values.size + 1}d", int(stop), *values)
        self.connection.sendall(msg)

    def close(self):
        """
        Close the connection and the listening socket.

        If the connection thread is still waiting, this first blocks until it
        either connects or times out.
        """
        if self.connect_socket_thread.is_alive():
            self.connect_socket_thread.join()
        # The wait either timed out (the thread already released the server
        # socket) or produced a connection, which can be closed now.
        if self.connection is None:
            logger.info(f"{self._debug_prefix}Socket not connected, nothing to close")
            # Still release the idle listening socket (e.g. from __init__).
            self._release_socket(self.server, shutdown=False)
            return
        # Release both sockets independently so that a failing shutdown of
        # the connection (peer already gone) does not leak the server socket.
        connection_ok = self._release_socket(self.connection, shutdown=True)
        server_ok = self._release_socket(self.server, shutdown=False)
        if not (connection_ok and server_ok):
            logger.info(
                f"Something went wrong while closing socket "
                f"({self.address}, {self.port})"
            )
        self.connection = None
        self.address = None
        self.server = None

    def is_connected(self):
        """
        Check for connection of the socket.

        Returns:
            True if a connection exists and no connection thread is running
        """
        return self.connection is not None and not self.connect_socket_thread.is_alive()

    def wait_for_connection(self, timeout: float = None):
        """
        Wait for the connection thread started by ``open_socket``.

        Returns immediately if no connection attempt is running.

        Args:
            timeout: timeout in seconds for joining the connection thread,
                default: None (wait indefinitely)
        """
        # join() on the never-started placeholder thread raises RuntimeError.
        if self.connect_socket_thread.is_alive():
            self.connect_socket_thread.join(timeout=timeout)