-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathserver_platform.py
364 lines (329 loc) · 13.2 KB
/
server_platform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
import asyncio
import sys
from typing import Optional
import aiohttp
from aiohttp import ClientRequest, ClientResponse, hdrs
from aiohttp.connector import Connection
from aiohttp.http_writer import HttpVersion10, HttpVersion11
from aiohttp.http import StreamWriter
import base64
import functools
import logging
import zipfile
from io import BytesIO
import os
import ssl
from multidict import CIMultiDict
from cryptography.utils import int_to_bytes
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.serialization import load_pem_private_key
from cryptography.hazmat.primitives.asymmetric import ec, padding
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature
from app.models import TLS, MakeRequestParams
import app.key_loader as key_loader
def _safe_header(string: str) -> str:
if "\r" in string or "\n" in string:
raise ValueError(
"Newline or carriage return detected in headers. "
"Potential header injection attack."
)
return string
def _py_serialize_headers(status_line: str, headers: "CIMultiDict[str]") -> bytes:
    """Serialize an HTTP start line and header fields into latin-1 bytes.

    The result is ``status_line CRLF (name ": " value CRLF)* CRLF``. Every
    header name and value passes through ``_safe_header``, so CR/LF anywhere
    raises ``ValueError``. Characters outside latin-1 raise
    ``UnicodeEncodeError``.

    Each header line carries its own CRLF terminator. With an empty header
    collection the output is therefore ``status_line\\r\\n\\r\\n`` rather
    than emitting a stray extra blank line.
    """
    header_lines = "".join(
        f"{_safe_header(name)}: {_safe_header(value)}\r\n"
        for name, value in headers.items()
    )
    return (status_line + "\r\n" + header_lines + "\r\n").encode("latin-1")
class Latin1HeadersStreamWriter(StreamWriter):
    """aiohttp ``StreamWriter`` that emits headers via ``_py_serialize_headers``.

    The serialized header block is latin-1 encoded, and CR/LF in any header
    name or value is rejected with ``ValueError``.
    """

    async def write_headers(
        self, status_line: str, headers: "CIMultiDict[str]"
    ) -> None:
        """Write request/response status and headers."""
        on_headers_sent = self._on_headers_sent
        if on_headers_sent is not None:
            # The trace hook is awaited before any header bytes are written.
            await on_headers_sent(headers)
        payload = _py_serialize_headers(status_line, headers)
        self._write(payload)
class Latin1HeadersClientRequest(ClientRequest):
    """aiohttp ``ClientRequest`` whose headers are written by
    ``Latin1HeadersStreamWriter`` (latin-1 encoding, CR/LF rejected).

    Only ``send`` is overridden.

    NOTE(review): ``send`` reads many private ``ClientRequest`` attributes
    (``_traces``, ``_continue``, ``_timer``, ``_session``, ``_writer``,
    ``_skip_auto_headers``, ``conn._connector``). It looks like a copy of
    upstream aiohttp's ``ClientRequest.send`` and must be re-checked against
    the pinned aiohttp version on every upgrade.
    """

    async def send(self, conn: "Connection") -> "ClientResponse":
        """Write the request line, headers and (if needed) body to *conn*.

        Args:
            conn: connection whose ``protocol`` receives the request bytes.

        Returns:
            A new ``self.response_class`` instance for this request. It is
            also stored on ``self.response``.
        """
        # Specify request target:
        # - CONNECT request must send authority form URI
        # - not CONNECT proxy must send absolute form URI
        # - most common is origin form URI
        if self.method == hdrs.METH_CONNECT:
            connect_host = self.url.raw_host
            assert connect_host is not None
            path = f"{connect_host}:{self.url.port}"
        elif self.proxy and not self.is_ssl():
            path = str(self.url)
        else:
            path = self.url.raw_path_qs
        protocol = conn.protocol
        assert protocol is not None
        # Use this module's latin-1 header writer. Trace callbacks are bound
        # only when tracing is configured (self._traces is truthy).
        writer = Latin1HeadersStreamWriter(
            protocol,
            self.loop,
            on_chunk_sent=(
                functools.partial(self._on_chunk_request_sent, self.method, self.url)
                if self._traces
                else None
            ),
            on_headers_sent=(
                functools.partial(self._on_headers_request_sent, self.method, self.url)
                if self._traces
                else None
            ),
        )
        if self.compress:
            writer.enable_compression(self.compress)  # type: ignore[arg-type]
        if self.chunked is not None:
            writer.enable_chunking()
        # set default content-type
        # (only for POST_METHODS, and only when the caller neither set it nor
        # listed it in skip_auto_headers)
        if (
            self.method in self.POST_METHODS
            and (
                self._skip_auto_headers is None
                or hdrs.CONTENT_TYPE not in self._skip_auto_headers
            )
            and hdrs.CONTENT_TYPE not in self.headers
        ):
            self.headers[hdrs.CONTENT_TYPE] = "application/octet-stream"
        # Add a Connection header only if the caller did not supply one.
        # A force_close connector on HTTP/1.1 gets "close". Otherwise
        # HTTP/1.0 gets an explicit "keep-alive".
        v = self.version
        if hdrs.CONNECTION not in self.headers:
            if conn._connector.force_close:
                if v == HttpVersion11:
                    self.headers[hdrs.CONNECTION] = "close"
            elif v == HttpVersion10:
                self.headers[hdrs.CONNECTION] = "keep-alive"
        # status + headers
        status_line = f"{self.method} {path} HTTP/{v.major}.{v.minor}"
        await writer.write_headers(status_line, self.headers)
        # task: background body writer handed to the response, or None.
        task: Optional["asyncio.Task[None]"]
        if self.body or self._continue is not None or protocol.writing_paused:
            coro = self.write_bytes(writer, conn)
            if sys.version_info >= (3, 12):
                # Optimization for Python 3.12, try to write
                # bytes immediately to avoid having to schedule
                # the task on the event loop.
                task = asyncio.Task(coro, loop=self.loop, eager_start=True)
            else:
                task = self.loop.create_task(coro)
            if task.done():
                # write_bytes already finished (possible with eager start),
                # so there is no in-flight writer to track.
                task = None
            else:
                self._writer = task
        else:
            # We have nothing to write because
            # - there is no body
            # - the protocol does not have writing paused
            # - we are not waiting for a 100-continue response
            protocol.start_timeout()
            writer.set_eof()
            task = None
        response_class = self.response_class
        assert response_class is not None
        self.response = response_class(
            self.method,
            self.original_url,
            writer=task,
            continue100=self._continue,
            timer=self._timer,
            request_info=self.request_info,
            traces=self._traces,
            loop=self.loop,
            session=self._session,
        )
        return self.response
class ServerPlatform:
    """Outbound HTTP client and request-signing helpers.

    Client TLS material and private-key bytes come from the injected
    ``key_loader``: ``update_ssl_context`` is used by ``get_ssl_context``, and
    ``get_content`` is used by ``sign_with_key``.
    """

    # Base directory for certificate files. It can be overridden through the
    # OB_CERTS_DIR environment variable and is resolved once, at class
    # creation. NOTE(review): not referenced inside this class — presumably
    # read by callers.
    OB_CERTS_DIR = os.path.abspath(
        os.environ.get("OB_CERTS_DIR", "/app/open_banking_certs")
    )

    def __init__(self, key_loader: key_loader.KeyLoader):
        # This parameter shadows the ``key_loader`` module only inside
        # __init__. Other methods still see the module under that name.
        self.key_loader = key_loader

    @staticmethod
    def _set_tls_version_for_ssl_context(
        ssl_context: ssl.SSLContext, tls_version: str | None = None
    ) -> None:
        """Pin *ssl_context* to exactly one TLS protocol version.

        Arguments:
            ssl_context {SSLContext} -- Context to modify in place
            tls_version {String} -- One of "TLSv1", "TLSv1_1", "TLSv1_2",
                "TLSv1_3". ``None`` leaves the context untouched. Any other
                value is ignored (with a warning) and the defaults are kept.
        """
        if tls_version is None:
            return
        tls_map = {
            "TLSv1": ssl.TLSVersion.TLSv1,
            "TLSv1_1": ssl.TLSVersion.TLSv1_1,
            "TLSv1_2": ssl.TLSVersion.TLSv1_2,
            "TLSv1_3": ssl.TLSVersion.TLSv1_3,
        }
        forced_tls_version = tls_map.get(tls_version)
        if forced_tls_version is None:
            logging.warning(
                "Unsupported TLS version %r ignored; allowed: %s",
                tls_version,
                list(tls_map),
            )
            return
        ssl_context.minimum_version = forced_tls_version
        ssl_context.maximum_version = forced_tls_version

    def get_ssl_context(self, tls: TLS | None) -> ssl.SSLContext:
        """Build a default client SSL context.

        When *tls* is given, the key loader adds client material to the
        context and it is optionally pinned to ``tls.tls_version``.
        """
        ssl_context = ssl.create_default_context()
        if tls:
            ssl_context = self.key_loader.update_ssl_context(ssl_context, tls)
            self._set_tls_version_for_ssl_context(ssl_context, tls.tls_version)
        return ssl_context

    def _handle_binary_response(self, response: bytes) -> bytes:
        """Return the first member of a zip archive body, else *response*.

        Bodies that are not zip archives, and empty archives, are returned
        unchanged.
        """
        try:
            with zipfile.ZipFile(BytesIO(response), "r") as archive:
                names = archive.namelist()
                logging.debug("Archive contains following files: %s", names)
                if not names:
                    logging.error("Zip archive is empty")
                    return response
                # assume that there is only one file in the archive
                return archive.read(names[0])
        except zipfile.BadZipFile:
            logging.error("Response is not a zip archive")
            return response

    @staticmethod
    def _is_octet_stream(content_type: str | None) -> bool:
        """True if a Content-Type header value denotes application/octet-stream.

        Media-type parameters (``; charset=...``) and letter case are ignored.
        """
        if not content_type:
            return False
        media_type = content_type.split(";", 1)[0].strip().lower()
        return media_type == "application/octet-stream"

    @staticmethod
    def _get_peer_certificate_pem(response: ClientResponse) -> str | None:
        """Return the server certificate as PEM, or ``None`` if unavailable."""
        connection = response.connection
        if not (connection and connection.transport):
            return None
        sslobj = connection.transport.get_extra_info("ssl_object")
        if not sslobj:
            return None
        der_cert = sslobj.getpeercert(True)
        if der_cert is None:
            return None
        return ssl.DER_cert_to_PEM_cert(der_cert)

    @staticmethod
    def _headers_to_pairs(headers) -> list:
        """Normalize a header container into a list of ``(name, value)`` pairs.

        Mappings (dicts, multidicts and their read-only proxies) are expanded
        through ``.items()``, which keeps repeated names. Any other iterable is
        assumed to already yield pairs. A falsy value yields ``[]``.
        """
        if not headers:
            return []
        if hasattr(headers, "items"):
            return list(headers.items())
        return list(headers)

    async def make_request(
        self, request: MakeRequestParams, follow_redirects: bool = True
    ) -> dict:
        """Send *request* with aiohttp and return a serializable summary.

        Arguments:
            request -- Target (``origin`` + ``path``), ``method``, ``body``,
                ``headers``, ``query`` (iterable of key/value pairs) and
                optional ``tls`` settings
            follow_redirects -- Forwarded to aiohttp as ``allow_redirects``

        Returns:
            On success: ``status``, ``response`` (body text; zipped
            octet-stream bodies are unpacked first), ``headers`` (list of
            ``(name, value)`` pairs, duplicates kept) and ``certificate``
            (server certificate as PEM, or ``None``).
            If aiohttp raises ``ClientResponseError``: that error's ``status``,
            its ``message`` as ``response``, its headers, and
            ``certificate=None``.
        """
        url = request.origin + request.path
        data = request.body.encode()
        request_headers = dict(request.headers)
        logging.debug(
            "Request(%r, %r, params=%r, headers=%r, method=%r)",
            url,
            data,
            request.query,
            request_headers,
            request.method,
        )
        ssl_context = self.get_ssl_context(request.tls)
        # Keep query parameters as a list of pairs (not a dict) so repeated
        # keys such as ``?a=1&a=2`` are all sent, instead of the last winning.
        query_params = [(k, v) for k, v in request.query]
        try:
            async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=60.0),
                request_class=Latin1HeadersClientRequest,
            ) as session:
                async with session.request(
                    method=request.method,
                    url=url,
                    params=query_params,
                    data=data,
                    headers=request_headers,
                    ssl=ssl_context,
                    allow_redirects=follow_redirects,
                ) as response:
                    # Read the certificate before the body, while the
                    # connection is still attached to the response.
                    peercert = self._get_peer_certificate_pem(response)
                    if self._is_octet_stream(response.headers.get("Content-Type")):
                        response_text = self._handle_binary_response(
                            await response.read()
                        ).decode("utf-8")
                    else:
                        response_text = await response.text()
                    return {
                        "status": response.status,
                        "response": response_text,
                        "headers": list(response.headers.items()),
                        "certificate": peercert,
                    }
        except aiohttp.ClientResponseError as e:
            # e.headers may be a CIMultiDictProxy, a dict, a list of pairs or
            # None. Normalize it to the same pair-list shape as the success path.
            return {
                "status": e.status,
                "response": e.message,
                "headers": self._headers_to_pairs(e.headers),
                "certificate": None,
            }

    @staticmethod
    def _force_bytes(value: str | bytes) -> bytes:
        """Convert value to bytes if necessary

        Arguments:
            value {String, Bytes} -- Some value to convert to bytes

        Returns:
            Bytes -- ``value`` encoded as UTF-8 if it is a ``str``; any other
            value is returned unchanged (no type check is performed)
        """
        if isinstance(value, str):
            return value.encode("utf-8")
        return value

    @staticmethod
    def _decode_signature(
        signature: bytes, hash_algorithm: str, key_size: int | None = None
    ) -> bytes:
        """Convert a DER-encoded ECDSA signature into fixed-width raw ``r || s``.

        Arguments:
            signature {Bytes} -- DER signature as returned by ``key.sign``
            hash_algorithm {String} -- Only consulted when ``key_size`` is
                None; then only "SHA256" (256-bit width) is accepted
            key_size {Int} -- Bit size of the signing key's curve. Each of r
                and s is left-padded to ceil(key_size / 8) bytes

        Raises:
            ValueError: if ``key_size`` is None and ``hash_algorithm`` is not
                supported
        """
        if key_size is None:
            hash_algorithms_map = {"SHA256": 256}
            try:
                key_size = hash_algorithms_map[hash_algorithm]
            except KeyError:
                raise ValueError(
                    f"Wrong hash algorithm: {hash_algorithm}. Allowed: {list(hash_algorithms_map.keys())}"
                ) from None
        num_bytes = (key_size + 7) // 8
        r, s = decode_dss_signature(signature)
        return int_to_bytes(r, num_bytes) + int_to_bytes(s, num_bytes)

    async def sign_with_key(
        self,
        data: str | bytes,
        key_path: str,
        hash_algorithm: str | None = None,
        crypto_algorithm: str | None = None,
    ) -> str:
        """Sign passed data with private key

        Arguments:
            data {String, Bytes} -- Data to be signed (str is UTF-8 encoded)
            key_path {String} -- Key identifier passed to the key loader
            hash_algorithm {String} -- "SHA256" or "SHA512", case-insensitive.
                If not provided then `sha256` will be used
            crypto_algorithm {String} -- For RSA keys, "PS" selects RSASSA-PSS
                (MGF1, salt length = digest size). Anything else selects
                PKCS#1 v1.5. Ignored for EC keys

        Returns:
            String -- Base64 of the signature. EC signatures are raw
            ``r || s``, each sized to the key's curve

        Raises:
            ValueError: if the hash algorithm is not supported
            TypeError: if the loaded key is neither RSA nor EC
        """
        if hash_algorithm is None:
            hash_algorithm = "SHA256"
        hash_algorithm = hash_algorithm.upper()
        hash_algorithms_map = {
            "SHA256": hashes.SHA256,
            "SHA512": hashes.SHA512,
        }
        try:
            hash_cls = hash_algorithms_map[hash_algorithm]
        except KeyError:
            raise ValueError(
                f"Wrong hash algorithm: {hash_algorithm}. Allowed: {list(hash_algorithms_map.keys())}"
            ) from None
        data = self._force_bytes(data)
        key_bytes = self.key_loader.get_content(key_path)
        # ``key_loader`` here is the module, not the instance attribute.
        # NOTE(review): confirm the password lookup is meant to bypass the
        # configured loader.
        password = key_loader.read_key_password(key_path)
        key = load_pem_private_key(
            key_bytes,
            password.encode("utf-8") if password is not None else None,
        )
        if isinstance(key, RSAPrivateKey):
            if crypto_algorithm == "PS":
                rsa_padding = padding.PSS(
                    mgf=padding.MGF1(hash_cls()), salt_length=hash_cls.digest_size
                )
            else:
                rsa_padding = padding.PKCS1v15()
            signature = key.sign(data, rsa_padding, hash_cls())
        elif isinstance(key, ec.EllipticCurvePrivateKey):
            der_signature = key.sign(data, ec.ECDSA(hash_cls()))
            # The raw r || s width is fixed by the curve, not by the hash.
            signature = self._decode_signature(
                der_signature, hash_algorithm, key.curve.key_size
            )
        else:
            # Previously an empty signature was returned silently.
            raise TypeError(f"Unsupported private key type: {type(key).__name__}")
        return base64.b64encode(signature).decode("utf8")
def get_server_platform() -> ServerPlatform:
    """Create a ``ServerPlatform`` whose key loader is chosen by ``KEY_LOADER``.

    ``KEY_LOADER`` defaults to ``"FILE"`` (``FileKeyLoader``). ``"ENV"``
    selects ``EnvKeyLoader``.

    Raises:
        ValueError: for any other ``KEY_LOADER`` value.
    """
    selected_loader = os.environ.get("KEY_LOADER", "FILE")
    if selected_loader == "FILE":
        loader = key_loader.FileKeyLoader()
    elif selected_loader == "ENV":
        loader = key_loader.EnvKeyLoader()
    else:
        raise ValueError(f"Unsupported key loader: {selected_loader}")
    return ServerPlatform(loader)