Skip to content

Commit

Permalink
Added avg function with type casting to SQLAlchemy - #44
Browse files Browse the repository at this point in the history
Co-authored-by: lucasgadams
  • Loading branch information
ankane committed Oct 2, 2024
1 parent f65c361 commit e90260a
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 39 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
## 0.3.5 (unreleased)

- Added `avg` function with type casting to SQLAlchemy
- Added `globally` option for Psycopg 2

## 0.3.4 (2024-09-26)
Expand Down
5 changes: 4 additions & 1 deletion pgvector/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .bit import BIT
from .functions import avg, sum
from .halfvec import HALFVEC
from .sparsevec import SPARSEVEC
from .vector import VECTOR
Expand All @@ -12,5 +13,7 @@
'BIT',
'SPARSEVEC',
'HalfVector',
'SparseVector'
'SparseVector',
'avg',
'sum'
]
8 changes: 8 additions & 0 deletions pgvector/sqlalchemy/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# https://docs.sqlalchemy.org/en/20/core/functions.html
# re-export SQLAlchemy's built-in sum so avg and sum can be imported from the same module
from sqlalchemy.sql.functions import ReturnTypeFromArgs, sum


# SQL `avg` aggregate built on SQLAlchemy's ReturnTypeFromArgs, so the result
# type is taken from the type of the argument expression. The tests changed in
# this commit depend on that: avg over a vector column now comes back as an
# array instead of the raw string '[2.5,3.5,4.5]'.
class avg(ReturnTypeFromArgs):
# opt this subclass in to SQLAlchemy's compiled-statement caching
inherit_cache = True
# NOTE(review): presumably registers the function under a separate 'pgvector'
# package namespace so the stock func.avg is left untouched — confirm against
# SQLAlchemy's GenericFunction documentation
package = 'pgvector'
36 changes: 17 additions & 19 deletions tests/test_sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, SparseVector
from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, SparseVector, avg, sum
import pytest
from sqlalchemy import create_engine, insert, inspect, select, text, MetaData, Table, Column, Index, Integer
from sqlalchemy.exc import StatementError
Expand Down Expand Up @@ -339,41 +339,39 @@ def test_select_orm(self):

def test_avg(self):
with Session(engine) as session:
avg = session.query(func.avg(Item.embedding)).first()[0]
assert avg is None
res = session.query(avg(Item.embedding)).first()[0]
assert res is None
session.add(Item(embedding=[1, 2, 3]))
session.add(Item(embedding=[4, 5, 6]))
avg = session.query(func.avg(Item.embedding)).first()[0]
# does not type cast
assert avg == '[2.5,3.5,4.5]'
res = session.query(avg(Item.embedding)).first()[0]
assert np.array_equal(res, np.array([2.5, 3.5, 4.5]))

def test_avg_orm(self):
with Session(engine) as session:
avg = session.scalars(select(func.avg(Item.embedding))).first()
assert avg is None
res = session.scalars(select(avg(Item.embedding))).first()
assert res is None
session.add(Item(embedding=[1, 2, 3]))
session.add(Item(embedding=[4, 5, 6]))
avg = session.scalars(select(func.avg(Item.embedding))).first()
# does not type cast
assert avg == '[2.5,3.5,4.5]'
res = session.scalars(select(avg(Item.embedding))).first()
assert np.array_equal(res, np.array([2.5, 3.5, 4.5]))

def test_sum(self):
with Session(engine) as session:
sum = session.query(func.sum(Item.embedding)).first()[0]
assert sum is None
res = session.query(sum(Item.embedding)).first()[0]
assert res is None
session.add(Item(embedding=[1, 2, 3]))
session.add(Item(embedding=[4, 5, 6]))
sum = session.query(func.sum(Item.embedding)).first()[0]
assert np.array_equal(sum, np.array([5, 7, 9]))
res = session.query(sum(Item.embedding)).first()[0]
assert np.array_equal(res, np.array([5, 7, 9]))

def test_sum_orm(self):
with Session(engine) as session:
sum = session.scalars(select(func.sum(Item.embedding))).first()
assert sum is None
res = session.scalars(select(sum(Item.embedding))).first()
assert res is None
session.add(Item(embedding=[1, 2, 3]))
session.add(Item(embedding=[4, 5, 6]))
sum = session.scalars(select(func.sum(Item.embedding))).first()
assert np.array_equal(sum, np.array([5, 7, 9]))
res = session.scalars(select(sum(Item.embedding))).first()
assert np.array_equal(res, np.array([5, 7, 9]))

def test_bad_dimensions(self):
item = Item(embedding=[1, 2])
Expand Down
36 changes: 17 additions & 19 deletions tests/test_sqlmodel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, SparseVector
from pgvector.sqlalchemy import VECTOR, HALFVEC, BIT, SPARSEVEC, SparseVector, avg, sum
import pytest
from sqlalchemy import Column, Index
from sqlalchemy.exc import StatementError
Expand Down Expand Up @@ -198,41 +198,39 @@ def test_select(self):

def test_vector_avg(self):
with Session(engine) as session:
avg = session.exec(select(func.avg(Item.embedding))).first()
assert avg is None
res = session.exec(select(avg(Item.embedding))).first()
assert res is None
session.add(Item(embedding=[1, 2, 3]))
session.add(Item(embedding=[4, 5, 6]))
avg = session.exec(select(func.avg(Item.embedding))).first()
# does not type cast
assert avg == '[2.5,3.5,4.5]'
res = session.exec(select(avg(Item.embedding))).first()
assert np.array_equal(res, np.array([2.5, 3.5, 4.5]))

def test_vector_sum(self):
with Session(engine) as session:
sum = session.exec(select(func.sum(Item.embedding))).first()
assert sum is None
res = session.exec(select(sum(Item.embedding))).first()
assert res is None
session.add(Item(embedding=[1, 2, 3]))
session.add(Item(embedding=[4, 5, 6]))
sum = session.exec(select(func.sum(Item.embedding))).first()
assert np.array_equal(sum, np.array([5, 7, 9]))
res = session.exec(select(sum(Item.embedding))).first()
assert np.array_equal(res, np.array([5, 7, 9]))

def test_halfvec_avg(self):
with Session(engine) as session:
avg = session.exec(select(func.avg(Item.half_embedding))).first()
assert avg is None
res = session.exec(select(avg(Item.half_embedding))).first()
assert res is None
session.add(Item(half_embedding=[1, 2, 3]))
session.add(Item(half_embedding=[4, 5, 6]))
avg = session.exec(select(func.avg(Item.half_embedding))).first()
# does not type cast
assert avg == '[2.5,3.5,4.5]'
res = session.exec(select(avg(Item.half_embedding))).first()
assert res.to_list() == [2.5, 3.5, 4.5]

def test_halfvec_sum(self):
with Session(engine) as session:
sum = session.exec(select(func.sum(Item.half_embedding))).first()
assert sum is None
res = session.exec(select(sum(Item.half_embedding))).first()
assert res is None
session.add(Item(half_embedding=[1, 2, 3]))
session.add(Item(half_embedding=[4, 5, 6]))
sum = session.exec(select(func.sum(Item.half_embedding))).first()
assert sum.to_list() == [5, 7, 9]
res = session.exec(select(sum(Item.half_embedding))).first()
assert res.to_list() == [5, 7, 9]

def test_bad_dimensions(self):
item = Item(embedding=[1, 2])
Expand Down

0 comments on commit e90260a

Please sign in to comment.