Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #5018 remove global uid in flask #5064

Merged
merged 32 commits into from
Nov 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 33 additions & 26 deletions sirepo/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,6 @@
_cfg = None


def hack_logged_in_user():
# avoids case of no quest (sirepo.agent)
return _cfg.logged_in_user or sirepo.quest.hack_current().auth.logged_in_user()


def init_quest(qcall):
o = _Auth(qcall=qcall)
qcall.attr_set("auth", o)
Expand All @@ -107,7 +102,6 @@ def init_module(**imports):
def _init_full():
global visible_methods, valid_methods, non_guest_methods

simulation_db.hook_auth_user = hack_logged_in_user
p = pkinspect.this_module().__name__
visible_methods = []
valid_methods = _cfg.methods.union(_cfg.deprecated_methods)
Expand Down Expand Up @@ -164,8 +158,8 @@ def check_sim_type_role(self, sim_type):
u = self.logged_in_user()
r = sirepo.auth_role.for_sim_type(t)
if sirepo.auth_db.UserRole.has_role(
u, r
) and not sirepo.auth_db.UserRole.is_expired(u, r):
qcall=self.qcall, role=r
) and not sirepo.auth_db.UserRole.is_expired(qcall=self.qcall, role=r):
return
elif r in sirepo.auth_role.for_proprietary_oauth_sim_types():
oauth.raise_authorize_redirect(self.qcall, sirepo.auth_role.sim_type(r))
Expand Down Expand Up @@ -225,10 +219,14 @@ def cookie_cleaner(self, values):
values[_COOKIE_STATE] = _STATE_LOGGED_OUT
return values

def create_user(self, uid_generated_callback, module):
def create_user(self, module, want_login=False):
u = simulation_db.user_create()
uid_generated_callback(u)
self._create_roles_for_new_user(u, module.AUTH_METHOD)
if want_login:
self._login_user(module, u)
self._create_roles_for_new_user(module.AUTH_METHOD)
else:
with self.logged_in_user_set(u, method=module.AUTH_METHOD):
self._create_roles_for_new_user(module.AUTH_METHOD)
return u

def get_module(self, name):
Expand All @@ -253,8 +251,8 @@ def is_logged_in(self, state=None):

def is_premium_user(self):
return sirepo.auth_db.UserRole.has_role(
self.logged_in_user(),
sirepo.auth_role.ROLE_PAYMENT_PLAN_PREMIUM,
qcall=self.qcall,
role=sirepo.auth_role.ROLE_PAYMENT_PLAN_PREMIUM,
)

def logged_in_user(self, check_path=True):
Expand All @@ -265,8 +263,6 @@ def logged_in_user(self, check_path=True):
Returns:
str: uid of authenticated user
"""
if self._logged_in_user:
return self._logged_in_user
u = self._qcall_bound_user()
if not self.is_logged_in():
raise sirepo.util.SRException(
Expand All @@ -280,7 +276,7 @@ def logged_in_user(self, check_path=True):
self._qcall_bound_method(),
)
if check_path:
simulation_db.user_path(u, check=True)
simulation_db.user_path(uid=u, check=True)
return u

def logged_in_user_name(self):
Expand All @@ -290,10 +286,18 @@ def logged_in_user_name(self):
method=self._qcall_bound_method(),
)

@contextlib.contextmanager
def logged_in_user_set(self, uid, method=METHOD_GUEST):
"""Ephemeral login"""
self._logged_in_user = uid
self._logged_in_method = method
"""Ephemeral login or may be used to logout"""
u = self._logged_in_user
m = self._logged_in_method
try:
self._logged_in_user = uid
self._logged_in_method = None if uid is None else method
yield
finally:
self._logged_in_user = u
self._logged_in_method = m

def login(
self,
Expand Down Expand Up @@ -357,7 +361,7 @@ def login(
# This handles the case where logging in as guest, creates a user every time
self._login_user(method, uid)
else:
uid = self.create_user(lambda u: self._login_user(method, u), method)
uid = self.create_user(method, want_login=True)
if model:
model.uid = uid
model.save()
Expand Down Expand Up @@ -444,7 +448,10 @@ def parse_display_name(self, value):

def require_adm(self):
u = self.require_user()
if not sirepo.auth_db.UserRole.has_role(u, sirepo.auth_role.ROLE_ADM):
if not sirepo.auth_db.UserRole.has_role(
qcall=self.qcall,
role=sirepo.auth_role.ROLE_ADM,
):
sirepo.util.raise_forbidden(f"uid={u} role=ROLE_ADM not found")

def require_auth_basic(self):
Expand Down Expand Up @@ -634,7 +641,7 @@ def _get_slack_uri():
r = sirepo.auth_db.UserRegistration.search_by(uid=u)
if r:
v.displayName = r.display_name
v.roles = sirepo.auth_db.UserRole.get_roles(u)
v.roles = sirepo.auth_db.UserRole.get_roles(qcall=self.qcall)
self._plan(v)
self._method_auth_state(v, u)
if pkconfig.channel_in_internal_test():
Expand All @@ -643,10 +650,10 @@ def _get_slack_uri():
pkdc("state={}", v)
return v

def _create_roles_for_new_user(self, uid, method):
r = sirepo.auth_role.for_new_user(method == METHOD_GUEST)
def _create_roles_for_new_user(self, method):
r = sirepo.auth_role.for_new_user(is_guest=method == METHOD_GUEST)
if r:
sirepo.auth_db.UserRole.add_roles(uid, r)
sirepo.auth_db.UserRole.add_roles(qcall=self.qcall, roles=r)

def _login_user(self, module, uid):
"""Set up the cookie for logged in state
Expand Down Expand Up @@ -713,7 +720,7 @@ def _qcall_bound_state(self):
return self.qcall.cookie.unchecked_get_value(_COOKIE_STATE)

def _qcall_bound_user(self):
return _cfg.logged_in_user or self.qcall.cookie.unchecked_get_value(
return self._logged_in_user or self.qcall.cookie.unchecked_get_value(
_COOKIE_USER
)

Expand Down
4 changes: 1 addition & 3 deletions sirepo/auth/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,5 @@ def _cfg_uid(value):

if value and value == "dev-no-validate" and pkconfig.channel_in_internal_test():
return value
assert simulation_db.user_path(value).check(
dir=True
), "uid={} does not exist".format(value)
simulation_db.user_path(uid=value, check=True)
return value
2 changes: 1 addition & 1 deletion sirepo/auth/bluesky.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def api_authBlueskyLogin(self):
)
return self.reply_ok(
PKDict(
data=simulation_db.open_json_file(req.type, sid=req.id),
data=simulation_db.open_json_file(req.type, sid=req.id, qcall=self),
schema=simulation_db.get_schema(req.type),
),
)
Expand Down
73 changes: 41 additions & 32 deletions sirepo/auth_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,17 @@
UserRoleInvite = None


def all_uids():
def all_uids(qcall):
return UserRegistration.search_all_for_column("uid")


def audit_proprietary_lib_files(uid, force=False, sim_types=None):
def audit_proprietary_lib_files(qcall, force=False, sim_types=None):
"""Add/removes proprietary files based on a user's roles

For example, add the Flash tarball if user has the flash role.

Args:
uid (str): The uid of the user to audit
qcall (quest.API): logged in user
force (bool): Overwrite existing lib files with the same name as new ones
sim_types (set): Set of sim_types to audit (proprietary_sim_types if None)
"""
Expand All @@ -66,7 +66,7 @@ def audit_proprietary_lib_files(uid, force=False, sim_types=None):

def _add(proprietary_code_dir, sim_type, sim_data_class):
p = proprietary_code_dir.join(sim_data_class.proprietary_code_tarball())
with sirepo.simulation_db.tmp_dir(chdir=True, uid=uid) as t:
with sirepo.simulation_db.tmp_dir(chdir=True, qcall=qcall) as t:
d = t.join(p.basename)
d.mksymlinkto(p, absolute=False)
subprocess.check_output(
Expand All @@ -80,7 +80,7 @@ def _add(proprietary_code_dir, sim_type, sim_data_class):
)
# lib_dir may not exist: git.radiasoft.org/ops/issues/645
l = pykern.pkio.mkdir_parent(
sirepo.simulation_db.simulation_lib_dir(sim_type, uid=uid),
sirepo.simulation_db.simulation_lib_dir(sim_type, qcall=qcall),
)
e = [f.basename for f in pykern.pkio.sorted_glob(l.join("*"))]
for f in sim_data_class.proprietary_code_lib_file_basenames():
Expand All @@ -93,6 +93,7 @@ def _add(proprietary_code_dir, sim_type, sim_data_class):
s
), f"sim_types={sim_types} not a subset of proprietary_sim_types={s}"
s = sim_types
u = qcall.auth.logged_in_user()
for t in s:
c = sirepo.sim_data.get_class(t)
if not c.proprietary_code_tarball():
Expand All @@ -102,14 +103,16 @@ def _add(proprietary_code_dir, sim_type, sim_data_class):
"; run: sirepo setup_dev" if pykern.pkconfig.channel_in("dev") else ""
)
r = UserRole.has_role(
uid,
sirepo.auth_role.for_sim_type(t),
qcall=qcall,
role=sirepo.auth_role.for_sim_type(t),
)
if r:
_add(d, t, c)
continue
# SECURITY: User no longer has access so remove all artifacts
pykern.pkio.unchecked_remove(sirepo.simulation_db.simulation_dir(t, uid=uid))
pykern.pkio.unchecked_remove(
sirepo.simulation_db.simulation_dir(t, qcall=qcall)
)


def db_filename():
Expand Down Expand Up @@ -281,60 +284,66 @@ def all_roles(cls):
return [r[0] for r in cls._session().query(cls.role.distinct()).all()]

@classmethod
def add_roles(cls, uid, role_or_roles, expiration=None):
if isinstance(role_or_roles, str):
role_or_roles = [role_or_roles]
def add_roles(cls, qcall, roles, expiration=None):
with sirepo.util.THREAD_LOCK:
for r in role_or_roles:
u = qcall.auth.logged_in_user()
for r in roles:
try:
UserRole(uid=uid, role=r, expiration=expiration).save()
UserRole(
uid=u,
role=r,
expiration=expiration,
).save()
except sqlalchemy.exc.IntegrityError:
pass
audit_proprietary_lib_files(uid)
audit_proprietary_lib_files(qcall=qcall)

@classmethod
def add_role_or_update_expiration(cls, uid, role, expiration):
def add_role_or_update_expiration(cls, qcall, role, expiration):
with sirepo.util.THREAD_LOCK:
if not cls.has_role(uid, role):
cls.add_roles(uid, role, expiration=expiration)
if not cls.has_role(qcall, role):
cls.add_roles(qcall=qcall, roles=[role], expiration=expiration)
return
r = cls.search_by(uid=uid, role=role)
r = cls.search_by(uid=qcall.auth.logged_in_user(), role=role)
r.expiration = expiration
r.save()

@classmethod
def delete_roles(cls, uid, roles):
def delete_roles(cls, qcall, roles):
with sirepo.util.THREAD_LOCK:
cls.execute(
sqlalchemy.delete(cls)
.where(
cls.uid == uid,
cls.uid == qcall.auth.logged_in_user(),
)
.where(
cls.role.in_(roles),
)
)
cls._session().commit()
audit_proprietary_lib_files(uid)
audit_proprietary_lib_files(qcall=qcall)

@classmethod
def get_roles(cls, uid):
def get_roles(cls, qcall):
with sirepo.util.THREAD_LOCK:
return UserRole.search_all_for_column(
"role",
uid=uid,
uid=qcall.auth.logged_in_user(),
)

@classmethod
def has_role(cls, uid, role):
def has_role(cls, qcall, role):
with sirepo.util.THREAD_LOCK:
return bool(cls.search_by(uid=uid, role=role))
return bool(cls.search_by(uid=qcall.auth.logged_in_user(), role=role))

@classmethod
def is_expired(cls, uid, role):
def is_expired(cls, qcall, role):
with sirepo.util.THREAD_LOCK:
assert cls.has_role(uid, role), f"No role for uid={uid} and role={role}"
r = cls.search_by(uid=uid, role=role)
u = qcall.auth.logged_in_user()
assert cls.has_role(
qcall=qcall, role=role
), f"No role for uid={u} and role={role}"
r = cls.search_by(uid=u, role=role)
if not r.expiration:
# Roles with no expiration can't expire
return False
Expand Down Expand Up @@ -392,17 +401,17 @@ def get_moderation_request_rows(cls, qcall):
return [PKDict(zip(r.keys(), r)) for r in q]

@classmethod
def get_status(cls, uid, role):
def get_status(cls, qcall, role):
with sirepo.util.THREAD_LOCK:
s = cls.search_by(uid=uid, role=role)
s = cls.search_by(uid=qcall.auth.logged_in_user(), role=role)
if not s:
return None
return sirepo.auth_role.ModerationStatus.check(s.status)

@classmethod
def set_status(cls, uid, role, status, moderator_uid=None):
def set_status(cls, qcall, role, status, moderator_uid):
with sirepo.util.THREAD_LOCK:
s = cls.search_by(uid=uid, role=role)
s = cls.search_by(uid=qcall.auth.logged_in_user(), role=role)
s.status = sirepo.auth_role.ModerationStatus.check(status)
if moderator_uid:
s.moderator_uid = moderator_uid
Expand Down
Loading