Skip to content

Commit

Permalink
[Fix] Consume much more GPU memory when running eval_rm (LAION-AI#3614)
Browse files Browse the repository at this point in the history
Fix LAION-AI#3611.
Still debugging for model_training.

---------

Co-authored-by: Lin Junpeng <linjunpeng@sensetime.com>
  • Loading branch information
SingL3 and Lin Junpeng authored Aug 30, 2023
1 parent 7e40ee3 commit 709bb99
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 18 deletions.
27 changes: 16 additions & 11 deletions model/model_eval/eval_rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from model_training.custom_datasets.ranking_collator import RankingDataCollator
from model_training.metrics import RewardMetrics
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers.trainer_utils import EvalPrediction
from utils import write_to_json
Expand All @@ -29,15 +30,16 @@ def get_ranking_dataset(dataset, split):
def batch_inference(inputs, model):
batch, cu_lens = inputs
batch = {k: v.to(model.device) for k, v in batch.items()}
logits = (
model(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
)
.logits.detach()
.cpu()
.numpy()
)

with torch.no_grad():
logits = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits.detach().cpu()

if logits.dtype == torch.bfloat16:
# As of NumPy 1.21.4, NumPy does not support bfloat16 (see
# https://github.com/numpy/numpy/blob/a47ecdea856986cd60eabbd53265c2ca5916ad5d/doc/source/user/basics.types.rst ).
# Until NumPy adds bfloat16, we must convert to float32.
logits = logits.to(torch.float32)
logits = logits.numpy()

labels = []
for i, (s, e) in enumerate(zip(cu_lens[:-1], cu_lens[1:])):
Expand All @@ -54,6 +56,7 @@ def batch_inference(inputs, model):
parser.add_argument("--metrics", type=str, help="metrics to evaluate", default="accuracy")
parser.add_argument("--batch_size", type=int, help="Batch Size", default=8)
parser.add_argument("--device", type=str, help="device", default="cuda")
parser.add_argument("--dtype", type=str, help="data type", default=None)
args = parser.parse_args().__dict__

if args.get("device") != "cpu":
Expand All @@ -64,7 +67,9 @@ def batch_inference(inputs, model):
model_name = args.get("model")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
model_name, torch_dtype="auto" if not args.dtype else args.dtype
)
model.eval()
model.to(device)
max_length = args.get("max_length") or model.config.max_position_embeddings
Expand All @@ -77,7 +82,7 @@ def batch_inference(inputs, model):
metrics = args.get("metrics").split(",")
compute_metrics = RewardMetrics(metrics)
score_dict = defaultdict(float)
for i, data in enumerate(dataset):
for i, data in enumerate(tqdm(dataset)):
eval_pred = batch_inference(data, model)
results = compute_metrics(eval_pred)
for metric in metrics:
Expand Down
6 changes: 2 additions & 4 deletions model/model_training/custom_datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,9 @@ def get_one_dataset(
elif dataset_name == "gpt4all":
dataset = Gpt4All(mode=mode, cache_dir=data_path)
elif dataset_name == "prosocial_dialogue":
train = ProsocialDialogue(cache_dir=data_path, split="train")
eval = ProsocialDialogue(cache_dir=data_path, split="validation")
dataset = ProsocialDialogue(cache_dir=data_path, split="train")
elif dataset_name == "explain_prosocial":
train = ProsocialDialogueExplaination(cache_dir=data_path, split="train")
eval = ProsocialDialogueExplaination(cache_dir=data_path, split="validation")
dataset = ProsocialDialogueExplaination(cache_dir=data_path, split="train")
elif dataset_name == "soda":
dataset = SODA(data_path, **kwargs)
elif dataset_name == "soda_dialogue":
Expand Down
5 changes: 2 additions & 3 deletions model/model_training/custom_datasets/qa_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,10 +519,9 @@ def __init__(self, cache_dir: str | Path, mode: str = "sft", input_max_length: i
self.mode = mode

dataset = load_dataset(
"gozfarb/ShareGPT_Vicuna_unfiltered",
"Aeala/ShareGPT_Vicuna_unfiltered",
cache_dir=cache_dir,
data_files=["ShareGPT_2023.05.02v0_unfiltered_cleaned_split.json"],
revision="7b8551404f3de5704d634e7516b9ff77be3e2700",
data_files=["ShareGPT_V4.3_unfiltered_cleaned_split.json"],
)["train"]

self.pairs = []
Expand Down

0 comments on commit 709bb99

Please sign in to comment.