Skip to content

Commit

Permalink
updated test
Browse files Browse the repository at this point in the history
  • Loading branch information
aagrawl3 committed Dec 15, 2016
1 parent 024ce64 commit f325cd4
Showing 1 changed file with 45 additions and 20 deletions.
65 changes: 45 additions & 20 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,52 @@ def group_queries(data):
index += 1
return query_indexes

training_data = get_data('/Users/madhavagrawal/Downloads/MQ2007/Fold1/train.txt')
test_data = get_data('/Users/madhavagrawal/Downloads/MQ2007/Fold1/test.txt')
# model = LambdaMART(training_data, 2, 10, 0.001)
# model.fit()
# model.save('lambdamart_model')
model = LambdaMART()
model.load('lambdamart_model.lmart')
average_ndcg, predicted_scores = model.validate(test_data)

print 'NDCG score: %f' % (average_ndcg)
query_indexes = group_queries(test_data)
index = query_indexes.keys()[0]
testdata = [test_data[i][0] for i in query_indexes[index]]
pred = [predicted_scores[i] for i in query_indexes[index]]
output = pd.DataFrame({"True label": testdata, "prediction": pred})
output = output.sort('prediction',ascending = False)
output.to_csv("outdemo.csv", index =False)
print output
# for i in query_indexes[index]:
# print test_data[i][0], predicted_scores[i]

def main():
total_ndcg = 0.0
for i in [1,2,3,4,5]:
print 'start Fold ' + str(i)
training_data = get_data('Fold%d/train.txt' % (i))
test_data = get_data('Fold%d/test.txt' % (i))
model = LambdaMART(training_data, 300, 0.001, 'sklearn')
model.fit()
model.save('lambdamart_model_%d' % (i))
# model = LambdaMART()
# model.load('lambdamart_model.lmart')
average_ndcg, predicted_scores = model.validate(test_data, 10)
print average_ndcg
total_ndcg += average_ndcg
total_ndcg /= 5.0
print 'Original average ndcg at 10 is: ' + str(total_ndcg)

total_ndcg = 0.0
for i in [1,2,3,4,5]:
print 'start Fold ' + str(i)
training_data = get_data('Fold%d/train.txt' % (i))
test_data = get_data('Fold%d/test.txt' % (i))
model = LambdaMART(training_data, 300, 0.001, 'original')
model.fit()
model.save('lambdamart_model_sklearn_%d' % (i))
# model = LambdaMART()
# model.load('lambdamart_model.lmart')
average_ndcg, predicted_scores = model.validate(test_data, 10)
print average_ndcg
total_ndcg += average_ndcg
total_ndcg /= 5.0
print 'Sklearn average ndcg at 10 is: ' + str(total_ndcg)

# print 'NDCG score: %f' % (average_ndcg)
# query_indexes = group_queries(test_data)
# index = query_indexes.keys()[0]
# testdata = [test_data[i][0] for i in query_indexes[index]]
# pred = [predicted_scores[i] for i in query_indexes[index]]
# output = pd.DataFrame({"True label": testdata, "prediction": pred})
# output = output.sort('prediction',ascending = False)
# output.to_csv("outdemo.csv", index =False)
# print output
# # for i in query_indexes[index]:
# # print test_data[i][0], predicted_scores[i]


if __name__ == '__main__':
main()

0 comments on commit f325cd4

Please sign in to comment.