diff --git a/README.md b/README.md index 6ef4785a..572b4b05 100644 --- a/README.md +++ b/README.md @@ -147,21 +147,21 @@ The following command will run TORAX using the default configuration file ```shell python3 run_simulation_main.py \ - --python_config='torax.examples.basic_config' --log_progress + --config='torax.examples.basic_config' --log_progress ``` To run more involved, ITER-inspired simulations, run: ```shell python3 run_simulation_main.py \ - --python_config='torax.examples.iterhybrid_rampup' --log_progress + --config='torax.examples.iterhybrid_rampup' --log_progress ``` and ```shell python3 run_simulation_main.py \ - --python_config='torax.examples.iterhybrid_predictor_corrector' --log_progress + --config='torax.examples.iterhybrid_predictor_corrector' --log_progress ``` Additional configuration is provided through flags which append the above run command, and environment variables: @@ -198,7 +198,7 @@ Output simulation time, dt, and number of stepper iterations (dt backtracking wi ```shell python3 run_simulation_main.py \ - --python_config='torax.examples.iterhybrid_predictor_corrector' \ + --config='torax.examples.iterhybrid_predictor_corrector' \ --log_progress ``` @@ -206,7 +206,7 @@ Live plotting of simulation state and derived quantities. ```shell python3 run_simulation_main.py \ - --python_config='torax.examples.iterhybrid_predictor_corrector' \ + --config='torax.examples.iterhybrid_predictor_corrector' \ --plot_progress ``` @@ -214,7 +214,7 @@ Combination of the above. ```shell python3 run_simulation_main.py \ - --python_config='torax.examples.iterhybrid_predictor_corrector' \ + --config='torax.examples.iterhybrid_predictor_corrector' \ --log_progress --plot_progress ``` diff --git a/docs/quickstart.rst b/docs/quickstart.rst index 023f3508..39534ae6 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -13,7 +13,7 @@ of all input configuration fields. .. code-block:: console python3 run_simulation_main.py \ - --python_config='torax.examples.basic_config' --log_progress + --config='torax.examples.basic_config' --log_progress More involved examples in ``torax/examples`` include non-rigorous mockups of the ITER hybrid scenario: @@ -72,7 +72,7 @@ For nonlinear solvers, the stepper iterations can be more than 1 due to dt backt .. code-block:: console python3 run_simulation_main.py \ - --python_config='torax.examples.basic_config' \ + --config='torax.examples.basic_config' \ --log_progress plot_progress @@ -82,7 +82,7 @@ Live plotting of simulation state and derived quantities as the simulation progr .. code-block:: console python3 run_simulation_main.py \ - --python_config='torax.examples.basic_config' \ + --config='torax.examples.basic_config' \ --plot_progress For a combination of the above: @@ -90,7 +90,7 @@ For a combination of the above: .. code-block:: console python3 run_simulation_main.py \ - --python_config='torax.examples.basic_config' \ + --config='torax.examples.basic_config' \ --log_progress --plot_progress Post-simulation diff --git a/run_simulation_main.py b/run_simulation_main.py index 70390511..c14edcf1 100644 --- a/run_simulation_main.py +++ b/run_simulation_main.py @@ -16,7 +16,7 @@ Example command with a configuration defined in Python: python3 run_simulation_main.py \ - --python_config='torax.tests.test_data.default_config' \ + --config='torax.tests.test_data.default_config' \ --log_progress """ @@ -33,19 +33,19 @@ _PYTHON_CONFIG_MODULE = flags.DEFINE_string( - 'python_config', + 'config', None, 'Module from which to import a python-based config. This program expects a ' '`get_sim()` function to be implemented in this module. Can either be ' 'an absolute or relative path. See importlib.import_module() for more ' - 'information on how to use this flag and --python_config_package.', + 'information on how to use this flag and --config_package.', ) _PYTHON_CONFIG_PACKAGE = flags.DEFINE_string( - 'python_config_package', + 'config_package', None, - 'If provided, it is the base package the --python_config is imported from. ' - 'This is required if --python_config is a relative path.', + 'If provided, it is the base package the --config is imported from. ' + 'This is required if --config is a relative path.', ) _LOG_SIM_PROGRESS = flags.DEFINE_bool( @@ -73,8 +73,7 @@ class _UserCommand(enum.Enum): """Options to do on every iteration of the script.""" - QUIT = ('quit', 'q') - RUN = ('run the simulation', 'r') + RUN = ('RUN SIMULATION', 'r', simulation_app.AnsiColors.GREEN) CHANGE_CONFIG = ( 'change config for the same sim object (may recompile)', 'cc', @@ -86,6 +85,7 @@ class _UserCommand(enum.Enum): TOGGLE_LOG_SIM_PROGRESS = ('toggle --log_progress', 'tlp') TOGGLE_PLOT_SIM_PROGRESS = ('toggle --plot_progress', 'tpp') TOGGLE_LOG_SIM_OUTPUT = ('toggle --log_output', 'tlo') + QUIT = ('quit', 'q', simulation_app.AnsiColors.RED) # Tracks all the modules imported so far. Maps the name to the module object. @@ -93,38 +93,41 @@ class _UserCommand(enum.Enum): def _import_module(module_name: str): - if module_name in _ALL_MODULES: - return importlib.reload(_ALL_MODULES[module_name]) - else: - module = importlib.import_module(module_name, _PYTHON_CONFIG_PACKAGE.value) - _ALL_MODULES[module_name] = module - return module - - -def _get_config_module( - config_module_str: str | None = None, -): - config_module_str = config_module_str or _PYTHON_CONFIG_MODULE.value - return _import_module(config_module_str), config_module_str + """Imports a module.""" + try: + if module_name in _ALL_MODULES: + return importlib.reload(_ALL_MODULES[module_name]) + else: + module = importlib.import_module( + module_name, _PYTHON_CONFIG_PACKAGE.value + ) + _ALL_MODULES[module_name] = module + return module + except Exception as e: + simulation_app.log_to_stdout(f'Exception raised: {e}') + raise ValueError('Exception while importing.') from e def prompt_user(config_module_str: str) -> _UserCommand: """Prompts the user for the next thing to do.""" + simulation_app.log_to_stdout('\n') simulation_app.log_to_stdout( - f'Running with the config: {config_module_str}', + f'Using the config: {config_module_str}', color=simulation_app.AnsiColors.YELLOW, ) user_command = None + simulation_app.log_to_stdout('\n') while user_command is None: simulation_app.log_to_stdout( 'What would you like to do next?', color=simulation_app.AnsiColors.BLUE, ) for uc in _UserCommand: - simulation_app.log_to_stdout( - f'{uc.value[1]}: {uc.value[0]}', - color=simulation_app.AnsiColors.BLUE, - ) + if len(uc.value) == 3: + color = uc.value[2] + else: + color = simulation_app.AnsiColors.BLUE + simulation_app.log_to_stdout(f'{uc.value[1]}: {uc.value[0]}', color) input_text = input('Your choice: ') text_to_uc = {uc.value[1]: uc for uc in _UserCommand} input_text = input_text.lower().strip() @@ -146,30 +149,23 @@ def maybe_update_config_module( f'Existing module: {config_module_str}', color=simulation_app.AnsiColors.BLUE, ) - input_text = None - while input_text is None: - simulation_app.log_to_stdout( - 'Would you like to change which file to import?', - color=simulation_app.AnsiColors.BLUE, - ) - input_text = input('y/n: ') - input_text = input_text.lower().strip() - if input_text not in ('y', 'n'): - simulation_app.log_to_stdout( - 'Unrecognized input. Try again.', - color=simulation_app.AnsiColors.YELLOW, - ) - input_text = None - if input_text == 'y': + simulation_app.log_to_stdout( + 'Would you like to change which file to import?', + color=simulation_app.AnsiColors.BLUE, + ) + should_change = _get_yes_or_no() + if should_change: logging.info('Updating the module.') return input('Enter the new module to use: ').strip() - return config_module_str + else: + logging.info('Continuing with %s', config_module_str) + return config_module_str def change_config( sim: torax.Sim, config_module_str: str, -) -> tuple[torax.Sim, torax.GeneralRuntimeParams, str]: +) -> tuple[torax.Sim, torax.GeneralRuntimeParams] | None: """Returns a new Sim with the updated config but same SimulationStepFn. This function gives the user a chance to reuse the SimulationStepFn without @@ -182,22 +178,39 @@ def change_config( Args: sim: Sim object used in the previous run. - config_module_str: Config module used previously. User will have the - opportunity to update which module to load. + config_module_str: Config module being used. Returns: Tuple with: - New Sim object with new config. - New Config object with modified configuration attributes - - Name of the module used to load the config. """ - config_module_str = maybe_update_config_module(config_module_str) simulation_app.log_to_stdout( f'Change {config_module_str} to include new values.', color=simulation_app.AnsiColors.BLUE, ) - input('Press Enter when ready.') - config_module, _ = _get_config_module(config_module_str) + yellow = simulation_app.AnsiColors.YELLOW + simulation_app.log_to_stdout('You cannot change the following:', color=yellow) + simulation_app.log_to_stdout(' - stepper type', color=yellow) + simulation_app.log_to_stdout(' - transport model type', color=yellow) + simulation_app.log_to_stdout(' - source types', color=yellow) + simulation_app.log_to_stdout(' - time step calculator', color=yellow) + simulation_app.log_to_stdout( + 'To change these parameters, select "cs" from the main menu.', + color=yellow, + ) + simulation_app.log_to_stdout( + 'Modify the config with new values, then enter "y".', + color=simulation_app.AnsiColors.BLUE, + ) + simulation_app.log_to_stdout( + 'Enter "n" to go back to the main menu without loading any changes.', + color=simulation_app.AnsiColors.BLUE, + ) + proceed_with_run = _get_yes_or_no() + if not proceed_with_run: + return None + config_module = _import_module(config_module_str) if hasattr(config_module, 'CONFIG'): # Assume that the config module uses the basic config dict to build Sim. sim_config = config_module.CONFIG @@ -245,7 +258,7 @@ def change_config( source_runtime_params=new_source_params, stepper_runtime_params_getter=stepper_params_getter, ) - return sim, new_runtime_params, config_module_str + return sim, new_runtime_params def change_sim_obj( @@ -270,16 +283,60 @@ def change_sim_obj( config_module_str = maybe_update_config_module(config_module_str) simulation_app.log_to_stdout( f'Change {config_module_str} to include new values. Any changes to ' - 'get_sim() will be picked up.', + 'CONFIG or get_sim() will be picked up.', color=simulation_app.AnsiColors.BLUE, ) input('Press Enter when done changing the module.') - config_module, _ = _get_config_module(config_module_str) - new_runtime_params = config_module.get_runtime_params() - sim = config_module.get_sim() + sim, new_runtime_params = _build_sim_and_runtime_params_from_config_module( + config_module_str + ) return sim, new_runtime_params, config_module_str +def _build_sim_and_runtime_params_from_config_module( + config_module_str: str, +) -> tuple[torax.Sim, torax.GeneralRuntimeParams]: + """Returns a Sim and RuntimeParams from the config module.""" + config_module = _import_module(config_module_str) + if hasattr(config_module, 'CONFIG'): + # The module likely uses the "basic" config setup which has a single CONFIG + # dictionary defining the full simulation. + config = config_module.CONFIG + new_runtime_params = build_sim.build_runtime_params_from_config( + config['runtime_params'] + ) + sim = build_sim.build_sim_from_config(config) + elif hasattr(config_module, 'get_runtime_params') and hasattr( + config_module, 'get_sim' + ): + # The module is likely using the "advances", more Python-forward + # configuration setup. + new_runtime_params = config_module.get_runtime_params() + sim = config_module.get_sim() + else: + raise ValueError( + f'Config module {config_module_str} must either define a get_sim() ' + 'method or a CONFIG dictionary.' + ) + return sim, new_runtime_params + + +def _get_yes_or_no() -> bool: + """Returns a boolean indicating yes depending on user input.""" + input_text = None + while input_text is None: + input_text = input('y/n: ') + input_text = input_text.lower().strip() + if input_text not in ('y', 'n'): + simulation_app.log_to_stdout( + 'Unrecognized input. Try again.', + color=simulation_app.AnsiColors.YELLOW, + ) + input_text = None + else: + return input_text == 'y' + + def _toggle_log_progress(log_sim_progress: bool) -> bool: """Toggles the --log_progress flag.""" log_sim_progress = not log_sim_progress @@ -341,37 +398,34 @@ def _toggle_log_output(log_sim_output: bool) -> bool: def main(_): - config_module, config_module_str = _get_config_module() - if hasattr(config_module, 'CONFIG'): - # The module likely uses the "basic" config setup which has a single CONFIG - # dictionary defining the full simulation. - config = config_module.CONFIG - new_runtime_params = build_sim.build_runtime_params_from_config( - config['runtime_params'] - ) - sim = build_sim.build_sim_from_config(config) - elif hasattr(config_module, 'get_runtime_params') and hasattr( - config_module, 'get_sim' - ): - # The module is likely using the "advances", more Python-forward - # configuration setup. - new_runtime_params = config_module.get_runtime_params() - sim = config_module.get_sim() - else: - raise ValueError( - f'Config module {config_module_str} must either define a get_sim() ' - 'method or a CONFIG dictionary.' - ) + config_module_str = _PYTHON_CONFIG_MODULE.value + if config_module_str is None: + raise ValueError(f'--{_PYTHON_CONFIG_MODULE.name} must be specified.') log_sim_progress = _LOG_SIM_PROGRESS.value plot_sim_progress = _PLOT_SIM_PROGRESS.value log_sim_output = _LOG_SIM_OUTPUT.value - simulation_app.main( - lambda: sim, - output_dir=new_runtime_params.output_dir, - log_sim_progress=log_sim_progress, - plot_sim_progress=plot_sim_progress, - log_sim_output=log_sim_output, - ) + sim = None + new_runtime_params = None + try: + sim, new_runtime_params = _build_sim_and_runtime_params_from_config_module( + config_module_str + ) + simulation_app.main( + lambda: sim, + output_dir=new_runtime_params.output_dir, + log_sim_progress=log_sim_progress, + plot_sim_progress=plot_sim_progress, + log_sim_output=log_sim_output, + ) + except ValueError as ve: + simulation_app.log_to_stdout( + f'Error ocurred: {ve}', + color=simulation_app.AnsiColors.RED, + ) + simulation_app.log_to_stdout( + 'Not running sim. Update config and try again.', + color=simulation_app.AnsiColors.RED, + ) user_command = prompt_user(config_module_str) while user_command != _UserCommand.QUIT: match user_command: @@ -379,23 +433,67 @@ def main(_): # This line shouldn't get hit, but is here for pytype. return # Exit the function. case _UserCommand.RUN: - simulation_app.main( - lambda: sim, - output_dir=new_runtime_params.output_dir, - log_sim_progress=log_sim_progress, - plot_sim_progress=plot_sim_progress, - log_sim_output=log_sim_output, - ) + if sim is None or new_runtime_params is None: + simulation_app.log_to_stdout( + 'Need to reload the simulation.', + color=simulation_app.AnsiColors.RED, + ) + simulation_app.log_to_stdout( + 'Try changing the config and running with' + f' {_UserCommand.CHANGE_SIM_OBJ.value[1]} from the main menu.', + color=simulation_app.AnsiColors.RED, + ) + else: + simulation_app.main( + lambda: sim, + output_dir=new_runtime_params.output_dir, + log_sim_progress=log_sim_progress, + plot_sim_progress=plot_sim_progress, + log_sim_output=log_sim_output, + ) case _UserCommand.CHANGE_CONFIG: # See docstring for detailed info on what recompiles. - sim, new_runtime_params, config_module_str = change_config( - sim, config_module_str - ) + if sim is None or new_runtime_params is None: + simulation_app.log_to_stdout( + 'Need to reload the simulation.', + color=simulation_app.AnsiColors.RED, + ) + simulation_app.log_to_stdout( + 'Try changing the config and running with' + f' {_UserCommand.CHANGE_SIM_OBJ.value[1]} from the main menu.', + color=simulation_app.AnsiColors.RED, + ) + else: + try: + sim_and_runtime_params_or_none = change_config( + sim, config_module_str + ) + if sim_and_runtime_params_or_none is not None: + sim, new_runtime_params = sim_and_runtime_params_or_none + except ValueError as ve: + simulation_app.log_to_stdout( + f'Error ocurred: {ve}', + color=simulation_app.AnsiColors.RED, + ) + simulation_app.log_to_stdout( + 'Update config and try again.', + color=simulation_app.AnsiColors.RED, + ) case _UserCommand.CHANGE_SIM_OBJ: # This always builds a new object and requires recompilation. - sim, new_runtime_params, config_module_str = change_sim_obj( - config_module_str - ) + try: + sim, new_runtime_params, config_module_str = change_sim_obj( + config_module_str + ) + except ValueError as ve: + simulation_app.log_to_stdout( + f'Error ocurred: {ve}', + color=simulation_app.AnsiColors.RED, + ) + simulation_app.log_to_stdout( + 'Update config and try again.', + color=simulation_app.AnsiColors.RED, + ) case _UserCommand.TOGGLE_LOG_SIM_PROGRESS: log_sim_progress = _toggle_log_progress(log_sim_progress) case _UserCommand.TOGGLE_PLOT_SIM_PROGRESS: