Skip to content

Commit

Permalink
Merge pull request #62 from DanCardin/dc/function-security
Browse files Browse the repository at this point in the history
feat: Add function security option to postgres function.
  • Loading branch information
DanCardin authored May 27, 2024
2 parents 55c9c0f + 0e34139 commit 23383bb
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 17 deletions.
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
from __future__ import annotations

import enum
import textwrap
from dataclasses import dataclass, replace

from sqlalchemy_declarative_extensions.function import base


@enum.unique
class FunctionSecurity(enum.Enum):
invoker = "INVOKER"
definer = "DEFINER"


@dataclass
class Function(base.Function):
"""Describes a PostgreSQL function.
Many function attributes are not currently supported. Support is **currently**
minimal due to being a means to an end for defining triggers.
minimal due to being a means to an end for defining triggers, but can certainly
be evaluated/added on request.
"""

security: FunctionSecurity = FunctionSecurity.invoker

@classmethod
def from_unknown_function(cls, f: base.Function | Function) -> Function:
if not isinstance(f, Function):
Expand All @@ -27,13 +37,39 @@ def from_unknown_function(cls, f: base.Function | Function) -> Function:

return f

def to_sql_create(self, replace=False):
components = ["CREATE"]

if replace:
components.append("OR REPLACE")

components.append("FUNCTION")
components.append(self.qualified_name + "()")

if self.returns:
components.append(f"RETURNS {self.returns}")

if self.security == FunctionSecurity.definer:
components.append("SECURITY DEFINER")

components.append(f"LANGUAGE {self.language}")
components.append(f"AS $${self.definition}$$")

return " ".join(components) + ";"

def normalize(self) -> Function:
returns = self.returns.lower()
definition = textwrap.dedent(self.definition)
return replace(
self, returns=type_map.get(returns, returns), definition=definition
)

def with_security(self, security: FunctionSecurity):
return replace(self, security=security)

def with_security_definer(self):
return replace(self, security=FunctionSecurity.definer)


type_map = {
"bigint": "int8",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
parse_acl,
parse_default_acl,
)
from sqlalchemy_declarative_extensions.dialects.postgresql.function import Function
from sqlalchemy_declarative_extensions.dialects.postgresql.function import (
Function,
FunctionSecurity,
)
from sqlalchemy_declarative_extensions.dialects.postgresql.role import Role
from sqlalchemy_declarative_extensions.dialects.postgresql.schema import (
default_acl_query,
Expand Down Expand Up @@ -165,6 +168,9 @@ def get_functions_postgresql(connection: Connection) -> list[Function]:
returns=f.return_type,
language=f.language,
schema=f.schema if f.schema != "public" else None,
security=FunctionSecurity.definer
if f.security_definer
else FunctionSecurity.invoker,
)
functions.append(function)
return functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
column("pronamespace"),
column("prolang"),
column("prorettype"),
column("prosecdef"),
)

pg_language = table(
Expand Down Expand Up @@ -258,6 +259,7 @@ def _schema_not_pg(column=pg_namespace.c.nspname):
pg_language.c.lanname.label("language"),
pg_type.c.typname.label("return_type"),
pg_proc.c.prosrc.label("source"),
pg_proc.c.prosecdef.label("security_definer"),
)
.select_from(
pg_proc.join(pg_namespace, pg_proc.c.pronamespace == pg_namespace.c.oid)
Expand Down
25 changes: 10 additions & 15 deletions src/sqlalchemy_declarative_extensions/function/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,28 +31,23 @@ def normalize(self) -> Function:
raise NotImplementedError() # pragma: no cover

def to_sql_create(self, replace=False):
components = ["CREATE"]

if replace:
components.append("OR REPLACE")

components.append("FUNCTION")
components.append(self.qualified_name + "()")

if self.returns:
components.append(f"RETURNS {self.returns}")

components.append(f"LANGUAGE {self.language}")
components.append(f"AS $${self.definition}$$")

return " ".join(components) + ";"
raise NotImplementedError()

def to_sql_update(self):
return self.to_sql_create(replace=True)

def to_sql_drop(self):
return f"DROP FUNCTION {self.qualified_name}();"

def with_name(self, name: str):
return replace(self, name=name)

def with_language(self, language: str):
return replace(self, language=language)

def with_return_type(self, return_type: str):
return replace(self, returns=return_type)


@dataclass
class Functions:
Expand Down
105 changes: 105 additions & 0 deletions tests/function/test_pg_security_definer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import pytest
from pytest_mock_resources import create_postgres_fixture
from sqlalchemy import Column, text, types
from sqlalchemy.exc import ProgrammingError

from sqlalchemy_declarative_extensions import (
declarative_database,
register_function,
register_sqlalchemy_events,
)
from sqlalchemy_declarative_extensions.dialects.postgresql import Function
from sqlalchemy_declarative_extensions.dialects.postgresql.grant import (
DefaultGrant,
)
from sqlalchemy_declarative_extensions.dialects.postgresql.role import Role
from sqlalchemy_declarative_extensions.function.compare import compare_functions
from sqlalchemy_declarative_extensions.grant.base import Grants
from sqlalchemy_declarative_extensions.role.base import Roles
from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base

_Base = declarative_base()


@declarative_database
class Base(_Base): # type: ignore
__abstract__ = True

roles = Roles(ignore_unspecified=True).are(Role("can_insert"), Role("cant_insert"))
grants = Grants().are(
DefaultGrant.on_tables_in_schema("public").grant("select", to="cant_insert"),
DefaultGrant.on_tables_in_schema("public")
.grant("insert", to="cant_insert")
.invert(),
DefaultGrant.on_tables_in_schema("public").grant(
"insert", "select", to="can_insert"
),
)


class Foo(Base):
__tablename__ = "foo"

id = Column(types.Integer(), primary_key=True)


function_invoker = (
Function(
"add_one_invoker",
"""
DECLARE
m INTEGER;
BEGIN
SELECT coalesce(max(id), 0) + 1 INTO m FROM foo;
INSERT INTO foo (id) VALUES (m);
RETURN m;
END
""",
)
.with_return_type("INTEGER")
.with_language("plpgsql")
)

function_definer = function_invoker.with_security_definer().with_name("add_one_definer")

register_function(Base.metadata, function_invoker)
register_function(Base.metadata, function_definer)


register_sqlalchemy_events(Base.metadata, roles=True, grants=True, functions=True)

pg = create_postgres_fixture(engine_kwargs={"echo": True}, session=True)


def test_function_security(pg):
Base.metadata.create_all(bind=pg.connection())
pg.commit()

# Permission decided by invoker
pg.execute(text("SET ROLE cant_insert"))
with pytest.raises(ProgrammingError) as e:
result = pg.execute(text("SELECT add_one_invoker()")).scalar()
assert "permission denied for table foo" in str(e)

pg.rollback()

pg.execute(text("SET ROLE can_insert"))
result = pg.execute(text("SELECT add_one_invoker()")).scalar()
assert result == 1

# Permission decided by definer
pg.execute(text("SET ROLE cant_insert"))
result = pg.execute(text("SELECT add_one_definer()")).scalar()
assert result == 2

pg.execute(text("SET ROLE can_insert"))
result = pg.execute(text("SELECT add_one_definer()")).scalar()
assert result == 3

# Just a double check
result = pg.query(Foo.id).count()
assert result == 3

connection = pg.connection()
diff = compare_functions(connection, Base.metadata.info["functions"], Base.metadata)
assert diff == []

0 comments on commit 23383bb

Please sign in to comment.