# classify.py
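"""Compare document-vector representations on a labelled corpus.

For each vector table found in the dataset's SQLite database (bert,
word2vec, pv_dbow, lsa, lda), load the vectors and their label ids,
split them per label so every class appears in both halves, fit a
one-vs-rest SVM, and print the weighted F1 score on the test half.
"""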
import argparse
import sqlite3
from collections import defaultdict

import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, log_loss
from sklearn.model_selection import GridSearchCV, train_test_split as tts
from sklearn.multiclass import OneVsRestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import OneHotEncoder
from sklearn.svm import SVC


def read(conn, name):
    """Load vectors and multi-hot label rows for the given vector table."""
    df = pd.read_sql(
        f'SELECT vec, label_ids FROM {name} '
        f'JOIN Files ON {name}.file_id == Files.file_id',
        conn,
    )
    # Vectors and label ids are stored as comma-separated strings.
    vectors = []
    for v in df['vec']:
        vectors.append(list(map(float, v.split(','))))
    labels = []
    unique = set()
    count = defaultdict(int)
    for lbls in df['label_ids']:
        labels.append(list(map(int, lbls.split(','))))
        for l in labels[-1]:
            unique.add(l)
            count[l] += 1
    #for l in count:
    #    nm = conn.cursor().execute(f'SELECT label_desc FROM Labels WHERE label_id = {l}').fetchone()
    #    print(nm, count[l])
    # Sum the one-hot encodings of a document's labels into one multi-hot row.
    ohe = OneHotEncoder()
    ohe.fit(np.asarray(list(unique)).reshape(-1, 1))
    for i, lbls in enumerate(labels):
        labels[i] = np.zeros(len(unique))  # ndarray, so += adds elementwise
        for l in lbls:
            labels[i] += ohe.transform([[l]]).toarray()[0]
    return np.asarray(vectors), np.asarray(labels)
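
# The accumulation loop above is effectively a multi-hot encoding; sklearn
# ships this directly as MultiLabelBinarizer. A minimal equivalent sketch
# over read()'s list of label-id lists, left commented out:
#from sklearn.preprocessing import MultiLabelBinarizer
#multi_hot = MultiLabelBinarizer().fit_transform(labels)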


def train_test_split(vectors, labels, test_size, random_state):
    """Split per label: each vector is filed under every label it carries,
    and each label's vectors are split separately so all classes appear on
    both sides. Labels with a single vector are dropped. Returns integer
    label indices, not multi-hot rows."""
    lbl2vec = defaultdict(list)
    for vec, lbls in zip(vectors, labels):
        for i, l in enumerate(lbls):
            if l:
                lbl2vec[i].append(vec)
    vectors_train = []
    labels_train = []
    vectors_test = []
    labels_test = []
    for l in lbl2vec:
        vec = lbl2vec[l]
        if len(vec) > 1:
            v_train, v_test, l_train, l_test = tts(
                vec, [l] * len(vec), test_size=test_size, random_state=random_state
            )
            vectors_train.extend(v_train)
            vectors_test.extend(v_test)
            labels_train.extend(l_train)
            labels_test.extend(l_test)
    return (np.asarray(vectors_train), np.asarray(vectors_test),
            np.asarray(labels_train), np.asarray(labels_test))
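
# For example (hypothetical toy input): with multi-hot rows
# [[1, 0], [1, 1], [1, 0]], label 0 collects all three vectors and is split,
# while label 1 has only one vector and is dropped; the returned labels are
# integer column indices (here all 0), not multi-hot rows.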


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', choices=['mouse', 'trecgen', '20ng'])
    args = parser.parse_args()
    #vector_names = ['word2vec', 'word2vec_tfidf', 'word2vec_idf', 'doc2vec', 'lsa', 'lda', 'rdf', 'topic_net']
    conn = sqlite3.connect(f'data/{args.dataset}.sqlite')
    vector_names = ['bert', 'word2vec', 'pv_dbow', 'lsa', 'lda']
    for name in vector_names:
        try:
            v, true_labels = read(conn, name)
        except pd.io.sql.DatabaseError:
            # This representation has no table in the current dataset.
            continue
        v_train, v_test, l_train, l_test = train_test_split(
            v, true_labels, test_size=0.5, random_state=0
        )
        #parameters = {'C': [1, 10]}
        #clf = OneVsRestClassifier(GridSearchCV(SVC(probability=True, class_weight='balanced'), parameters))
        clf = OneVsRestClassifier(SVC(probability=True, class_weight='balanced'))
        #clf = KNeighborsClassifier(n_neighbors=10, metric='euclidean')
        clf.fit(v_train, l_train)
        #pred = clf.predict_proba(v_test)
        #loss = log_loss(l_test, pred)
        l_pred = clf.predict(v_test)
        print('{}, {:.4f}'.format(
            name, f1_score(l_test, l_pred, average='weighted')
        ))
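
# Usage, assuming a matching database such as data/20ng.sqlite has been built:
#   python classify.py --dataset 20ng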