diff --git a/src/marqo/tensor_search/models/index_settings.py b/src/marqo/tensor_search/models/index_settings.py index 4a975afd7..1012b6d48 100644 --- a/src/marqo/tensor_search/models/index_settings.py +++ b/src/marqo/tensor_search/models/index_settings.py @@ -1,5 +1,5 @@ import time -from typing import Dict, Any, Optional, List +from typing import Dict, Any, Optional, List, Union from pydantic import root_validator @@ -45,15 +45,19 @@ class IndexSettings(StrictBaseModel): @root_validator(pre=True) def validate_field_names(cls, values): # Verify no snake case field names (pydantic won't catch these due to allow_population_by_field_name = True) - def validate_dict_keys(d: dict): - for key in d.keys(): - if '_' in key: - raise ValueError(f"Invalid field name '{key}'. " - f"See Create Index API reference here https://docs.marqo.ai/2.0.0/API-Reference/Indexes/create_index/") - if isinstance(d[key], dict): - validate_dict_keys(d[key]) - - validate_dict_keys(values) + def validate_keys(d: Union[dict, list]): + if isinstance(d, dict): + for key in d.keys(): + if '_' in key: + raise ValueError(f"Invalid field name '{key}'. " + f"See Create Index API reference here https://docs.marqo.ai/2.0.0/API-Reference/Indexes/create_index/") + + validate_keys(d[key]) + elif isinstance(d, list): + for item in d: + validate_keys(item) + + validate_keys(values) return values diff --git a/tests/tensor_search/test_api.py b/tests/tensor_search/test_api.py index 0f1b43890..67b0bafd0 100644 --- a/tests/tensor_search/test_api.py +++ b/tests/tensor_search/test_api.py @@ -145,23 +145,54 @@ def test_invalid_argument_error(self): assert "Could not find model properties for" in response.json()["message"] def test_create_index_snake_case_fails(self): - # Verify snake case rejected for fields that have camel case as alias - response = self.client.post( - "/indexes/my_index", - json={ - "type": "structured", - "allFields": [{"name": "field1", "type": "text"}], - "tensorFields": [], - 'annParameters': { - 'spaceType': 'dotproduct', - 'parameters': { - 'ef_construction': 128, - 'm': 16 - } - } - } - ) - - self.assertEqual(response.status_code, 422) - self.assertTrue("Invalid field name 'ef_construction'" in response.text) - + """ + Verify snake case rejected for fields that have camel case as alias + """ + test_cases = [ + ({ + "type": "structured", + "allFields": [ + { + "name": "field1", + "type": "text" + }, + { + "name": "field2", + "type": "multimodal_combination", + "dependent_fields": ["field1"] + } + ], + "tensorFields": [], + }, 'dependent_fields', 'Snake case within a list'), + ({ + "type": "structured", + "allFields": [], + "tensorFields": [], + 'annParameters': { + 'spaceType': 'dotproduct', + 'parameters': { + 'ef_construction': 128, + 'm': 16 + } + } + }, 'ef_construction', 'Snake case within a dict'), + ({ + "type": "unstructured", + 'annParameters': { + 'spaceType': 'dotproduct', + 'parameters': { + 'ef_construction': 128, + 'm': 16 + } + } + }, 'ef_construction', 'Snake case within a dict, unstructured index') + ] + for test_case, field, test_name in test_cases: + with self.subTest(test_name): + response = self.client.post( + "/indexes/my_index", + json=test_case + ) + + self.assertEqual(response.status_code, 422) + self.assertTrue(f"Invalid field name '{field}'" in response.text)