Skip to content

Commit

Permalink
turn methods into private
Browse files Browse the repository at this point in the history
  • Loading branch information
goFrendiAsgard committed Nov 29, 2023
1 parent 1e63c7f commit 808d0e7
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 88 deletions.
2 changes: 1 addition & 1 deletion src/zrb/task/base_task/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ async def _set_local_keyval(
self.log_info('Merging task envs, task env files, and native envs')
for env_name, env in self._get_combined_env().items():
env_value = env.get(env_prefix)
if env.should_render:
if env.should_render():
env_value = self.render_any(env_value)
self._set_env_map(env_name, env_value)
self._set_env_map('_ZRB_EXECUTION_ID', self.get_execution_id())
Expand Down
89 changes: 46 additions & 43 deletions src/zrb/task/cmd_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def __init__(
max_error_line = max_error_line if max_error_line > 0 else 1
self._cmd = cmd
self._cmd_path = cmd_path
self._set_cwd(cwd)
self.__set_cwd(cwd)
self._max_output_size = max_output_line
self._max_error_size = max_error_line
self._output_buffer: Iterable[str] = []
Expand All @@ -165,7 +165,10 @@ def __init__(
def copy(self) -> TCmdTask:
return super().copy()

def _set_cwd(
def set_cwd(self, cwd: Union[str, pathlib.Path]):
self.__set_cwd(cwd)

def __set_cwd(
self, cwd: Optional[Union[str, pathlib.Path]]
):
if cwd is None:
Expand Down Expand Up @@ -206,7 +209,7 @@ def inject_envs(self):

async def run(self, *args: Any, **kwargs: Any) -> CmdResult:
cmd = self.get_cmd_script(*args, **kwargs)
self.print_out_dark('Run script: ' + self._get_multiline_repr(cmd))
self.print_out_dark('Run script: ' + self.__get_multiline_repr(cmd))
self.print_out_dark('Working directory: ' + self._cwd)
self._output_buffer = []
self._error_buffer = []
Expand All @@ -226,14 +229,14 @@ async def run(self, *args: Any, **kwargs: Any) -> CmdResult:
self._pids.append(process.pid)
self._process = process
try:
signal.signal(signal.SIGINT, self._on_kill)
signal.signal(signal.SIGTERM, self._on_kill)
signal.signal(signal.SIGINT, self.__on_kill)
signal.signal(signal.SIGTERM, self.__on_kill)
except Exception as e:
self.print_err(e)
atexit.register(self._on_exit)
await self._wait_process(process)
atexit.register(self.__on_exit)
await self.__wait_process(process)
self.log_info('Process completed')
atexit.unregister(self._on_exit)
atexit.unregister(self.__on_exit)
output = '\n'.join(self._output_buffer)
error = '\n'.join(self._error_buffer)
# get return code
Expand All @@ -244,16 +247,6 @@ async def run(self, *args: Any, **kwargs: Any) -> CmdResult:
f'Process {self._name} exited ({return_code}): {error}'
)
return CmdResult(output, error)

def _get_multiline_repr(self, text: str) -> str:
lines_repr: Iterable[str] = []
lines = text.split('\n')
if len(lines) == 1:
return lines[0]
for index, line in enumerate(lines):
line_number_repr = str(index + 1).rjust(4, '0')
lines_repr.append(f' {line_number_repr} | {line}')
return '\n' + '\n'.join(lines_repr)

def _should_attempt(self) -> bool:
if self._global_state.no_more_attempt:
Expand All @@ -265,12 +258,12 @@ def _is_last_attempt(self) -> bool:
return True
return super()._is_last_attempt()

def _on_kill(self, signum: Any, frame: Any):
def __on_kill(self, signum: Any, frame: Any):
self._global_state.no_more_attempt = True
self._global_state.is_killed_by_signal = True
self.print_out_dark(f'Getting signal {signum}')
for pid in self._pids:
self._kill_by_pid(pid)
self.__kill_by_pid(pid)
tasks = asyncio.all_tasks()
for task in tasks:
try:
Expand All @@ -281,57 +274,57 @@ def _on_kill(self, signum: Any, frame: Any):
self.print_out_dark(f'Exiting with signal {signum}')
sys.exit(signum)

def _on_exit(self):
def __on_exit(self):
self._global_state.no_more_attempt = True
self._kill_by_pid(self._process.pid)
self.__kill_by_pid(self._process.pid)

def _kill_by_pid(self, pid: int):
def __kill_by_pid(self, pid: int):
'''
Kill a pid, gracefully
'''
try:
process_ever_exists = False
if self._is_process_exist(pid):
if self.__is_process_exist(pid):
process_ever_exists = True
self.print_out_dark(f'Send SIGTERM to process {pid}')
os.killpg(os.getpgid(pid), signal.SIGTERM)
time.sleep(0.3)
if self._is_process_exist(pid):
if self.__is_process_exist(pid):
self.print_out_dark(f'Send SIGINT to process {pid}')
os.killpg(os.getpgid(pid), signal.SIGINT)
time.sleep(0.3)
if self._is_process_exist(pid):
if self.__is_process_exist(pid):
self.print_out_dark(f'Send SIGKILL to process {pid}')
os.killpg(os.getpgid(pid), signal.SIGKILL)
if process_ever_exists:
self.print_out_dark(f'Process {pid} is killed successfully')
except Exception:
self.log_error(f'Cannot kill process {pid}')

def _is_process_exist(self, pid: int) -> bool:
def __is_process_exist(self, pid: int) -> bool:
try:
os.killpg(os.getpgid(pid), 0)
return True
except ProcessLookupError:
return False

async def _wait_process(self, process: asyncio.subprocess.Process):
async def __wait_process(self, process: asyncio.subprocess.Process):
# Create queue
stdout_queue = asyncio.Queue()
stderr_queue = asyncio.Queue()
# Read from streams and put into queue
stdout_process = asyncio.create_task(self._queue_stream(
stdout_process = asyncio.create_task(self.__queue_stream(
process.stdout, stdout_queue
))
stderr_process = asyncio.create_task(self._queue_stream(
stderr_process = asyncio.create_task(self.__queue_stream(
process.stderr, stderr_queue
))
# Handle messages in queue
stdout_log_process = asyncio.create_task(self._log_from_queue(
stdout_log_process = asyncio.create_task(self.__log_from_queue(
stdout_queue, self.print_out,
self._output_buffer, self._max_output_size
))
stderr_log_process = asyncio.create_task(self._log_from_queue(
stderr_log_process = asyncio.create_task(self.__log_from_queue(
stderr_queue, self.print_err,
self._error_buffer, self._max_error_size
))
Expand All @@ -355,13 +348,13 @@ def _create_cmd_script(
) -> str:
if not isinstance(cmd_path, str) or cmd_path != '':
if callable(cmd_path):
return self._get_rendered_cmd_path(cmd_path(*args, **kwargs))
return self._get_rendered_cmd_path(cmd_path)
return self.__get_rendered_cmd_path(cmd_path(*args, **kwargs))
return self.__get_rendered_cmd_path(cmd_path)
if callable(cmd):
return self._get_rendered_cmd(cmd(*args, **kwargs))
return self._get_rendered_cmd(cmd)
return self.__get_rendered_cmd(cmd(*args, **kwargs))
return self.__get_rendered_cmd(cmd)

def _get_rendered_cmd_path(
def __get_rendered_cmd_path(
self, cmd_path: Union[str, Iterable[str]]
) -> str:
if isinstance(cmd_path, str):
Expand All @@ -371,12 +364,12 @@ def _get_rendered_cmd_path(
for cmd_path_str in cmd_path
])

def _get_rendered_cmd(self, cmd: Union[str, Iterable[str]]) -> str:
def __get_rendered_cmd(self, cmd: Union[str, Iterable[str]]) -> str:
if isinstance(cmd, str):
return self.render_str(cmd)
return self.render_str('\n'.join(list(cmd)))

async def _queue_stream(self, stream, queue: asyncio.Queue):
async def __queue_stream(self, stream, queue: asyncio.Queue):
while True:
try:
line = await stream.readline()
Expand All @@ -386,7 +379,7 @@ async def _queue_stream(self, stream, queue: asyncio.Queue):
break
await queue.put(line)

async def _log_from_queue(
async def __log_from_queue(
self,
queue: asyncio.Queue,
print_log: Callable[[str], None],
Expand All @@ -398,17 +391,27 @@ async def _log_from_queue(
if not line:
break
line_str = line.decode('utf-8').rstrip()
self._add_to_buffer(buffer, max_size, line_str)
self.__add_to_buffer(buffer, max_size, line_str)
_reset_stty()
print_log(line_str)
_reset_stty()

def _add_to_buffer(
def __add_to_buffer(
self, buffer: Iterable[str], max_size: int, new_line: str
):
if len(buffer) >= max_size:
buffer.pop(0)
buffer.append(new_line)


def __get_multiline_repr(self, text: str) -> str:
lines_repr: Iterable[str] = []
lines = text.split('\n')
if len(lines) == 1:
return lines[0]
for index, line in enumerate(lines):
line_number_repr = str(index + 1).rjust(4, '0')
lines_repr.append(f' {line_number_repr} | {line}')
return '\n' + '\n'.join(lines_repr)

def __repr__(self) -> str:
return f'<CmdTask name={self._name}>'
37 changes: 20 additions & 17 deletions src/zrb/task/docker_compose_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,10 @@ def __init__(
self._compose_flags = compose_flags
self._compose_args = compose_args
self._compose_env_prefix = compose_env_prefix
self._compose_template_file = self._get_compose_template_file(
self._compose_template_file = self.__get_compose_template_file(
compose_file
)
self._compose_runtime_file = self._get_compose_runtime_file(
self._compose_runtime_file = self.__get_compose_runtime_file(
self._compose_template_file
)
# Flag to make mark whether service config and compose environments
Expand All @@ -160,7 +160,7 @@ def copy(self) -> TDockerComposeTask:
return super().copy()

async def run(self, *args, **kwargs: Any) -> CmdResult:
self._generate_compose_runtime_file()
self.__generate_compose_runtime_file()
try:
result = await super().run(*args, **kwargs)
finally:
Expand Down Expand Up @@ -192,18 +192,20 @@ def inject_env_files(self):
for _, service_config in self._compose_service_configs.items():
self.insert_env_file(*service_config.get_env_files())

def _generate_compose_runtime_file(self):
def __generate_compose_runtime_file(self):
compose_data = read_compose_file(self._compose_template_file)
for service, service_config in self._compose_service_configs.items():
envs: List[Env] = []
env_files = service_config.get_env_files()
for env_file in env_files:
envs += env_file.get_envs()
envs += service_config.get_envs()
compose_data = self._apply_service_env(compose_data, service, envs)
compose_data = self.__apply_service_env(
compose_data, service, envs
)
write_compose_file(self._compose_runtime_file, compose_data)

def _apply_service_env(
def __apply_service_env(
self, compose_data: Any, service: str, envs: List[Env]
) -> Any:
# service not found
Expand All @@ -213,38 +215,39 @@ def _apply_service_env(
# service has no environment definition
if 'environment' not in compose_data['services'][service]:
compose_data['services'][service]['environment'] = {
env.get_name(): self._get_env_compose_value(env)
env.get_name(): self.__get_env_compose_value(env)
for env in envs
}
return compose_data
# service environment is a map
if isinstance(compose_data['services'][service]['environment'], dict):
new_env_map = self._get_service_new_env_map(
new_env_map = self.__get_service_new_env_map(
compose_data['services'][service]['environment'], envs
)
for key, value in new_env_map.items():
compose_data['services'][service]['environment'][key] = value
return compose_data
# service environment is a list
if isinstance(compose_data['services'][service]['environment'], list):
new_env_list = self._get_service_new_env_list(
new_env_list = self.__get_service_new_env_list(
compose_data['services'][service]['environment'], envs
)
compose_data['services'][service]['environment'] += new_env_list
return compose_data
return compose_data

def _get_service_new_env_map(
def __get_service_new_env_map(
self, service_env_map: Mapping[str, str], new_envs: List[Env]
) -> Mapping[str, str]:
new_service_envs: Mapping[str, str] = {}
for env in new_envs:
if env.get_name() in service_env_map:
env_name = env.get_name()
if env_name in service_env_map:
continue
new_service_envs[env.get_name()] = self._get_env_compose_value(env)
new_service_envs[env_name] = self.__get_env_compose_value(env)
return new_service_envs

def _get_service_new_env_list(
def __get_service_new_env_list(
self, service_env_list: List[str], new_envs: List[Env]
) -> List[str]:
new_service_envs: List[str] = []
Expand All @@ -256,14 +259,14 @@ def _get_service_new_env_list(
if not should_be_added:
continue
new_service_envs.append(
env.get_name() + '=' + self._get_env_compose_value(env)
env.get_name() + '=' + self.__get_env_compose_value(env)
)
return new_service_envs

def _get_env_compose_value(self, env: Env) -> str:
def __get_env_compose_value(self, env: Env) -> str:
return '${' + env.get_name() + ':-' + env.get_default() + '}'

def _get_compose_runtime_file(self, compose_file_name: str) -> str:
def __get_compose_runtime_file(self, compose_file_name: str) -> str:
directory, file = os.path.split(compose_file_name)
prefix = '_' if file.startswith('.') else '._'
runtime_prefix = self.get_cmd_name()
Expand All @@ -282,7 +285,7 @@ def _get_compose_runtime_file(self, compose_file_name: str) -> str:
runtime_file_name = prefix + file + runtime_prefix
return os.path.join(directory, runtime_file_name)

def _get_compose_template_file(self, compose_file: Optional[str]) -> str:
def __get_compose_template_file(self, compose_file: Optional[str]) -> str:
if compose_file is None:
for _compose_file in [
'compose.yml', 'compose.yaml',
Expand Down
Loading

0 comments on commit 808d0e7

Please sign in to comment.