-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset_relabel.py
82 lines (63 loc) · 2.75 KB
/
dataset_relabel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
from openai import OpenAI
from datasets import load_dataset, DatasetDict
from tqdm import tqdm
import time # Added import for time module
# Shared OpenAI client used by extract_final_answer(). The API key is read
# from a local "api.txt" file rather than hard-coded in the script.
with open("api.txt", "r", encoding="utf-8") as key_file:
    _api_key = key_file.read().strip()
if not _api_key:
    # Fail fast here instead of on the first (rate-limited) API request.
    raise ValueError("api.txt is empty; expected an OpenAI API key")
client = OpenAI(
    # Passed explicitly because the key comes from api.txt, not from the
    # OPENAI_API_KEY environment variable the SDK would otherwise read.
    api_key=_api_key,
)
def extract_final_answer(solution, *, model="gpt-3.5-turbo", max_tokens=20):
    """Ask the chat model for the final answer stated in a worked solution.

    Uses the module-level OpenAI ``client``.

    Args:
        solution: Text of a worked math solution.
        model: Chat-completions model to query (keyword-only).
        max_tokens: Completion-length cap (keyword-only). 20 matches the
            original script; raise it if long LaTeX answers (sets, matrices,
            multi-part expressions) start failing the truncation check.

    Returns:
        The answer text with surrounding whitespace stripped.

    Raises:
        ValueError: If the completion was cut off at ``max_tokens`` or came
            back empty -- writing either into the dataset would silently
            produce a wrong label.
        Errors from the OpenAI client propagate unchanged for the caller
        to handle.
    """
    prompt = (
        "Given the following math solution, what is the final answer?\n"
        "Provide only the answer in its simplest form, without any additional explanation.\n"
        "The answer could be a number, an expression, a set, or any other mathematical entity.\n"
        f"Solution:\n{solution}\n"
        "Final answer:\n"
    )
    response = client.chat.completions.create(
        model=model,
        messages=[
            # Not "numerical": the user prompt explicitly allows expressions,
            # sets, etc., so the system role must not contradict it.
            {"role": "system", "content": "You are a math expert tasked with extracting the final answer from a given solution."},
            {"role": "user", "content": prompt},
        ],
        max_tokens=max_tokens,
        n=1,
        # Extraction should be deterministic, not creative.
        temperature=0,
    )
    choice = response.choices[0]
    if choice.finish_reason == "length":
        # The answer hit the token cap; a truncated label is worse than none.
        raise ValueError(f"answer truncated at max_tokens={max_tokens}")
    # message.content is Optional in the SDK; treat None like an empty reply.
    answer = (choice.message.content or "").strip()
    if not answer:
        raise ValueError("model returned an empty answer")
    return answer
def update_solution(solution, *, delay=0.15):
    """Return *solution* with a model-extracted final-answer sentence appended.

    Args:
        solution: Text of a worked math solution.
        delay: Seconds to sleep before the API request (keyword-only). This is
            crude client-side rate limiting: the default 0.15 s caps a
            sequential loop at roughly 6.7 requests per second.

    Returns:
        ``"<solution>\\nThe final answer is $<answer>$. I hope it is correct."``

    Raises:
        Anything raised by extract_final_answer().
    """
    time.sleep(delay)
    final_answer = extract_final_answer(solution)
    # The template below supplies its own $...$ delimiters; drop any the model
    # already added so the answer isn't turned into $$...$$ (display math).
    if final_answer.startswith("$") and final_answer.endswith("$"):
        final_answer = final_answer.strip("$").strip()
    return f"{solution}\nThe final answer is ${final_answer}$. I hope it is correct."
def update_dataset_split(dataset_split):
    """Return a new dataset with a ``final_answer_solution`` column added.

    Every original field of each row is kept. ``final_answer_solution`` holds
    the solution with an appended final-answer sentence, or ``None`` for rows
    whose API call failed. Failures are reported but do not abort the split.

    Args:
        dataset_split: A Hugging Face ``Dataset`` whose rows have a
            ``solution`` field.

    Returns:
        A ``Dataset`` built with ``from_list`` from the updated rows.

    Raises:
        KeyError: If a row has no ``solution`` field.
    """
    updated_rows = []
    for item in tqdm(dataset_split, desc=f"Updating solutions for {dataset_split.split}"):
        original_solution = item["solution"]
        try:
            updated_solution = update_solution(original_solution)
        except Exception as e:
            # Best-effort: one failed request should not abort the whole split.
            # tqdm.write keeps the message from corrupting the progress bar.
            tqdm.write(f"Error processing item: {e}")
            updated_solution = None
        # Always emit the key: Dataset.from_list takes its columns from the
        # FIRST row, so omitting it on a failed first row would drop the
        # column for the entire split.
        updated_rows.append({**item, "final_answer_solution": updated_solution})
    return dataset_split.from_list(updated_rows)
def update_dataset(
    *,
    dataset_name="lighteval/MATH",
    train_split="train[95%:]",
    test_split="test[95%:]",
    output_dir="updated_math_dataset",
):
    """Relabel the train and test slices of a dataset and save them to disk.

    All arguments are keyword-only; the defaults reproduce the original
    hard-coded behaviour (last 5% of lighteval/MATH train and test).

    Args:
        dataset_name: Hub name passed to ``load_dataset``.
        train_split: Split/slice spec used for the ``train`` output split.
        test_split: Split/slice spec used for the ``test`` output split.
        output_dir: Directory passed to ``DatasetDict.save_to_disk``.
    """
    updated_train = update_dataset_split(load_dataset(dataset_name, split=train_split))
    updated_test = update_dataset_split(load_dataset(dataset_name, split=test_split))
    dataset = DatasetDict({
        "train": updated_train,
        "test": updated_test,
    })
    dataset.save_to_disk(output_dir)
    print(f"Dataset updated and saved to '{output_dir}' directory")
if "__main__" == __name__:
    # Script entry point: relabel the configured dataset slices and save them.
    update_dataset()