Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MLE-1053 rebase asapp fixes #4

Open
wants to merge 29 commits into
base: ASAPP-fixes
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
c0e7a13
merge computeSubwords functions
Celebio Sep 13, 2018
711f513
added FAQ on how to get reproducible results (#633)
Sep 13, 2018
a5d22ab
Fix broken link (#590)
EmilStenstrom Sep 13, 2018
8e68462
Conforming to Facebook c++ style
Celebio Oct 24, 2018
be1e597
Compute precision/recall for each label
Celebio Oct 24, 2018
25c3994
fixing python binding for `predict` function
Celebio Oct 26, 2018
6efad35
duplicate import removed
edenbaus Nov 2, 2018
58fe650
fixing missing include in productquantizer.cc that is causing compila…
Celebio Nov 2, 2018
2e52f53
Refactor model testing and metrics code. (#672)
Nov 6, 2018
4a3b5af
meter class refactoring for per-label stats, some function deprecatio…
Celebio Nov 6, 2018
d759dd1
adding python binding for `test-label`
Celebio Nov 6, 2018
0ddcd5f
adding coverage option for Makefile and setup.py
Celebio Nov 6, 2018
41a0f39
putting back the usage of vector to loop in C++ in multiline prediction
Celebio Nov 6, 2018
c180783
fix circleci errors
Celebio Nov 7, 2018
5c229ab
Fixed typos at readme.md (#662)
schneiderl Nov 8, 2018
ead7911
fix support for older C++11 compilers for python bindings
Celebio Nov 20, 2018
4aee63d
Add circleci build badges to the README.md
Celebio Nov 21, 2018
256032b
remove printing functions from fasttext class
Celebio Nov 23, 2018
b8022b5
python install, a more robust pybind11 include
Celebio Nov 27, 2018
a84a6a4
add argument names in fasttext.h
Celebio Nov 27, 2018
71b4101
Normalize buffer vector in analogy queries
Celebio Nov 27, 2018
8850c51
One-vs-all cross-entropy loss
Celebio Nov 27, 2018
7deac6d
adding ova loss option to python bindings
Celebio Dec 4, 2018
501b9b1
Better default for number of threads
whiletruelearn Dec 4, 2018
7842495
Re-licensing fasttext to MIT
Dec 18, 2018
3c4a3ea
footer language : default to EN (#581)
Dec 20, 2018
67e8950
set version to have an ASAPP suffix, add Cython to install_requires
fwph May 22, 2018
f74aad6
bump version after publish script change
cdfox-asapp Sep 7, 2018
b7fa4e7
Update setup.py
cdfox-asapp Oct 18, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
meter class refactoring for per-label stats, some function deprecatio…
…n in fasttext

Summary:
This diff is following up the pull-request diff `Refactor model testing and metrics code`:
- Merging classes LabelMetricsAccumulator and MetricsAccumulator into one : Meter
- putting back removed function signatures in fasttext.h and marking them as deprecated
- removal of f1 score from results (that will be added again later)
- simplifying main.cc thanks to the new api

Reviewed By: EdouardGrave

Differential Revision: D12903111

fbshipit-source-id: eb4116b207aad1713754c136e2a064e9517fdb57
Celebio authored and facebook-github-bot committed Nov 6, 2018

Unverified

This commit is not signed, but one or more authors requires that any commit attributed to them is signed.
commit 4a3b5afa9206810ef43be5d9b20e409b18d86a72
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@ set(HEADER_FILES
src/dictionary.h
src/fasttext.h
src/matrix.h
src/metrics.h
src/meter.h
src/model.h
src/productquantizer.h
src/qmatrix.h
@@ -37,7 +37,7 @@ set(SOURCE_FILES
src/fasttext.cc
src/main.cc
src/matrix.cc
src/metrics.cc
src/meter.cc
src/model.cc
src/productquantizer.cc
src/qmatrix.cc
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@

CXX = c++
CXXFLAGS = -pthread -std=c++0x -march=native
OBJS = args.o dictionary.o productquantizer.o matrix.o qmatrix.o vector.o model.o utils.o metrics.o fasttext.o
OBJS = args.o dictionary.o productquantizer.o matrix.o qmatrix.o vector.o model.o utils.o meter.o fasttext.o
INCLUDES = -I.

opt: CXXFLAGS += -O3 -funroll-loops
@@ -42,8 +42,8 @@ model.o: src/model.cc src/model.h src/args.h
utils.o: src/utils.cc src/utils.h
$(CXX) $(CXXFLAGS) -c src/utils.cc

metrics.o: src/metrics.cc src/metrics.h
$(CXX) $(CXXFLAGS) -c src/metrics.cc
meter.o: src/meter.cc src/meter.h
$(CXX) $(CXXFLAGS) -c src/meter.cc

fasttext.o: src/fasttext.cc src/*.h
$(CXX) $(CXXFLAGS) -c src/fasttext.cc
7 changes: 3 additions & 4 deletions python/fastText/pybind/fasttext_pybind.cc
Original file line number Diff line number Diff line change
@@ -150,12 +150,11 @@ PYBIND11_MODULE(fasttext_pybind, m) {
if (!ifs.is_open()) {
throw std::invalid_argument("Test file cannot be opened!");
}
fasttext::MetricsAccumulator metricsAccumulator;
m.test(ifs, k, 0.0, metricsAccumulator);
fasttext::Meter meter;
m.test(ifs, k, 0.0, meter);
ifs.close();
const auto& metrics = metricsAccumulator.metrics();
return std::tuple<int64_t, double, double>(
metrics.numExamples, metrics.precision(), metrics.recall());
meter.nexamples(), meter.precision(), meter.recall());
})
.def(
"getSentenceVector",
47 changes: 41 additions & 6 deletions src/fasttext.cc
Original file line number Diff line number Diff line change
@@ -369,11 +369,17 @@ void FastText::skipgram(
}
}

void FastText::test(
std::istream& in,
int32_t k,
real threshold,
MetricsAccumulator& accumulator) {
std::tuple<int64_t, double, double>
FastText::test(std::istream& in, int32_t k, real threshold) {
Meter meter;
test(in, k, threshold, meter);

return std::tuple<int64_t, double, double>(
meter.nexamples(), meter.precision(), meter.recall());
}

void FastText::test(std::istream& in, int32_t k, real threshold, Meter& meter)
const {
std::vector<int32_t> line;
std::vector<int32_t> labels;
std::vector<std::pair<real, int32_t>> predictions;
@@ -386,7 +392,7 @@ void FastText::test(
if (!labels.empty() && !line.empty()) {
predictions.clear();
predict(k, line, predictions, threshold);
accumulator.log(labels, predictions);
meter.log(labels, predictions);
}
}
}
@@ -432,6 +438,13 @@ void FastText::predict(
}
}

void FastText::printLabelStats(std::istream& in, int32_t k, real threshold)
const {
Meter meter;
test(in, k, threshold, meter);
writePerLabelMetrics(std::cout, meter);
}

void FastText::getSentenceVector(std::istream& in, fasttext::Vector& svec) {
svec.zero();
if (args_->model == model_name::sup) {
@@ -714,4 +727,26 @@ bool FastText::isQuant() const {
return quant_;
}

void FastText::writePerLabelMetrics(std::ostream& out, Meter& meter) const {
out << std::fixed;
out << std::setprecision(6);

auto writeMetric = [&](const std::string& name, double value) {
out << name << " : ";
if (std::isfinite(value)) {
out << value;
} else {
out << "--------";
}
out << " ";
};

for (int32_t i = 0; i < dict_->nlabels(); i++) {
writeMetric("F1-Score", meter.f1Score(i));
writeMetric("Precision", meter.precision(i));
writeMetric("Recall", meter.recall(i));
out << " " << dict_->getLabel(i) << std::endl;
}
}

} // namespace fasttext
12 changes: 10 additions & 2 deletions src/fasttext.h
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@
#include "args.h"
#include "dictionary.h"
#include "matrix.h"
#include "metrics.h"
#include "meter.h"
#include "model.h"
#include "qmatrix.h"
#include "real.h"
@@ -94,7 +94,8 @@ class FastText {
std::vector<int32_t> selectEmbeddings(int32_t) const;
void getSentenceVector(std::istream&, Vector&);
void quantize(const Args);
void test(std::istream&, int32_t, real, MetricsAccumulator&);
std::tuple<int64_t, double, double> test(std::istream&, int32_t, real = 0.0);
void test(std::istream&, int32_t, real, Meter&) const;
void predict(
int32_t,
const std::vector<int32_t>&,
@@ -116,5 +117,12 @@ class FastText {
void loadVectors(std::string);
int getDimension() const;
bool isQuant() const;

FASTTEXT_DEPRECATED(
"This function is deprecated, please use `test` function.")
void printLabelStats(std::istream&, int32_t, real = 0.0) const;
FASTTEXT_DEPRECATED(
"This function is deprecated and will be removed along with `printLabelStats`.")
void writePerLabelMetrics(std::ostream&, Meter&) const;
};
} // namespace fasttext
72 changes: 14 additions & 58 deletions src/main.cc
Original file line number Diff line number Diff line change
@@ -60,7 +60,7 @@ void printPredictUsage() {
<< std::endl;
}

void printPrintLabelStatsUsage() {
void printTestLabelUsage() {
std::cerr
<< "usage: fasttext test-label <model> <test-data> [<k>] [<th>]\n\n"
<< " <model> model filename\n"
@@ -126,41 +126,40 @@ void printDumpUsage() {
}

void test(const std::vector<std::string>& args) {
bool perLabel = args[1] == "test-label";

if (args.size() < 4 || args.size() > 6) {
printTestUsage();
perLabel ? printTestLabelUsage() : printTestUsage();
exit(EXIT_FAILURE);
}

const auto& model = args[2];
const auto& input = args[3];
int32_t k = args.size() > 4 ? std::stoi(args[4]) : 1;
real threshold = args.size() > 5 ? std::stof(args[5]) : 0;
real threshold = args.size() > 5 ? std::stof(args[5]) : 0.0;

FastText fasttext;
fasttext.loadModel(model);

MetricsAccumulator metricsAccumulator;
Meter meter;

if (input == "-") {
fasttext.test(std::cin, k, threshold, metricsAccumulator);
fasttext.test(std::cin, k, threshold, meter);
} else {
std::ifstream ifs(input);
if (!ifs.is_open()) {
std::cerr << "Test file cannot be opened!" << std::endl;
exit(EXIT_FAILURE);
}
fasttext.test(ifs, k, threshold, metricsAccumulator);
fasttext.test(ifs, k, threshold, meter);
}

const auto& metrics = metricsAccumulator.metrics();
if (perLabel) {
fasttext.writePerLabelMetrics(std::cout, meter);
}
meter.writeGeneralMetrics(std::cout, k);

std::cout << "N"
<< "\t" << metrics.numExamples << std::endl;
std::cout << std::setprecision(3);
std::cout << "P@" << k << "\t" << metrics.precision() << std::endl;
std::cout << "R@" << k << "\t" << metrics.recall() << std::endl;
std::cout << "F1"
<< "\t" << metrics.f1Score() << std::endl;
exit(0);
}

void predict(const std::vector<std::string>& args) {
@@ -197,47 +196,6 @@ void predict(const std::vector<std::string>& args) {
exit(0);
}

void printLabelStats(const std::vector<std::string>& args) {
if (args.size() < 4 || args.size() > 6) {
printPrintLabelStatsUsage();
exit(EXIT_FAILURE);
}

const auto& model = args[2];
const auto& input = args[3];
int32_t k = args.size() > 4 ? std::stoi(args[4]) : 1;
real threshold = args.size() > 5 ? std::stof(args[5]) : 0;

FastText fasttext;
fasttext.loadModel(model);

LabelMetricsAccumulator metricsAccumulator;

if (input == "-") {
fasttext.test(std::cin, k, threshold, metricsAccumulator);
} else {
std::ifstream ifs(input);
if (!ifs.is_open()) {
std::cerr << "Test file cannot be opened!" << std::endl;
exit(EXIT_FAILURE);
}
fasttext.test(ifs, k, threshold, metricsAccumulator);
}

metricsAccumulator.write(std::cout, fasttext.getDictionary());
const auto& metrics = metricsAccumulator.metrics();

std::cout << "N"
<< "\t" << metrics.numExamples << std::endl;
std::cout << std::setprecision(3);
std::cout << "P@" << k << "\t" << metrics.precision() << std::endl;
std::cout << "R@" << k << "\t" << metrics.recall() << std::endl;
std::cout << "F1"
<< "\t" << metrics.f1Score() << std::endl;

exit(0);
}

void printWordVectors(const std::vector<std::string> args) {
if (args.size() != 3) {
printPrintWordVectorsUsage();
@@ -391,7 +349,7 @@ int main(int argc, char** argv) {
std::string command(args[1]);
if (command == "skipgram" || command == "cbow" || command == "supervised") {
train(args);
} else if (command == "test") {
} else if (command == "test" || command == "test-label") {
test(args);
} else if (command == "quantize") {
quantize(args);
@@ -407,8 +365,6 @@ int main(int argc, char** argv) {
analogies(args);
} else if (command == "predict" || command == "predict-prob") {
predict(args);
} else if (command == "test-label") {
printLabelStats(args);
} else if (command == "dump") {
dump(args);
} else {
60 changes: 60 additions & 0 deletions src/meter.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#include "meter.h"

#include <algorithm>
#include <cmath>
#include <iomanip>
#include <limits>

namespace fasttext {

void Meter::log(
const std::vector<int32_t>& labels,
const std::vector<std::pair<real, int32_t>>& predictions) {
nexamples_++;
metrics_.gold += labels.size();
metrics_.predicted += predictions.size();

for (const auto& prediction : predictions) {
labelMetrics_[prediction.second].predicted++;

if (std::find(labels.begin(), labels.end(), prediction.second) !=
labels.end()) {
labelMetrics_[prediction.second].predictedGold++;
metrics_.predictedGold++;
}
}

for (const auto& label : labels) {
labelMetrics_[label].gold++;
}
}

double Meter::precision(int32_t i) {
return labelMetrics_[i].precision();
}

double Meter::recall(int32_t i) {
return labelMetrics_[i].recall();
}

double Meter::f1Score(int32_t i) {
return labelMetrics_[i].f1Score();
}

double Meter::precision() const {
return metrics_.precision();
}

double Meter::recall() const {
return metrics_.recall();
}

void Meter::writeGeneralMetrics(std::ostream& out, int32_t k) const {
out << "N"
<< "\t" << nexamples_ << std::endl;
out << std::setprecision(3);
out << "P@" << k << "\t" << metrics_.precision() << std::endl;
out << "R@" << k << "\t" << metrics_.recall() << std::endl;
}

} // namespace fasttext
53 changes: 53 additions & 0 deletions src/meter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#pragma once

#include <unordered_map>
#include <vector>

#include "dictionary.h"
#include "real.h"

namespace fasttext {

class Meter {
struct Metrics {
uint64_t gold;
uint64_t predicted;
uint64_t predictedGold;

Metrics() : gold(0), predicted(0), predictedGold(0) {}

double precision() const {
return predictedGold / double(predicted);
}
double recall() const {
return predictedGold / double(gold);
}
double f1Score() const {
return 2 * predictedGold / double(predicted + gold);
}
};

public:
Meter() : metrics_(), nexamples_(0), labelMetrics_() {}

void log(
const std::vector<int32_t>& labels,
const std::vector<std::pair<real, int32_t>>& predictions);

double precision(int32_t);
double recall(int32_t);
double f1Score(int32_t);
double precision() const;
double recall() const;
uint64_t nexamples() const {
return nexamples_;
}
void writeGeneralMetrics(std::ostream& out, int32_t k) const;

private:
Metrics metrics_{};
uint64_t nexamples_;
std::unordered_map<int32_t, Metrics> labelMetrics_;
};

} // namespace fasttext
81 changes: 0 additions & 81 deletions src/metrics.cc

This file was deleted.

53 changes: 0 additions & 53 deletions src/metrics.h

This file was deleted.