-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgbmModel.py
81 lines (47 loc) · 1.81 KB
/
gbmModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import lightgbm as lgb
import os
import numpy as np
import pickle
class GBMModel:
    """Ensemble of LightGBM boosters with a logistic calibration step.

    Every regular file in ``path`` whose name contains ``name`` (other than the
    calibrator file ``logistic.pkl``) is loaded as an ``lgb.Booster``. The
    boosters' raw predictions are averaged and then passed through the pickled
    calibrator ``path/logistic.pkl`` to obtain the scores used for ranking.
    """

    # Name of the pickled calibrator inside ``path``; never loaded as a booster.
    _LOGISTIC_FILE = "logistic.pkl"

    def __init__(self, path="", name=""):
        """Load the booster ensemble and the calibrator from ``path``.

        Args:
            path: directory holding the booster files and ``logistic.pkl``.
                An empty string means the current working directory.
            name: substring a file name must contain to be loaded as a
                booster; the default "" matches every regular file.
        """
        self.path = path
        self.name = name
        self.models = []
        print("Loading LGB models")
        self.load_models()

    @staticmethod
    def _model_file_paths(directory, name):
        """Return sorted paths of booster files in ``directory`` matching ``name``."""
        paths = []
        # Sorted so the ensemble (and float summation order) is reproducible;
        # os.listdir returns entries in arbitrary order.
        for fname in sorted(os.listdir(directory)):
            full = os.path.join(directory, fname)
            # Skip the calibrator (which the default name="" would otherwise
            # match) and anything that is not a regular file.
            if (
                name in fname
                and fname != GBMModel._LOGISTIC_FILE
                and os.path.isfile(full)
            ):
                paths.append(full)
        return paths

    def load_models(self):
        """Load every matching booster and the logistic calibrator.

        Boosters are appended to ``self.models``; the unpickled calibrator is
        stored in ``self.logistic``.

        Returns:
            True once loading has finished.

        Raises:
            OSError: if ``path`` cannot be listed or ``logistic.pkl`` cannot
                be opened.
        """
        directory = self.path or os.curdir
        for model_path in self._model_file_paths(directory, self.name):
            self.models.append(lgb.Booster(model_file=model_path))
        # SECURITY: pickle.load can execute arbitrary code; only point
        # ``path`` at trusted model directories.
        with open(os.path.join(directory, self._LOGISTIC_FILE), "rb") as fid:
            self.logistic = pickle.load(fid)
        return True

    def get_top_k(self, data, k):
        """Score the rows of ``data`` and return the ``k`` highest-ranked ones.

        Args:
            data: feature rows accepted by ``Booster.predict``; presumably each
                row is ``[word_emb_issue, word_emb_candidate]`` -- confirm
                against callers.
            k: number of rows to return. Values <= 0 yield empty results;
                values larger than ``len(data)`` return every row.

        Returns:
            ``(indices, scores)``: row indices ordered by descending calibrated
            score, and the matching scores (column 1 of
            ``self.logistic.predict_proba``).

        Raises:
            ValueError: if no booster models are loaded.
        """
        if not self.models:
            # Otherwise the mean below divides by zero and NaNs reach the
            # calibrator.
            raise ValueError("no LightGBM models loaded; cannot score data")
        # Plain mean of the ensemble's raw predictions.
        total = np.zeros(len(data))
        for gbm in self.models:
            total += gbm.predict(data)
        mean_scores = total / len(self.models)
        # The calibrator is fed the averaged score as a single feature column.
        scores = self.logistic.predict_proba(mean_scores.reshape(-1, 1))[:, 1]
        # Descending order, then the first k. The old ``argsort()[-k:]`` form
        # returned *every* row for k == 0.
        order = scores.argsort()[::-1][:max(k, 0)]
        return order, scores[order]
#TITLE-ABS-KEY ( ( requirement OR requirements OR software OR NLP OR machine learning) AND ( link OR dependency OR interdependency ) AND ( detection OR detector OR extraction ) AND ( "natural language" ) ) AND ( PUBYEAR > 2010 AND PUBYEAR < 2019 )