From b7e887556ca2d9a3a32a854fd80501b53ec698ff Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Sun, 14 Jul 2024 17:27:48 +0800 Subject: [PATCH] add logger --- .gitignore | 7 +- olah/constants.py | 1 + olah/server.py | 6 ++ olah/utils/logging.py | 233 ++++++++++++++++++++++++++++++++++++++++++ pyproject.toml | 2 +- 5 files changed, 245 insertions(+), 4 deletions(-) create mode 100644 olah/utils/logging.py diff --git a/.gitignore b/.gitignore index 3793090..fbafc9a 100644 --- a/.gitignore +++ b/.gitignore @@ -159,6 +159,7 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ -model_dir/ -dataset_dir/ -repos/ +/model_dir/ +/dataset_dir/ +/repos/ +/logs/ \ No newline at end of file diff --git a/olah/constants.py b/olah/constants.py index aee106f..3e55db0 100644 --- a/olah/constants.py +++ b/olah/constants.py @@ -3,6 +3,7 @@ CHUNK_SIZE = 4096 LFS_FILE_BLOCK = 64 * 1024 * 1024 +DEFAULT_LOGGER_DIR = "./logs" from huggingface_hub.constants import ( _HF_DEFAULT_ENDPOINT, diff --git a/olah/server.py b/olah/server.py index 665e8fa..66d1678 100644 --- a/olah/server.py +++ b/olah/server.py @@ -16,6 +16,8 @@ from olah.meta import meta_generator from olah.utils.url_utils import check_proxy_rules_hf, check_commit_hf, get_commit_hf, get_newest_commit_hf, parse_org_repo +from olah.utils.logging import build_logger + app = FastAPI(debug=False) class AppSettings(BaseSettings): @@ -239,7 +241,11 @@ async def index(): parser.add_argument("--ssl-key", type=str, default=None) parser.add_argument("--ssl-cert", type=str, default=None) parser.add_argument("--repos-path", type=str, default="./repos") + parser.add_argument("--log-path", type=str, default="./logs", help="The folder to save logs") args = parser.parse_args() + print(args) + + logger = build_logger("olah", "olah.log", logger_dir=args.log_path) def is_default_value(args, arg_name): if hasattr(args, arg_name): diff --git a/olah/utils/logging.py b/olah/utils/logging.py new file mode 100644 index 0000000..8067309 --- /dev/null +++ b/olah/utils/logging.py @@ -0,0 +1,233 @@ + +from asyncio import AbstractEventLoop +import json +import logging +import logging.handlers +import os +import platform +import sys +from typing import AsyncGenerator, Generator +import warnings + +import requests +import torch + +from olah.constants import DEFAULT_LOGGER_DIR + +server_error_msg = ( + "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" +) +moderation_msg = ( + "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." +) + +handler = None + + +def build_logger(logger_name, logger_filename, logger_dir=DEFAULT_LOGGER_DIR) -> logging.Logger: + global handler + + formatter = logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + # Set the format of root handlers + if logging.getLogger().handlers is None or len(logging.getLogger().handlers) == 0: + if sys.version_info[1] >= 9: + # This is for windows + logging.basicConfig(level=logging.INFO, encoding="utf-8") + else: + if platform.system() == "Windows": + warnings.warn( + "If you are running on Windows, " + "we recommend you use Python >= 3.9 for UTF-8 encoding." + ) + logging.basicConfig(level=logging.INFO) + logging.getLogger().handlers[0].setFormatter(formatter) + + # Redirect stdout and stderr to loggers + stdout_logger = logging.getLogger("stdout") + stdout_logger.setLevel(logging.DEBUG) + sl = StreamToLogger(stdout_logger, logging.INFO) + sys.stdout = sl + + # stderr_logger = logging.getLogger("stderr") + # stderr_logger.setLevel(logging.ERROR) + # sl = StreamToLogger(stderr_logger, logging.ERROR) + # sys.stderr = sl + + # Get logger + logger = logging.getLogger(logger_name) + logger.setLevel(logging.DEBUG) + + # Add a file handler for all loggers + if handler is None: + os.makedirs(logger_dir, exist_ok=True) + filename = os.path.join(logger_dir, logger_filename) + handler = logging.handlers.TimedRotatingFileHandler( + filename, when="M", utc=True, encoding="utf-8" + ) + handler.setFormatter(formatter) + handler.namer = lambda name: name.replace(".log", "") + ".log" + + for name, item in logging.root.manager.loggerDict.items(): + if isinstance(item, logging.Logger): + item.addHandler(handler) + + return logger + + +class StreamToLogger(object): + """ + Fake file-like stream object that redirects writes to a logger instance. + """ + + def __init__(self, logger, log_level=logging.INFO): + self.terminal = sys.stdout + self.logger = logger + self.log_level = log_level + self.linebuf = "" + + def __getattr__(self, attr): + try: + attr_value = getattr(self.terminal, attr) + except: + return None + return attr_value + + def write(self, buf): + temp_linebuf = self.linebuf + buf + self.linebuf = "" + for line in temp_linebuf.splitlines(True): + # From the io.TextIOWrapper docs: + # On output, if newline is None, any '\n' characters written + # are translated to the system default line separator. + # By default sys.stdout.write() expects '\n' newlines and then + # translates them so this is still cross platform. + if line[-1] == "\n": + encoded_message = line.encode("utf-8", "ignore").decode("utf-8") + self.logger.log(self.log_level, encoded_message.rstrip()) + else: + self.linebuf += line + + def flush(self): + if self.linebuf != "": + encoded_message = self.linebuf.encode("utf-8", "ignore").decode("utf-8") + self.logger.log(self.log_level, encoded_message.rstrip()) + self.linebuf = "" + + +def disable_torch_init(): + """ + Disable the redundant torch default initialization to accelerate model creation. + """ + import torch + + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) + + +def get_gpu_memory(max_gpus=None): + """Get available memory for each GPU.""" + gpu_memory = [] + num_gpus = ( + torch.cuda.device_count() + if max_gpus is None + else min(max_gpus, torch.cuda.device_count()) + ) + + for gpu_id in range(num_gpus): + with torch.cuda.device(gpu_id): + device = torch.cuda.current_device() + gpu_properties = torch.cuda.get_device_properties(device) + total_memory = gpu_properties.total_memory / (1024**3) + allocated_memory = torch.cuda.memory_allocated() / (1024**3) + available_memory = total_memory - allocated_memory + gpu_memory.append(available_memory) + return gpu_memory + + +def violates_moderation(text): + """ + Check whether the text violates OpenAI moderation API. + """ + url = "https://api.openai.com/v1/moderations" + headers = { + "Content-Type": "application/json", + "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"], + } + text = text.replace("\n", "") + data = "{" + '"input": ' + f'"{text}"' + "}" + data = data.encode("utf-8") + try: + ret = requests.post(url, headers=headers, data=data, timeout=5) + flagged = ret.json()["results"][0]["flagged"] + except requests.exceptions.RequestException as e: + flagged = False + except KeyError as e: + flagged = False + + return flagged + + +# Flan-t5 trained with HF+FSDP saves corrupted weights for shared embeddings, +# Use this function to make sure it can be correctly loaded. +def clean_flant5_ckpt(ckpt_path): + index_file = os.path.join(ckpt_path, "pytorch_model.bin.index.json") + index_json = json.load(open(index_file, "r")) + + weightmap = index_json["weight_map"] + + share_weight_file = weightmap["shared.weight"] + share_weight = torch.load(os.path.join(ckpt_path, share_weight_file))[ + "shared.weight" + ] + + for weight_name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]: + weight_file = weightmap[weight_name] + weight = torch.load(os.path.join(ckpt_path, weight_file)) + weight[weight_name] = share_weight + torch.save(weight, os.path.join(ckpt_path, weight_file)) + + +def pretty_print_semaphore(semaphore): + if semaphore is None: + return "None" + return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" + + +get_window_url_params_js = """ +function() { + const params = new URLSearchParams(window.location.search); + url_params = Object.fromEntries(params); + console.log("url_params", url_params); + return url_params; + } +""" + + +def iter_over_async( + async_gen: AsyncGenerator, event_loop: AbstractEventLoop +) -> Generator: + """ + Convert async generator to sync generator + + :param async_gen: the AsyncGenerator to convert + :param event_loop: the event loop to run on + :returns: Sync generator + """ + ait = async_gen.__aiter__() + + async def get_next(): + try: + obj = await ait.__anext__() + return False, obj + except StopAsyncIteration: + return True, None + + while True: + done, obj = event_loop.run_until_complete(get_next()) + if done: + break + yield obj \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index b402824..2f3e246 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "olah" -version = "0.0.6" +version = "0.1.0" description = "Self-hosted lightweight huggingface mirror." readme = "README.md" requires-python = ">=3.8"