Skip to content

Commit

Permalink
Added tests for double and numeric arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Oct 22, 2024
1 parent 59a3efc commit 79d4111
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion tests/test_django.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from django.contrib.postgres.fields import ArrayField
from django.core import serializers
from django.db import connection, migrations, models
from django.db.models import Avg, Sum
from django.db.models import Avg, Sum, FloatField, DecimalField
from django.db.models.functions import Cast
from django.db.migrations.loader import MigrationLoader
from django.forms import ModelForm
from math import sqrt
Expand Down Expand Up @@ -48,6 +49,8 @@ class Item(models.Model):
binary_embedding = BitField(length=3, null=True, blank=True)
sparse_embedding = SparseVectorField(dimensions=3, null=True, blank=True)
embeddings = ArrayField(VectorField(dimensions=3), null=True, blank=True)
double_embedding = ArrayField(FloatField(), null=True, blank=True)
numeric_embedding = ArrayField(DecimalField(max_digits=20, decimal_places=10), null=True, blank=True)

class Meta:
app_label = 'django_app'
Expand Down Expand Up @@ -85,6 +88,8 @@ class Migration(migrations.Migration):
('binary_embedding', pgvector.django.BitField(length=3, null=True, blank=True)),
('sparse_embedding', pgvector.django.SparseVectorField(dimensions=3, null=True, blank=True)),
('embeddings', ArrayField(pgvector.django.VectorField(dimensions=3), null=True, blank=True)),
('double_embedding', ArrayField(FloatField(), null=True, blank=True)),
('numeric_embedding', ArrayField(DecimalField(max_digits=20, decimal_places=10), null=True, blank=True)),
],
),
migrations.AddIndex(
Expand Down Expand Up @@ -448,3 +453,23 @@ def test_vector_array(self):
item = Item.objects.get(pk=1)
assert item.embeddings[0].tolist() == [1, 2, 3]
assert item.embeddings[1].tolist() == [4, 5, 6]

def test_double_array(self):
Item(id=1, double_embedding=[1, 1, 1]).save()
Item(id=2, double_embedding=[2, 2, 2]).save()
Item(id=3, double_embedding=[1, 1, 2]).save()
distance = L2Distance(Cast('double_embedding', VectorField()), [1, 1, 1])
items = Item.objects.annotate(distance=distance).order_by(distance)
assert [v.id for v in items] == [1, 3, 2]
assert [v.distance for v in items] == [0, 1, sqrt(3)]
assert items[1].double_embedding == [1, 1, 2]

def test_numeric_array(self):
Item(id=1, numeric_embedding=[1, 1, 1]).save()
Item(id=2, numeric_embedding=[2, 2, 2]).save()
Item(id=3, numeric_embedding=[1, 1, 2]).save()
distance = L2Distance(Cast('numeric_embedding', VectorField()), [1, 1, 1])
items = Item.objects.annotate(distance=distance).order_by(distance)
assert [v.id for v in items] == [1, 3, 2]
assert [v.distance for v in items] == [0, 1, sqrt(3)]
assert items[1].numeric_embedding == [1, 1, 2]

0 comments on commit 79d4111

Please sign in to comment.