diff --git a/sirepo/sim_db_file.py b/sirepo/sim_db_file.py index 427ad260bf..88e230b0f5 100644 --- a/sirepo/sim_db_file.py +++ b/sirepo/sim_db_file.py @@ -140,9 +140,10 @@ class SimDbServer(sirepo.agent_supervisor_api.ReqBase): _UID_TO_TOKEN = PKDict() async def get(self, unused_arg): - p = self.__authenticate_and_path() - if p.exists(): - self.write(pkio.read_binary(p)) + if not self.__authenticate_and_path(): + return + if self.__path.exists(): + self.write(pkio.read_binary(self.__path)) else: self.send_error(404) @@ -155,7 +156,8 @@ def _result(value): return value.pksetdefault(state="ok") try: - self.__path = self.__authenticate_and_path() + if not self.__authenticate_and_path(): + return r = pkjson.load_any(self.request.body) # note that args may be empty (but must be PKDict), since uri has path if not isinstance(a := r.get("args"), PKDict): @@ -182,12 +184,16 @@ def _result(value): async def put(self, unused_arg): # TODO(robnagler) should this be atomic? # check size - async with aiofiles.open(self.__authenticate_and_path(), "wb") as f: + if not self.__authenticate_and_path(): + return + async with aiofiles.open(self.__path, "wb") as f: await f.write(self.request.body) async def _sr_post_delete_glob(self, args): + if not self.__authenticate_and_path(): + return t = [] - for f in pkio.sorted_glob(f"{self.__authenticate_and_path()}*"): + for f in pkio.sorted_glob(f"{self.__path}*"): if f.check(dir=True): pkdlog("path={} is a directory", f) self.send_error(403) @@ -197,14 +203,20 @@ async def _sr_post_delete_glob(self, args): pkio.unchecked_remove(f) async def _sr_post_copy(self, args): + p = self.__uri_arg_to_path(args.dst_uri) + if not p: + return # TODO(robnagler) should this be atomic? - self.__path.copy(self.__uri_arg_to_path(args.dst_uri)) + self.__path.copy(p) async def _sr_post_exists(self, args): return self.__path.check(file=True) async def _sr_post_move(self, args): - self.__path.move(self.__uri_arg_to_path(args.dst_uri)) + p = self.__uri_arg_to_path(args.dst_uri) + if not p: + return + self.__path.move(p) async def _sr_post_save_from_url(self, args): max_size = sirepo.job.cfg().max_message_bytes @@ -244,34 +256,18 @@ async def _sr_post_size(self, args): def __authenticate_and_path(self): self.__uid = self._rs_authenticate() - return self.__uri_to_path(self.request.path) + self.__path = _uri_parse(self.request.path, uid=self.__uid) + if self.__path is None: + self.send_error(403) + return False + return self.__path def __uri_arg_to_path(self, uri): - p = uri.split("/") - if len(p) != 3: - raise AssertionError(f"uri={p} must be 3 parts") - return self.__uri_to_path_simple(*p) - - def __uri_to_path(self, uri): - m = _URI_RE.search(uri) - if not m: - pkdlog("uri={} missing {sirepo.job.SIM_DB_FILE_URI} prefix", uri) + res = _uri_parse(uri, uid=self.__uid, is_arg_uri=True) + if res is None: self.send_error(403) - return - p = m.group(1).split("/") - assert len(p) == 4, f"uri={p} must be 4 parts" - assert p[0] == self.__uid, f"uid={p[0]} is not expect_uid={self.__uid}" - return self.__uri_to_path_simple(*p[1:]) - - def __uri_to_path_simple(self, stype, sid_or_lib, basename): - from sirepo import simulation_db, template - - template.assert_sim_type(stype), - _sid_or_lib(sid_or_lib), - simulation_db.assert_sim_db_basename(basename), - return simulation_db.user_path_root().join( - self.__uid, stype, sid_or_lib, basename - ) + return None + return res class SimDbUri(str): @@ -322,6 +318,60 @@ def _sid_or_lib(value): ) +def _uri_parse(uri, uid, is_arg_uri=False): + """Evaluate uri matches correct form and uid + + Separate function for testability + + Args: + uri (str): to test + uid (str): expected user + is_arg_uri (bool): True then do not test uid + Returns: + str: validated relative path to sim_db. None if error + """ + + def _path_join(stype, sid_or_lib, basename): + from sirepo import simulation_db, template + + try: + return simulation_db.user_path(uid=uid, check=True).join( + template.assert_sim_type(stype), + _sid_or_lib(sid_or_lib), + simulation_db.assert_sim_db_basename(basename), + ) + except Exception as e: + pkdlog( + "error={} uid={} stype={} sid_or_lib={} basename={}", + e, + uid, + stype, + sid_or_lib, + basename, + ) + return None + + if len(uri) <= 0: + # no point in logging anything + return + if is_arg_uri: + p = uri.split("/") + else: + m = _URI_RE.search(uri) + if not m: + pkdlog("uri={} missing prefix={}", uri, sirepo.job.SIM_DB_FILE_URI) + return + p = m.group(1).split("/") + if p[0] != uid: + pkdlog("uri={} does not match expect_uid={}", p[0], uid) + return + p.pop(0) + if len(p) != 3: + pkdlog("uri={} invalid part count is_arg_uri={}", uri, is_arg_uri) + return + return _path_join(*p) + + _cfg = pkconfig.init( server_token=(None, str, "credential to connect"), server_uri=(None, str, "how to connect to server"), diff --git a/sirepo/simulation_db.py b/sirepo/simulation_db.py index 258eb03fd0..b5af63428b 100644 --- a/sirepo/simulation_db.py +++ b/sirepo/simulation_db.py @@ -79,7 +79,8 @@ #: configuration _cfg = None -_SIM_DB_BASENAME_RE = re.compile(r"^[a-zA-Z0-9-_\.]{1,128}$") +#: begin/end with alnum, 128 chars max +_SIM_DB_BASENAME_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_\.-]{1,126}[a-zA-Z0-9]$") #: For re-entrant `user_lock` _USER_LOCK = PKDict(paths=set()) diff --git a/tests/sim_db_file_test.py b/tests/sim_db_file_test.py index 4378b221a2..cd5dc8ee46 100644 --- a/tests/sim_db_file_test.py +++ b/tests/sim_db_file_test.py @@ -44,3 +44,33 @@ def test_save_from_uri(sim_db_file_server): pkunit.pkok(not c.exists(f), "favicon.ico should not exist") c.save_from_url(u, f) pkunit.pkeq(requests.get(u).content, c.get(f)) + + +def test_uri(): + from pykern import pkunit, pkdebug + from sirepo import srunit + + def _full(uri, deviance=True): + r = sim_db_file._uri_parse(f"{job.SIM_DB_FILE_URI}/{uri}", uid) + if deviance: + pkunit.pkok(not r, "unexpected res={} uri={} uid={} ", r, uri, uid) + else: + pkunit.pkeq(simulation_db.user_path_root().join(uri), r) + + srunit.setup_srdb_root() + from sirepo import simulation_db, sim_db_file, job + + uid = simulation_db.user_create() + stype = srunit.SR_SIM_TYPE_DEFAULT + _full( + f"{uid}/{stype}/aValidId/flash_exe-SwBZWpYFR-PqFi81T6rQ8g", + deviance=False, + ) + _full(f"{uid}/{stype}/invalid/valid-file") + _full(f"{uid}/invalid/aValidId/valid-file") + _full(f"notfound/{stype}/aValidId/valid-file") + _full(f"{uid}/{stype}/aValidId/{'too-long':x>129s}") + _full(f"{uid}/{stype}/aValidId/.invalid-part") + _full(f"{uid}/{stype}/aValidId/invalid-part.") + # too few parts + _full(f"{uid}/{stype}/aValidId") diff --git a/tests/simulation_db_test.py b/tests/simulation_db_test.py index 8cd09e4722..d220e369e7 100644 --- a/tests/simulation_db_test.py +++ b/tests/simulation_db_test.py @@ -32,45 +32,3 @@ def _time(data, data_path, trigger, time): pkio.sorted_glob(f.dirpath().join("*/in.json"))[0], t - 10000, ) - - -def test_uid(): - from pykern.pkunit import pkeq, pkexcept, pkre - from pykern.pkdebug import pkdp - from sirepo import simulation_db - - qcall = None - - def _do(uri, uid, expect=True): - if expect: - with pkexcept(AssertionError): - simulation_db.sim_db_file_uri_to_path(uri=uri, expect_uid=uid) - else: - p = simulation_db.sim_db_file_uri_to_path(uri=uri, expect_uid=uid) - pkre(uri + "$", str(p)) - - _do( - "xxx/elegant/RrCoL7rQ/flash_exe-SwBZWpYFR-PqFi81T6rQ8g", - "yyy", - ) - _do( - "xxx/elegant/RrCoL7rQ/../../../foo", - "xxx", - ) - _do( - "yyy/invalid/R/flash_exe-SwBZWpYFR-PqFi81T6rQ8g", - "yyy", - ) - _do( - "yyy/invalid/RrCoL7rQ/flash_exe-SwBZWpYFR-PqFi81T6rQ8g", - "yyy", - ) - _do( - "HsCFbRrQ/elegant/RrCoL7rQ/{}".format("x" * 129), - "HsCFbRrQ", - ) - _do( - "HsCFbRrQ/elegant/RrCoL7rQ/flash_exe-SwBZWpYFR-PqFi81T6rQ8g", - "HsCFbRrQ", - expect=False, - )