Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

63 bring in server #64

Merged
merged 6 commits into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ dependencies = [
"numpy>=2.2.1",
# Install solvers
"clarabel>=0.9.0",
"pyarrow>=19.0.0",
"loguru>=0.7.3",
]

[project.urls]
Expand Down Expand Up @@ -48,3 +50,7 @@ DEP002 = ["clarabel"]

[tool.bandit]
exclude_dirs = ["tests"]

[project.scripts]
server = "cvx.ball.server:BallServer.start"
client = "example.client:main"
95 changes: 95 additions & 0 deletions src/cvx/ball/numpy_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import threading

import loguru
import numpy as np
import pyarrow.flight as fl


class NumpyServer(fl.FlightServerBase):
def __init__(self, host, port, logger=None, **kwargs):
uri = f"grpc+tcp://{host}:{port}"
super().__init__(uri, **kwargs)
self._logger = logger or loguru.logger
self._storage = {} # Dictionary to store uploaded data
self._lock = threading.Lock() # Lock for thread safety

@property
def logger(self):
return self._logger

@staticmethod
def _handle_arrow_table(table, logger) -> dict[str, np.ndarray]:
# Directly work with the Arrow Table (no Polars)
logger.info(f"Handling Arrow Table: {table}")
logger.info(f"Names: {table.schema.names}")

matrices = {}
for name in table.schema.names:
logger.info(f"Name: {name}")
struct = table.column(name)[0].as_py()

# Extract the matrix and shape data from the Arrow Table
matrix_data = np.array(struct["data"]) # .to_numpy() # Flattened matrix data
shape = np.array(struct["shape"]) # .to_numpy() # Shape of the matrix

logger.info(f"Matrix (flattened): {matrix_data}")
logger.info(f"Shape: {shape}")

if len(matrix_data) != np.prod(shape):
raise fl.FlightServerError("Data length does not match the provided shape")

# Reshape the flattened matrix data based on the shape
matrix = matrix_data.reshape(shape)
logger.info(f"Reshaped Matrix: {matrix}")

matrices[name] = matrix

return matrices

@staticmethod
def _extract_command_from_ticket(ticket):
"""Helper method to extract the command from a Flight Ticket."""
return ticket.ticket.decode("utf-8")

def do_put(self, context, descriptor, reader, writer):
with self._lock:
# Read and store the data
command = descriptor.command.decode("utf-8")
self.logger.info(f"Processing PUT request for command: {command}")

table = reader.read_all()
self.logger.info(f"Table: {table}")

# Store the table using the command as key
self._storage[command] = table

self.logger.info(f"Data stored for command: {command}")

return fl.FlightDescriptor.for_command(command)

def do_get(self, context, ticket):
# Get the command from the ticket
command = self._extract_command_from_ticket(ticket)
self.logger.info(f"Processing GET request for command: {command}")

# Retrieve the stored table
if command not in self._storage:
raise fl.FlightServerError(f"No data found for command: {command}")

table = self._storage[command]
self.logger.info(f"Retrieved data for command: {command}")

matrices = NumpyServer._handle_arrow_table(table, logger=self.logger)
result_table = self.f(matrices)

self.logger.info("Computation completed. Returning results.")
stream = fl.RecordBatchStream(result_table)

return stream

@classmethod
def start(cls, port=5008, logger=None, **kwargs):
logger = logger or loguru.logger # pragma: no cover
server = cls("127.0.0.1", port=port, logger=logger, **kwargs) # pragma: no cover
server.logger.info(f"Starting {cls} Flight server on port {port}...") # pragma: no cover
server.serve() # pragma: no cover
23 changes: 23 additions & 0 deletions src/cvx/ball/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import numpy as np
import pyarrow as pa

from .numpy_server import NumpyServer
from .solver import min_circle_cvx


class BallServer(NumpyServer):
def f(self, matrices: dict[str, np.ndarray]) -> pa.Table:
self.logger.info(f"Matrices: {matrices.keys()}")
matrix = matrices["input"]

self.logger.info(f"Matrix: {matrix}")

# Compute the smallest enclosing ball
self.logger.info("Computing smallest enclosing ball...")
radius, midpoint = min_circle_cvx(matrix, solver="CLARABEL")

# Create result table
radius_array = pa.array([radius], type=pa.float64())
midpoint_array = pa.array([midpoint], type=pa.list_(pa.float64()))
result_table = pa.table({"radius": radius_array, "midpoint": midpoint_array})
return result_table
Empty file added src/example/__init__.py
Empty file.
47 changes: 47 additions & 0 deletions src/example/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import numpy as np
import pyarrow as pa
import pyarrow.flight as fl
from loguru import logger


def numpy2pyarrow(data):
return pa.array([{"data": data.flatten(), "shape": data.shape}])


def compute(client, data, logger=None):
table = pa.table({"input": numpy2pyarrow(data)})
logger.info("Created example data.")

# Upload data
command = "compute_ball"
descriptor = fl.FlightDescriptor.for_command(command)

logger.info(f"Uploading data with command: {command}")
writer, _ = client.do_put(descriptor, table.schema)
writer.write_table(table)
writer.close()

# Retrieve result
ticket = fl.Ticket(command) # Create a Ticket with the command
reader = client.do_get(ticket)

result_table = reader.read_all()
logger.info("Result retrieved successfully.")

results = {name: result_table.column(name)[0].as_py() for name in result_table.schema.names}
logger.info(f"Results: {results}")
return results


def main():
# Connect to the server
client = fl.connect("grpc+tcp://127.0.0.1:5008")
logger.info("Connected to the server.")

# Example data
data = np.random.rand(10000, 20) # 10000 points in 20D space
compute(client, data, logger=logger)


if __name__ == "__main__":
main()
13 changes: 0 additions & 13 deletions src/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +0,0 @@
# Copyright 2025 Stanford University Convex Optimization Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
144 changes: 144 additions & 0 deletions src/tests/test_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import threading
import time

import numpy as np
import pyarrow as pa
import pyarrow.flight as fl
import pytest

from cvx.ball.server import BallServer # Adjust to your actual import path

from .utils.reader import TableReader


@pytest.fixture(scope="module")
def server():
"""Fixture to initialize and properly kill the BallServer for each test."""
server = BallServer("127.0.0.1", 5007)

# Function to run the server in a separate thread
def run_server():
server.serve()

# Start the server in a separate thread
server_thread = threading.Thread(target=run_server, daemon=True)
server_thread.start()

# Give the server a moment to start
time.sleep(1)

yield server
# Connect to the server (flight client)
# flight_client = fl.connect("grpc+tcp://127.0.0.1:5007")

# yield flight_client # Provide the flight client to the test

# After the test, ensure the server is properly cleaned up
# flight_client.close() # Close the client connection

# After the test, ensure the server is properly cleaned up
server_thread.join(timeout=5) # Ensure the server thread has time to shutdown


@pytest.fixture(scope="module")
def client(server):
"""Fixture to initialize and properly kill the BallServer for each test."""
# Connect to the server (flight client)
flight_client = fl.connect("grpc+tcp://127.0.0.1:5007")

yield flight_client # Provide the flight client to the test

# After the test, ensure the server is properly cleaned up
flight_client.close() # Close the client connection


@pytest.fixture
def mock_table():
"""Fixture to create a mock Arrow Table for tests."""
matrix_data = [1, 2, 3, 4] # Flattened 2x2 matrix
shape = [2, 2] # Shape of the matrix
struct = {"data": matrix_data, "shape": shape}

# Create an Arrow Table with the matrix data and shape
table = pa.table({"input": [struct]})
return table


@pytest.fixture
def mock_table_faulty():
"""Fixture to create a mock Arrow Table for tests."""
matrix_data = [1, 2, 3, 4, 5, 6] # Flattened 3x2 matrix
shape = [2, 2] # Shape of the matrix
struct = {"data": matrix_data, "shape": shape}

# Create an Arrow Table with the matrix data and shape
table = pa.table({"input": [struct]})
return table


def test_client(client, mock_table):
# Simulate a 'do_put' request
command = "compute_ball"
descriptor = fl.FlightDescriptor.for_command(command)

writer, _ = client.do_put(descriptor, mock_table.schema)
writer.write_table(mock_table)
writer.close()

ticket = fl.Ticket(command) # Create a Ticket with the command
reader = client.do_get(ticket)

result = reader.read_all()
assert result.column("radius")[0].as_py() == pytest.approx(1.4142135605902473)
assert result.column("midpoint")[0].as_py() == pytest.approx(np.array([2.0, 3.0]))


def test_do_put_server(server, mock_table):
command = "compute_ball"
descriptor = fl.FlightDescriptor.for_command(command)

reader = TableReader(mock_table)

server.do_put(None, descriptor, reader, None)


def test_do_get_server(server, mock_table):
command = "compute_ball"
descriptor = fl.FlightDescriptor.for_command(command)

# fill the storage for the correct command
reader = TableReader(mock_table)
server.do_put(None, descriptor, reader, None)

# from the ticket we can extract the correct storage
ticket = fl.Ticket(command)
server.do_get(None, ticket)


def test_wrong_command(server, mock_table):
command = "compute_ball"
descriptor = fl.FlightDescriptor.for_command(command)

# fill the storage for the correct command
reader = TableReader(mock_table)
server.do_put(None, descriptor, reader, None)

with pytest.raises(fl.FlightServerError):
command = "Dunno"
# from the ticket we can extract the correct storage
ticket = fl.Ticket(command)
server.do_get(None, ticket)


def test_faulty_data(server, mock_table_faulty):
command = "compute_ball"
descriptor = fl.FlightDescriptor.for_command(command)

# fill the storage for the correct command
reader = TableReader(mock_table_faulty)
server.do_put(None, descriptor, reader, None)

with pytest.raises(fl.FlightServerError):
# from the ticket we can extract the correct storage
ticket = fl.Ticket(command)
server.do_get(None, ticket)
Empty file added src/tests/utils/__init__.py
Empty file.
Loading