
feat(share): add sharing using magic-wormhole #223

Merged: 81 commits, Dec 12, 2023
Commits
8c0a78d
Copy from closed PR magic wormhole working code and test cases
Mustaballer Jun 5, 2023
b4b7879
Merge branch 'main' into share-magic-wormhole
Mustaballer Jun 15, 2023
395fa00
Merge branch 'main' into share-magic-wormhole
Mustaballer Jun 16, 2023
5fbdc1a
Merge remote-tracking branch 'upstream/main' into share-magic-wormhole
Mustaballer Jun 23, 2023
0a0208e
Merge branch 'share-magic-wormhole' of https://github.com/Mustaballer…
Mustaballer Jun 23, 2023
9dc1850
modify export_sql to use parameterized queries to prevent SQL injection
Mustaballer Jun 23, 2023
53f11f5
try resolve merge conflict with poetry.lock
Mustaballer Jun 23, 2023
0977859
Merge branch 'main' into share-magic-wormhole
Mustaballer Jun 23, 2023
cbee5e8
fix merge conflict and use better approach for overwriting env
Mustaballer Jun 23, 2023
9819330
Remove unnecessary function and pass test cases
Mustaballer Jun 25, 2023
afd9810
reformat file and group constants together in config.py
Mustaballer Jun 25, 2023
2211f78
ran black
Mustaballer Jun 25, 2023
f58fd5f
ran black
Mustaballer Jun 25, 2023
dafca05
moved functions in crud.py to db.py
Mustaballer Jun 26, 2023
3434a00
use tempfile in test_share.py and address minor changes
Mustaballer Jun 26, 2023
b801caf
restore original db name instead of literal
Mustaballer Jun 26, 2023
4c59c10
ran black -l 60, and it used call chain
Mustaballer Jun 26, 2023
088f370
Merge branch 'main' into share-magic-wormhole
0dm Jun 29, 2023
1096a15
Merge branch 'main' into share-magic-wormhole
Mustaballer Jul 4, 2023
01eef24
Add exception handling when ctrl+c during sharing that deletes db and…
Mustaballer Jul 4, 2023
5e52b5c
ran black and update poetry.lock
Mustaballer Jul 4, 2023
8317634
Add missing docstring in db.py
Mustaballer Jul 5, 2023
e675227
delete temp .env if ctrl+c during sharing
Mustaballer Jul 5, 2023
04c69b8
merge with latest main
Mustaballer Jul 6, 2023
2d54e83
Add .env.example and generate env in config.py
Mustaballer Jul 10, 2023
3992d1c
use .env.example for creating .env and removed unnecessary exceptions
Mustaballer Jul 12, 2023
d4d3f1c
Merge branch 'main' into share-magic-wormhole
Mustaballer Jul 17, 2023
2f2c0dd
use new approach for copying db
Mustaballer Jul 19, 2023
7a67386
Merge branch 'share-magic-wormhole' of https://github.com/Mustaballer…
Mustaballer Jul 19, 2023
6196876
modify copy db function to return recording data and remove comments …
Mustaballer Jul 19, 2023
5893b3e
modify env names and add asserts to share.py
Mustaballer Jul 20, 2023
2149adc
copy alembic migrations
Mustaballer Jul 20, 2023
2f1548d
Copy all data relating to recording_timestamp in all tables
Mustaballer Jul 20, 2023
625c192
extract db file upon receiving recording
Mustaballer Jul 21, 2023
18a4918
add command to visualize recording
Mustaballer Jul 22, 2023
9b7f7af
remove unnecessary function and todo comment
Mustaballer Jul 22, 2023
786e063
refactor copy_recording_data
Mustaballer Jul 22, 2023
644c83a
remove unittest class and fix failing test case for receiving recordi…
Mustaballer Jul 22, 2023
653e785
modify conftest.py and fixtures.sql to insert data to every table for…
Mustaballer Jul 22, 2023
76f14b0
update unit tests
Mustaballer Jul 24, 2023
ada2aad
Merge branch 'main' into share-magic-wormhole
Mustaballer Jul 25, 2023
2f66b8a
resolve merge conflicts
Mustaballer Jul 25, 2023
f29f199
address flake8 errors
Mustaballer Jul 25, 2023
1a10f8a
Merge branch 'main' into share-magic-wormhole
Mustaballer Jul 31, 2023
c3173f8
resolve merge issues
Mustaballer Aug 1, 2023
ea25b53
Update openadapt/share.py
Mustaballer Aug 1, 2023
0123aae
resolve https://github.com/OpenAdaptAI/OpenAdapt/issues/441
Mustaballer Aug 2, 2023
7d0d343
add type annotation
Mustaballer Aug 3, 2023
b6cce11
Merge branch 'share-magic-wormhole' of https://github.com/Mustaballer…
Mustaballer Aug 3, 2023
d9eefa0
run black --preview and modify main.yml to check black --preview
Mustaballer Aug 3, 2023
961130b
remove unused import
Mustaballer Aug 4, 2023
5fd9868
Add timestamp to exported recording db files and update unit tests
Mustaballer Aug 4, 2023
453d2c2
Merge branch 'main' into share-magic-wormhole
Mustaballer Aug 10, 2023
4d66584
update poetry.lock
Mustaballer Aug 10, 2023
75b324d
Merge branch 'main' into share-magic-wormhole
Mustaballer Aug 10, 2023
dc5f2b6
fix: enhance publish action and authors in pyproject.toml
Mustaballer Aug 11, 2023
ebea2cb
modify release-and-publish.yml
Mustaballer Aug 11, 2023
c77f339
change author name to OpenAdapt.AI Team
Mustaballer Aug 11, 2023
8b1299d
Merge remote-tracking branch 'upstream/enhance-publishing' into share…
Mustaballer Aug 14, 2023
09dd4d6
test publish to test pypi
Mustaballer Aug 14, 2023
8ea0ee4
fix poetry conflicts and conflicting files
Mustaballer Aug 14, 2023
1b2eef3
Merge branch 'main' into share-magic-wormhole
Mustaballer Aug 18, 2023
0e8d620
resolve merge conflicts and linting errors from recent merge
Mustaballer Aug 18, 2023
5b9bbe0
Merge branch 'main' into share-magic-wormhole
Mustaballer Aug 28, 2023
24a3e79
Merge branch 'main' into share-magic-wormhole
Mustaballer Aug 29, 2023
a03bd59
update poetry.lock file and some formatting
Mustaballer Aug 29, 2023
a94c303
fix wormhole sharing
Mustaballer Sep 8, 2023
ca04296
Merge remote-tracking branch 'upstream/main' into share-magic-wormhole
Mustaballer Sep 8, 2023
61553f0
poetry lock
abrichr Nov 18, 2023
22b5f0c
fix failing tests
abrichr Nov 19, 2023
00c79cf
Merge branch 'main' into share-magic-wormhole
abrichr Nov 19, 2023
21340b2
poetry lock
abrichr Nov 19, 2023
0edcd88
add spacy-curated-transformers
abrichr Nov 19, 2023
1367cb9
remove custom visualize function
Mustaballer Dec 11, 2023
54a954f
change logging.py name to resolve naming conflict with logger library
Mustaballer Dec 11, 2023
9682209
Merge branch 'main' into share-magic-wormhole
Mustaballer Dec 11, 2023
6651f2c
Restore `export_recording` import with lint ignore comment
Mustaballer Dec 11, 2023
117694b
Merge branch 'share-magic-wormhole' of https://github.com/Mustaballer…
Mustaballer Dec 11, 2023
bf9e8bf
Merge branch 'main' into share-magic-wormhole
Mustaballer Dec 12, 2023
39b334e
Update openadapt/config.py
Mustaballer Dec 12, 2023
b79f3b5
Update openadapt/config.py
Mustaballer Dec 12, 2023
78 changes: 67 additions & 11 deletions openadapt/config.py
@@ -8,27 +8,69 @@
...

"""

import multiprocessing
import os
import pathlib

from dotenv import load_dotenv
from dotenv import find_dotenv, load_dotenv
from loguru import logger

ROOT_DIRPATH = pathlib.Path(__file__).parent.parent.resolve()
ZIPPED_RECORDING_FOLDER_PATH = ROOT_DIRPATH / "data" / "zipped"

ENV_FILE_PATH = (ROOT_DIRPATH / ".env").resolve()
logger.info(f"{ENV_FILE_PATH=}")
dotenv_file = find_dotenv()
load_dotenv(dotenv_file)


def read_env_file(file_path):
"""Read the contents of an environment file.

Args:
file_path (str): The path to the environment file.

Returns:
dict: A dictionary containing the environment variables and their values.
"""
env_vars = {}
with open(file_path, "r") as f:
for line in f:
line = line.strip()
if line and not line.startswith("#"):
key, value = line.split("=", 1)
env_vars[key] = value.strip('"')
return env_vars
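For a self-contained look at how this helper behaves, the sketch below reproduces the function verbatim and runs it against a hypothetical `.env` file (the sample keys and values are illustrative, not taken from the repository):

```python
import tempfile

# Copy of the read_env_file helper above, so this sketch runs standalone.
def read_env_file(file_path):
    env_vars = {}
    with open(file_path, "r") as f:
        for line in f:
            line = line.strip()
            if line and not line.startswith("#"):
                key, value = line.split("=", 1)
                env_vars[key] = value.strip('"')
    return env_vars

# Hypothetical .env contents: comments and blank lines are skipped,
# and surrounding double quotes are stripped from values.
with tempfile.NamedTemporaryFile("w", suffix=".env", delete=False) as f:
    f.write('# comment\n\nDB_FNAME="openadapt.db"\nOPENAI_API_KEY=abc123\n')
    path = f.name

env_vars = read_env_file(path)
print(env_vars)  # {'DB_FNAME': 'openadapt.db', 'OPENAI_API_KEY': 'abc123'}
```

Note that splitting on the first `=` means values containing `=` survive intact, while a line with no `=` at all would raise `ValueError`; `python-dotenv`'s `dotenv_values` offers similar parsing if a battle-tested alternative is preferred.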


env_file_path = ".env"
env_vars = read_env_file(env_file_path)

DB_FNAME = env_vars.get("DB_FNAME")


def set_db_url(db_fname):
"""Set the database URL based on the given database file name.

Args:
db_fname (str): The database file name.
"""
global DB_FNAME, DB_FPATH, DB_URL
DB_FNAME = db_fname
DB_FPATH = ROOT_DIRPATH / DB_FNAME
DB_URL = f"sqlite:///{DB_FPATH}"
logger.info(f"{DB_URL=}")


_DEFAULTS = {
"CACHE_DIR_PATH": ".cache",
"CACHE_ENABLED": True,
"CACHE_VERBOSITY": 0,
"DB_ECHO": False,
"DB_FNAME": "openadapt.db",
"DB_FNAME": DB_FNAME,
"OPENAI_API_KEY": "<set your api key in .env>",
# "OPENAI_MODEL_NAME": "gpt-4",
"OPENAI_MODEL_NAME": "gpt-3.5-turbo",
# may incur significant performance penalty
"RECORD_READ_ACTIVE_ELEMENT_STATE": False,
# TODO: remove?
"REPLAY_STRIP_ELEMENT_STATE": True,
# IGNORES WARNINGS (PICKLING, ETC.)
# TODO: ignore warnings by default on GUI
@@ -87,29 +129,43 @@


def getenv_fallback(var_name):
rval = os.getenv(var_name) or _DEFAULTS.get(var_name)
"""Get the value of an environment variable with fallback to default.

Args:
var_name (str): The name of the environment variable.

Returns:
str: The value of the environment variable or the default value if not found.

Raises:
ValueError: If the environment variable is not defined.
"""
if var_name == "DB_FNAME":
rval = _DEFAULTS.get(var_name)
else:
rval = os.getenv(var_name) or _DEFAULTS.get(var_name)
if rval is None:
raise ValueError(f"{var_name=} not defined")
return rval
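The fallback order here (environment variable first, then `_DEFAULTS`, else raise) can be exercised in isolation. This sketch uses a trimmed-down defaults table and omits the PR's `DB_FNAME` special case; the variable names are only examples:

```python
import os

# Trimmed-down defaults table for illustration; the real _DEFAULTS is larger.
_DEFAULTS = {"CACHE_ENABLED": True, "OPENAI_API_KEY": "<set your api key in .env>"}

def getenv_fallback(var_name):
    # Environment first, then _DEFAULTS; raise if defined in neither.
    rval = os.getenv(var_name) or _DEFAULTS.get(var_name)
    if rval is None:
        raise ValueError(f"{var_name=} not defined")
    return rval

os.environ["OPENAI_API_KEY"] = "sk-example"   # env wins over the default
print(getenv_fallback("OPENAI_API_KEY"))      # sk-example
print(getenv_fallback("CACHE_ENABLED"))       # True (falls back to _DEFAULTS)
try:
    getenv_fallback("UNKNOWN_VAR")
except ValueError as exc:
    print("raised:", exc)
```

One quirk worth noting: because `or` is used rather than a sentinel check, an environment variable set to the empty string is treated as missing and falls through to the default.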


load_dotenv()

for key in _DEFAULTS:
val = getenv_fallback(key)
locals()[key] = val


ROOT_DIRPATH = pathlib.Path(__file__).parent.parent.resolve()
DB_FPATH = ROOT_DIRPATH / DB_FNAME
DB_URL = f"sqlite:///{DB_FPATH}"
DIRNAME_PERFORMANCE_PLOTS = "performance"
DB_ECHO = False
DT_FMT = "%Y-%m-%d_%H-%M-%S"

if multiprocessing.current_process().name == "MainProcess":
for key, val in locals().items():
if not key.startswith("_") and key.isupper():
logger.info(f"{key}={val}")


def filter_log_messages(data):
"""
This function filters log messages by ignoring any message that contains a specific string.
@@ -127,4 +183,4 @@ def filter_log_messages(data):
messages_to_ignore = [
"Cannot pickle Objective-C objects",
]
return not any(msg in data["message"] for msg in messages_to_ignore)
return not any(msg in data["message"] for msg in messages_to_ignore)
134 changes: 110 additions & 24 deletions openadapt/crud.py
@@ -1,7 +1,14 @@
import os
import time
import shutil

from datetime import datetime
from loguru import logger
from sqlalchemy.orm import sessionmaker
import sqlalchemy as sa

from openadapt.db import Session
from openadapt.db import Session, get_base, get_engine, engine
from openadapt import config
from openadapt.models import (
ActionEvent,
Screenshot,
@@ -19,13 +26,11 @@
window_events = []
performance_stats = []


def _insert(event_data, table, buffer=None):
"""Insert using Core API for improved performance (no rows are returned)"""

db_obj = {
column.name: None
for column in table.__table__.columns
}
db_obj = {column.name: None for column in table.__table__.columns}
for key in db_obj:
if key in event_data:
val = event_data[key]
@@ -74,6 +79,7 @@ def insert_window_event(recording_timestamp, event_timestamp, event_data):
}
_insert(event_data, WindowEvent, window_events)


def insert_perf_stat(recording_timestamp, event_type, start_time, end_time):
"""
Insert event performance stat into db
@@ -87,19 +93,20 @@ def insert_perf_stat(recording_timestamp, event_type, start_time, end_time):
}
_insert(event_perf_stat, PerformanceStat, performance_stats)


def get_perf_stats(recording_timestamp):
"""
return performance stats for a given recording
"""

return (
db
.query(PerformanceStat)
db.query(PerformanceStat)
.filter(PerformanceStat.recording_timestamp == recording_timestamp)
.order_by(PerformanceStat.start_time)
.all()
)


def insert_recording(recording_data):
db_obj = Recording(**recording_data)
db.add(db_obj)
Expand All @@ -109,28 +116,107 @@ def insert_recording(recording_data):


def get_latest_recording():
return (
db
.query(Recording)
.order_by(sa.desc(Recording.timestamp))
.limit(1)
.first()
)
return db.query(Recording).order_by(sa.desc(Recording.timestamp)).limit(1).first()


def get_recording_by_id(recording_id):
return db.query(Recording).filter_by(id=recording_id).first()


def export_sql(recording_id):
"""Export the recording data as SQL statements.

Args:
recording_id (int): The ID of the recording.

Returns:
str: The SQL statements to insert the recording into the output file.
"""
engine = sa.create_engine(config.DB_URL)
Session = sessionmaker(bind=engine)
session = Session()

recording = get_recording_by_id(recording_id)

if recording:
sql = f"INSERT INTO recording VALUES ({recording.id}, {recording.timestamp}, {recording.monitor_width}, {recording.monitor_height}, {recording.double_click_interval_seconds}, {recording.double_click_distance_pixels}, '{recording.platform}', '{recording.task_description}')"
logger.info(f"Recording with ID {recording_id} exported successfully.")
else:
logger.info(f"No recording found with ID {recording_id}.")

return sql
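The f-string interpolation in `export_sql` was flagged in review, and the commit history mentions moving to parameterized queries. As a hedged illustration of what that looks like with placeholders (the table schema and values below are hypothetical, using `sqlite3` directly rather than SQLAlchemy):

```python
import sqlite3

def insert_recording_row(conn, recording):
    # Placeholders keep values out of the SQL text entirely, so quotes in
    # fields like task_description cannot break or alter the statement.
    sql = "INSERT INTO recording VALUES (?, ?, ?, ?, ?, ?, ?, ?)"
    conn.execute(sql, (
        recording["id"],
        recording["timestamp"],
        recording["monitor_width"],
        recording["monitor_height"],
        recording["double_click_interval_seconds"],
        recording["double_click_distance_pixels"],
        recording["platform"],
        recording["task_description"],
    ))

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE recording (id INTEGER, timestamp REAL, monitor_width INTEGER,"
    " monitor_height INTEGER, double_click_interval_seconds REAL,"
    " double_click_distance_pixels INTEGER, platform TEXT, task_description TEXT)"
)
insert_recording_row(conn, {
    "id": 1, "timestamp": 1690000000.0, "monitor_width": 1920,
    "monitor_height": 1080, "double_click_interval_seconds": 0.5,
    "double_click_distance_pixels": 4, "platform": "darwin",
    "task_description": "it's a 'test'",  # embedded quotes are handled safely
})
row = conn.execute("SELECT platform, task_description FROM recording").fetchone()
print(row)  # ('darwin', "it's a 'test'")
```

With SQLAlchemy, the equivalent approach is `sa.text(...)` with bound parameters rather than string formatting.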


def create_db(recording_id, sql):
"""Create a new database and import the recording data.

Args:
recording_id (int): The ID of the recording.
sql (str): The SQL statements to import the recording.

Returns:
tuple: A tuple containing the timestamp and the file path of the new database.
"""
db.close()
db_fname = f"recording_{recording_id}.db"

t = time.time()
shutil.copyfile(config.ENV_FILE_PATH, f"{config.ENV_FILE_PATH}-{t}")
config.set_db_url(db_fname)

with open(config.ENV_FILE_PATH, "r") as f:
lines = f.readlines()
lines[1] = f"DB_FNAME={db_fname}\n"
with open(config.ENV_FILE_PATH, "w") as f:
f.writelines(lines)

engine = sa.create_engine(config.DB_URL)
Session = sessionmaker(bind=engine)
session = Session()
os.system("alembic upgrade head")
db.engine = engine

with engine.begin() as connection:
connection.execute(sql)

db_file_path = config.DB_FPATH.resolve()

return t, db_file_path


def restore_db(timestamp):
"""Restore the database to a previous state.

Args:
timestamp (float): The timestamp associated with the backup file.
"""
backup_file = f"{config.ENV_FILE_PATH}-{timestamp}"
shutil.copyfile(backup_file, config.ENV_FILE_PATH)
config.set_db_url("openadapt.db")
db.engine = get_engine()


def export_recording(recording_id):
"""Export a recording by creating a new database, importing the recording, and then restoring the previous state.

Args:
recording_id (int): The ID of the recording to export.

Returns:
str: The file path of the new database.
"""
sql = export_sql(recording_id)
t, db_file_path = create_db(recording_id, sql)
restore_db(t)
return db_file_path

def get_recording(timestamp):
return (
db
.query(Recording)
.filter(Recording.timestamp == timestamp)
.first()
)
return db.query(Recording).filter(Recording.timestamp == timestamp).first()


def _get(table, recording_timestamp):
return (
db
.query(table)
db.query(table)
.filter(table.recording_timestamp == recording_timestamp)
.order_by(table.timestamp)
.all()
@@ -150,11 +236,11 @@ def get_screenshots(recording, precompute_diffs=False):

# TODO: store diffs
if precompute_diffs:
logger.info(f"precomputing diffs...")
logger.info("precomputing diffs...")
[(screenshot.diff, screenshot.diff_mask) for screenshot in screenshots]

return screenshots


def get_window_events(recording):
return _get(WindowEvent, recording.timestamp)
return _get(WindowEvent, recording.timestamp)