- Implemented shape check on inference
- Disabled inference for LocalModels (made the code more complicated and wasn't needed imo)
- Suppressed a warning regarding a default parameter change in torch.load
BeFranke committed Aug 8, 2024
1 parent d6bce32 commit 48a2e55
Showing 3 changed files with 51 additions and 6 deletions.
33 changes: 32 additions & 1 deletion fl_server_api/tests/test_inference.py
@@ -114,7 +114,7 @@ def test_model_not_exist(self):
self.assertEqual(response.status_code, 400)
response_json = response.json()
self.assertIsNotNone(response_json)
self.assertEqual(f"Model {unused_id} not found.", response_json["detail"])
self.assertEqual(f"GlobalModel {unused_id} not found.", response_json["detail"])

def test_model_weights_corrupted(self):
inp = from_torch_tensor(torch.zeros(3, 3))
@@ -168,3 +168,34 @@ def _inference_result(self, torch_model: torch.nn.Module):
self.assertIsNotNone(inference)
inference_tensor = torch.as_tensor(inference)
self.assertTrue(torch.all(torch.tensor([2, 0, 0]) == inference_tensor))

def test_inference_input_shape_positive(self):
inp = from_torch_tensor(torch.zeros(3, 3))
model = Dummy.create_model(input_shape=[None, 3])
training = Dummy.create_training(actor=self.user, model=model)
input_file = SimpleUploadedFile(
"input.pt",
inp,
content_type="application/octet-stream"
)
response = self.client.post(
f"{BASE_URL}/inference/",
{"model_id": str(training.model.id), "model_input": input_file}
)
self.assertEqual(response.status_code, 200)

def test_inference_input_shape_negative(self):
inp = from_torch_tensor(torch.zeros(3, 3))
model = Dummy.create_model(input_shape=[None, 5])
training = Dummy.create_training(actor=self.user, model=model)
input_file = SimpleUploadedFile(
"input.pt",
inp,
content_type="application/octet-stream"
)
response = self.client.post(
f"{BASE_URL}/inference/",
{"model_id": str(training.model.id), "model_input": input_file}
)
self.assertEqual(response.status_code, 400)
self.assertEqual(response.json()[0], "Input shape does not match model input shape.")
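
As a rough sketch of how a client might exercise this endpoint outside the test suite — assuming the API is reachable under an `/inference/` route and that `from_torch_tensor` boils down to a `torch.save`-style byte serialization (both are assumptions, not confirmed by this diff):

```python
import io

import requests  # third-party HTTP client, not part of this repository
import torch


def post_inference(base_url: str, model_id: str, tensor: torch.Tensor) -> requests.Response:
    """Serialize a tensor and POST it to the inference endpoint.

    The field names ("model_id", "model_input") and the octet-stream content
    type mirror the tests above; the base URL is a placeholder.
    """
    buffer = io.BytesIO()
    torch.save(tensor, buffer)  # hypothetical stand-in for from_torch_tensor
    buffer.seek(0)
    return requests.post(
        f"{base_url}/inference/",
        data={"model_id": model_id},
        files={"model_input": ("input.pt", buffer, "application/octet-stream")},
    )


# A (3, 3) input satisfies a model input_shape of [None, 3] and returns HTTP 200;
# against a model declaring [None, 5] the same call now returns HTTP 400.
# response = post_inference("http://localhost:8000/api", "<model-uuid>", torch.zeros(3, 3))
```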
21 changes: 17 additions & 4 deletions fl_server_api/views/inference.py
@@ -109,11 +109,18 @@ def _process_post(self, request: HttpRequest) -> HttpResponse:
self._logger.error(e)
raise ValidationError("Inference Request could not be interpreted!")

model = get_entity(Model, pk=model_id)
# (Benedikt) imo, it does not make sense to allow inference on the local model updates, therefore I changed this
model = get_entity(GlobalModel, pk=model_id)
input_tensor = to_torch_tensor(feature_vectors)
if isinstance(model, GlobalModel) and model.preprocessing is not None:
if model.preprocessing is not None:
preprocessing = model.get_preprocessing_torch_model()
input_tensor = preprocessing(input_tensor)

if model.input_shape is not None:
if not all(dim_input == dim_model for (dim_input, dim_model) in
zip(input_tensor.shape, model.input_shape) if dim_model is not None):
raise ValidationError("Input shape does not match model input shape.")

uncertainty_cls, inference, uncertainty = self.do_inference(model, input_tensor)
return self._make_response(uncertainty_cls, inference, uncertainty, return_format)

@@ -137,11 +144,17 @@ def _process_post_json(self, request: HttpRequest, body: Any = None) -> HttpResponse:
self._logger.error(e)
raise ValidationError("Inference Request could not be interpreted!")

model = get_entity(Model, pk=model_id)
model = get_entity(GlobalModel, pk=model_id)
input_tensor = torch.as_tensor(model_input)
if isinstance(model, GlobalModel) and model.preprocessing is not None:
if model.preprocessing is not None:
preprocessing = model.get_preprocessing_torch_model()
input_tensor = preprocessing(input_tensor)

if model.input_shape is not None:
if not all(dim_input == dim_model for (dim_input, dim_model) in
zip(input_tensor.shape, model.input_shape) if dim_model is not None):
raise ValidationError("Input shape does not match model input shape.")

uncertainty_cls, inference, uncertainty = self.do_inference(model, input_tensor)
return self._make_response(uncertainty_cls, inference, uncertainty, return_format)

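The check added in both handlers treats `None` entries in `model.input_shape` as wildcard dimensions (typically the batch dimension) and only compares dimensions that are pinned to a concrete size. A minimal standalone sketch of the same comparison, with hypothetical shapes:

```python
import torch


def shapes_compatible(input_shape, model_input_shape) -> bool:
    """Mirror of the new check: every concretely specified model dimension
    must match the corresponding input dimension; None means "any size"."""
    return all(
        dim_input == dim_model
        for dim_input, dim_model in zip(input_shape, model_input_shape)
        if dim_model is not None
    )


x = torch.zeros(3, 3)
print(shapes_compatible(x.shape, [None, 3]))  # True  -> inference proceeds
print(shapes_compatible(x.shape, [None, 5]))  # False -> ValidationError (HTTP 400)
```

Note that `zip` stops at the shorter of the two sequences, so a rank mismatch (an input with fewer or more dimensions than `input_shape` declares) is not caught by this comparison on its own.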
3 changes: 2 additions & 1 deletion fl_server_core/utils/torch_serialization.py
@@ -41,7 +41,8 @@ def to_torch(obj: Any, supported_types: Type[T] | Tuple[Type[T], ...]):
message="'torch.load' received a zip file that looks like a TorchScript archive",
category=UserWarning
)
t_obj = torch.load(obj)
# default for "weights_only" will change in upcoming torch versions!
t_obj = torch.load(obj, weights_only=False)
except Exception as e:
getLogger("fl.server").error(f"Error loading torch object: {e}")
raise TorchDeserializationException("Error loading torch object") from e
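For context on the `weights_only` flag: recent PyTorch releases warn that the default of `torch.load(..., weights_only=...)` will flip from `False` to `True`, which restricts unpickling to tensors and a safelist of types. Passing the flag explicitly, as the diff does, pins the current permissive behavior and silences the warning. A small illustration with a hypothetical module class:

```python
import io

import torch


class TinyModel(torch.nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 1)


buffer = io.BytesIO()
torch.save(TinyModel(), buffer)  # pickles the whole module object, not just its weights
buffer.seek(0)

# Under weights_only=True (the upcoming default) this load fails, because
# arbitrary classes are not on the unpickling safelist; weights_only=False
# keeps the old behavior that the server's serialization helper relies on.
model = torch.load(buffer, weights_only=False)
print(type(model).__name__)  # TinyModel
```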
