Skip to content

Commit

Permalink
Added globally option for Psycopg 2
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Oct 2, 2024
1 parent 9018d36 commit b88ebed
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 10 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 0.3.5 (unreleased)

- Added `globally` option for Psycopg 2

## 0.3.4 (2024-09-26)

- Added `schema` option for asyncpg
Expand Down
4 changes: 2 additions & 2 deletions pgvector/psycopg2/halfvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def cast_halfvec(value, cur):
return HalfVector._from_db(value)


def register_halfvec_info(oid):
def register_halfvec_info(oid, scope):
halfvec = new_type((oid,), 'HALFVEC', cast_halfvec)
register_type(halfvec)
register_type(halfvec, scope)
register_adapter(HalfVector, HalfvecAdapter)
10 changes: 6 additions & 4 deletions pgvector/psycopg2/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from .vector import register_vector_info


def register_vector(conn_or_curs=None):
# TODO make globally False by default in 0.4.0
def register_vector(conn_or_curs=None, globally=True):
conn = conn_or_curs if hasattr(conn_or_curs, 'cursor') else conn_or_curs.connection
cur = conn.cursor(cursor_factory=cursor)
scope = None if globally else conn_or_curs

# use to_regtype to get first matching type in search path
cur.execute("SELECT typname, oid FROM pg_type WHERE oid IN (to_regtype('vector'), to_regtype('halfvec'), to_regtype('sparsevec'))")
Expand All @@ -16,10 +18,10 @@ def register_vector(conn_or_curs=None):
if 'vector' not in type_info:
raise psycopg2.ProgrammingError('vector type not found in the database')

register_vector_info(type_info['vector'])
register_vector_info(type_info['vector'], scope)

if 'halfvec' in type_info:
register_halfvec_info(type_info['halfvec'])
register_halfvec_info(type_info['halfvec'], scope)

if 'sparsevec' in type_info:
register_sparsevec_info(type_info['sparsevec'])
register_sparsevec_info(type_info['sparsevec'], scope)
4 changes: 2 additions & 2 deletions pgvector/psycopg2/sparsevec.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def cast_sparsevec(value, cur):
return SparseVector._from_db(value)


def register_sparsevec_info(oid):
def register_sparsevec_info(oid, scope):
sparsevec = new_type((oid,), 'SPARSEVEC', cast_sparsevec)
register_type(sparsevec)
register_type(sparsevec, scope)
register_adapter(SparseVector, SparsevecAdapter)
4 changes: 2 additions & 2 deletions pgvector/psycopg2/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def cast_vector(value, cur):
return Vector._from_db(value)


def register_vector_info(oid):
def register_vector_info(oid, scope):
vector = new_type((oid,), 'VECTOR', cast_vector)
register_type(vector)
register_type(vector, scope)
register_adapter(np.ndarray, VectorAdapter)

0 comments on commit b88ebed

Please sign in to comment.