Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add function security option to postgres function. #62

Merged
merged 1 commit into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
from __future__ import annotations

import enum
import textwrap
from dataclasses import dataclass, replace

from sqlalchemy_declarative_extensions.function import base


@enum.unique
class FunctionSecurity(enum.Enum):
invoker = "INVOKER"
definer = "DEFINER"


@dataclass
class Function(base.Function):
"""Describes a PostgreSQL function.

Many function attributes are not currently supported. Support is **currently**
minimal due to being a means to an end for defining triggers.
minimal due to being a means to an end for defining triggers, but can certainly
be evaluated/added on request.
"""

security: FunctionSecurity = FunctionSecurity.invoker

@classmethod
def from_unknown_function(cls, f: base.Function | Function) -> Function:
if not isinstance(f, Function):
Expand All @@ -27,13 +37,39 @@ def from_unknown_function(cls, f: base.Function | Function) -> Function:

return f

def to_sql_create(self, replace=False):
components = ["CREATE"]

if replace:
components.append("OR REPLACE")

components.append("FUNCTION")
components.append(self.qualified_name + "()")

if self.returns:
components.append(f"RETURNS {self.returns}")

if self.security == FunctionSecurity.definer:
components.append("SECURITY DEFINER")

components.append(f"LANGUAGE {self.language}")
components.append(f"AS $${self.definition}$$")

return " ".join(components) + ";"

def normalize(self) -> Function:
returns = self.returns.lower()
definition = textwrap.dedent(self.definition)
return replace(
self, returns=type_map.get(returns, returns), definition=definition
)

def with_security(self, security: FunctionSecurity):
return replace(self, security=security)

def with_security_definer(self):
return replace(self, security=FunctionSecurity.definer)


type_map = {
"bigint": "int8",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
parse_acl,
parse_default_acl,
)
from sqlalchemy_declarative_extensions.dialects.postgresql.function import Function
from sqlalchemy_declarative_extensions.dialects.postgresql.function import (
Function,
FunctionSecurity,
)
from sqlalchemy_declarative_extensions.dialects.postgresql.role import Role
from sqlalchemy_declarative_extensions.dialects.postgresql.schema import (
default_acl_query,
Expand Down Expand Up @@ -165,6 +168,9 @@ def get_functions_postgresql(connection: Connection) -> list[Function]:
returns=f.return_type,
language=f.language,
schema=f.schema if f.schema != "public" else None,
security=FunctionSecurity.definer
if f.security_definer
else FunctionSecurity.invoker,
)
functions.append(function)
return functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
column("pronamespace"),
column("prolang"),
column("prorettype"),
column("prosecdef"),
)

pg_language = table(
Expand Down Expand Up @@ -258,6 +259,7 @@ def _schema_not_pg(column=pg_namespace.c.nspname):
pg_language.c.lanname.label("language"),
pg_type.c.typname.label("return_type"),
pg_proc.c.prosrc.label("source"),
pg_proc.c.prosecdef.label("security_definer"),
)
.select_from(
pg_proc.join(pg_namespace, pg_proc.c.pronamespace == pg_namespace.c.oid)
Expand Down
25 changes: 10 additions & 15 deletions src/sqlalchemy_declarative_extensions/function/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,28 +31,23 @@ def normalize(self) -> Function:
raise NotImplementedError() # pragma: no cover

def to_sql_create(self, replace=False):
components = ["CREATE"]

if replace:
components.append("OR REPLACE")

components.append("FUNCTION")
components.append(self.qualified_name + "()")

if self.returns:
components.append(f"RETURNS {self.returns}")

components.append(f"LANGUAGE {self.language}")
components.append(f"AS $${self.definition}$$")

return " ".join(components) + ";"
raise NotImplementedError()

def to_sql_update(self):
return self.to_sql_create(replace=True)

def to_sql_drop(self):
return f"DROP FUNCTION {self.qualified_name}();"

def with_name(self, name: str):
return replace(self, name=name)

def with_language(self, language: str):
return replace(self, language=language)

def with_return_type(self, return_type: str):
return replace(self, returns=return_type)


@dataclass
class Functions:
Expand Down
105 changes: 105 additions & 0 deletions tests/function/test_pg_security_definer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import pytest
from pytest_mock_resources import create_postgres_fixture
from sqlalchemy import Column, text, types
from sqlalchemy.exc import ProgrammingError

from sqlalchemy_declarative_extensions import (
declarative_database,
register_function,
register_sqlalchemy_events,
)
from sqlalchemy_declarative_extensions.dialects.postgresql import Function
from sqlalchemy_declarative_extensions.dialects.postgresql.grant import (
DefaultGrant,
)
from sqlalchemy_declarative_extensions.dialects.postgresql.role import Role
from sqlalchemy_declarative_extensions.function.compare import compare_functions
from sqlalchemy_declarative_extensions.grant.base import Grants
from sqlalchemy_declarative_extensions.role.base import Roles
from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base

_Base = declarative_base()


@declarative_database
class Base(_Base): # type: ignore
__abstract__ = True

roles = Roles(ignore_unspecified=True).are(Role("can_insert"), Role("cant_insert"))
grants = Grants().are(
DefaultGrant.on_tables_in_schema("public").grant("select", to="cant_insert"),
DefaultGrant.on_tables_in_schema("public")
.grant("insert", to="cant_insert")
.invert(),
DefaultGrant.on_tables_in_schema("public").grant(
"insert", "select", to="can_insert"
),
)


class Foo(Base):
__tablename__ = "foo"

id = Column(types.Integer(), primary_key=True)


function_invoker = (
Function(
"add_one_invoker",
"""
DECLARE
m INTEGER;
BEGIN
SELECT coalesce(max(id), 0) + 1 INTO m FROM foo;
INSERT INTO foo (id) VALUES (m);
RETURN m;
END
""",
)
.with_return_type("INTEGER")
.with_language("plpgsql")
)

function_definer = function_invoker.with_security_definer().with_name("add_one_definer")

register_function(Base.metadata, function_invoker)
register_function(Base.metadata, function_definer)


register_sqlalchemy_events(Base.metadata, roles=True, grants=True, functions=True)

pg = create_postgres_fixture(engine_kwargs={"echo": True}, session=True)


def test_function_security(pg):
Base.metadata.create_all(bind=pg.connection())
pg.commit()

# Permission decided by invoker
pg.execute(text("SET ROLE cant_insert"))
with pytest.raises(ProgrammingError) as e:
result = pg.execute(text("SELECT add_one_invoker()")).scalar()
assert "permission denied for table foo" in str(e)

pg.rollback()

pg.execute(text("SET ROLE can_insert"))
result = pg.execute(text("SELECT add_one_invoker()")).scalar()
assert result == 1

# Permission decided by definer
pg.execute(text("SET ROLE cant_insert"))
result = pg.execute(text("SELECT add_one_definer()")).scalar()
assert result == 2

pg.execute(text("SET ROLE can_insert"))
result = pg.execute(text("SELECT add_one_definer()")).scalar()
assert result == 3

# Just a double check
result = pg.query(Foo.id).count()
assert result == 3

connection = pg.connection()
diff = compare_functions(connection, Base.metadata.info["functions"], Base.metadata)
assert diff == []
Loading