Skip to content

Commit

Permalink
Checks LLM response for mismatches (#125)
Browse files Browse the repository at this point in the history
  • Loading branch information
pipliggins authored Jan 31, 2025
1 parent d99702d commit d28d65f
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 4 deletions.
6 changes: 5 additions & 1 deletion src/adtl/autoparser/dict_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from .util import (
DEFAULT_CONFIG,
check_matches,
load_data_dict,
read_config_schema,
read_data,
Expand Down Expand Up @@ -200,7 +201,10 @@ def generate_descriptions(
# check ordering is correct even if the return field names aren't quite the same
# e.g. numbering has been stripped
assert all(
descrip.apply(lambda x: x["source_field_gpt"] in x.source_field, axis=1)
descrip.apply(
lambda x: check_matches(x["source_field_gpt"], [x.source_field]),
axis=1,
)
), "Field names from the LLM don't match the originals."

descrip.drop(columns=["source_field_gpt"], inplace=True)
Expand Down
24 changes: 24 additions & 0 deletions src/adtl/autoparser/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from .util import (
DEFAULT_CONFIG,
check_matches,
load_data_dict,
read_config_schema,
read_json,
Expand Down Expand Up @@ -201,6 +202,29 @@ def match_fields_to_schema(self) -> pd.DataFrame:
)
df_merged.set_index("target_field", inplace=True, drop=True)

# Check to see if any fields with mapped descriptions are missing after merge
missed_merge = df_merged[
(df_merged["source_description"].notna())
& (df_merged["source_field"].isna())
]

if not missed_merge.empty:
descriptions_list = self.data_dictionary["source_description"].tolist()
df_merged.loc[
(df_merged["source_description"].notna())
& (df_merged["source_field"].isna()),
"source_description",
] = missed_merge["source_description"].apply(
lambda x: check_matches(x, descriptions_list)
)

df_merged = (
df_merged["source_description"]
.reset_index()
.merge(self.data_dictionary, how="left")
.set_index("target_field")
)

self.mapped_fields = df_merged.source_field
self.filtered_data_dict = df_merged
return df_merged
Expand Down
25 changes: 25 additions & 0 deletions src/adtl/autoparser/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from __future__ import annotations

import difflib
import json
import re
from pathlib import Path
Expand Down Expand Up @@ -139,3 +140,27 @@ def setup_llm(provider, api_key):
return GeminiLanguageModel(api_key=api_key)
else:
raise ValueError(f"Unsupported LLM provider: {provider}")


def check_matches(llm: str, source: list[str], cutoff=0.8) -> str | None:
"""
Use to check if a string returned by an llm is a close enough match to the original
source.
Useful for checking or finding the original word if the LLM misspells it when
returning results.
Parameters
----------
llm
String returned by the LLM
source
List of strings to compare against (usually the original fields/descriptions
from the previous step)
"""
if not isinstance(source, list):
raise ValueError(
f"check matches: source must be a list of strings, got '{source}'"
)
matches = difflib.get_close_matches(llm, source, n=1, cutoff=cutoff)
return matches[0] if matches else None
2 changes: 1 addition & 1 deletion tests/test_autoparser/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_map_fields(monkeypatch):

# Define the mocked response
def mock_generate_content(*args, **kwargs):
json_str = '{"targets_descriptions": [{"source_description": "Identity", "target_field": "identity"}, {"source_description": "Full Name", "target_field": "name"}, {"source_description": "Province", "target_field": "loc_admin_1"}, {"source_description": null, "target_field": "country_iso3"}, {"source_description": "Notification Date", "target_field": "notification_date"}, {"source_description": "Classification", "target_field": "classification"}, {"source_description": "Case Status", "target_field": "case_status"}, {"source_description": "Death Date", "target_field": "date_of_death"}, {"source_description": "Age in Years", "target_field": "age_years"}, {"source_description": "Age in Months", "target_field": "age_months"}, {"source_description": "Gender", "target_field": "sex"}, {"source_description": "Pet Animal", "target_field": "pet"}, {"source_description": "Microchipped", "target_field": "chipped"}, {"source_description": null, "target_field": "owner"}]}' # noqa
json_str = '{"targets_descriptions": [{"source_description": "Identity", "target_field": "identity"}, {"source_description": "Full Name", "target_field": "name"}, {"source_description": "Province", "target_field": "loc_admin_1"}, {"source_description": null, "target_field": "country_iso3"}, {"source_description": "Notification Date", "target_field": "notification_date"}, {"source_description": "Classification", "target_field": "classification"}, {"source_description": "Case Status", "target_field": "case_status"}, {"source_description": "Death Date", "target_field": "date_of_death"}, {"source_description": "Age Years", "target_field": "age_years"}, {"source_description": "Age in Months", "target_field": "age_months"}, {"source_description": "Gender", "target_field": "sex"}, {"source_description": "Pet Animal", "target_field": "pet"}, {"source_description": "Microchipped", "target_field": "chipped"}, {"source_description": null, "target_field": "owner"}]}' # noqa
res = protos.GenerateContentResponse(
candidates=[
protos.Candidate(
Expand Down
4 changes: 4 additions & 0 deletions tests/test_autoparser/test_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,10 @@ def test_match_fields_to_schema_dummy_data():
# check mapped_values now filled
pd.testing.assert_series_equal(mapper.mapped_fields, df["source_field"])

# check the description that was misspelled is now corrected
assert df.at["age_years", "source_field"] == "AgeAns"
assert df.at["date_of_death", "source_field"] is np.nan


def test_match_values_to_schema_dummy_data():
mapper = ANIMAL_MAPPER
Expand Down
16 changes: 16 additions & 0 deletions tests/test_autoparser/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

from adtl.autoparser.util import (
check_matches,
load_data_dict,
parse_choices,
read_config_schema,
Expand Down Expand Up @@ -106,3 +107,18 @@ def test_setup_llm_no_key():
def test_setup_llm_bad_provider():
with pytest.raises(ValueError, match="Unsupported LLM provider: fish"):
setup_llm("fish", "abcd")


@pytest.mark.parametrize(
"input, expected", [(("fish", ["fishes"]), "fishes"), (("fish", ["shark"]), None)]
)
def test_check_matches(input, expected):
llm, source = input
assert check_matches(llm, source) == expected


def test_check_matches_error():
with pytest.raises(
ValueError, match="check matches: source must be a list of strings, got 'fish'"
):
check_matches("fish", "fish")
4 changes: 2 additions & 2 deletions tests/test_autoparser/testing_data_animals.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def get_definitions(*args):
"notification_date": "Notification Date",
"classification": "Classification",
"case_status": "Case Status",
"date_of_death": "Death Date",
"age_years": "Age in Years",
"date_of_death": "Death Date", # "Date of Death", misspelled by 'LLM'
"age_years": "Age Years", # "Age in Years", misspelled by 'LLM'
"age_months": "Age in Months",
"sex": "Gender",
"pet": "Pet Animal",
Expand Down

0 comments on commit d28d65f

Please sign in to comment.