Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SPARKNLP-750 DependencyParserModel Outputs All Chunks as <no-type> #13620

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
package com.johnsnowlabs.nlp.annotators.parser.typdep;

public class DependencyArcList {
private int n;
private int headsSize;
private int[] st;
private int[] edges;

// Allocates the st and edges arrays (one slot per entry of heads) and fills
// them via constructDepTreeArcList.
// NOTE(review): this text is a rendered diff whose +/- markers were lost. The
// three statements using `n` look like the removed lines and the three using
// `headsSize` their replacements — confirm against the PR before compiling.
DependencyArcList(int[] heads) {
n = heads.length;
st = new int[n];
edges = new int[n];
headsSize = heads.length;
st = new int[headsSize];
edges = new int[headsSize];
constructDepTreeArcList(heads);
}

Expand All @@ -33,7 +33,7 @@ int startIndex(int i) {
}

// Returns the upper bound of node i's slice of edges: st[i + 1] for every
// index except the last, and size - 1 for the last index (or anything past it).
// Callers in this diff iterate `startIndex(i) <= k < endIndex(i)`, so the
// bound is used as exclusive.
// NOTE(review): the two return statements are the old (`n`) and new
// (`headsSize`) versions of the same line from the diff rendering.
int endIndex(int i) {
return (i >= n - 1) ? n - 1 : st[i + 1];
return (i >= headsSize - 1) ? headsSize - 1 : st[i + 1];
}

public int get(int i) {
Expand All @@ -42,21 +42,22 @@ public int get(int i) {

// Fills st and edges from heads in three passes (count, running sum, backward
// fill), so that the entries of edges grouped under the same head end up
// contiguous and st[j] points at the first entry for head j.
// NOTE(review): every loop header appears twice because the diff's removed
// (`n`) and added (`headsSize`) lines are interleaved; only one of each pair
// belongs in a compilable file.
private void constructDepTreeArcList(int[] heads) {

// Pass 0: reset the per-node counters.
for (int i = 0; i < n; ++i)
for (int i = 0; i < headsSize; ++i)
st[i] = 0;

// Pass 1: for every index i >= 1, count one arc under its head heads[i].
// Index 0 is never counted as a dependent.
for (int i = 1; i < n; ++i) {
for (int i = 1; i < headsSize; ++i) {
int j = heads[i];
++st[j];
}

// Pass 2: running sum, so st[i] is the number of arcs whose head is <= i.
for (int i = 1; i < n; ++i)
for (int i = 1; i < headsSize; ++i)
st[i] += st[i - 1];

// Pass 3: walk the dependents backwards; each one decrements its head's
// cursor and is stored there. When the loop ends, st[j] has been walked back
// to the first slot used by head j.
for (int i = n - 1; i > 0; --i) {
for (int i = headsSize - 1; i > 0; --i) {
int j = heads[i];
--st[j];
edges[st[j]] = i;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ public class LocalFeatureData {
private Options options;
private Parameters parameters;

private final int len; // sentence length
private final int ntypes; // number of label types
private final int sentenceLength;
private final int numberOfLabelTypes;
private final float gammaL;

FeatureVector[] wordFvs; // word feature vector
Expand All @@ -45,7 +45,7 @@ public class LocalFeatureData {
float[][] wpV2;
float[][] wpW2;

private float[][] f;
private float[][] scoresOrProbabilities;
private float[][][] labelScores;

LocalFeatureData(DependencyInstance dependencyInstance,
Expand All @@ -56,25 +56,25 @@ public class LocalFeatureData {
options = parser.getOptions();
parameters = parser.getParameters();

len = dependencyInstance.getLength();
ntypes = pipe.getTypes().length;
sentenceLength = dependencyInstance.getLength();
numberOfLabelTypes = pipe.getTypes().length;
int rank = options.rankFirstOrderTensor;
int rank2 = options.rankSecondOrderTensor;
gammaL = options.gammaLabel;

wordFvs = new FeatureVector[len];
wpU = new float[len][rank];
wpV = new float[len][rank];
wordFvs = new FeatureVector[sentenceLength];
wpU = new float[sentenceLength][rank];
wpV = new float[sentenceLength][rank];

wpU2 = new float[len][rank2];
wpV2 = new float[len][rank2];
wpW2 = new float[len][rank2];
wpU2 = new float[sentenceLength][rank2];
wpV2 = new float[sentenceLength][rank2];
wpW2 = new float[sentenceLength][rank2];


f = new float[len][ntypes];
labelScores = new float[len][ntypes][ntypes];
scoresOrProbabilities = new float[sentenceLength][numberOfLabelTypes];
labelScores = new float[sentenceLength][numberOfLabelTypes][numberOfLabelTypes];

for (int i = 0; i < len; ++i) {
for (int i = 0; i < sentenceLength; ++i) {
wordFvs[i] = synFactory.createWordFeatures(dependencyInstance, i);

parameters.projectU(wordFvs[i], wpU[i]);
Expand All @@ -84,7 +84,6 @@ public class LocalFeatureData {
parameters.projectV2(wordFvs[i], wpV2 != null ? wpV2[i] : new float[0]);
parameters.projectW2(wordFvs[i], wpW2 != null ? wpW2[i] : new float[0]);


}

}
Expand Down Expand Up @@ -117,88 +116,104 @@ private FeatureVector getLabelFeature(int[] heads, int[] types, int mod, int ord
return fv;
}

// Scores every (label of mod, label of head) pair for each node mod >= 1 into
// labelScores, then runs treeDP from node 0 and walks back down the tree to
// write the chosen label of every node into the label-id array argument.
//
// heads              head index of each node; heads[head] is read as the grandparent
// dependencyLabelIds scratch/output array of label ids (named deplbids in the old lines)
// addLoss            when true, labels start at index 0 and a 1.0 penalty is added
//                    wherever the candidate label differs from the instance's label
// arcLis             arc list handed to treeDP and to the label walk
//
// NOTE(review): this is a rendered diff with the +/- markers lost; old lines
// (len, ntypes, lab0, p, deplbids, getType) are interleaved with their
// replacements (sentenceLength, numberOfLabelTypes, startLabelIndex,
// labelIndex, dependencyLabelIds, computeDependencyLabels).
private void predictLabelsDP(int[] heads, int[] deplbids, boolean addLoss, DependencyArcList arcLis) {
private void predictLabelsDP(int[] heads, int[] dependencyLabelIds, boolean addLoss, DependencyArcList arcLis) {

// Label index 0 is only considered when addLoss is set.
int lab0 = addLoss ? 0 : 1;
int startLabelIndex = addLoss ? 0 : 1;

for (int mod = 1; mod < len; ++mod) {
for (int mod = 1; mod < sentenceLength; ++mod) {
int head = heads[mod];
// Direction codes: 1 when the head index is greater than the dependent's, otherwise 2.
int dir = head > mod ? 1 : 2;
int gp = heads[head];
int pdir = gp > head ? 1 : 2;
for (int p = lab0; p < ntypes; ++p) {
if (pipe.getPruneLabel()[dependencyInstance.getXPosTagIds()[head]][dependencyInstance.getXPosTagIds()[mod]][p]) {
deplbids[mod] = p;
for (int labelIndex = startLabelIndex; labelIndex < numberOfLabelTypes; ++labelIndex) {
int[] posTagIds = dependencyInstance.getXPosTagIds();
// Despite the name, a true entry means the label is scored for this
// (head tag, mod tag) pair; a false entry sends the whole row to -Infinity below.
boolean pruneLabel = pipe.getPruneLabel()[posTagIds[head]][posTagIds[mod]][labelIndex];
if (pruneLabel) {
// The candidate label is written into the array because getLabelScoreTheta is passed that array.
dependencyLabelIds[mod] = labelIndex;
// s1: head-mod score, mixing the theta score (weight gammaL) with dotProductL (weight 1 - gammaL).
float s1 = 0;
if (gammaL > 0)
s1 += gammaL * getLabelScoreTheta(heads, deplbids, mod, 1);
s1 += gammaL * getLabelScoreTheta(heads, dependencyLabelIds, mod, 1);
if (gammaL < 1)
s1 += (1 - gammaL) * parameters.dotProductL(wpU[head], wpV[mod], p, dir);
for (int q = lab0; q < ntypes; ++q) {
s1 += (1 - gammaL) * parameters.dotProductL(wpU[head], wpV[mod], labelIndex, dir);
for (int q = startLabelIndex; q < numberOfLabelTypes; ++q) {
// s2: grandparent-head-mod score for head label q; stays 0 when gp == -1.
float s2 = 0;
if (gp != -1) {
if (pipe.getPruneLabel()[dependencyInstance.getXPosTagIds()[gp]][dependencyInstance.getXPosTagIds()[head]][q]) {
deplbids[head] = q;
if (pipe.getPruneLabel()[posTagIds[gp]][posTagIds[head]][q]) {
dependencyLabelIds[head] = q;
if (gammaL > 0)
s2 += gammaL * getLabelScoreTheta(heads, deplbids, mod, 2);
s2 += gammaL * getLabelScoreTheta(heads, dependencyLabelIds, mod, 2);
if (gammaL < 1)
s2 += (1 - gammaL) * parameters.dotProduct2L(wpU2[gp], wpV2[head], wpW2[mod], q, p, pdir, dir);
} else s2 = Float.NEGATIVE_INFINITY;
s2 += (1 - gammaL) * parameters.dotProduct2L(wpU2[gp], wpV2[head], wpW2[mod], q, labelIndex, pdir, dir);
} else {
s2 = Float.NEGATIVE_INFINITY;
}
}
// Indexed as [node][label of node][label of its head]; the trailing term is the loss penalty.
labelScores[mod][p][q] = s1 + s2 + (addLoss && dependencyInstance.getDependencyLabelIds()[mod] != p ? 1.0f : 0.0f);
labelScores[mod][labelIndex][q] = s1 + s2 + (addLoss && dependencyInstance.getDependencyLabelIds()[mod] != labelIndex ? 1.0f : 0.0f);
}
} else Arrays.fill(labelScores[mod][p], Float.NEGATIVE_INFINITY);
} else {
Arrays.fill(labelScores[mod][labelIndex], Float.NEGATIVE_INFINITY);
}
}
}

// Old decoding sequence: DP from node 0, copy node 0's label from the instance, then walk down.
treeDP(0, arcLis, lab0);
deplbids[0] = dependencyInstance.getDependencyLabelIds()[0];
getType(0, arcLis, deplbids, lab0);

// New decoding sequence: same three steps with the renamed variables and method.
treeDP(0, arcLis, startLabelIndex);
dependencyLabelIds[0] = dependencyInstance.getDependencyLabelIds()[0];
computeDependencyLabels(0, arcLis, dependencyLabelIds, startLabelIndex);
}

// Returns the score accumulated by a ScoreCollector built on
// parameters.getParamsL() after synFactory.createLabelFeatures has fed it the
// label features for `mod` at the given `order` (callers in this diff pass 1 or 2).
// NOTE(review): the `col` lines and the `collector` lines are the old and new
// versions of the same three statements from the diff rendering.
private float getLabelScoreTheta(int[] heads, int[] types, int mod, int order) {
ScoreCollector col = new ScoreCollector(parameters.getParamsL());
synFactory.createLabelFeatures(col, dependencyInstance, heads, types, mod, order);
return col.getScore();
ScoreCollector collector = new ScoreCollector(parameters.getParamsL());
synFactory.createLabelFeatures(collector, dependencyInstance, heads, types, mod, order);
return collector.getScore();
}

// Bottom-up pass over the tree rooted at the given node. After it returns, the
// node's row of the score table (f in the old lines, scoresOrProbabilities in
// the new ones) holds, for each label of that node, the sum over its children
// of the best "child row entry + labelScores[child][childLabel][label]".
// NOTE(review): this is a rendered diff with the +/- markers lost. The old
// signature and body come first, then the new signature and body, and both
// share the closing braces at the end.
// Old version: per-label maximum starts from -Infinity.
private void treeDP(int i, DependencyArcList arcLis, int lab0) {
Arrays.fill(f[i], 0);
int st = arcLis.startIndex(i);
int ed = arcLis.endIndex(i);
for (int l = st; l < ed; ++l) {
int j = arcLis.get(l);
treeDP(j, arcLis, lab0);
for (int p = lab0; p < ntypes; ++p) {
float best = Float.NEGATIVE_INFINITY;
for (int q = lab0; q < ntypes; ++q) {
float s = f[j][q] + labelScores[j][q][p];
if (s > best)
best = s;
// New version: same recurrence, but the maximum is seeded with the candidate at
// q == startLabelIndex and the inner loop starts at startLabelIndex + 1.
private void treeDP(int indexNode, DependencyArcList dependencyArcs, int startLabelIndex) {
Arrays.fill(scoresOrProbabilities[indexNode], 0);
int startArcIndex = dependencyArcs.startIndex(indexNode);
int endArcIndex = dependencyArcs.endIndex(indexNode);
for (int arcIndex = startArcIndex; arcIndex < endArcIndex; ++arcIndex) {
int currentNode = dependencyArcs.get(arcIndex);
// The child's row must be complete before it is read below.
treeDP(currentNode, dependencyArcs, startLabelIndex);
for (int labelIndex = startLabelIndex; labelIndex < numberOfLabelTypes; ++labelIndex) {
float currentScore = scoresOrProbabilities[currentNode][startLabelIndex];
float currentLabelScore = labelScores[currentNode][startLabelIndex][labelIndex];
float bestScore = currentScore + currentLabelScore;
for (int q = startLabelIndex + 1; q < numberOfLabelTypes; ++q) {
float score = scoresOrProbabilities[currentNode][q] + labelScores[currentNode][q][labelIndex];
if (score > bestScore)
bestScore = score;
}
f[i][p] += best;
scoresOrProbabilities[indexNode][labelIndex] += bestScore;
}
}
}

// Top-down pass that assigns a label to every descendant of the given node.
// For each child it picks the label maximizing
// "child row entry + labelScores[child][label][label of this node]", stores it
// in the label array, and recurses into the child.
// NOTE(review): this is a rendered diff with the +/- markers lost. The old
// method getType (L267-279 and L300-301) is interleaved with its replacement
// computeDependencyLabels, and both share the closing braces.
// Old version: a child whose candidates are all -Infinity keeps bestq == 0.
private void getType(int i, DependencyArcList arcLis, int[] types, int lab0) {
int p = types[i];
int st = arcLis.startIndex(i);
int ed = arcLis.endIndex(i);
for (int l = st; l < ed; ++l) {
int j = arcLis.get(l);
int bestq = 0;
float best = Float.NEGATIVE_INFINITY;
for (int q = lab0; q < ntypes; ++q) {
float s = f[j][q] + labelScores[j][q][p];
if (s > best) {
best = s;
bestq = q;
// New version: same search, plus a fallback that keeps the child's existing
// entry in dependencyLabelIds when no candidate beats -Infinity.
private void computeDependencyLabels(int indexNode,
DependencyArcList dependencyArcs,
int[] dependencyLabelIds,
int startLabelIndex) {
int dependencyLabelId = dependencyLabelIds[indexNode];
int startArcIndex = dependencyArcs.startIndex(indexNode);
int endArcIndex = dependencyArcs.endIndex(indexNode);
for (int arcIndex = startArcIndex; arcIndex < endArcIndex; ++arcIndex) {
int currentNode = dependencyArcs.get(arcIndex);
int bestLabel = 0;
float bestScore = Float.NEGATIVE_INFINITY;
for (int labelIndex = startLabelIndex; labelIndex < numberOfLabelTypes; ++labelIndex) {
float currentScore = scoresOrProbabilities[currentNode][labelIndex];
float currentLabelScore = labelScores[currentNode][labelIndex][dependencyLabelId];
float totalScore = currentScore + currentLabelScore;
// Strict '>' means the lowest label index wins ties.
if (totalScore > bestScore) {
bestScore = totalScore;
bestLabel = labelIndex;
}
}
types[j] = bestq;
getType(j, arcLis, types, lab0);
if (bestScore == Float.NEGATIVE_INFINITY) {
// if all scores are -Infinity, assign the original type
bestLabel = dependencyLabelIds[currentNode];
}
dependencyLabelIds[currentNode] = bestLabel;
computeDependencyLabels(currentNode, dependencyArcs, dependencyLabelIds, startLabelIndex);
}
}

Expand Down