Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Unit-tests #9

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
14 changes: 0 additions & 14 deletions .github/workflows/mongodb_settings.py

This file was deleted.

2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,5 @@ docs/build
site/

_development/
tests/django
documentdb_settings.py
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,9 @@ class TestModel(DocumentModel):
## Forked Project

This project, **django-documentdb**, is a fork of the original **django-mongodb** library, which aimed to integrate MongoDB with Django. The fork was created to enhance compatibility with AWS DocumentDB, addressing the limitations of its API support while maintaining the core functionalities of the original library. We appreciate the work of the MongoDB Python Team and aim to build upon their foundation to better serve users needing DocumentDB integration.

## Run tests

docker build . -t test:latest -f tests/Dockerfile && docker run -it test:latest

docker build . -t mongo_test:latest -f tests/mongodb.Dockerfile && docker run -it mongo_test:latest
13 changes: 5 additions & 8 deletions django_documentdb/aggregates.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from django.db.models.lookups import IsNull

from .query_utils import process_lhs
from .utils import prefix_with_dollar

# Aggregates whose MongoDB aggregation name differ from Aggregate.function.lower().
MONGO_AGGREGATIONS = {Count: "sum"}
Expand All @@ -26,9 +27,9 @@ def aggregate(
node = self
lhs_mql = process_lhs(node, compiler, connection)
if resolve_inner_expression:
return lhs_mql
return prefix_with_dollar(lhs_mql)
operator = operator or MONGO_AGGREGATIONS.get(self.__class__, self.function.lower())
return {f"${operator}": f"${lhs_mql}"}
return {f"${operator}": prefix_with_dollar(lhs_mql)}


def count(self, compiler, connection, resolve_inner_expression=False, **extra_context): # noqa: ARG001
Expand Down Expand Up @@ -64,12 +65,8 @@ def count(self, compiler, connection, resolve_inner_expression=False, **extra_co
return {"$add": [{"$size": lhs_mql}, exits_null]}


def stddev_variance(self, compiler, connection, **extra_context):
if self.function.endswith("_SAMP"):
operator = "stdDevSamp"
elif self.function.endswith("_POP"):
operator = "stdDevPop"
return aggregate(self, compiler, connection, operator=operator, **extra_context)
def stddev_variance(*args, **kwargs): # noqa: ARG001
raise NotImplementedError("StdDev and Variance are not supported yet.")


def register_aggregates():
Expand Down
63 changes: 42 additions & 21 deletions django_documentdb/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .operations import DatabaseOperations
from .query_utils import regex_match
from .schema import DatabaseSchemaEditor
from .utils import IndexNotUsedWarning, OperationDebugWrapper
from .utils import IndexNotUsedWarning, OperationDebugWrapper, prefix_with_dollar

# ignore warning from pymongo about DocumentDB
warnings.filterwarnings("ignore", "You appear to be connected to a DocumentDB cluster", UserWarning)
Expand Down Expand Up @@ -87,37 +87,58 @@ class DatabaseWrapper(BaseDatabaseWrapper):
"iendswith": "LIKE '%%' || UPPER({})",
}

def _isnull_operator(a, b):
def _isnull_operator(a, b, pos: bool = False):
if b:
return {a: None}
return {a: None} if not pos else {"$eq": [prefix_with_dollar(a), None]}

warnings.warn("You're using $ne, index will not be used", IndexNotUsedWarning, stacklevel=1)
return {a: {"$ne": None}}
return {a: {"$ne": None}} if not pos else {"$ne": [prefix_with_dollar(a), None]}

mongo_operators = {
# Where a = field_name, b = value
"exact": lambda a, b: {a: b},
"gt": lambda a, b: {a: {"$gt": b}},
"gte": lambda a, b: {a: {"$gte": b}},
"lt": lambda a, b: {a: {"$lt": b}},
"lte": lambda a, b: {a: {"$lte": b}},
"in": lambda a, b: {a: {"$in": b}},
# Where a = field_name, b = value, pos = positional operator syntax
"exact": lambda a, b, pos: {a: b} if not pos else {"$eq": [prefix_with_dollar(a), b]},
"gt": lambda a, b, pos: {a: {"$gt": b}} if not pos else {"$gt": [prefix_with_dollar(a), b]},
"gte": lambda a, b, pos: {a: {"$gte": b}}
if not pos
else {"$gte": [prefix_with_dollar(a), b]},
"lt": lambda a, b, pos: {a: {"$lt": b}} if not pos else {"$lt": [prefix_with_dollar(a), b]},
"lte": lambda a, b, pos: {a: {"$lte": b}}
if not pos
else {"$lte": [prefix_with_dollar(a), b]},
"in": lambda a, b, pos: {a: {"$in": b}} if not pos else {"$in": [prefix_with_dollar(a), b]},
"isnull": _isnull_operator,
"range": lambda a, b: {
"range": lambda a, b, pos: {
"$and": [
{"$or": [{a: {"$gte": b[0]}}, {a: None}]},
{"$or": [{a: {"$lte": b[1]}}, {a: None}]},
]
}
if not pos
else {
"$and": [
{
"$or": [
{"$gte": [prefix_with_dollar(a), b[0]]},
{"$eq": [prefix_with_dollar(a), None]},
]
},
{
"$or": [
{"$lte": [prefix_with_dollar(a), b[1]]},
{"$eq": [prefix_with_dollar(a), None]},
]
},
]
},
"iexact": lambda a, b: regex_match(a, f"^{b}$", insensitive=True),
"startswith": lambda a, b: regex_match(a, f"^{b}"),
"istartswith": lambda a, b: regex_match(a, f"^{b}", insensitive=True),
"endswith": lambda a, b: regex_match(a, f"{b}$"),
"iendswith": lambda a, b: regex_match(a, f"{b}$", insensitive=True),
"contains": lambda a, b: regex_match(a, b),
"icontains": lambda a, b: regex_match(a, b, insensitive=True),
"regex": lambda a, b: regex_match(a, b),
"iregex": lambda a, b: regex_match(a, b, insensitive=True),
"iexact": lambda a, b, pos: regex_match(a, f"^{b}$", insensitive=True, pos=pos),
"startswith": lambda a, b, pos: regex_match(a, f"^{b}", pos=pos),
"istartswith": lambda a, b, pos: regex_match(a, f"^{b}", insensitive=True, pos=pos),
"endswith": lambda a, b, pos: regex_match(a, f"{b}$", pos=pos),
"iendswith": lambda a, b, pos: regex_match(a, f"{b}$", insensitive=True, pos=pos),
"contains": lambda a, b, pos: regex_match(a, b, pos=pos),
"icontains": lambda a, b, pos: regex_match(a, b, insensitive=True, pos=pos),
"regex": lambda a, b, pos: regex_match(a, b, pos=pos),
"iregex": lambda a, b, pos: regex_match(a, b, insensitive=True, pos=pos),
}

display_name = "DocumentDB"
Expand Down
34 changes: 12 additions & 22 deletions django_documentdb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@

from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
from django.db import IntegrityError, NotSupportedError
from django.db.models import Count
from django.db.models.aggregates import Aggregate, Variance
from django.db.models.expressions import Case, Col, OrderBy, Ref, Value, When
from django.db.models.functions.comparison import Coalesce
from django.db.models.functions.math import Power
from django.db.models.lookups import IsNull
from django.db.models.sql import compiler
Expand All @@ -18,6 +16,7 @@

from .base import Cursor
from .query import MongoQuery, wrap_database_errors
from .utils import Distinct, prefix_with_dollar


class SQLCompiler(compiler.SQLCompiler):
Expand Down Expand Up @@ -95,8 +94,8 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group
group[alias] = sub_expr.as_mql(self, self.connection)
replacing_expr = inner_column
# Count must return 0 rather than null.
if isinstance(sub_expr, Count):
replacing_expr = Coalesce(replacing_expr, 0)
# if isinstance(sub_expr, Count):
# replacing_expr = Coalesce(replacing_expr, 0)
# Variance = StdDev^2
if isinstance(sub_expr, Variance):
replacing_expr = Power(replacing_expr, 2)
Expand Down Expand Up @@ -245,7 +244,8 @@ def execute_sql(
else:
return self._make_result(obj, columns)
# result_type is MULTI
cursor.batch_size(chunk_size)
if not isinstance(cursor, list):
cursor.batch_size(chunk_size)
result = self.cursor_iter(cursor, chunk_size, columns)
if not chunked_fetch:
# If using non-chunked reads, read data into memory.
Expand Down Expand Up @@ -347,24 +347,16 @@ def build_query(self, columns=None):
if self.query.distinct:
# If query is distinct, build a $group stage for distinct
# fields, then set project fields based on the grouped _id.
distinct_fields = self.get_project_fields(
columns, ordering_fields, force_expression=True
)
if not query.aggregation_pipeline:
query.aggregation_pipeline = []
query.aggregation_pipeline.extend(
[
{"$group": {"_id": distinct_fields}},
{"$project": {key: f"$_id.{key}" for key in distinct_fields}},
]
query.distinct = Distinct(
fields=self.get_project_fields(columns, ordering_fields, force_expression=True)
)
else:
# Otherwise, project fields without grouping.
query.project_fields = self.get_project_fields(columns, ordering_fields)
# If columns is None, then get_project_fields() won't add
# ordering_fields to $project. Use $addFields (extra_fields) instead.
# if columns is None:
# extra_fields += ordering_fields
if columns is None:
extra_fields += ordering_fields
query.lookup_pipeline = self.get_lookup_pipeline()
where = self.get_where()
try:
Expand Down Expand Up @@ -479,10 +471,11 @@ def get_combinator_queries(self):
inner_pipeline.append({"$project": fields})
# Combine query with the current combinator pipeline.
if combinator_pipeline:
raise NotSupportedError
combinator_pipeline.append(
{"$unionWith": {"coll": compiler_.collection_name, "pipeline": inner_pipeline}}
)
else:
else: # noqa: RET506
combinator_pipeline = inner_pipeline
if not self.query.combinator_all:
ids = defaultdict(dict)
Expand Down Expand Up @@ -528,10 +521,7 @@ def get_project_fields(self, columns=None, ordering=None, force_expression=False
fields[collection][name] = 1
else:
mql = expr.as_mql(self, self.connection)
if isinstance(mql, str):
fields[collection][name] = f"${mql}"
else:
fields[collection][name] = mql
fields[collection][name] = prefix_with_dollar(mql)

except EmptyResultSet:
empty_result_set_value = getattr(expr, "empty_result_set_value", NotImplemented)
Expand Down
37 changes: 20 additions & 17 deletions django_documentdb/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,21 @@
)
from django.db.models.sql import Query

from django_documentdb.utils import IndexNotUsedWarning
from django_documentdb.utils import IndexNotUsedWarning, prefix_with_dollar


def case(self, compiler, connection):
case_parts = []
for case in self.cases:
case_mql = {}
try:
case_mql["case"] = case.as_mql(compiler, connection)
case_mql["case"] = case.as_mql(compiler, connection, positional_operator_syntax=True)
except EmptyResultSet:
continue
except FullResultSet:
default_mql = case.result.as_mql(compiler, connection)
break
case_mql["then"] = case.result.as_mql(compiler, connection)
case_mql["then"] = prefix_with_dollar(case.result.as_mql(compiler, connection))
case_parts.append(case_mql)
else:
default_mql = self.default.as_mql(compiler, connection)
Expand Down Expand Up @@ -73,12 +73,8 @@ def col(self, compiler, connection): # noqa: ARG001

def combined_expression(self, compiler, connection):
expressions = [
f"${self.lhs.as_mql(compiler, connection)}"
if isinstance(self.lhs, Col)
else self.lhs.as_mql(compiler, connection),
f"${self.rhs.as_mql(compiler, connection)}"
if isinstance(self.rhs, Col)
else self.rhs.as_mql(compiler, connection),
prefix_with_dollar(self.lhs.as_mql(compiler, connection)),
prefix_with_dollar(self.rhs.as_mql(compiler, connection)),
]
return connection.ops.combine_expression(self.connector, expressions)

Expand Down Expand Up @@ -121,10 +117,15 @@ def query(self, compiler, connection, lookup_name=None):
subquery.subquery_lookup = {
"as": table_output,
"from": from_table,
"let": {
compiler.PARENT_FIELD_TEMPLATE.format(i): col.as_mql(compiler, connection)
for col, i in subquery_compiler.column_indices.items()
},
"localField": next(
iter(
[
col.as_mql(compiler, connection)
for col, i in subquery_compiler.column_indices.items()
]
)
),
"foreignField": next(iter(subquery.mongo_query.keys())),
}
# The result must be a list of values. The output is compressed with an
# aggregation pipeline.
Expand Down Expand Up @@ -191,16 +192,18 @@ def subquery(self, compiler, connection, lookup_name=None):
return self.query.as_mql(compiler, connection, lookup_name=lookup_name)


def exists(self, compiler, connection, lookup_name=None):
def exists(self, compiler, connection, lookup_name=None, positional_operator_syntax: bool = False):
try:
lhs_mql = subquery(self, compiler, connection, lookup_name=lookup_name)
except EmptyResultSet:
return Value(False).as_mql(compiler, connection)
return connection.mongo_operators["isnull"](lhs_mql, False)
return connection.mongo_operators["isnull"](lhs_mql, False, pos=positional_operator_syntax)


def when(self, compiler, connection):
return self.condition.as_mql(compiler, connection)
def when(self, compiler, connection, positional_operator_syntax: bool = False):
return self.condition.as_mql(
compiler, connection, positional_operator_syntax=positional_operator_syntax
)


def value(self, compiler, connection): # noqa: ARG001
Expand Down
4 changes: 2 additions & 2 deletions django_documentdb/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
)

from .query_utils import process_lhs
from .utils import prefix_with_dollar

MONGO_OPERATORS = {
Ceil: "ceil",
Expand Down Expand Up @@ -102,8 +103,7 @@ def func(self, compiler, connection):
# Functions are using array syntax and for field name we want to add $
lhs_mql = process_lhs(self, compiler, connection)
if isinstance(lhs_mql, list):
field_name = lhs_mql[0]
lhs_mql[0] = f"${field_name}"
lhs_mql = [prefix_with_dollar(field_name) for field_name in lhs_mql]
operator = MONGO_OPERATORS.get(self.__class__, self.function.lower())
return {f"${operator}": lhs_mql}

Expand Down
12 changes: 6 additions & 6 deletions django_documentdb/lookups.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from .query_utils import process_lhs, process_rhs


def builtin_lookup(self, compiler, connection):
def builtin_lookup(self, compiler, connection, positional_operator_syntax: bool = False):
lhs_mql = process_lhs(self, compiler, connection)
value = process_rhs(self, compiler, connection)
return connection.mongo_operators[self.lookup_name](lhs_mql, value)
return connection.mongo_operators[self.lookup_name](lhs_mql, value, positional_operator_syntax)


_field_resolve_expression_parameter = FieldGetDbPrepValueIterableMixin.resolve_expression_parameter
Expand All @@ -33,7 +33,7 @@ def field_resolve_expression_parameter(self, compiler, connection, sql, param):
return sql, sql_params


def in_(self, compiler, connection):
def in_(self, compiler, connection, positional_operator_syntax: bool = False):
if isinstance(self.lhs, MultiColSource):
raise NotImplementedError("MultiColSource is not supported.")
db_rhs = getattr(self.rhs, "_db", None)
Expand All @@ -42,14 +42,14 @@ def in_(self, compiler, connection):
"Subqueries aren't allowed across different databases. Force "
"the inner query to be evaluated using `list(inner_query)`."
)
return builtin_lookup(self, compiler, connection)
return builtin_lookup(self, compiler, connection, positional_operator_syntax)


def is_null(self, compiler, connection):
def is_null(self, compiler, connection, positional_operator_syntax: bool = False):
if not isinstance(self.rhs, bool):
raise ValueError("The QuerySet value for an isnull lookup must be True or False.")
lhs_mql = process_lhs(self, compiler, connection)
return connection.mongo_operators["isnull"](lhs_mql, self.rhs)
return connection.mongo_operators["isnull"](lhs_mql, self.rhs, pos=positional_operator_syntax)


# from https://www.pcre.org/current/doc/html/pcre2pattern.html#SEC4
Expand Down
Loading