Skip to content

Commit

Permalink
add logger
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Jul 14, 2024
1 parent b763c23 commit b7e8875
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 4 deletions.
7 changes: 4 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

model_dir/
dataset_dir/
repos/
/model_dir/
/dataset_dir/
/repos/
/logs/
1 change: 1 addition & 0 deletions olah/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
CHUNK_SIZE = 4096
LFS_FILE_BLOCK = 64 * 1024 * 1024

DEFAULT_LOGGER_DIR = "./logs"

from huggingface_hub.constants import (
_HF_DEFAULT_ENDPOINT,
Expand Down
6 changes: 6 additions & 0 deletions olah/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from olah.meta import meta_generator
from olah.utils.url_utils import check_proxy_rules_hf, check_commit_hf, get_commit_hf, get_newest_commit_hf, parse_org_repo

from olah.utils.logging import build_logger

app = FastAPI(debug=False)

class AppSettings(BaseSettings):
Expand Down Expand Up @@ -239,7 +241,11 @@ async def index():
parser.add_argument("--ssl-key", type=str, default=None)
parser.add_argument("--ssl-cert", type=str, default=None)
parser.add_argument("--repos-path", type=str, default="./repos")
parser.add_argument("--log-path", type=str, default="./logs", help="The folder to save logs")
args = parser.parse_args()
print(args)

logger = build_logger("olah", "olah.log", logger_dir=args.log_path)

def is_default_value(args, arg_name):
if hasattr(args, arg_name):
Expand Down
233 changes: 233 additions & 0 deletions olah/utils/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@

from asyncio import AbstractEventLoop
import json
import logging
import logging.handlers
import os
import platform
import sys
from typing import AsyncGenerator, Generator
import warnings

import requests
import torch

from olah.constants import DEFAULT_LOGGER_DIR

# User-facing message for backend/network failures.
# NOTE(review): not referenced in this file -- presumably displayed by a web UI; confirm.
server_error_msg = (
    "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
)
# User-facing message for inputs rejected by content moderation.
# NOTE(review): not referenced in this file -- presumably displayed by a web UI; confirm.
moderation_msg = (
    "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
)

# Process-wide TimedRotatingFileHandler, created lazily by build_logger()
# on first call so that all loggers share a single log file.
handler = None


def build_logger(logger_name, logger_filename, logger_dir=DEFAULT_LOGGER_DIR) -> logging.Logger:
    """Create a logger that writes to the console and a shared rotating file.

    Args:
        logger_name: Name passed to ``logging.getLogger``.
        logger_filename: Base file name of the rotating log file.
        logger_dir: Directory for log files; created if missing.

    Returns:
        The configured ``logging.Logger`` for ``logger_name``.

    Side effects:
        Replaces ``sys.stdout`` with a ``StreamToLogger`` so ``print``
        output is captured, and (on first call) attaches one shared file
        handler to every logger registered at that point.
    """
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Configure the root logger exactly once; subsequent calls only
    # (re)apply the formatter to its existing first handler.
    if not logging.getLogger().handlers:
        # logging.basicConfig gained the ``encoding`` parameter in Python
        # 3.9; forcing UTF-8 matters mainly for Windows consoles.
        # (Fixed: the original compared only the minor version number,
        # ``sys.version_info[1] >= 9``, which is wrong for any major
        # version other than 3.)
        if sys.version_info >= (3, 9):
            logging.basicConfig(level=logging.INFO, encoding="utf-8")
        else:
            if platform.system() == "Windows":
                warnings.warn(
                    "If you are running on Windows, "
                    "we recommend you use Python >= 3.9 for UTF-8 encoding."
                )
            logging.basicConfig(level=logging.INFO)
    logging.getLogger().handlers[0].setFormatter(formatter)

    # Redirect stdout so bare print() calls also end up in the log.
    stdout_logger = logging.getLogger("stdout")
    stdout_logger.setLevel(logging.DEBUG)
    sys.stdout = StreamToLogger(stdout_logger, logging.INFO)

    # stderr is deliberately left untouched so tracebacks stay visible
    # on the real console.

    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.DEBUG)

    # Create the shared file handler on first use and attach it to every
    # logger known so far (getLogger() above registers ``logger`` first,
    # so it is included).
    if handler is None:
        os.makedirs(logger_dir, exist_ok=True)
        filename = os.path.join(logger_dir, logger_filename)
        # NOTE(review): when="M" rotates every MINUTE per
        # TimedRotatingFileHandler; "D" (daily) is probably intended --
        # confirm before changing.
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when="M", utc=True, encoding="utf-8"
        )
        handler.setFormatter(formatter)
        # Keep rotated file names ending in ".log" so they remain easy
        # to find and tail.
        handler.namer = lambda name: name.replace(".log", "") + ".log"

        for name, item in logging.root.manager.loggerDict.items():
            if isinstance(item, logging.Logger):
                item.addHandler(handler)

    return logger


class StreamToLogger(object):
    """Fake file-like stream object that redirects writes to a logger instance.

    Complete lines are logged immediately; a trailing partial line is
    buffered until the next ``write`` or ``flush``.
    """

    def __init__(self, logger, log_level=logging.INFO):
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        # Holds the trailing partial line between write() calls.
        self.linebuf = ""

    def __getattr__(self, attr):
        # Delegate unknown attributes to the wrapped stream; return None
        # instead of raising so callers probing for optional stream
        # attributes do not crash.
        # (Fixed: the original used a bare ``except:``, which also
        # swallowed KeyboardInterrupt/SystemExit.)
        try:
            return getattr(self.terminal, attr)
        except AttributeError:
            return None

    def write(self, buf):
        """Log each complete line in *buf*; buffer any trailing partial line."""
        temp_linebuf = self.linebuf + buf
        self.linebuf = ""
        for line in temp_linebuf.splitlines(True):
            # From the io.TextIOWrapper docs:
            # On output, if newline is None, any '\n' characters written
            # are translated to the system default line separator.
            # By default sys.stdout.write() expects '\n' newlines and then
            # translates them so this is still cross platform.
            # splitlines(True) keeps the separator, so only the final
            # fragment can lack "\n" -- buffer it for the next call.
            if line[-1] == "\n":
                # Round-trip through UTF-8 drops characters the log file
                # could not encode (e.g. lone surrogates).
                encoded_message = line.encode("utf-8", "ignore").decode("utf-8")
                self.logger.log(self.log_level, encoded_message.rstrip())
            else:
                self.linebuf += line

    def flush(self):
        """Emit any buffered partial line to the logger and clear the buffer."""
        if self.linebuf != "":
            encoded_message = self.linebuf.encode("utf-8", "ignore").decode("utf-8")
            self.logger.log(self.log_level, encoded_message.rstrip())
        self.linebuf = ""


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch

    # Replace the weight-initialization hooks with no-ops so constructing
    # these modules skips the (about-to-be-overwritten) init work.
    for module_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        module_cls.reset_parameters = lambda self: None


def get_gpu_memory(max_gpus=None):
    """Get available memory for each GPU.

    Args:
        max_gpus: Optional cap on how many GPUs to inspect.

    Returns:
        A list of free-memory figures in GiB, one per inspected GPU
        (empty when no GPUs are visible or ``max_gpus`` is 0).
    """
    count = torch.cuda.device_count()
    if max_gpus is not None:
        count = min(max_gpus, count)

    free_per_gpu = []
    for idx in range(count):
        with torch.cuda.device(idx):
            current = torch.cuda.current_device()
            props = torch.cuda.get_device_properties(current)
            total_gib = props.total_memory / (1024**3)
            used_gib = torch.cuda.memory_allocated() / (1024**3)
            free_per_gpu.append(total_gib - used_gib)
    return free_per_gpu


def violates_moderation(text):
    """
    Check whether the text violates OpenAI moderation API.

    Args:
        text: The user input to check (newlines are stripped first).

    Returns:
        True only when the API flags the text; any request failure or
        unexpected response shape is treated as not flagged.

    Raises:
        KeyError: if the OPENAI_API_KEY environment variable is unset
            (evaluated before the request, outside the try block).
    """
    url = "https://api.openai.com/v1/moderations"
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"],
    }
    text = text.replace("\n", "")
    # Fixed: the original built the JSON body by string concatenation,
    # which produced invalid JSON whenever ``text`` contained quotes or
    # backslashes. json.dumps escapes correctly.
    data = json.dumps({"input": text}).encode("utf-8")
    try:
        ret = requests.post(url, headers=headers, data=data, timeout=5)
        flagged = ret.json()["results"][0]["flagged"]
    except (requests.exceptions.RequestException, KeyError):
        # Best-effort check: on any network/parse problem, fail open.
        flagged = False

    return flagged


def clean_flant5_ckpt(ckpt_path):
    """Repair shared-embedding weights in a sharded Flan-T5 checkpoint.

    Flan-t5 trained with HF+FSDP saves corrupted weights for shared
    embeddings; this copies ``shared.weight`` over both the decoder and
    encoder token-embedding entries so the checkpoint loads correctly.

    Args:
        ckpt_path: Directory containing ``pytorch_model.bin.index.json``
            and the shard files it references. Shards are rewritten
            in place.
    """
    index_file = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
    # Use a context manager so the index file handle is closed promptly
    # (the original left the open file to the garbage collector).
    with open(index_file, "r") as f:
        index_json = json.load(f)

    weightmap = index_json["weight_map"]

    share_weight_file = weightmap["shared.weight"]
    share_weight = torch.load(os.path.join(ckpt_path, share_weight_file))[
        "shared.weight"
    ]

    for weight_name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]:
        weight_file = weightmap[weight_name]
        weight = torch.load(os.path.join(ckpt_path, weight_file))
        weight[weight_name] = share_weight
        torch.save(weight, os.path.join(ckpt_path, weight_file))


def pretty_print_semaphore(semaphore):
    """Render a semaphore's state as a short human-readable string."""
    if semaphore is None:
        return "None"
    # ``_value`` is the remaining permit count (private, but stable in
    # practice for asyncio semaphores).
    return "Semaphore(value={}, locked={})".format(
        semaphore._value, semaphore.locked()
    )


get_window_url_params_js = """
function() {
const params = new URLSearchParams(window.location.search);
url_params = Object.fromEntries(params);
console.log("url_params", url_params);
return url_params;
}
"""


def iter_over_async(
    async_gen: AsyncGenerator, event_loop: AbstractEventLoop
) -> Generator:
    """
    Convert async generator to sync generator
    :param async_gen: the AsyncGenerator to convert
    :param event_loop: the event loop to run on
    :returns: Sync generator
    """
    iterator = async_gen.__aiter__()

    async def _advance():
        # Pull one item; signal exhaustion with a (done, value) pair so
        # the sync loop never has to catch StopAsyncIteration itself.
        try:
            return False, await iterator.__anext__()
        except StopAsyncIteration:
            return True, None

    finished = False
    while not finished:
        finished, item = event_loop.run_until_complete(_advance())
        if not finished:
            yield item
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "olah"
version = "0.0.6"
version = "0.1.0"
description = "Self-hosted lightweight huggingface mirror."
readme = "README.md"
requires-python = ">=3.8"
Expand Down

0 comments on commit b7e8875

Please sign in to comment.