Add support for standardizable SQL tasks #2

Open
wants to merge 10 commits into base: main
5 changes: 3 additions & 2 deletions .github/workflows/CI.yml
@@ -18,6 +18,7 @@ jobs:
os: [ubuntu-latest]
python-version: ['3.9', '3.10', '3.11', '3.12']
airflow-version: ['<2.10', '<2.11']
airflow-postgres-version: ['<6']

steps:
- name: Checkout code
@@ -32,7 +33,7 @@ jobs:
run: |
python -m pip install --upgrade pip setuptools wheel
python -m pip install pytest
python -m pip install "apache-airflow${{ matrix.airflow-version }}"
python -m pip install "apache-airflow${{ matrix.airflow-version }}" "apache-airflow-providers-postgres${{ matrix.airflow-postgres-version }}"
- name: Run the test
run: pytest

@@ -59,7 +60,7 @@ jobs:
run: |
python -m pip install --upgrade pip setuptools wheel
python -m pip install pytest pytest-cov coveralls
python -m pip install "apache-airflow${{ matrix.airflow-version }}"
python -m pip install "apache-airflow${{ matrix.airflow-version }}" "apache-airflow-providers-postgres<6"
- name: Run the test with coverage
run: pytest --cov
- name: Coveralls
2 changes: 1 addition & 1 deletion LICENSE.md
@@ -1,6 +1,6 @@
BSD 3-Clause License

Copyright (c) 2024, Astro Data Lab
Copyright (c) 2024-2025, Astro Data Lab <datalab@noirlab.edu>

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
2 changes: 1 addition & 1 deletion MANIFEST.in
@@ -1,2 +1,2 @@
prune .github
global-exclude .gitignore .readthedocs.yml
global-exclude .gitignore .readthedocs.yaml
173 changes: 172 additions & 1 deletion dlairflow/postgresql.py
@@ -6,9 +6,11 @@

Standard tasks for working with PostgreSQL that can be imported into a DAG.
"""
import os
from airflow.operators.bash import BashOperator
from airflow.hooks.base import BaseHook
from .util import user_scratch
from airflow.providers.postgres.operators.postgres import PostgresOperator
from .util import user_scratch, ensure_sql


def _connection_to_environment(connection):
@@ -92,3 +94,172 @@ def pg_restore_schema(connection, schema, dump_dir=None):
'dump_dir': dump_dir},
env=pg_env,
append_env=True)


def q3c_index(connection, schema, table, ra='ra', dec='dec', overwrite=False):
"""Create a q3c index on `schema`.`table`.

Parameters
----------
connection : :class:`str`
An Airflow database connection string.
schema : :class:`str`
The name of the database schema.
table : :class:`str`
The name of the table in `schema`.
ra : :class:`str`, optional
Name of the column containing Right Ascension, default 'ra'.
dec : :class:`str`, optional
Name of the column containing Declination, default 'dec'.
overwrite : :class:`bool`, optional
If ``True`` replace any existing SQL template file.

Returns
-------
:class:`~airflow.providers.postgres.operators.postgres.PostgresOperator`
A task to create a q3c index.
"""
sql_dir = ensure_sql()
sql_basename = "dlairflow.postgresql.q3c_index.sql"
sql_file = os.path.join(sql_dir, sql_basename)
if overwrite or not os.path.exists(sql_file):
sql_data = """--
-- Created by dlairflow.postgresql.q3c_index().
-- Call q3c_index(..., overwrite=True) to replace this file.
--
CREATE INDEX {{ params.table }}_q3c_ang2ipix
ON {{ params.schema }}.{{ params.table }} (q3c_ang2ipix("{{ params.ra }}", "{{ params.dec }}"))
WITH (fillfactor=100);
CLUSTER {{ params.table }}_q3c_ang2ipix ON {{ params.schema }}.{{ params.table }};
"""
with open(sql_file, 'w') as s:
s.write(sql_data)
return PostgresOperator(task_id="q3c_index",
postgres_conn_id=connection,
sql=f"sql/{sql_basename}",
params={'schema': schema, 'table': table, 'ra': ra, 'dec': dec})
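
A minimal usage sketch for the new q3c_index() task factory, assuming a throwaway DAG; the connection ID "db_conn" and the schema/table names ("mydb", "objects") are hypothetical:

from datetime import datetime
from airflow import DAG
from dlairflow.postgresql import q3c_index

with DAG(dag_id="q3c_index_example", start_date=datetime(2025, 1, 1), schedule=None) as dag:
    # Returns a PostgresOperator that renders sql/dlairflow.postgresql.q3c_index.sql.
    q3c = q3c_index("db_conn", schema="mydb", table="objects", ra="ra", dec="dec")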


def index_columns(connection, schema, table, columns, overwrite=False):
"""Create "generic" indexes for a set of columns

Parameters
----------
connection : :class:`str`
An Airflow database connection string.
schema : :class:`str`
The name of the database schema.
table : :class:`str`
The name of the table in `schema`.
columns : :class:`list`
A list of columns to index. See below for the possible entries in
the list of columns.
overwrite : :class:`bool`, optional
If ``True`` replace any existing SQL template file.

Returns
-------
:class:`~airflow.providers.postgres.operators.postgres.PostgresOperator`
A task to create several indexes.

Notes
-----
`columns` may be a list containing multiple types:

* :class:`str`: create an index on one column.
* :class:`tuple`: create an index on the set of columns in the tuple.
* :class:`dict`: create a *function* index. The key is the name of the function
and the value is the column that is the argument to the function.
* Any other type in `columns` will be ignored.
"""
sql_dir = ensure_sql()
sql_basename = "dlairflow.postgresql.index_columns.sql"
sql_file = os.path.join(sql_dir, sql_basename)
if overwrite or not os.path.exists(sql_file):
sql_data = """--
-- Created by dlairflow.postgresql.index_columns().
-- Call index_columns(..., overwrite=True) to replace this file.
--
{% for col in params.columns %}
{% if col is string -%}
CREATE INDEX {{ params.table }}_{{ col }}_idx
ON {{ params.schema }}.{{ params.table }} ("{{ col }}")
WITH (fillfactor=100);
{% elif col is mapping -%}
{% for key, value in col.items() -%}
CREATE INDEX {{ params.table }}_{{ key|replace('.', '_') }}_{{ value }}_idx
ON {{ params.schema }}.{{ params.table }} ({{ key }}({{ value }}))
WITH (fillfactor=100);
{% endfor %}
{% elif col is sequence -%}
CREATE INDEX {{ params.table }}_{{ col|join("_") }}_idx
ON {{ params.schema }}.{{ params.table }} ("{{ col|join('", "') }}")
WITH (fillfactor=100);
{% else -%}
-- Unknown type: {{ col }}.
{% endif -%}
{% endfor %}
"""
with open(sql_file, 'w') as s:
s.write(sql_data)
return PostgresOperator(task_id="index_columns",
postgres_conn_id=connection,
sql=f"sql/{sql_basename}",
params={'schema': schema, 'table': table, 'columns': columns})
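
A minimal sketch of the three entry types that index_columns() accepts in `columns`; the connection ID "db_conn" and the table/column names are hypothetical:

from datetime import datetime
from airflow import DAG
from dlairflow.postgresql import index_columns

with DAG(dag_id="index_columns_example", start_date=datetime(2025, 1, 1), schedule=None) as dag:
    idx = index_columns("db_conn", "mydb", "objects",
                        columns=["flux",              # str: index on a single column
                                 ("ra", "dec"),       # tuple: index on a set of columns
                                 {"lower": "name"}])  # dict: function index lower(name)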


def primary_key(connection, schema, primary_keys, overwrite=False):
"""Create a primary key on one or more tables in `schema`.

Parameters
----------
connection : :class:`str`
An Airflow database connection string.
schema : :class:`str`
The name of the database schema.
primary_keys : :class:`dict`
A dictionary mapping the name of each table in `schema` to its
primary key column(s). See below for details.
overwrite : :class:`bool`, optional
If ``True`` replace any existing SQL template file.

Returns
-------
:class:`~airflow.providers.postgres.operators.postgres.PostgresOperator`
A task to create primary keys.

Notes
-----
`primary_keys` may be a :class:`dict` containing multiple types:

* The key is the table name within `schema`.
* The value can be:

- :class:`str`: create a primary key on one column.
- :class:`tuple`: create a primary key on the set of columns in the tuple.
- Any other type will be ignored.
"""
sql_dir = ensure_sql()
sql_basename = "dlairflow.postgresql.primary_key.sql"
sql_file = os.path.join(sql_dir, sql_basename)
if overwrite or not os.path.exists(sql_file):
sql_data = """--
-- Created by dlairflow.postgresql.primary_key().
-- Call primary_key(..., overwrite=True) to replace this file.
--
{% for table, columns in params.primary_keys.items() %}
{% if columns is string -%}
ALTER TABLE {{ params.schema }}.{{ table }} ADD PRIMARY KEY ("{{ columns }}");
{% elif columns is sequence -%}
ALTER TABLE {{ params.schema }}.{{ table }} ADD PRIMARY KEY ("{{ columns|join('", "') }}");
{% else -%}
-- Unknown type: {{ columns }}.
{% endif -%}
{% endfor %}
"""
with open(sql_file, 'w') as s:
s.write(sql_data)
return PostgresOperator(task_id="primary_key",
postgres_conn_id=connection,
sql=f"sql/{sql_basename}",
params={'schema': schema, 'primary_keys': primary_keys})
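
A minimal sketch of primary_key() with both supported value types; again the connection ID and the table/column names are hypothetical:

from datetime import datetime
from airflow import DAG
from dlairflow.postgresql import primary_key

with DAG(dag_id="primary_key_example", start_date=datetime(2025, 1, 1), schedule=None) as dag:
    pk = primary_key("db_conn", "mydb",
                     primary_keys={"objects": "object_id",             # str: single-column key
                                   "exposures": ("night", "expid")})   # tuple: composite key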