Skip to content

Commit

Permalink
🚨 Fix more mypy lint warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
i-be-snek committed Nov 1, 2024
1 parent 9ec4623 commit 0e88805
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 17 deletions.
11 changes: 7 additions & 4 deletions Database/scr/normalize_numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def _preprocess(self, text: str) -> str:

return text

def _extract_single_number(self, text: str) -> List[float] | BaseException:
def _extract_single_number(self, text: str) -> List[float | None] | BaseException:
number = None

for z in self.zero_phrases:
Expand Down Expand Up @@ -478,7 +478,7 @@ def _check_for_approximation(self, doc: spacy.tokens.doc.Doc, labels: List[str])

return 0

def _extract_simple_range(self, text: str) -> Tuple[float]:
def _extract_simple_range(self, text: str) -> Tuple[float, float] | None:
sep = "-"
for i in ("and", "to", "&"):
if i in text:
Expand All @@ -500,6 +500,7 @@ def _extract_simple_range(self, text: str) -> Tuple[float]:
return (self.atof(nums[0].strip()), self.atof(nums[1].strip()))
except:
return None
return None

def _get_scale(self, n_init: float | int):
"""
Expand All @@ -508,6 +509,7 @@ def _get_scale(self, n_init: float | int):
n = int(n_init) if isinstance(n_init, float) and n_init.is_integer() else n_init
abs_n = abs(n)
n_str = str(abs_n)
scale = 0

if isinstance(n, int):
# Check if the last digit is zero
Expand Down Expand Up @@ -556,12 +558,12 @@ def _extract_complex_range(self, text: str) -> Tuple[float, float] | None:
num = self._extract_single_number(" ".join(norm_text))[0]
except BaseException as err:
self.logger.error(f"Could not infer number from {norm_text}. Error: {err}")
return
return None
else:
if len(digits) == 1:
num = digits[0]
else:
return
return None
lower_mod, upper_mod = (
(3, 5)
if any([x in [y.lower() for y in text.split()] for x in self.family_synonyms])
Expand All @@ -581,6 +583,7 @@ def _extract_complex_range(self, text: str) -> Tuple[float, float] | None:
max(0, n - scale - inc) * lower_mod,
max(0, n - inc) * upper_mod,
)
return None

def _extract_approximate_quantifiers(self, text: str) -> Tuple[float, float] | None:
one, ten, hun, tho, mil, bil, tri = (
Expand Down
27 changes: 14 additions & 13 deletions Database/scr/normalize_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import pathlib
import re
from typing import Tuple, Union
from typing import Any, Tuple, Union

import pandas as pd
import pycountry
Expand All @@ -21,7 +21,7 @@ class NormalizeUtils:
def __init__(self):
self.logger = Logging.get_logger("normalize-utils")

def load_spacy_model(self, spacy_model: str = "en_core_web_trf") -> spacy_language:
def load_spacy_model(self, spacy_model: str = "en_core_web_trf") -> spacy_language.Language:
import spacy

try:
Expand Down Expand Up @@ -51,7 +51,7 @@ def replace_nulls(df: pd.DataFrame) -> pd.DataFrame:
df = df.astype(object).where(pd.notnull(df), None)
return df

def normalize_date(self, row: Union[str, None]) -> Tuple[int, int, int]:
def normalize_date(self, row: Union[str, None]) -> Tuple[str | None, str | None, str | None]:
"""
See https://github.com/scrapinghub/dateparser/issues/700
and https://dateparser.readthedocs.io/en/latest/dateparser.html#dateparser.date.DateDataParser.get_date_data
Expand Down Expand Up @@ -99,6 +99,7 @@ def normalize_date(self, row: Union[str, None]) -> Tuple[int, int, int]:
except BaseException as err:
self.logger.error(f"Date parsing error in {row} with date\n{err}\n")
return (None, None, None)
return (None, None, None)

@staticmethod
def unpack_col(df: pd.DataFrame, columns: list = []) -> pd.DataFrame:
Expand Down Expand Up @@ -221,7 +222,7 @@ def __init__(self):
self.logger = Logging.get_logger("normalize-utils-json")

@staticmethod
def infer_date_from_dict(x: any) -> str:
def infer_date_from_dict(x: Any) -> str:
"""
This function normalizes date output in various formats by some LLMs.
Current usecases:
Expand All @@ -235,6 +236,7 @@ def infer_date_from_dict(x: any) -> str:
If no date is found, an empty string is returned.
"""
day, month, year, date, time = None, None, None, None, None
if isinstance(x, str):
return x
if isinstance(x, list):
Expand All @@ -243,7 +245,6 @@ def infer_date_from_dict(x: any) -> str:
normalized_x = {}
for k, v in x.items():
normalized_x[k.strip().lower()] = str(v)
day, month, year, date, time = None, None, None, None, None
if "year" in normalized_x.keys():
if "year" in normalized_x.keys():
year = normalized_x["year"]
Expand Down Expand Up @@ -292,7 +293,7 @@ def merge_json(self, file_path_dir: str) -> list[pd.DataFrame]:
file_list = os.listdir(file_path_dir)
file_list_relative = [f"{file_path_dir}/{i}" for i in file_list if i and i.endswith(".json")]

dfs = []
dfs, json_file = [], None
for idx in range(len(file_list_relative)):
try:
json_file = json.load(open(file_list_relative[idx]))
Expand Down Expand Up @@ -520,7 +521,7 @@ def __init__(self, nid_path: str = "/tmp/geojson") -> None:
self.non_english_nids_path = f"{nid_path}/non-english-locations.csv"
self.non_english_nids_columns = ["location_name", "nid"]
pathlib.Path(self.nid_path).mkdir(parents=True, exist_ok=True)
self.nid_list = self.update_nid_list()
self.update_nid_list()
try:
self.non_english_nids_df = pd.read_csv(
self.non_english_nids_path,
Expand All @@ -543,8 +544,8 @@ def random_nid(self, length: int = 5) -> str:
"""Generates a short lowercase UID"""
return shortuuid.ShortUUID().random(length=length)

def generate_nid(self, text: str) -> tuple[str, None]:
nid = None
def generate_nid(self, text: str) -> str:
nid: str = ""
try:
assert text
text = unidecode(text)
Expand Down Expand Up @@ -573,8 +574,8 @@ def store_non_english_nids(self) -> None:
self.logger.info(f"Storing non english location names and their generated nids to {self.non_english_nids_path}")
self.non_english_nids_df.to_csv(self.non_english_nids_path, sep=",", index=False, mode="w")

def check_duplicate(self, nid: str, obj: json) -> tuple[str, bool]:
nid_path = f"{self.nid_path}/{nid}"
def check_duplicate(self, nid: str, obj: Any) -> tuple[str, bool]:
nid_path: str = f"{self.nid_path}/{nid}"
self.update_nid_list()

if nid_path in self.nid_list or nid in self.non_english_nids_df["location_name"].tolist():
Expand All @@ -588,7 +589,7 @@ def check_duplicate(self, nid: str, obj: json) -> tuple[str, bool]:
return alt_nid, False
return nid, False

def geojson_to_file(self, geojson_obj: str, area_name: str) -> str:
def geojson_to_file(self, geojson_obj: str, area_name: str) -> str | None:
"""Checks if a GeoJson object is stored by a specific nid. Handles three cases:
- If the nid and file content match, nothing is written to file.
- If the there is no record of the nid in self.nid_path, a new file is written.
Expand Down Expand Up @@ -645,7 +646,7 @@ def validate_categorical(self, text: str, categories: list) -> str | None:
return categories[cat_idx]
except BaseException as err:
self.logger.warning(f"Value `{text}` may be invalid for this category. Error: {err}")
return
return None

def validate_main_event_hazard_relation(
self, row: dict, hazards: str = "Hazards", main_event: str = "Main_Event"
Expand Down

0 comments on commit 0e88805

Please sign in to comment.