From 0e34139c5902f41799765427ddb96732aa762044 Mon Sep 17 00:00:00 2001 From: DanCardin Date: Tue, 21 May 2024 14:41:22 -0400 Subject: [PATCH] feat: Add function security option to postgres function. --- .../dialects/postgresql/function.py | 38 ++++++- .../dialects/postgresql/query.py | 8 +- .../dialects/postgresql/schema.py | 2 + .../function/base.py | 25 ++--- tests/function/test_pg_security_definer.py | 105 ++++++++++++++++++ 5 files changed, 161 insertions(+), 17 deletions(-) create mode 100644 tests/function/test_pg_security_definer.py diff --git a/src/sqlalchemy_declarative_extensions/dialects/postgresql/function.py b/src/sqlalchemy_declarative_extensions/dialects/postgresql/function.py index cf81df4..3144e57 100644 --- a/src/sqlalchemy_declarative_extensions/dialects/postgresql/function.py +++ b/src/sqlalchemy_declarative_extensions/dialects/postgresql/function.py @@ -1,19 +1,29 @@ from __future__ import annotations +import enum import textwrap from dataclasses import dataclass, replace from sqlalchemy_declarative_extensions.function import base +@enum.unique +class FunctionSecurity(enum.Enum): + invoker = "INVOKER" + definer = "DEFINER" + + @dataclass class Function(base.Function): """Describes a PostgreSQL function. Many function attributes are not currently supported. Support is **currently** - minimal due to being a means to an end for defining triggers. + minimal due to being a means to an end for defining triggers, but can certainly + be evaluated/added on request. """ + security: FunctionSecurity = FunctionSecurity.invoker + @classmethod def from_unknown_function(cls, f: base.Function | Function) -> Function: if not isinstance(f, Function): @@ -27,6 +37,26 @@ def from_unknown_function(cls, f: base.Function | Function) -> Function: return f + def to_sql_create(self, replace=False): + components = ["CREATE"] + + if replace: + components.append("OR REPLACE") + + components.append("FUNCTION") + components.append(self.qualified_name + "()") + + if self.returns: + components.append(f"RETURNS {self.returns}") + + if self.security == FunctionSecurity.definer: + components.append("SECURITY DEFINER") + + components.append(f"LANGUAGE {self.language}") + components.append(f"AS $${self.definition}$$") + + return " ".join(components) + ";" + def normalize(self) -> Function: returns = self.returns.lower() definition = textwrap.dedent(self.definition) @@ -34,6 +64,12 @@ def normalize(self) -> Function: self, returns=type_map.get(returns, returns), definition=definition ) + def with_security(self, security: FunctionSecurity): + return replace(self, security=security) + + def with_security_definer(self): + return replace(self, security=FunctionSecurity.definer) + type_map = { "bigint": "int8", diff --git a/src/sqlalchemy_declarative_extensions/dialects/postgresql/query.py b/src/sqlalchemy_declarative_extensions/dialects/postgresql/query.py index 30cf75d..e1c5078 100644 --- a/src/sqlalchemy_declarative_extensions/dialects/postgresql/query.py +++ b/src/sqlalchemy_declarative_extensions/dialects/postgresql/query.py @@ -10,7 +10,10 @@ parse_acl, parse_default_acl, ) -from sqlalchemy_declarative_extensions.dialects.postgresql.function import Function +from sqlalchemy_declarative_extensions.dialects.postgresql.function import ( + Function, + FunctionSecurity, +) from sqlalchemy_declarative_extensions.dialects.postgresql.role import Role from sqlalchemy_declarative_extensions.dialects.postgresql.schema import ( default_acl_query, @@ -165,6 +168,9 @@ def get_functions_postgresql(connection: Connection) -> list[Function]: returns=f.return_type, language=f.language, schema=f.schema if f.schema != "public" else None, + security=FunctionSecurity.definer + if f.security_definer + else FunctionSecurity.invoker, ) functions.append(function) return functions diff --git a/src/sqlalchemy_declarative_extensions/dialects/postgresql/schema.py b/src/sqlalchemy_declarative_extensions/dialects/postgresql/schema.py index b959786..b234dd5 100644 --- a/src/sqlalchemy_declarative_extensions/dialects/postgresql/schema.py +++ b/src/sqlalchemy_declarative_extensions/dialects/postgresql/schema.py @@ -77,6 +77,7 @@ column("pronamespace"), column("prolang"), column("prorettype"), + column("prosecdef"), ) pg_language = table( @@ -258,6 +259,7 @@ def _schema_not_pg(column=pg_namespace.c.nspname): pg_language.c.lanname.label("language"), pg_type.c.typname.label("return_type"), pg_proc.c.prosrc.label("source"), + pg_proc.c.prosecdef.label("security_definer"), ) .select_from( pg_proc.join(pg_namespace, pg_proc.c.pronamespace == pg_namespace.c.oid) diff --git a/src/sqlalchemy_declarative_extensions/function/base.py b/src/sqlalchemy_declarative_extensions/function/base.py index 0e6b0dc..7a23945 100644 --- a/src/sqlalchemy_declarative_extensions/function/base.py +++ b/src/sqlalchemy_declarative_extensions/function/base.py @@ -31,21 +31,7 @@ def normalize(self) -> Function: raise NotImplementedError() # pragma: no cover def to_sql_create(self, replace=False): - components = ["CREATE"] - - if replace: - components.append("OR REPLACE") - - components.append("FUNCTION") - components.append(self.qualified_name + "()") - - if self.returns: - components.append(f"RETURNS {self.returns}") - - components.append(f"LANGUAGE {self.language}") - components.append(f"AS $${self.definition}$$") - - return " ".join(components) + ";" + raise NotImplementedError() def to_sql_update(self): return self.to_sql_create(replace=True) @@ -53,6 +39,15 @@ def to_sql_update(self): def to_sql_drop(self): return f"DROP FUNCTION {self.qualified_name}();" + def with_name(self, name: str): + return replace(self, name=name) + + def with_language(self, language: str): + return replace(self, language=language) + + def with_return_type(self, return_type: str): + return replace(self, returns=return_type) + @dataclass class Functions: diff --git a/tests/function/test_pg_security_definer.py b/tests/function/test_pg_security_definer.py new file mode 100644 index 0000000..d381673 --- /dev/null +++ b/tests/function/test_pg_security_definer.py @@ -0,0 +1,105 @@ +import pytest +from pytest_mock_resources import create_postgres_fixture +from sqlalchemy import Column, text, types +from sqlalchemy.exc import ProgrammingError + +from sqlalchemy_declarative_extensions import ( + declarative_database, + register_function, + register_sqlalchemy_events, +) +from sqlalchemy_declarative_extensions.dialects.postgresql import Function +from sqlalchemy_declarative_extensions.dialects.postgresql.grant import ( + DefaultGrant, +) +from sqlalchemy_declarative_extensions.dialects.postgresql.role import Role +from sqlalchemy_declarative_extensions.function.compare import compare_functions +from sqlalchemy_declarative_extensions.grant.base import Grants +from sqlalchemy_declarative_extensions.role.base import Roles +from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base + +_Base = declarative_base() + + +@declarative_database +class Base(_Base): # type: ignore + __abstract__ = True + + roles = Roles(ignore_unspecified=True).are(Role("can_insert"), Role("cant_insert")) + grants = Grants().are( + DefaultGrant.on_tables_in_schema("public").grant("select", to="cant_insert"), + DefaultGrant.on_tables_in_schema("public") + .grant("insert", to="cant_insert") + .invert(), + DefaultGrant.on_tables_in_schema("public").grant( + "insert", "select", to="can_insert" + ), + ) + + +class Foo(Base): + __tablename__ = "foo" + + id = Column(types.Integer(), primary_key=True) + + +function_invoker = ( + Function( + "add_one_invoker", + """ + DECLARE + m INTEGER; + BEGIN + SELECT coalesce(max(id), 0) + 1 INTO m FROM foo; + INSERT INTO foo (id) VALUES (m); + RETURN m; + END + """, + ) + .with_return_type("INTEGER") + .with_language("plpgsql") +) + +function_definer = function_invoker.with_security_definer().with_name("add_one_definer") + +register_function(Base.metadata, function_invoker) +register_function(Base.metadata, function_definer) + + +register_sqlalchemy_events(Base.metadata, roles=True, grants=True, functions=True) + +pg = create_postgres_fixture(engine_kwargs={"echo": True}, session=True) + + +def test_function_security(pg): + Base.metadata.create_all(bind=pg.connection()) + pg.commit() + + # Permission decided by invoker + pg.execute(text("SET ROLE cant_insert")) + with pytest.raises(ProgrammingError) as e: + result = pg.execute(text("SELECT add_one_invoker()")).scalar() + assert "permission denied for table foo" in str(e) + + pg.rollback() + + pg.execute(text("SET ROLE can_insert")) + result = pg.execute(text("SELECT add_one_invoker()")).scalar() + assert result == 1 + + # Permission decided by definer + pg.execute(text("SET ROLE cant_insert")) + result = pg.execute(text("SELECT add_one_definer()")).scalar() + assert result == 2 + + pg.execute(text("SET ROLE can_insert")) + result = pg.execute(text("SELECT add_one_definer()")).scalar() + assert result == 3 + + # Just a double check + result = pg.query(Foo.id).count() + assert result == 3 + + connection = pg.connection() + diff = compare_functions(connection, Base.metadata.info["functions"], Base.metadata) + assert diff == []