Skip to content

Commit

Permalink
Add type hints (#198)
Browse files Browse the repository at this point in the history
  • Loading branch information
palfrey authored Dec 21, 2022
1 parent 7755b60 commit 698882e
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 35 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ jobs:
pip install -r requirements.txt
pip install "Django~=${{ matrix.django-version }}.0" .
- name: Run mypy
run: |
python -m mypy dj_database_url.py
- name: Run Tests
run: |
echo "$(python --version) / Django $(django-admin --version)"
Expand Down
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include py.typed
89 changes: 55 additions & 34 deletions dj_database_url.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import os
import urllib.parse as urlparse
from typing import Any, Dict, Optional, Union

from typing_extensions import TypedDict

# Register database schemes in URLs.
urlparse.uses_netloc.append("postgres")
Expand Down Expand Up @@ -45,15 +48,33 @@
}


# From https://docs.djangoproject.com/en/4.0/ref/settings/#databases
class DBConfig(TypedDict, total=False):
ATOMIC_REQUESTS: bool
AUTOCOMMIT: bool
CONN_MAX_AGE: int
CONN_HEALTH_CHECKS: bool
DISABLE_SERVER_SIDE_CURSORS: bool
ENGINE: str
HOST: str
NAME: str
OPTIONS: Optional[Dict[str, Any]]
PASSWORD: str
PORT: Union[str, int]
TEST: Dict[str, Any]
TIME_ZONE: str
USER: str


def config(
env=DEFAULT_ENV,
default=None,
engine=None,
conn_max_age=0,
conn_health_checks=False,
ssl_require=False,
test_options=None,
):
env: str = DEFAULT_ENV,
default: Optional[str] = None,
engine: Optional[str] = None,
conn_max_age: int = 0,
conn_health_checks: bool = False,
ssl_require: bool = False,
test_options: Optional[Dict] = None,
) -> DBConfig:
"""Returns configured DATABASE dictionary from DATABASE_URL."""
s = os.environ.get(env, default)

Expand All @@ -66,13 +87,13 @@ def config(


def parse(
url,
engine=None,
conn_max_age=0,
conn_health_checks=False,
ssl_require=False,
test_options=None,
):
url: str,
engine: Optional[str] = None,
conn_max_age: int = 0,
conn_health_checks: bool = False,
ssl_require: bool = False,
test_options: Optional[dict] = None,
) -> DBConfig:
"""Parses a database URL."""
if url == "sqlite://:memory:":
# this is a special case, because if we pass this URL into
Expand All @@ -82,31 +103,31 @@ def parse(
# note: no other settings are required for sqlite

# otherwise parse the url as normal
parsed_config = {}
parsed_config: DBConfig = {}

if test_options is None:
test_options = {}

url = urlparse.urlsplit(url)
spliturl = urlparse.urlsplit(url)

# Split query strings from path.
path = url.path[1:]
if "?" in path and not url.query:
path, query = path.split("?", 2)
path = spliturl.path[1:]
if "?" in path and not spliturl.query:
path, raw_query = path.split("?", 2)
else:
path, query = path, url.query
query = urlparse.parse_qs(query)
path, raw_query = path, spliturl.query
query = urlparse.parse_qs(raw_query)

# If we are using sqlite and we have no path, then assume we
# want an in-memory database (this is the behaviour of sqlalchemy)
if url.scheme == "sqlite" and path == "":
if spliturl.scheme == "sqlite" and path == "":
path = ":memory:"

# Handle postgres percent-encoded paths.
hostname = url.hostname or ""
hostname = spliturl.hostname or ""
if "%" in hostname:
# Switch to url.netloc to avoid lower cased paths
hostname = url.netloc
hostname = spliturl.netloc
if "@" in hostname:
hostname = hostname.rsplit("@", 1)[1]
if ":" in hostname:
Expand All @@ -116,26 +137,26 @@ def parse(

# Lookup specified engine.
if engine is None:
engine = SCHEMES.get(url.scheme)
engine = SCHEMES.get(spliturl.scheme)
if engine is None:
raise ValueError(
"No support for '%s'. We support: %s"
% (url.scheme, ", ".join(sorted(SCHEMES.keys())))
% (spliturl.scheme, ", ".join(sorted(SCHEMES.keys())))
)

port = (
str(url.port)
if url.port
str(spliturl.port)
if spliturl.port
and engine in (SCHEMES["oracle"], SCHEMES["mssql"], SCHEMES["mssqlms"])
else url.port
else spliturl.port
)

# Update with environment configuration.
parsed_config.update(
{
"NAME": urlparse.unquote(path or ""),
"USER": urlparse.unquote(url.username or ""),
"PASSWORD": urlparse.unquote(url.password or ""),
"USER": urlparse.unquote(spliturl.username or ""),
"PASSWORD": urlparse.unquote(spliturl.password or ""),
"HOST": hostname,
"PORT": port or "",
"CONN_MAX_AGE": conn_max_age,
Expand All @@ -150,9 +171,9 @@ def parse(
)

# Pass the query string into OPTIONS.
options = {}
options: Dict[str, Any] = {}
for key, values in query.items():
if url.scheme == "mysql" and key == "ssl-ca":
if spliturl.scheme == "mysql" and key == "ssl-ca":
options["ssl"] = {"ca": values[-1]}
continue

Expand Down
Empty file added py.typed
Empty file.
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
[tool.black]
skip-string-normalization = 1

[tool.mypy]
show_error_codes=true
disallow_untyped_defs=true
disallow_untyped_calls=true
warn_redundant_casts=true
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
coverage
mypy
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
long_description=readme,
long_description_content_type="text/x-rst",
py_modules=["dj_database_url"],
install_requires=["Django>=3.2"],
install_requires=["Django>=3.2", "typing_extensions >= 3.10.0.0"],
zip_safe=False,
include_package_data=True,
platforms="any",
Expand Down

0 comments on commit 698882e

Please sign in to comment.