Skip to content

Commit

Permalink
SPARKNLP-750 Handling all scores as -Infinity and refactoring depende…
Browse files Browse the repository at this point in the history
…ncy parser (#13620)
  • Loading branch information
danilojsl authored Mar 14, 2023
1 parent bad5435 commit 85ef6ae
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
package com.johnsnowlabs.nlp.annotators.parser.typdep;

public class DependencyArcList {
private int n;
private int headsSize;
private int[] st;
private int[] edges;

DependencyArcList(int[] heads) {
n = heads.length;
st = new int[n];
edges = new int[n];
headsSize = heads.length;
st = new int[headsSize];
edges = new int[headsSize];
constructDepTreeArcList(heads);
}

Expand All @@ -33,7 +33,7 @@ int startIndex(int i) {
}

int endIndex(int i) {
return (i >= n - 1) ? n - 1 : st[i + 1];
return (i >= headsSize - 1) ? headsSize - 1 : st[i + 1];
}

public int get(int i) {
Expand All @@ -42,21 +42,22 @@ public int get(int i) {

private void constructDepTreeArcList(int[] heads) {

for (int i = 0; i < n; ++i)
for (int i = 0; i < headsSize; ++i)
st[i] = 0;

for (int i = 1; i < n; ++i) {
for (int i = 1; i < headsSize; ++i) {
int j = heads[i];
++st[j];
}

for (int i = 1; i < n; ++i)
for (int i = 1; i < headsSize; ++i)
st[i] += st[i - 1];

for (int i = n - 1; i > 0; --i) {
for (int i = headsSize - 1; i > 0; --i) {
int j = heads[i];
--st[j];
edges[st[j]] = i;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ public class LocalFeatureData {
private Options options;
private Parameters parameters;

private final int len; // sentence length
private final int ntypes; // number of label types
private final int sentenceLength;
private final int numberOfLabelTypes;
private final float gammaL;

FeatureVector[] wordFvs; // word feature vector
Expand All @@ -45,7 +45,7 @@ public class LocalFeatureData {
float[][] wpV2;
float[][] wpW2;

private float[][] f;
private float[][] scoresOrProbabilities;
private float[][][] labelScores;

LocalFeatureData(DependencyInstance dependencyInstance,
Expand All @@ -56,25 +56,25 @@ public class LocalFeatureData {
options = parser.getOptions();
parameters = parser.getParameters();

len = dependencyInstance.getLength();
ntypes = pipe.getTypes().length;
sentenceLength = dependencyInstance.getLength();
numberOfLabelTypes = pipe.getTypes().length;
int rank = options.rankFirstOrderTensor;
int rank2 = options.rankSecondOrderTensor;
gammaL = options.gammaLabel;

wordFvs = new FeatureVector[len];
wpU = new float[len][rank];
wpV = new float[len][rank];
wordFvs = new FeatureVector[sentenceLength];
wpU = new float[sentenceLength][rank];
wpV = new float[sentenceLength][rank];

wpU2 = new float[len][rank2];
wpV2 = new float[len][rank2];
wpW2 = new float[len][rank2];
wpU2 = new float[sentenceLength][rank2];
wpV2 = new float[sentenceLength][rank2];
wpW2 = new float[sentenceLength][rank2];


f = new float[len][ntypes];
labelScores = new float[len][ntypes][ntypes];
scoresOrProbabilities = new float[sentenceLength][numberOfLabelTypes];
labelScores = new float[sentenceLength][numberOfLabelTypes][numberOfLabelTypes];

for (int i = 0; i < len; ++i) {
for (int i = 0; i < sentenceLength; ++i) {
wordFvs[i] = synFactory.createWordFeatures(dependencyInstance, i);

parameters.projectU(wordFvs[i], wpU[i]);
Expand All @@ -84,7 +84,6 @@ public class LocalFeatureData {
parameters.projectV2(wordFvs[i], wpV2 != null ? wpV2[i] : new float[0]);
parameters.projectW2(wordFvs[i], wpW2 != null ? wpW2[i] : new float[0]);


}

}
Expand Down Expand Up @@ -117,88 +116,104 @@ private FeatureVector getLabelFeature(int[] heads, int[] types, int mod, int ord
return fv;
}

private void predictLabelsDP(int[] heads, int[] deplbids, boolean addLoss, DependencyArcList arcLis) {
private void predictLabelsDP(int[] heads, int[] dependencyLabelIds, boolean addLoss, DependencyArcList arcLis) {

int lab0 = addLoss ? 0 : 1;
int startLabelIndex = addLoss ? 0 : 1;

for (int mod = 1; mod < len; ++mod) {
for (int mod = 1; mod < sentenceLength; ++mod) {
int head = heads[mod];
int dir = head > mod ? 1 : 2;
int gp = heads[head];
int pdir = gp > head ? 1 : 2;
for (int p = lab0; p < ntypes; ++p) {
if (pipe.getPruneLabel()[dependencyInstance.getXPosTagIds()[head]][dependencyInstance.getXPosTagIds()[mod]][p]) {
deplbids[mod] = p;
for (int labelIndex = startLabelIndex; labelIndex < numberOfLabelTypes; ++labelIndex) {
int[] posTagIds = dependencyInstance.getXPosTagIds();
boolean pruneLabel = pipe.getPruneLabel()[posTagIds[head]][posTagIds[mod]][labelIndex];
if (pruneLabel) {
dependencyLabelIds[mod] = labelIndex;
float s1 = 0;
if (gammaL > 0)
s1 += gammaL * getLabelScoreTheta(heads, deplbids, mod, 1);
s1 += gammaL * getLabelScoreTheta(heads, dependencyLabelIds, mod, 1);
if (gammaL < 1)
s1 += (1 - gammaL) * parameters.dotProductL(wpU[head], wpV[mod], p, dir);
for (int q = lab0; q < ntypes; ++q) {
s1 += (1 - gammaL) * parameters.dotProductL(wpU[head], wpV[mod], labelIndex, dir);
for (int q = startLabelIndex; q < numberOfLabelTypes; ++q) {
float s2 = 0;
if (gp != -1) {
if (pipe.getPruneLabel()[dependencyInstance.getXPosTagIds()[gp]][dependencyInstance.getXPosTagIds()[head]][q]) {
deplbids[head] = q;
if (pipe.getPruneLabel()[posTagIds[gp]][posTagIds[head]][q]) {
dependencyLabelIds[head] = q;
if (gammaL > 0)
s2 += gammaL * getLabelScoreTheta(heads, deplbids, mod, 2);
s2 += gammaL * getLabelScoreTheta(heads, dependencyLabelIds, mod, 2);
if (gammaL < 1)
s2 += (1 - gammaL) * parameters.dotProduct2L(wpU2[gp], wpV2[head], wpW2[mod], q, p, pdir, dir);
} else s2 = Float.NEGATIVE_INFINITY;
s2 += (1 - gammaL) * parameters.dotProduct2L(wpU2[gp], wpV2[head], wpW2[mod], q, labelIndex, pdir, dir);
} else {
s2 = Float.NEGATIVE_INFINITY;
}
}
labelScores[mod][p][q] = s1 + s2 + (addLoss && dependencyInstance.getDependencyLabelIds()[mod] != p ? 1.0f : 0.0f);
labelScores[mod][labelIndex][q] = s1 + s2 + (addLoss && dependencyInstance.getDependencyLabelIds()[mod] != labelIndex ? 1.0f : 0.0f);
}
} else Arrays.fill(labelScores[mod][p], Float.NEGATIVE_INFINITY);
} else {
Arrays.fill(labelScores[mod][labelIndex], Float.NEGATIVE_INFINITY);
}
}
}

treeDP(0, arcLis, lab0);
deplbids[0] = dependencyInstance.getDependencyLabelIds()[0];
getType(0, arcLis, deplbids, lab0);

treeDP(0, arcLis, startLabelIndex);
dependencyLabelIds[0] = dependencyInstance.getDependencyLabelIds()[0];
computeDependencyLabels(0, arcLis, dependencyLabelIds, startLabelIndex);
}

private float getLabelScoreTheta(int[] heads, int[] types, int mod, int order) {
ScoreCollector col = new ScoreCollector(parameters.getParamsL());
synFactory.createLabelFeatures(col, dependencyInstance, heads, types, mod, order);
return col.getScore();
ScoreCollector collector = new ScoreCollector(parameters.getParamsL());
synFactory.createLabelFeatures(collector, dependencyInstance, heads, types, mod, order);
return collector.getScore();
}

private void treeDP(int i, DependencyArcList arcLis, int lab0) {
Arrays.fill(f[i], 0);
int st = arcLis.startIndex(i);
int ed = arcLis.endIndex(i);
for (int l = st; l < ed; ++l) {
int j = arcLis.get(l);
treeDP(j, arcLis, lab0);
for (int p = lab0; p < ntypes; ++p) {
float best = Float.NEGATIVE_INFINITY;
for (int q = lab0; q < ntypes; ++q) {
float s = f[j][q] + labelScores[j][q][p];
if (s > best)
best = s;
private void treeDP(int indexNode, DependencyArcList dependencyArcs, int startLabelIndex) {
Arrays.fill(scoresOrProbabilities[indexNode], 0);
int startArcIndex = dependencyArcs.startIndex(indexNode);
int endArcIndex = dependencyArcs.endIndex(indexNode);
for (int arcIndex = startArcIndex; arcIndex < endArcIndex; ++arcIndex) {
int currentNode = dependencyArcs.get(arcIndex);
treeDP(currentNode, dependencyArcs, startLabelIndex);
for (int labelIndex = startLabelIndex; labelIndex < numberOfLabelTypes; ++labelIndex) {
float currentScore = scoresOrProbabilities[currentNode][startLabelIndex];
float currentLabelScore = labelScores[currentNode][startLabelIndex][labelIndex];
float bestScore = currentScore + currentLabelScore;
for (int q = startLabelIndex + 1; q < numberOfLabelTypes; ++q) {
float score = scoresOrProbabilities[currentNode][q] + labelScores[currentNode][q][labelIndex];
if (score > bestScore)
bestScore = score;
}
f[i][p] += best;
scoresOrProbabilities[indexNode][labelIndex] += bestScore;
}
}
}

private void getType(int i, DependencyArcList arcLis, int[] types, int lab0) {
int p = types[i];
int st = arcLis.startIndex(i);
int ed = arcLis.endIndex(i);
for (int l = st; l < ed; ++l) {
int j = arcLis.get(l);
int bestq = 0;
float best = Float.NEGATIVE_INFINITY;
for (int q = lab0; q < ntypes; ++q) {
float s = f[j][q] + labelScores[j][q][p];
if (s > best) {
best = s;
bestq = q;
private void computeDependencyLabels(int indexNode,
DependencyArcList dependencyArcs,
int[] dependencyLabelIds,
int startLabelIndex) {
int dependencyLabelId = dependencyLabelIds[indexNode];
int startArcIndex = dependencyArcs.startIndex(indexNode);
int endArcIndex = dependencyArcs.endIndex(indexNode);
for (int arcIndex = startArcIndex; arcIndex < endArcIndex; ++arcIndex) {
int currentNode = dependencyArcs.get(arcIndex);
int bestLabel = 0;
float bestScore = Float.NEGATIVE_INFINITY;
for (int labelIndex = startLabelIndex; labelIndex < numberOfLabelTypes; ++labelIndex) {
float currentScore = scoresOrProbabilities[currentNode][labelIndex];
float currentLabelScore = labelScores[currentNode][labelIndex][dependencyLabelId];
float totalScore = currentScore + currentLabelScore;
if (totalScore > bestScore) {
bestScore = totalScore;
bestLabel = labelIndex;
}
}
types[j] = bestq;
getType(j, arcLis, types, lab0);
if (bestScore == Float.NEGATIVE_INFINITY) {
// if all scores are -Infinity, assign the original type
bestLabel = dependencyLabelIds[currentNode];
}
dependencyLabelIds[currentNode] = bestLabel;
computeDependencyLabels(currentNode, dependencyArcs, dependencyLabelIds, startLabelIndex);
}
}

Expand Down

0 comments on commit 85ef6ae

Please sign in to comment.