Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into feat/optimizer
Browse files Browse the repository at this point in the history
# Conflicts:
#	classifier.py
#	datasets.py
  • Loading branch information
lkaesberg committed Jun 12, 2023
2 parents f32a442 + 25dd324 commit 96f7d6f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 32 deletions.
21 changes: 0 additions & 21 deletions classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,24 +379,3 @@ def get_args():

print('Evaluating on SST...')
test(config)

print('Training Sentiment Classifier on cfimdb...')
config = SimpleNamespace(
filepath='cfimdb-classifier.pt',
lr=args.lr,
use_gpu=args.use_gpu,
epochs=args.epochs,
batch_size=8,
hidden_dropout_prob=args.hidden_dropout_prob,
train='data/ids-cfimdb-train.csv',
dev='data/ids-cfimdb-dev.csv',
test='data/ids-cfimdb-test-student.csv',
option=args.option,
dev_out='predictions/' + args.option + '-cfimdb-dev-out.csv',
test_out='predictions/' + args.option + '-cfimdb-test-out.csv'
)

train(config)

print('Evaluating on cfimdb...')
test(config)
21 changes: 10 additions & 11 deletions datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,16 +211,16 @@ def load_multitask_test_data():
sentiment_data = []

with open(sentiment_filename, 'r', encoding='utf-8') as fp:
for record in csv.DictReader(fp, delimiter='\t'):
for record in csv.DictReader(fp,delimiter = '\t'):
sent = record['sentence'].lower().strip()
sentiment_data.append(sent)

print(f"Loaded {len(sentiment_data)} test examples from {sentiment_filename}")

paraphrase_data = []
with open(paraphrase_filename, 'r', encoding='utf-8') as fp:
for record in csv.DictReader(fp, delimiter='\t'):
# if record['split'] != split:
for record in csv.DictReader(fp,delimiter = '\t'):
#if record['split'] != split:
# continue
paraphrase_data.append((preprocess_string(record['sentence1']),
preprocess_string(record['sentence2']),
Expand All @@ -230,7 +230,7 @@ def load_multitask_test_data():

similarity_data = []
with open(similarity_filename, 'r', encoding='utf-8') as fp:
for record in csv.DictReader(fp, delimiter='\t'):
for record in csv.DictReader(fp,delimiter = '\t'):
similarity_data.append((preprocess_string(record['sentence1']),
preprocess_string(record['sentence2']),
))
Expand All @@ -245,13 +245,13 @@ def load_multitask_data(sentiment_filename, paraphrase_filename, similarity_file
num_labels = {}
if split == 'test':
with open(sentiment_filename, 'r', encoding='utf-8') as fp:
for record in csv.DictReader(fp, delimiter='\t'):
for record in csv.DictReader(fp,delimiter = '\t'):
sent = record['sentence'].lower().strip()
sent_id = record['id'].lower().strip()
sentiment_data.append((sent, sent_id))
else:
with open(sentiment_filename, 'r', encoding='utf-8') as fp:
for record in csv.DictReader(fp, delimiter='\t'):
for record in csv.DictReader(fp,delimiter = '\t'):
sent = record['sentence'].lower().strip()
sent_id = record['id'].lower().strip()
label = int(record['sentiment'].strip())
Expand All @@ -264,16 +264,15 @@ def load_multitask_data(sentiment_filename, paraphrase_filename, similarity_file
paraphrase_data = []
if split == 'test':
with open(paraphrase_filename, 'r', encoding='utf-8') as fp:
for record in csv.DictReader(fp, delimiter='\t'):
for record in csv.DictReader(fp,delimiter = '\t'):
sent_id = record['id'].lower().strip()
paraphrase_data.append((preprocess_string(record['sentence1']),
preprocess_string(record['sentence2']),
sent_id))

else:
with open(paraphrase_filename, 'r', encoding='utf-8') as fp:
print(fp)
for record in csv.DictReader(fp, delimiter='\t'):
for record in csv.DictReader(fp,delimiter = '\t'):
try:
sent_id = record['id'].lower().strip()
paraphrase_data.append((preprocess_string(record['sentence1']),
Expand All @@ -287,14 +286,14 @@ def load_multitask_data(sentiment_filename, paraphrase_filename, similarity_file
similarity_data = []
if split == 'test':
with open(similarity_filename, 'r', encoding='utf-8') as fp:
for record in csv.DictReader(fp, delimiter='\t'):
for record in csv.DictReader(fp,delimiter = '\t'):
sent_id = record['id'].lower().strip()
similarity_data.append((preprocess_string(record['sentence1']),
preprocess_string(record['sentence2'])
, sent_id))
else:
with open(similarity_filename, 'r', encoding='utf-8') as fp:
for record in csv.DictReader(fp, delimiter='\t'):
for record in csv.DictReader(fp,delimiter = '\t'):
sent_id = record['id'].lower().strip()
similarity_data.append((preprocess_string(record['sentence1']),
preprocess_string(record['sentence2']),
Expand Down

0 comments on commit 96f7d6f

Please sign in to comment.