Skip to content

Commit

Permalink
added more comments
Browse files Browse the repository at this point in the history
  • Loading branch information
aagrawl3 committed Dec 14, 2016
1 parent d2b7a63 commit 775c203
Showing 1 changed file with 91 additions and 6 deletions.
97 changes: 91 additions & 6 deletions lambdamart.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,9 @@ def single_dcg(scores, i, j):
"""
return (np.power(2, scores[i]) - 1) / np.log2(j + 2)

#true_scores, predicted_scores
def compute_lambda(args):
"""
Returns the DCG value at a single point.
Returns the lambda and w values for a given query.
Parameters
----------
args : zipped value of true_scores, predicted_scores, good_ij_pairs, idcg, query_key
Expand Down Expand Up @@ -155,6 +154,20 @@ def compute_lambda(args):
return lambdas[rev_indexes], w[rev_indexes], query_key

def group_queries(training_data, qid_index):
"""
Returns a dictionary that groups the documents by their query ids.
Parameters
----------
training_data : Numpy array of lists
Contains a list of document information. Each document’s format is [relevance score, query index, feature vector]
qid_index : int
This is the index where the qid is located in the training data
Returns
-------
query_indexes : dictionary
The keys were the different query ids and teh values were the indexes in the training data that are associated of those keys.
"""
query_indexes = {}
index = 0
for record in training_data:
Expand All @@ -164,6 +177,19 @@ def group_queries(training_data, qid_index):
return query_indexes

def get_pairs(scores):
"""
Returns pairs of indexes where the first value in the pair has a higher score than the second value in the pair.
Parameters
----------
scores : list of int
Contain a list of numbers
Returns
-------
query_pair : list of pairs
This contains a list of pairs of indexes in scores.
"""

query_pair = []
for query_scores in scores:
temp = sorted(query_scores, reverse=True)
Expand All @@ -178,10 +204,21 @@ def get_pairs(scores):
class LambdaMART:

def __init__(self, training_data=None, number_of_trees=5, learning_rate=0.1, tree_type='sklearn'):
'''
The format for training data is as follows:
[relevance, q_id, [feature vector]]
'''
"""
This is the constructor for the LambdaMART object.
Parameters
----------
training_data : list of int
Contain a list of numbers
number_of_trees : int (default: 5)
Number of trees LambdaMART goes through
learning_rate : float (default: 0.1)
Rate at which we update our prediction with each tree
tree_type : string (default: “sklearn”)
Either “sklearn” for using Sklearn implementation of the tree or “original” for using
our implementation of the tree.
"""

if tree_type != 'sklearn' and tree_type != 'original':
raise ValueError('The "tree_type" must be "sklearn" or "original"')
self.training_data = training_data
Expand All @@ -191,6 +228,10 @@ def __init__(self, training_data=None, number_of_trees=5, learning_rate=0.1, tre
self.tree_type = tree_type

def fit(self):
"""
Fits the model on the training data.
"""

predicted_scores = np.zeros(len(self.training_data))
query_indexes = group_queries(self.training_data, 1)
query_keys = query_indexes.keys()
Expand Down Expand Up @@ -230,6 +271,18 @@ def fit(self):
predicted_scores += prediction * self.learning_rate

def predict(self, data):
"""
Predicts the scores for the test dataset.
Parameters
----------
data : Numpy array of documents
Numpy array of documents with each document’s format is [query index, feature vector]
Returns
-------
predicted_scores : Numpy array of scores
This contains an array or the predicted scores for the documents.
"""
data = np.array(data)
query_indexes = group_queries(data, 0)
predicted_scores = np.zeros(len(data))
Expand All @@ -241,6 +294,22 @@ def predict(self, data):
return predicted_scores

def validate(self, data, k):
"""
Predicts the scores for the test dataset and calculates the NDCG value.
Parameters
----------
data : Numpy array of documents
Numpy array of documents with each document’s format is [relevance score, query index, feature vector]
k : int
this is used to compute the NDCG@k
Returns
-------
average_ndcg : float
This is the average NDCG value of all the queries
predicted_scores : Numpy array of scores
This contains an array or the predicted scores for the documents.
"""
data = np.array(data)
query_indexes = group_queries(data, 1)
average_ndcg = []
Expand All @@ -261,9 +330,25 @@ def validate(self, data, k):
return average_ndcg, predicted_scores

def save(self, fname):
"""
Saves the model into a “.lmart” file with the name given as a parameter.
Parameters
----------
fname : string
Filename of the file you want to save
"""
pickle.dump(self, open('%s.lmart' % (fname), "w"), protocol=2)

def load(self, fname):
"""
Loads the model from the “.lmart” file given as a parameter.
Parameters
----------
fname : string
Filename of the file you want to load
"""
model = pickle.load(open(fname , "r"))
self.training_data = model.training_data
self.number_of_trees = model.number_of_trees
Expand Down

0 comments on commit 775c203

Please sign in to comment.