Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Viz_classes code cleanup #1073

Merged
merged 3 commits on Jan 30, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions Core/LAMBDA/layers/viz_lambda_shared_funcs/python/viz_classes.py
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
import re
import urllib.parse
import inspect
import pathlib
try:
import fsspec
except:
@@ -98,7 +99,7 @@ def load_df_into_db(self, table_name, df, drop_first=True):
###################################
def execute_sql(self, sql):
if sql.endswith('.sql') and os.path.exists(sql):
sql = open(sql, 'r').read()
sql = pathlib.Path(sql).read_text()
with self.connection:
try:
with self.connection.cursor() as cur:
@@ -111,7 +112,7 @@ def execute_sql(self, sql):
###################################
def sql_to_dataframe(self, sql, return_geodataframe=False):
if sql.endswith(".sql"):
sql = open(sql, 'r').read()
sql = pathlib.Path(sql).read_text()

db_engine = self.engine
if not return_geodataframe:
@@ -223,16 +224,16 @@ def check_required_tables_updated(self, sql_path_or_str, sql_replace={}, referen
issues_encountered = []
# Determine if arg is file or raw SQL string
if os.path.exists(sql_path_or_str):
sql = open(sql_path_or_str, 'r').read()
sql = pathlib.Path(sql_path_or_str).read_text()
else:
sql = sql_path_or_str

for word, replacement in sql_replace.items():
sql = re.sub(word, replacement, sql, flags=re.IGNORECASE).replace('utc', 'UTC')

output_tables = set(re.findall('(?<=INTO )\w+\.\w+', sql, flags=re.IGNORECASE))
input_tables = set(re.findall('(?<=FROM |JOIN )\w+\.\w+', sql, flags=re.IGNORECASE))
check_tables = [t for t in input_tables if t not in output_tables]
output_tables = set(re.findall(r'(?<=INTO\s)\w+\.\w+', sql, flags=re.IGNORECASE))
input_tables = set(re.findall(r'(?<=FROM\s|JOIN\s)\w+\.\w+', sql, flags=re.IGNORECASE))
check_tables = input_tables - output_tables

if not check_tables:
return True
Expand Down
Loading