-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathextract_features.py
36 lines (31 loc) · 2.55 KB
/
extract_features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from src.quere import ClosedEndedExplanationDataset, OpenEndedExplanationDataset, SquadExplanationDataset
import argparse
import numpy as np
import torch
# Dataset names handled by ClosedEndedExplanationDataset; each name is passed to
# the constructor verbatim as its first positional argument.
# NOTE(review): "BooIQ" is kept exactly as originally spelled; it may be a typo
# for "BoolQ" -- confirm against src.quere before renaming.
CLOSED_ENDED_DATASETS = ("BooIQ", "HaluEval", "ToxicEval", "CommonsenseQA", "WinoGrande")

# Dataset names that are served by a dedicated class instead.
SQUAD_DATASET = "squad"
NQ_DATASET = "nq"

# Every value accepted by --dataset.
SUPPORTED_DATASETS = CLOSED_ENDED_DATASETS + (SQUAD_DATASET, NQ_DATASET)


def build_parser():
    """Return the command-line parser for this script.

    ``--dataset`` is restricted to ``SUPPORTED_DATASETS`` so that a misspelled
    name is rejected by argparse instead of the script finishing without
    constructing any dataset.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--llm", type=str, default="llama-7b")
    parser.add_argument("--dataset", type=str, default="WinoGrande", choices=SUPPORTED_DATASETS)
    parser.add_argument("--random", action="store_true", default=False, help="Use random prompts")
    parser.add_argument("--gpt_exp", action="store_true", default=False, help="Use GPT explanations")
    parser.add_argument("--gpt_diverse", action="store_true", default=False, help="Use diverse GPT explanations")
    parser.add_argument("--gpt_sim", action="store_true", default=False, help="Use GPT explanations with similar prompts")
    parser.add_argument("--random_tokens", action="store_true", default=False, help="Use random tokens")
    return parser


def prompt_options(args):
    """Return the prompt-selection flags of ``args`` as constructor keyword arguments.

    Args:
        args: Parsed namespace carrying the boolean attributes ``gpt_exp``,
            ``gpt_diverse``, ``random``, ``gpt_sim`` and ``random_tokens``.

    Returns:
        A dict mapping each of those five names to its value in ``args``.
    """
    return {
        "gpt_exp": args.gpt_exp,
        "gpt_diverse": args.gpt_diverse,
        "random": args.random,
        "gpt_sim": args.gpt_sim,
        "random_tokens": args.random_tokens,
    }


def build_dataset(args):
    """Construct the explanation dataset selected by ``args.dataset``.

    Args:
        args: Parsed namespace produced by ``build_parser()``.

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: If ``args.dataset`` is not one of ``SUPPORTED_DATASETS``.
    """
    options = prompt_options(args)
    if args.dataset in CLOSED_ENDED_DATASETS:
        return ClosedEndedExplanationDataset(args.dataset, args.llm, **options)
    if args.dataset == SQUAD_DATASET:
        return SquadExplanationDataset(args.llm, **options)
    if args.dataset == NQ_DATASET:
        return OpenEndedExplanationDataset(args.llm, **options)
    # Reachable only when called with a namespace that bypassed the parser's
    # `choices` check; fail loudly rather than returning nothing.
    raise ValueError(
        f"Unknown dataset {args.dataset!r}; expected one of {', '.join(SUPPORTED_DATASETS)}"
    )


def main(argv=None):
    """Seed the RNGs, parse the command line and construct the requested dataset.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read ``sys.argv[1:]``.

    Returns:
        The constructed dataset object.
    """
    # set random seed
    np.random.seed(0)
    torch.manual_seed(0)
    args = build_parser().parse_args(argv)
    return build_dataset(args)


if __name__ == "__main__":
    # The dataset is not used further in this script; presumably constructing it
    # is what performs the feature extraction -- confirm in src.quere.
    dataset = main()