Skip to content

Commit

Permalink
Merge pull request #2 from carschno/getSentenceVector
Browse files Browse the repository at this point in the history
Add getSentenceVector()
  • Loading branch information
carschno authored Jul 25, 2019
2 parents 0e03434 + 9793d7b commit a2d8f93
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
language: java
jdk:
- oraclejdk11
- openjdk11
os:
- linux
cache: bundler
Expand Down
9 changes: 8 additions & 1 deletion src/main/cpp/fasttext_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,14 @@ namespace FastTextWrapper {

// Returns the vector that the wrapped fastText object produces for `word`,
// copied into a plain std::vector<real>. The scratch Vector is sized from the
// model arguments (args_->dim).
std::vector<real> FastTextApi::getVector(const std::string& word) {
Vector vec(privateMembers->args_->dim);
// NOTE(review): this is a diff view without +/- markers; the next two lines
// look like the removed call followed by its replacement. Confirm that only
// getWordVector() remains in the committed file.
fastText.getVector(vec, word);
fastText.getWordVector(vec, word);
// Copy the [data, data + size) range into the returned std::vector.
return std::vector<real>(vec.data(), vec.data() + vec.size());
}

std::vector<real> FastTextApi::getSentenceVector(const std::string& sentence) {
Vector vec(privateMembers->args_->dim);
std::istringstream in(sentence);
fastText.getSentenceVector(in, vec);
return std::vector<real>(vec.data(), vec.data() + vec.size());
}

Expand Down
1 change: 1 addition & 0 deletions src/main/cpp/fasttext_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ namespace FastTextWrapper {
std::vector<std::string> predict(const std::string&, int32_t);
std::vector<std::pair<real,std::string>> predictProba(const std::string&, int32_t);
// Returns the vector fastText produces for a single word, as a std::vector<real>.
std::vector<real> getVector(const std::string&);
// Returns the vector fastText produces for a whole sentence passed as one string.
std::vector<real> getSentenceVector(const std::string&);
std::vector<std::string> getWords();
std::vector<std::string> getLabels();
std::string getWord(int32_t);
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/com/github/jfasttext/FastTextWrapper.java
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ public DoubleIntPair put(float firstValue, int secondValue) {
public native @ByVal FloatStringPairVector predictProba(@StdString String arg0, int arg1);
public native @ByVal RealVector getVector(@StdString BytePointer arg0);
public native @ByVal RealVector getVector(@StdString String arg0);
/** Native sentence-vector call taking the sentence as a BytePointer; presumably bound to the C++ FastTextApi::getSentenceVector (confirm). */
public native @ByVal RealVector getSentenceVector(@StdString BytePointer arg0);
/** Overload of the same native call taking the sentence as a Java String. */
public native @ByVal RealVector getSentenceVector(@StdString String arg0);
public native @ByVal StringVector getWords();
public native @ByVal StringVector getLabels();
public native @StdString BytePointer getWord(int arg0);
Expand Down
12 changes: 12 additions & 0 deletions src/main/java/com/github/jfasttext/JFastText.java
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,18 @@ public List<Float> getVector(String word) {
return wordVec;
}

/**
 * Computes the vector of a sentence using the loaded model.
 *
 * <p>The text handed to the native layer always ends with a newline: one is
 * appended when the caller's string does not already end with it.
 *
 * @param sentence the text to embed; must not be {@code null}
 * @return a new list holding the components of the native result, in order
 * @throws NullPointerException if {@code sentence} is {@code null}
 */
public List<Float> getSentenceVector(String sentence) {
    // Build the newline-terminated form separately instead of reassigning the parameter.
    String line = sentence.endsWith("\n") ? sentence : sentence + "\n";
    FastTextWrapper.RealVector rv = fta.getSentenceVector(line);

    // Read the size once: it is a native call, and knowing it up front lets the
    // list allocate its backing array a single time.
    int dim = (int) rv.size();
    List<Float> sentenceVec = new ArrayList<>(dim);
    for (int i = 0; i < dim; i++) {
        sentenceVec.add(rv.get(i));
    }
    return sentenceVec;
}

/**
 * Delegates to the native API object's {@code getNWords()}.
 *
 * @return the value reported by the native layer
 */
public int getNWords() {
    final int nativeCount = fta.getNWords();
    return nativeCount;
}
Expand Down
20 changes: 17 additions & 3 deletions src/test/java/com/github/jfasttext/JFastTextTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import java.util.List;

import static org.junit.Assert.assertEquals;

@FixMethodOrder(MethodSorters.NAME_ASCENDING)
public class JFastTextTest {

Expand All @@ -16,7 +18,9 @@ public void test01TrainSupervisedCmd() {
jft.runCmd(new String[] {
"supervised",
"-input", "src/test/resources/data/labeled_data.txt",
"-output", "src/test/resources/models/supervised.model"
"-output", "src/test/resources/models/supervised.model",
"-wordNgrams", "3",
"-bucket", "100"
});
}

Expand Down Expand Up @@ -86,11 +90,21 @@ public void test07GetVector() throws Exception {
System.out.printf("\nWord embedding vector of '%s': %s\n", word, vec);
}

/**
 * Checks that the sentence vector returned by the supervised model has the
 * expected number of components.
 */
@Test
public void test08GetSentenceVector() throws Exception {
    final int expectedSize = 100;
    final String input = "soccers";

    JFastText model = new JFastText();
    model.loadModel("src/test/resources/models/supervised.model.bin");

    List<Float> sentenceVector = model.getSentenceVector(input);
    assertEquals(expectedSize, sentenceVector.size());
}

/**
* Test retrieving model's information: words, labels, learning rate, etc.
*/
@Test
public void test08ModelInfo() throws Exception {
public void test09ModelInfo() throws Exception {
System.out.printf("\nSupervised model information:\n");
JFastText jft = new JFastText();
jft.loadModel("src/test/resources/models/supervised.model.bin");
Expand All @@ -113,7 +127,7 @@ public void test08ModelInfo() throws Exception {
* allocated by native function calls).
*/
@Test
public void test09ModelUnloading() throws Exception {
public void test10ModelUnloading() throws Exception {
JFastText jft = new JFastText();
System.out.println("\nLoading model ...");
jft.loadModel("src/test/resources/models/supervised.model.bin");
Expand Down

0 comments on commit a2d8f93

Please sign in to comment.