Skip to content

Commit

Permalink
[SPARKNLP-1031] Solves Dependency Parsers training issue (#14225)
Browse files Browse the repository at this point in the history
  • Loading branch information
danilojsl authored Apr 5, 2024
1 parent 2c54a27 commit 75dbfcc
Show file tree
Hide file tree
Showing 12 changed files with 280 additions and 348 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ class DependencyParserApproach(override val uid: String)

val sentences = cleanConllUSentence.map { conllUWord =>
val wordArray = conllUWord.split(SEPARATOR)
if (!wordArray(ID_INDEX).contains(".")) {
if (wordArray(ID_INDEX).matches("\\d+") && !wordArray(ID_INDEX).contains(".")) {
var head = wordArray(HEAD_INDEX).toInt
if (head == 0) {
head = cleanConllUSentence.length
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,6 @@ void setInstIds(DictionarySet dicts,
}

// set special pos
//TODO: Check if this is used somewhere
specialPos = new SpecialPos[length];
for (int i = 0; i < length; ++i) {
if (coarseMap.containsKey(uPosTags[i])) {
Expand All @@ -249,6 +248,9 @@ else if (cpos.equals("VERB"))
else
specialPos[i] = SpecialPos.OTHER;
} else {
if (forms[i] == null || uPosTags[i] == null) {
continue;
}
specialPos[i] = getSpecialPos(forms[i], uPosTags[i]);
}
}
Expand All @@ -265,7 +267,9 @@ private String normalize(String s) {
// (http://groups.csail.mit.edu/nlp/egstra/).
//
private SpecialPos getSpecialPos(String form, String tag) {

if (tag == null || form == null) {
return SpecialPos.OTHER;
}
if (tag.charAt(0) == 'v' || tag.charAt(0) == 'V')
return SpecialPos.V;
else if (tag.charAt(0) == 'n' || tag.charAt(0) == 'N')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,32 +290,33 @@ public DependencyInstance nextSentence(ConllData[] sentence, String conllFormat)
if (dependencyInstance == null) {
return null;
}
//TODO: Here is where cpostagids are set
//Here is where cpostagids are set
dependencyInstance.setInstIds(dictionariesSet, coarseMap, conjWord);

return dependencyInstance;
}

/**
 * Builds the {@code pruneLabel} lookup table from a set of dependency instances.
 *
 * <p>The table is a boolean cube indexed as
 * {@code [POS tag id of head][POS tag id of modifier][dependency label id]}; each
 * dimension is one larger than the corresponding dictionary size. For every token
 * index from 1 up to the instance length (index 0 is skipped) the combination found
 * in the instance is marked {@code true}, and {@code num} counts how many distinct
 * combinations were marked.
 *
 * @param dependencyInstances instances whose head / POS tag / label combinations are recorded
 */
public void pruneLabel(DependencyInstance[] dependencyInstances) {
int numPOS = dictionariesSet.getDictionarySize(POS) + 1;
int numLab = dictionariesSet.getDictionarySize(DEP_LABEL) + 1;
this.pruneLabel = new boolean[numPOS][numPOS][numLab];
int numLabel = dictionariesSet.getDictionarySize(DEP_LABEL) + 1;
this.pruneLabel = new boolean[numPOS][numPOS][numLabel];
int num = 0;

for (DependencyInstance dependencyInstance : dependencyInstances) {
int n = dependencyInstance.getLength();
for (int mod = 1; mod < n; ++mod) {
int head = dependencyInstance.getHeads()[mod];
int lab = dependencyInstance.getDependencyLabelIds()[mod];
if (!this.pruneLabel[dependencyInstance.getXPosTagIds()[head]][dependencyInstance.getXPosTagIds()[mod]][lab]) {
this.pruneLabel[dependencyInstance.getXPosTagIds()[head]][dependencyInstance.getXPosTagIds()[mod]][lab] = true;
int label = dependencyInstance.getDependencyLabelIds()[mod];
int [] posTagIds = dependencyInstance.getXPosTagIds();
// Count a combination only the first time it is seen.
if (!this.pruneLabel[posTagIds[head]][posTagIds[mod]][label]) {
this.pruneLabel[posTagIds[head]][posTagIds[mod]][label] = true;
num++;
}
}
}

// NOTE(review): the denominator below uses numCPOS, which is not declared in this
// method, while the table itself is dimensioned with numPOS. The reported total
// looks like it should be numPOS * numPOS * numLabel - confirm against the field.
if (logger.isDebugEnabled()) {
logger.debug(String.format("Prune label: %d/%d", num, numCPOS * numCPOS * numLab));
logger.debug(String.format("Prune label: %d/%d", num, numCPOS * numCPOS * numLabel));
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ public class LocalFeatureData {
private float[][] scoresOrProbabilities;
private float[][][] labelScores;

// Code returned by getDirection when its first index is greater than its second.
// calculateInitialScore also passes this value (1) as the `order` argument of
// getLabelScoreTheta, so the name covers two unrelated meanings.
private static final int FORWARD_DIRECTION = 1;
// Code returned by getDirection otherwise (first index <= second index).
// adjustScoreForGrandparent also passes this value (2) as the `order` argument of
// getLabelScoreTheta.
private static final int BACKWARD_DIRECTION = 2;

LocalFeatureData(DependencyInstance dependencyInstance,
TypedDependencyParser parser) {
this.dependencyInstance = dependencyInstance;
Expand Down Expand Up @@ -116,57 +119,137 @@ private FeatureVector getLabelFeature(int[] heads, int[] types, int mod, int ord
return fv;
}

private void predictLabelsDP(int[] heads, int[] dependencyLabelIds, boolean addLoss, DependencyArcList arcLis) {
/**
 * Entry point for label prediction over a fixed head assignment.
 *
 * <p>Wraps {@code heads} in a {@link DependencyArcList} and hands everything to
 * {@code predictLabelsDependencyParser}.
 *
 * @param heads              head index of every token
 * @param dependencyLabelIds working array of label ids, passed through to the delegate
 * @param addLoss            forwarded unchanged to the delegate
 */
void predictLabels(int[] heads, int[] dependencyLabelIds, boolean addLoss) {
    predictLabelsDependencyParser(heads, dependencyLabelIds, addLoss, new DependencyArcList(heads));
}

private void predictLabelsDependencyParser(int[] heads,
int[] dependencyLabelIds,
boolean addLoss,
DependencyArcList arcList) {
int startLabelIndex = addLoss ? 0 : 1;

for (int mod = 1; mod < sentenceLength; ++mod) {
int head = heads[mod];
int dir = head > mod ? 1 : 2;
int gp = heads[head];
int pdir = gp > head ? 1 : 2;
for (int modifier = 1; modifier < sentenceLength; ++modifier) {
for (int labelIndex = startLabelIndex; labelIndex < numberOfLabelTypes; ++labelIndex) {
int[] posTagIds = dependencyInstance.getXPosTagIds();
boolean pruneLabel = pipe.getPruneLabel()[posTagIds[head]][posTagIds[mod]][labelIndex];
if (pruneLabel) {
dependencyLabelIds[mod] = labelIndex;
float s1 = 0;
if (gammaL > 0)
s1 += gammaL * getLabelScoreTheta(heads, dependencyLabelIds, mod, 1);
if (gammaL < 1)
s1 += (1 - gammaL) * parameters.dotProductL(wpU[head], wpV[mod], labelIndex, dir);
for (int q = startLabelIndex; q < numberOfLabelTypes; ++q) {
float s2 = 0;
if (gp != -1) {
if (pipe.getPruneLabel()[posTagIds[gp]][posTagIds[head]][q]) {
dependencyLabelIds[head] = q;
if (gammaL > 0)
s2 += gammaL * getLabelScoreTheta(heads, dependencyLabelIds, mod, 2);
if (gammaL < 1)
s2 += (1 - gammaL) * parameters.dotProduct2L(wpU2[gp], wpV2[head], wpW2[mod], q, labelIndex, pdir, dir);
} else {
s2 = Float.NEGATIVE_INFINITY;
}
}
labelScores[mod][labelIndex][q] = s1 + s2 + (addLoss && dependencyInstance.getDependencyLabelIds()[mod] != labelIndex ? 1.0f : 0.0f);
}
} else {
Arrays.fill(labelScores[mod][labelIndex], Float.NEGATIVE_INFINITY);
}
processLabel(
heads,
modifier,
startLabelIndex,
labelIndex,
addLoss,
dependencyLabelIds);
}
}
finalizePredictions(arcList, startLabelIndex, dependencyLabelIds);
}

/**
 * Encodes the relative position of two token indices.
 *
 * @return {@code FORWARD_DIRECTION} when {@code from} is strictly greater than
 *         {@code to}, {@code BACKWARD_DIRECTION} otherwise (including equality)
 */
private int getDirection(int from, int to) {
    if (from > to) {
        return FORWARD_DIRECTION;
    }
    return BACKWARD_DIRECTION;
}

/**
 * Fills {@code labelScores[modifier][labelIndex][*]} for one modifier / label pair.
 *
 * <p>When {@code shouldPruneLabel(head, modifier, labelIndex)} is {@code false} the
 * whole row is set to negative infinity and the method returns. Otherwise
 * {@code labelIndex} is written into {@code dependencyLabelIds[modifier]}, the score
 * from {@code calculateInitialScore} is computed once, and for every index from
 * {@code startLabelIndex} up to {@code numberOfLabelTypes} the value returned by
 * {@code adjustScoreForGrandparent} is added to it and stored.
 *
 * @param heads              head index of every token
 * @param modifier           token whose label scores are being filled
 * @param startLabelIndex    first label index iterated for the head's label
 * @param labelIndex         candidate label of {@code modifier}
 * @param addLoss            forwarded to {@code adjustScoreForGrandparent}
 * @param dependencyLabelIds working label array; entries for {@code modifier} (and, inside
 *                           {@code adjustScoreForGrandparent}, its head) are overwritten
 */
private void processLabel(
int[] heads,
int modifier,
int startLabelIndex,
int labelIndex,
boolean addLoss,
int[] dependencyLabelIds
) {
int head = heads[modifier];
// A false lookup means this head-tag / modifier-tag / label combination is not allowed.
if (!shouldPruneLabel(head, modifier, labelIndex)) {
Arrays.fill(labelScores[modifier][labelIndex], Float.NEGATIVE_INFINITY);
return;
}

// NOTE(review): the next three statements reference arcLis, which is neither a
// parameter nor a local of this method; finalizePredictions issues the same three
// calls with its own arcList. They look like leftovers of the removed
// predictLabelsDP body - confirm and drop.
treeDP(0, arcLis, startLabelIndex);
dependencyLabelIds[0] = dependencyInstance.getDependencyLabelIds()[0];
computeDependencyLabels(0, arcLis, dependencyLabelIds, startLabelIndex);
dependencyLabelIds[modifier] = labelIndex;
float score = calculateInitialScore(heads, dependencyLabelIds, modifier, labelIndex);
for (int index = startLabelIndex; index < numberOfLabelTypes; ++index) {
float adjustedScore =
adjustScoreForGrandparent(
heads,
modifier,
labelIndex,
index,
addLoss,
dependencyLabelIds);
labelScores[modifier][labelIndex][index] = score + adjustedScore;
}
}

/**
 * Computes the part of a label score that depends on the label of the modifier's head
 * (the arc grandparent -> head), plus the optional loss term.
 *
 * <p>Behaviour, matching the pre-refactoring {@code predictLabelsDP} loop:
 * <ul>
 *   <li>If the head has no parent ({@code heads[head] == -1}) there is no second-order
 *       term: the contribution is 0 (plus loss). Returning negative infinity here would
 *       make every label of such a modifier unscoreable.</li>
 *   <li>If the grandparent / head / {@code currentLabelIndex} combination is not allowed
 *       by the prune table, the result is negative infinity.</li>
 *   <li>Otherwise {@code currentLabelIndex} is written into
 *       {@code dependencyLabelIds[head]} and the second-order feature score and
 *       low-rank tensor score are mixed with weight {@code gammaL}.</li>
 * </ul>
 * The loss (1.0 when {@code addLoss} is set and {@code labelIndex} differs from the
 * instance's label for {@code modifier}) is added whether or not a grandparent exists.
 *
 * @param heads              head index of every token
 * @param modifier           token being scored
 * @param labelIndex         candidate label of {@code modifier}
 * @param currentLabelIndex  candidate label of the modifier's head
 * @param addLoss            whether to add the label-mismatch loss
 * @param dependencyLabelIds working label array; the head's entry is overwritten
 * @return the second-order contribution including loss, or negative infinity when pruned
 */
private float adjustScoreForGrandparent(
        int[] heads,
        int modifier,
        int labelIndex,
        int currentLabelIndex,
        boolean addLoss,
        int[] dependencyLabelIds
) {
    // `order` argument of getLabelScoreTheta for the second-order features. This is a
    // feature order, not an arc direction, so it is spelled out here rather than
    // borrowing a direction constant that merely has the same value.
    final int secondOrder = 2;

    final int head = heads[modifier];
    final int grandParent = heads[head];

    float score = 0;
    if (grandParent != -1) {
        if (!shouldPruneLabel(grandParent, head, currentLabelIndex)) {
            // Disallowed head label: adding the loss would not change the result.
            return Float.NEGATIVE_INFINITY;
        }

        dependencyLabelIds[head] = currentLabelIndex;
        if (gammaL > 0) {
            // Second-order features are taken for the modifier, as in the original loop.
            score += gammaL * getLabelScoreTheta(heads, dependencyLabelIds, modifier, secondOrder);
        }
        if (gammaL < 1) {
            final int direction = getDirection(head, modifier);
            final int parentDirection = getDirection(grandParent, head);
            final float dotProduct2L = parameters.dotProduct2L(
                    wpU2[grandParent],
                    wpV2[head],
                    wpW2[modifier],
                    currentLabelIndex,
                    labelIndex,
                    parentDirection,
                    direction);
            score += (1 - gammaL) * dotProduct2L;
        }
    }

    if (addLoss && dependencyInstance.getDependencyLabelIds()[modifier] != labelIndex) {
        score += 1.0f;
    }
    return score;
}

/**
 * Looks up the prune table held by {@code pipe}.
 *
 * <p>{@code dim1} and {@code dim2} are token indices that are first translated to POS
 * tag ids; {@code dim3} is used directly as the label index. Despite the name, callers
 * in this class treat a {@code true} result as "combination allowed" and a
 * {@code false} result as the reason to assign negative infinity.
 *
 * @return the table entry for (tag of {@code dim1}, tag of {@code dim2}, {@code dim3})
 */
private boolean shouldPruneLabel(int dim1, int dim2, int dim3) {
    final int[] tagIds = dependencyInstance.getXPosTagIds();
    final int firstTag = tagIds[dim1];
    final int secondTag = tagIds[dim2];
    return pipe.getPruneLabel()[firstTag][secondTag][dim3];
}

/**
 * Computes the score of assigning {@code labelIndex} to the arc head -> modifier,
 * independent of the head's own label.
 *
 * <p>The result mixes two terms with weight {@code gammaL}: the feature score from
 * {@code getLabelScoreTheta} (only when {@code gammaL > 0}) and the
 * {@code parameters.dotProductL} score (only when {@code gammaL < 1}).
 *
 * @param heads              head index of every token
 * @param dependencyLabelIds working label array read by the feature scorer
 * @param modifier           token being scored
 * @param labelIndex         candidate label of {@code modifier}
 * @return the weighted sum of the enabled terms, 0 if neither is enabled
 */
private float calculateInitialScore(
        int[] heads,
        int[] dependencyLabelIds,
        int modifier,
        int labelIndex
) {
    final int head = heads[modifier];
    float total = 0;
    if (gammaL > 0) {
        total += gammaL * getLabelScoreTheta(heads, dependencyLabelIds, modifier, FORWARD_DIRECTION);
    }
    if (gammaL < 1) {
        final int arcDirection = getDirection(head, modifier);
        total += (1 - gammaL) * parameters.dotProductL(wpU[head], wpV[modifier], labelIndex, arcDirection);
    }
    return total;
}

/**
 * Scores the label features of one token.
 *
 * <p>Creates a {@code ScoreCollector} over {@code parameters.getParamsL()}, lets
 * {@code synFactory.createLabelFeatures} feed it the features for
 * ({@code dependencyInstance}, {@code heads}, {@code types}, {@code mod}, {@code order}),
 * and returns the collected score.
 *
 * @param heads head index of every token
 * @param types label id of every token
 * @param mod   token whose label features are scored
 * @param order feature order forwarded to the factory; callers in this class pass 1 or 2
 * @return the score accumulated by the collector
 */
private float getLabelScoreTheta(int[] heads, int[] types, int mod, int order) {
private float getLabelScoreTheta(
int[] heads,
int[] types,
int mod,
int order
) {
ScoreCollector collector = new ScoreCollector(parameters.getParamsL());
synFactory.createLabelFeatures(collector, dependencyInstance, heads, types, mod, order);
return collector.getScore();
}

/**
 * Runs the final pass once all label scores have been filled in.
 *
 * <p>Calls {@code treeDP} starting at node 0, copies the instance's label id for
 * token 0 into {@code dependencyLabelIds[0]}, then calls
 * {@code computeDependencyLabels} starting at node 0 with the same arc list.
 *
 * @param arcList            arcs derived from the head assignment
 * @param startLabelIndex    first label index considered by both passes
 * @param dependencyLabelIds working label array; entry 0 is overwritten here
 */
private void finalizePredictions(DependencyArcList arcList, int startLabelIndex, int[] dependencyLabelIds) {
    final int startNode = 0;
    treeDP(startNode, arcList, startLabelIndex);
    final int[] instanceLabelIds = dependencyInstance.getDependencyLabelIds();
    dependencyLabelIds[startNode] = instanceLabelIds[startNode];
    computeDependencyLabels(startNode, arcList, dependencyLabelIds, startLabelIndex);
}

private void treeDP(int indexNode, DependencyArcList dependencyArcs, int startLabelIndex) {
Arrays.fill(scoresOrProbabilities[indexNode], 0);
int startArcIndex = dependencyArcs.startIndex(indexNode);
Expand Down Expand Up @@ -217,9 +300,4 @@ private void computeDependencyLabels(int indexNode,
}
}

void predictLabels(int[] heads, int[] dependencyLabelIds, boolean addLoss) {
DependencyArcList arcLis = new DependencyArcList(heads);
predictLabelsDP(heads, dependencyLabelIds, addLoss, arcLis);
}

}
Loading

0 comments on commit 75dbfcc

Please sign in to comment.