Skip to content

Commit

Permalink
changed output
Browse files Browse the repository at this point in the history
  • Loading branch information
aagrawl3 committed Dec 8, 2016
1 parent 10ea2b4 commit 9496c78
Showing 1 changed file with 31 additions and 4 deletions.
35 changes: 31 additions & 4 deletions test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from lambdamart import LambdaMART
import numpy as np
import pandas as pd

def get_data(file_loc):
f = open(file_loc, 'r')
Expand All @@ -18,10 +19,36 @@ def get_data(file_loc):
f.close()
return np.array(data)

def group_queries(data):
    """Group row positions of ``data`` by the value at position 1 of each row.

    Parameters
    ----------
    data : iterable of indexable records
        Each record's element at index 1 is used as the grouping key
        (presumably the query id -- confirm against ``get_data``).

    Returns
    -------
    dict
        Maps each distinct key to the list of 0-based positions of the
        records carrying that key, in the order they were encountered.
    """
    query_indexes = {}
    for index, record in enumerate(data):
        # setdefault returns the (possibly freshly inserted) list, so a
        # single lookup is enough to both create and extend the group.
        query_indexes.setdefault(record[1], []).append(index)
    return query_indexes

# Evaluation script: train and save a LambdaMART model, reload it from disk,
# score the test fold, and dump the ranking of one query to CSV.
training_data = get_data('/Users/madhavagrawal/Downloads/MQ2007/Fold1/train.txt')
test_data = get_data('/Users/madhavagrawal/Downloads/MQ2007/Fold1/test.txt')

# NOTE(review): the same three training calls also appear commented out
# directly below; confirm whether training is meant to run on every
# invocation or only the load-from-disk path should remain.
model = LambdaMART(training_data, 2, 10, 0.001)
model.fit()
model.save('lambdamart_model')
# model = LambdaMART(training_data, 2, 10, 0.001)
# model.fit()
# model.save('lambdamart_model')

# Reload the persisted model and evaluate it on the held-out fold.
# NOTE(review): save() is given 'lambdamart_model' while load() reads
# 'lambdamart_model.lmart' -- presumably save() appends the extension; verify.
model = LambdaMART()
model.load('lambdamart_model.lmart')
average_ndcg, predicted_scores = model.validate(test_data)

# Single-argument parenthesised print behaves identically on Python 2 and 3.
print(average_ndcg)
print('NDCG score: %f' % (average_ndcg))

# Pick the first query in the grouping and collect its labels and scores.
query_indexes = group_queries(test_data)
# next(iter(...)) yields the same key as keys()[0] but also works on
# Python 3, where dict.keys() returns a non-indexable view.
index = next(iter(query_indexes))
testdata = [test_data[i][0] for i in query_indexes[index]]
pred = [predicted_scores[i] for i in query_indexes[index]]

output = pd.DataFrame({"True label": testdata, "prediction": pred})
# DataFrame.sort was removed from pandas; sort_values is the replacement.
output = output.sort_values('prediction', ascending=False)
output.to_csv("outdemo.csv", index=False)
print(output)
# for i in query_indexes[index]:
# print test_data[i][0], predicted_scores[i]




0 comments on commit 9496c78

Please sign in to comment.