Skip to content

Commit

Permalink
Viz_classes code cleanup (#1073)
Browse files Browse the repository at this point in the history
Some very minor cleanup of some code in viz_classes. Changes include:

1. Using raw strings to properly escape regex tokens. This silences some
warnings.
2. Properly close the sql files that are opened
3. Use native set difference instead of manually computing it with a
list comprehension.
  • Loading branch information
groutr authored Jan 30, 2025
1 parent f26681b commit a16fa9f
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions Core/LAMBDA/layers/viz_lambda_shared_funcs/python/viz_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import re
import urllib.parse
import inspect
import pathlib
try:
import fsspec
except:
Expand Down Expand Up @@ -98,7 +99,7 @@ def load_df_into_db(self, table_name, df, drop_first=True):
###################################
def execute_sql(self, sql):
if sql.endswith('.sql') and os.path.exists(sql):
sql = open(sql, 'r').read()
sql = pathlib.Path(sql).read_text()
with self.connection:
try:
with self.connection.cursor() as cur:
Expand All @@ -111,7 +112,7 @@ def execute_sql(self, sql):
###################################
def sql_to_dataframe(self, sql, return_geodataframe=False):
if sql.endswith(".sql"):
sql = open(sql, 'r').read()
sql = pathlib.Path(sql).read_text()

db_engine = self.engine
if not return_geodataframe:
Expand Down Expand Up @@ -223,16 +224,16 @@ def check_required_tables_updated(self, sql_path_or_str, sql_replace={}, referen
issues_encountered = []
# Determine if arg is file or raw SQL string
if os.path.exists(sql_path_or_str):
sql = open(sql_path_or_str, 'r').read()
sql = pathlib.Path(sql_path_or_str).read_text()
else:
sql = sql_path_or_str

for word, replacement in sql_replace.items():
sql = re.sub(word, replacement, sql, flags=re.IGNORECASE).replace('utc', 'UTC')

output_tables = set(re.findall('(?<=INTO )\w+\.\w+', sql, flags=re.IGNORECASE))
input_tables = set(re.findall('(?<=FROM |JOIN )\w+\.\w+', sql, flags=re.IGNORECASE))
check_tables = [t for t in input_tables if t not in output_tables]
output_tables = set(re.findall(r'(?<=INTO\s)\w+\.\w+', sql, flags=re.IGNORECASE))
input_tables = set(re.findall(r'(?<=FROM\s|JOIN\s)\w+\.\w+', sql, flags=re.IGNORECASE))
check_tables = input_tables - output_tables

if not check_tables:
return True
Expand Down

0 comments on commit a16fa9f

Please sign in to comment.