Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to post a file via the data param #194

Merged
merged 1 commit into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 38 additions & 22 deletions adafruit_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

if not sys.implementation.name == "circuitpython":
from types import TracebackType
from typing import Any, Dict, Optional, Type
from typing import IO, Any, Dict, Optional, Type

from circuitpython_typing.socket import (
SocketpoolModuleType,
Expand Down Expand Up @@ -387,19 +387,7 @@ def _build_boundary_data(self, files: dict): # pylint: disable=too-many-locals
boundary_objects.append("\r\n")

if hasattr(file_handle, "read"):
is_binary = False
try:
content = file_handle.read(1)
is_binary = isinstance(content, bytes)
except UnicodeError:
is_binary = False

if not is_binary:
raise ValueError("Files must be opened in binary mode")

file_handle.seek(0, SEEK_END)
content_length += file_handle.tell()
file_handle.seek(0)
content_length += self._get_file_length(file_handle)

boundary_objects.append(file_handle)
boundary_objects.append("\r\n")
Expand Down Expand Up @@ -428,6 +416,25 @@ def _check_headers(headers: Dict[str, str]):
f"Header part ({value}) from {key} must be of type str or bytes, not {type(value)}"
)

@staticmethod
def _get_file_length(file_handle: IO):
is_binary = False
try:
file_handle.seek(0)
# read at least 4 bytes incase we are reading a b64 stream
content = file_handle.read(4)
is_binary = isinstance(content, bytes)
except UnicodeError:
is_binary = False

if not is_binary:
raise ValueError("Files must be opened in binary mode")

file_handle.seek(0, SEEK_END)
content_length = file_handle.tell()
file_handle.seek(0)
return content_length

@staticmethod
def _send(socket: SocketType, data: bytes):
total_sent = 0
Expand Down Expand Up @@ -458,13 +465,16 @@ def _send_boundary_objects(self, socket: SocketType, boundary_objects: Any):
if isinstance(boundary_object, str):
self._send_as_bytes(socket, boundary_object)
else:
chunk_size = 32
b = bytearray(chunk_size)
while True:
size = boundary_object.readinto(b)
if size == 0:
break
self._send(socket, b[:size])
self._send_file(socket, boundary_object)

def _send_file(self, socket: SocketType, file_handle: IO):
chunk_size = 36
b = bytearray(chunk_size)
while True:
size = file_handle.readinto(b)
if size == 0:
break
self._send(socket, b[:size])

def _send_header(self, socket, header, value):
if value is None:
Expand Down Expand Up @@ -517,12 +527,16 @@ def _send_request( # pylint: disable=too-many-arguments

# If files are send, build data to send and calculate length
content_length = 0
data_is_file = False
boundary_objects = None
if files and isinstance(files, dict):
boundary_string, content_length, boundary_objects = (
self._build_boundary_data(files)
)
content_type_header = f"multipart/form-data; boundary={boundary_string}"
elif data and hasattr(data, "read"):
data_is_file = True
content_length = self._get_file_length(data)
else:
if data is None:
data = b""
Expand Down Expand Up @@ -551,7 +565,9 @@ def _send_request( # pylint: disable=too-many-arguments
self._send(socket, b"\r\n")

# Send data
if data:
if data_is_file:
self._send_file(socket, data)
elif data:
self._send(socket, bytes(data))
elif boundary_objects:
self._send_boundary_objects(socket, boundary_objects)
Expand Down
24 changes: 23 additions & 1 deletion tests/files_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def get_actual_request_data(log_stream):
boundary = boundary_search[0]
if content_length_search:
content_length = content_length_search[0]
if "Content-Disposition" in log_arg:
if "Content-Disposition" in log_arg or "\\x" in log_arg:
# this will look like:
# b\'{content}\'
# and escaped characters look like:
Expand All @@ -63,6 +63,28 @@ def get_actual_request_data(log_stream):
return boundary, content_length, actual_request_post


def test_post_file_as_data( # pylint: disable=unused-argument
requests, sock, log_stream, post_url, request_logging
):
with open("tests/files/red_green.png", "rb") as file_1:
python_requests.post(post_url, data=file_1, timeout=30)
__, content_length, actual_request_post = get_actual_request_data(log_stream)

requests.post("http://" + mocket.MOCK_HOST_1 + "/post", data=file_1)

sock.connect.assert_called_once_with((mocket.MOCK_POOL_IP, 80))
sock.send.assert_has_calls(
[
mock.call(b"Content-Length"),
mock.call(b": "),
mock.call(content_length.encode()),
mock.call(b"\r\n"),
]
)
sent = b"".join(sock.sent_data)
assert sent.endswith(actual_request_post)


def test_post_files_text( # pylint: disable=unused-argument
sock, requests, log_stream, post_url, request_logging
):
Expand Down
Loading