Skip to content

Commit

Permalink
#5118 - merge conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
mkeilman committed Nov 14, 2022
2 parents 3fd1525 + acaf9d4 commit 6126b87
Show file tree
Hide file tree
Showing 110 changed files with 1,524 additions and 775 deletions.
59 changes: 33 additions & 26 deletions sirepo/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,6 @@
_cfg = None


def hack_logged_in_user():
# avoids case of no quest (sirepo.agent)
return _cfg.logged_in_user or sirepo.quest.hack_current().auth.logged_in_user()


def init_quest(qcall):
o = _Auth(qcall=qcall)
qcall.attr_set("auth", o)
Expand All @@ -107,7 +102,6 @@ def init_module(**imports):
def _init_full():
global visible_methods, valid_methods, non_guest_methods

simulation_db.hook_auth_user = hack_logged_in_user
p = pkinspect.this_module().__name__
visible_methods = []
valid_methods = _cfg.methods.union(_cfg.deprecated_methods)
Expand Down Expand Up @@ -164,8 +158,8 @@ def check_sim_type_role(self, sim_type):
u = self.logged_in_user()
r = sirepo.auth_role.for_sim_type(t)
if sirepo.auth_db.UserRole.has_role(
u, r
) and not sirepo.auth_db.UserRole.is_expired(u, r):
qcall=self.qcall, role=r
) and not sirepo.auth_db.UserRole.is_expired(qcall=self.qcall, role=r):
return
elif r in sirepo.auth_role.for_proprietary_oauth_sim_types():
oauth.raise_authorize_redirect(self.qcall, sirepo.auth_role.sim_type(r))
Expand Down Expand Up @@ -225,10 +219,14 @@ def cookie_cleaner(self, values):
values[_COOKIE_STATE] = _STATE_LOGGED_OUT
return values

def create_user(self, uid_generated_callback, module):
def create_user(self, module, want_login=False):
u = simulation_db.user_create()
uid_generated_callback(u)
self._create_roles_for_new_user(u, module.AUTH_METHOD)
if want_login:
self._login_user(module, u)
self._create_roles_for_new_user(module.AUTH_METHOD)
else:
with self.logged_in_user_set(u, method=module.AUTH_METHOD):
self._create_roles_for_new_user(module.AUTH_METHOD)
return u

def get_module(self, name):
Expand All @@ -253,8 +251,8 @@ def is_logged_in(self, state=None):

def is_premium_user(self):
return sirepo.auth_db.UserRole.has_role(
self.logged_in_user(),
sirepo.auth_role.ROLE_PAYMENT_PLAN_PREMIUM,
qcall=self.qcall,
role=sirepo.auth_role.ROLE_PAYMENT_PLAN_PREMIUM,
)

def logged_in_user(self, check_path=True):
Expand All @@ -265,8 +263,6 @@ def logged_in_user(self, check_path=True):
Returns:
str: uid of authenticated user
"""
if self._logged_in_user:
return self._logged_in_user
u = self._qcall_bound_user()
if not self.is_logged_in():
raise sirepo.util.SRException(
Expand All @@ -280,7 +276,7 @@ def logged_in_user(self, check_path=True):
self._qcall_bound_method(),
)
if check_path:
simulation_db.user_path(u, check=True)
simulation_db.user_path(uid=u, check=True)
return u

def logged_in_user_name(self):
Expand All @@ -290,10 +286,18 @@ def logged_in_user_name(self):
method=self._qcall_bound_method(),
)

@contextlib.contextmanager
def logged_in_user_set(self, uid, method=METHOD_GUEST):
"""Ephemeral login"""
self._logged_in_user = uid
self._logged_in_method = method
"""Ephemeral login or may be used to logout"""
u = self._logged_in_user
m = self._logged_in_method
try:
self._logged_in_user = uid
self._logged_in_method = None if uid is None else method
yield
finally:
self._logged_in_user = u
self._logged_in_method = m

def login(
self,
Expand Down Expand Up @@ -357,7 +361,7 @@ def login(
# This handles the case where logging in as guest, creates a user every time
self._login_user(method, uid)
else:
uid = self.create_user(lambda u: self._login_user(method, u), method)
uid = self.create_user(method, want_login=True)
if model:
model.uid = uid
model.save()
Expand Down Expand Up @@ -444,7 +448,10 @@ def parse_display_name(self, value):

def require_adm(self):
u = self.require_user()
if not sirepo.auth_db.UserRole.has_role(u, sirepo.auth_role.ROLE_ADM):
if not sirepo.auth_db.UserRole.has_role(
qcall=self.qcall,
role=sirepo.auth_role.ROLE_ADM,
):
sirepo.util.raise_forbidden(f"uid={u} role=ROLE_ADM not found")

def require_auth_basic(self):
Expand Down Expand Up @@ -634,7 +641,7 @@ def _get_slack_uri():
r = sirepo.auth_db.UserRegistration.search_by(uid=u)
if r:
v.displayName = r.display_name
v.roles = sirepo.auth_db.UserRole.get_roles(u)
v.roles = sirepo.auth_db.UserRole.get_roles(qcall=self.qcall)
self._plan(v)
self._method_auth_state(v, u)
if pkconfig.channel_in_internal_test():
Expand All @@ -643,10 +650,10 @@ def _get_slack_uri():
pkdc("state={}", v)
return v

def _create_roles_for_new_user(self, uid, method):
r = sirepo.auth_role.for_new_user(method == METHOD_GUEST)
def _create_roles_for_new_user(self, method):
r = sirepo.auth_role.for_new_user(is_guest=method == METHOD_GUEST)
if r:
sirepo.auth_db.UserRole.add_roles(uid, r)
sirepo.auth_db.UserRole.add_roles(qcall=self.qcall, roles=r)

def _login_user(self, module, uid):
"""Set up the cookie for logged in state
Expand Down Expand Up @@ -713,7 +720,7 @@ def _qcall_bound_state(self):
return self.qcall.cookie.unchecked_get_value(_COOKIE_STATE)

def _qcall_bound_user(self):
return _cfg.logged_in_user or self.qcall.cookie.unchecked_get_value(
return self._logged_in_user or self.qcall.cookie.unchecked_get_value(
_COOKIE_USER
)

Expand Down
4 changes: 1 addition & 3 deletions sirepo/auth/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,5 @@ def _cfg_uid(value):

if value and value == "dev-no-validate" and pkconfig.channel_in_internal_test():
return value
assert simulation_db.user_path(value).check(
dir=True
), "uid={} does not exist".format(value)
simulation_db.user_path(uid=value, check=True)
return value
2 changes: 1 addition & 1 deletion sirepo/auth/bluesky.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def api_authBlueskyLogin(self):
)
return self.reply_ok(
PKDict(
data=simulation_db.open_json_file(req.type, sid=req.id),
data=simulation_db.open_json_file(req.type, sid=req.id, qcall=self),
schema=simulation_db.get_schema(req.type),
),
)
Expand Down
73 changes: 41 additions & 32 deletions sirepo/auth_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,17 @@
UserRoleInvite = None


def all_uids():
def all_uids(qcall):
return UserRegistration.search_all_for_column("uid")


def audit_proprietary_lib_files(uid, force=False, sim_types=None):
def audit_proprietary_lib_files(qcall, force=False, sim_types=None):
"""Add/removes proprietary files based on a user's roles
For example, add the Flash tarball if user has the flash role.
Args:
uid (str): The uid of the user to audit
qcall (quest.API): logged in user
force (bool): Overwrite existing lib files with the same name as new ones
sim_types (set): Set of sim_types to audit (proprietary_sim_types if None)
"""
Expand All @@ -66,7 +66,7 @@ def audit_proprietary_lib_files(uid, force=False, sim_types=None):

def _add(proprietary_code_dir, sim_type, sim_data_class):
p = proprietary_code_dir.join(sim_data_class.proprietary_code_tarball())
with sirepo.simulation_db.tmp_dir(chdir=True, uid=uid) as t:
with sirepo.simulation_db.tmp_dir(chdir=True, qcall=qcall) as t:
d = t.join(p.basename)
d.mksymlinkto(p, absolute=False)
subprocess.check_output(
Expand All @@ -80,7 +80,7 @@ def _add(proprietary_code_dir, sim_type, sim_data_class):
)
# lib_dir may not exist: git.radiasoft.org/ops/issues/645
l = pykern.pkio.mkdir_parent(
sirepo.simulation_db.simulation_lib_dir(sim_type, uid=uid),
sirepo.simulation_db.simulation_lib_dir(sim_type, qcall=qcall),
)
e = [f.basename for f in pykern.pkio.sorted_glob(l.join("*"))]
for f in sim_data_class.proprietary_code_lib_file_basenames():
Expand All @@ -93,6 +93,7 @@ def _add(proprietary_code_dir, sim_type, sim_data_class):
s
), f"sim_types={sim_types} not a subset of proprietary_sim_types={s}"
s = sim_types
u = qcall.auth.logged_in_user()
for t in s:
c = sirepo.sim_data.get_class(t)
if not c.proprietary_code_tarball():
Expand All @@ -102,14 +103,16 @@ def _add(proprietary_code_dir, sim_type, sim_data_class):
"; run: sirepo setup_dev" if pykern.pkconfig.channel_in("dev") else ""
)
r = UserRole.has_role(
uid,
sirepo.auth_role.for_sim_type(t),
qcall=qcall,
role=sirepo.auth_role.for_sim_type(t),
)
if r:
_add(d, t, c)
continue
# SECURITY: User no longer has access so remove all artifacts
pykern.pkio.unchecked_remove(sirepo.simulation_db.simulation_dir(t, uid=uid))
pykern.pkio.unchecked_remove(
sirepo.simulation_db.simulation_dir(t, qcall=qcall)
)


def db_filename():
Expand Down Expand Up @@ -281,60 +284,66 @@ def all_roles(cls):
return [r[0] for r in cls._session().query(cls.role.distinct()).all()]

@classmethod
def add_roles(cls, uid, role_or_roles, expiration=None):
if isinstance(role_or_roles, str):
role_or_roles = [role_or_roles]
def add_roles(cls, qcall, roles, expiration=None):
with sirepo.util.THREAD_LOCK:
for r in role_or_roles:
u = qcall.auth.logged_in_user()
for r in roles:
try:
UserRole(uid=uid, role=r, expiration=expiration).save()
UserRole(
uid=u,
role=r,
expiration=expiration,
).save()
except sqlalchemy.exc.IntegrityError:
pass
audit_proprietary_lib_files(uid)
audit_proprietary_lib_files(qcall=qcall)

@classmethod
def add_role_or_update_expiration(cls, uid, role, expiration):
def add_role_or_update_expiration(cls, qcall, role, expiration):
with sirepo.util.THREAD_LOCK:
if not cls.has_role(uid, role):
cls.add_roles(uid, role, expiration=expiration)
if not cls.has_role(qcall, role):
cls.add_roles(qcall=qcall, roles=[role], expiration=expiration)
return
r = cls.search_by(uid=uid, role=role)
r = cls.search_by(uid=qcall.auth.logged_in_user(), role=role)
r.expiration = expiration
r.save()

@classmethod
def delete_roles(cls, uid, roles):
def delete_roles(cls, qcall, roles):
with sirepo.util.THREAD_LOCK:
cls.execute(
sqlalchemy.delete(cls)
.where(
cls.uid == uid,
cls.uid == qcall.auth.logged_in_user(),
)
.where(
cls.role.in_(roles),
)
)
cls._session().commit()
audit_proprietary_lib_files(uid)
audit_proprietary_lib_files(qcall=qcall)

@classmethod
def get_roles(cls, uid):
def get_roles(cls, qcall):
with sirepo.util.THREAD_LOCK:
return UserRole.search_all_for_column(
"role",
uid=uid,
uid=qcall.auth.logged_in_user(),
)

@classmethod
def has_role(cls, uid, role):
def has_role(cls, qcall, role):
with sirepo.util.THREAD_LOCK:
return bool(cls.search_by(uid=uid, role=role))
return bool(cls.search_by(uid=qcall.auth.logged_in_user(), role=role))

@classmethod
def is_expired(cls, uid, role):
def is_expired(cls, qcall, role):
with sirepo.util.THREAD_LOCK:
assert cls.has_role(uid, role), f"No role for uid={uid} and role={role}"
r = cls.search_by(uid=uid, role=role)
u = qcall.auth.logged_in_user()
assert cls.has_role(
qcall=qcall, role=role
), f"No role for uid={u} and role={role}"
r = cls.search_by(uid=u, role=role)
if not r.expiration:
# Roles with no expiration can't expire
return False
Expand Down Expand Up @@ -392,17 +401,17 @@ def get_moderation_request_rows(cls, qcall):
return [PKDict(zip(r.keys(), r)) for r in q]

@classmethod
def get_status(cls, uid, role):
def get_status(cls, qcall, role):
with sirepo.util.THREAD_LOCK:
s = cls.search_by(uid=uid, role=role)
s = cls.search_by(uid=qcall.auth.logged_in_user(), role=role)
if not s:
return None
return sirepo.auth_role.ModerationStatus.check(s.status)

@classmethod
def set_status(cls, uid, role, status, moderator_uid=None):
def set_status(cls, qcall, role, status, moderator_uid):
with sirepo.util.THREAD_LOCK:
s = cls.search_by(uid=uid, role=role)
s = cls.search_by(uid=qcall.auth.logged_in_user(), role=role)
s.status = sirepo.auth_role.ModerationStatus.check(status)
if moderator_uid:
s.moderator_uid = moderator_uid
Expand Down
Loading

0 comments on commit 6126b87

Please sign in to comment.