From 166bcbc32a01f280d5b34e3cc2efa5ce9f5c686f Mon Sep 17 00:00:00 2001
From: Danilo Burbano <37355249+danilojsl@users.noreply.github.com>
Date: Tue, 14 Mar 2023 03:53:42 -0500
Subject: [PATCH] SPARKNLP-750 Handling all scores as -Infinity and refactoring
 dependency parser (#13620)

---
 .../parser/typdep/DependencyArcList.java      |  19 +--
 .../parser/typdep/LocalFeatureData.java       | 149 ++++++++++--------
 2 files changed, 92 insertions(+), 76 deletions(-)

diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/parser/typdep/DependencyArcList.java b/src/main/scala/com/johnsnowlabs/nlp/annotators/parser/typdep/DependencyArcList.java
index 8502101186b944..d3d383b95f3e26 100644
--- a/src/main/scala/com/johnsnowlabs/nlp/annotators/parser/typdep/DependencyArcList.java
+++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/parser/typdep/DependencyArcList.java
@@ -17,14 +17,14 @@
 package com.johnsnowlabs.nlp.annotators.parser.typdep;
 
 public class DependencyArcList {
-    private int n;
+    private int headsSize;
     private int[] st;
     private int[] edges;
 
     DependencyArcList(int[] heads) {
-        n = heads.length;
-        st = new int[n];
-        edges = new int[n];
+        headsSize = heads.length;
+        st = new int[headsSize];
+        edges = new int[headsSize];
         constructDepTreeArcList(heads);
     }
 
@@ -33,7 +33,7 @@ int startIndex(int i) {
     }
 
     int endIndex(int i) {
-        return (i >= n - 1) ? n - 1 : st[i + 1];
+        return (i >= headsSize - 1) ? headsSize - 1 : st[i + 1];
     }
 
     public int get(int i) {
@@ -42,21 +42,22 @@ public int get(int i) {
 
     private void constructDepTreeArcList(int[] heads) {
 
-        for (int i = 0; i < n; ++i)
+        for (int i = 0; i < headsSize; ++i)
             st[i] = 0;
 
-        for (int i = 1; i < n; ++i) {
+        for (int i = 1; i < headsSize; ++i) {
             int j = heads[i];
             ++st[j];
         }
 
-        for (int i = 1; i < n; ++i)
+        for (int i = 1; i < headsSize; ++i)
             st[i] += st[i - 1];
 
-        for (int i = n - 1; i > 0; --i) {
+        for (int i = headsSize - 1; i > 0; --i) {
             int j = heads[i];
             --st[j];
             edges[st[j]] = i;
         }
     }
+
 }
diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/parser/typdep/LocalFeatureData.java b/src/main/scala/com/johnsnowlabs/nlp/annotators/parser/typdep/LocalFeatureData.java
index 83d5d0fdf4d9db..f6487c32d6d39b 100644
--- a/src/main/scala/com/johnsnowlabs/nlp/annotators/parser/typdep/LocalFeatureData.java
+++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/parser/typdep/LocalFeatureData.java
@@ -30,8 +30,8 @@ public class LocalFeatureData {
     private Options options;
     private Parameters parameters;
 
-    private final int len; // sentence length
-    private final int ntypes; // number of label types
+    private final int sentenceLength;
+    private final int numberOfLabelTypes;
     private final float gammaL;
 
     FeatureVector[] wordFvs; // word feature vector
@@ -45,7 +45,7 @@ public class LocalFeatureData {
     float[][] wpV2;
     float[][] wpW2;
 
-    private float[][] f;
+    private float[][] scoresOrProbabilities;
     private float[][][] labelScores;
 
     LocalFeatureData(DependencyInstance dependencyInstance,
@@ -56,25 +56,25 @@ public class LocalFeatureData {
         options = parser.getOptions();
         parameters = parser.getParameters();
 
-        len = dependencyInstance.getLength();
-        ntypes = pipe.getTypes().length;
+        sentenceLength = dependencyInstance.getLength();
+        numberOfLabelTypes = pipe.getTypes().length;
         int rank = options.rankFirstOrderTensor;
         int rank2 = options.rankSecondOrderTensor;
 
         gammaL = options.gammaLabel;
 
-        wordFvs = new FeatureVector[len];
-        wpU = new float[len][rank];
-        wpV = new float[len][rank];
+        wordFvs = new FeatureVector[sentenceLength];
+        wpU = new float[sentenceLength][rank];
+        wpV = new float[sentenceLength][rank];
 
-        wpU2 = new float[len][rank2];
-        wpV2 = new float[len][rank2];
-        wpW2 = new float[len][rank2];
+        wpU2 = new float[sentenceLength][rank2];
+        wpV2 = new float[sentenceLength][rank2];
+        wpW2 = new float[sentenceLength][rank2];
 
-        f = new float[len][ntypes];
-        labelScores = new float[len][ntypes][ntypes];
+        scoresOrProbabilities = new float[sentenceLength][numberOfLabelTypes];
+        labelScores = new float[sentenceLength][numberOfLabelTypes][numberOfLabelTypes];
 
-        for (int i = 0; i < len; ++i) {
+        for (int i = 0; i < sentenceLength; ++i) {
             wordFvs[i] = synFactory.createWordFeatures(dependencyInstance, i);
 
             parameters.projectU(wordFvs[i], wpU[i]);
@@ -84,7 +84,6 @@ public class LocalFeatureData {
             parameters.projectV2(wordFvs[i], wpV2 != null ? wpV2[i] : new float[0]);
             parameters.projectW2(wordFvs[i], wpW2 != null ? wpW2[i] : new float[0]);
 
-
         }
     }
 
@@ -117,88 +116,104 @@ private FeatureVector getLabelFeature(int[] heads, int[] types, int mod, int ord
         return fv;
     }
 
-    private void predictLabelsDP(int[] heads, int[] deplbids, boolean addLoss, DependencyArcList arcLis) {
+    private void predictLabelsDP(int[] heads, int[] dependencyLabelIds, boolean addLoss, DependencyArcList arcLis) {
 
-        int lab0 = addLoss ? 0 : 1;
+        int startLabelIndex = addLoss ? 0 : 1;
 
-        for (int mod = 1; mod < len; ++mod) {
+        for (int mod = 1; mod < sentenceLength; ++mod) {
             int head = heads[mod];
             int dir = head > mod ? 1 : 2;
             int gp = heads[head];
             int pdir = gp > head ? 1 : 2;
 
-            for (int p = lab0; p < ntypes; ++p) {
-                if (pipe.getPruneLabel()[dependencyInstance.getXPosTagIds()[head]][dependencyInstance.getXPosTagIds()[mod]][p]) {
-                    deplbids[mod] = p;
+            for (int labelIndex = startLabelIndex; labelIndex < numberOfLabelTypes; ++labelIndex) {
+                int[] posTagIds = dependencyInstance.getXPosTagIds();
+                boolean pruneLabel = pipe.getPruneLabel()[posTagIds[head]][posTagIds[mod]][labelIndex];
+                if (pruneLabel) {
+                    dependencyLabelIds[mod] = labelIndex;
                     float s1 = 0;
                     if (gammaL > 0)
-                        s1 += gammaL * getLabelScoreTheta(heads, deplbids, mod, 1);
+                        s1 += gammaL * getLabelScoreTheta(heads, dependencyLabelIds, mod, 1);
                     if (gammaL < 1)
-                        s1 += (1 - gammaL) * parameters.dotProductL(wpU[head], wpV[mod], p, dir);
-                    for (int q = lab0; q < ntypes; ++q) {
+                        s1 += (1 - gammaL) * parameters.dotProductL(wpU[head], wpV[mod], labelIndex, dir);
+                    for (int q = startLabelIndex; q < numberOfLabelTypes; ++q) {
                         float s2 = 0;
                         if (gp != -1) {
-                            if (pipe.getPruneLabel()[dependencyInstance.getXPosTagIds()[gp]][dependencyInstance.getXPosTagIds()[head]][q]) {
-                                deplbids[head] = q;
+                            if (pipe.getPruneLabel()[posTagIds[gp]][posTagIds[head]][q]) {
+                                dependencyLabelIds[head] = q;
                                 if (gammaL > 0)
-                                    s2 += gammaL * getLabelScoreTheta(heads, deplbids, mod, 2);
+                                    s2 += gammaL * getLabelScoreTheta(heads, dependencyLabelIds, mod, 2);
                                 if (gammaL < 1)
-                                    s2 += (1 - gammaL) * parameters.dotProduct2L(wpU2[gp], wpV2[head], wpW2[mod], q, p, pdir, dir);
-                            } else s2 = Float.NEGATIVE_INFINITY;
+                                    s2 += (1 - gammaL) * parameters.dotProduct2L(wpU2[gp], wpV2[head], wpW2[mod], q, labelIndex, pdir, dir);
+                            } else {
+                                s2 = Float.NEGATIVE_INFINITY;
+                            }
                         }
-                        labelScores[mod][p][q] = s1 + s2 + (addLoss && dependencyInstance.getDependencyLabelIds()[mod] != p ? 1.0f : 0.0f);
+                        labelScores[mod][labelIndex][q] = s1 + s2 + (addLoss && dependencyInstance.getDependencyLabelIds()[mod] != labelIndex ? 1.0f : 0.0f);
                     }
-                } else Arrays.fill(labelScores[mod][p], Float.NEGATIVE_INFINITY);
+                } else {
+                    Arrays.fill(labelScores[mod][labelIndex], Float.NEGATIVE_INFINITY);
+                }
             }
         }
 
-        treeDP(0, arcLis, lab0);
-        deplbids[0] = dependencyInstance.getDependencyLabelIds()[0];
-        getType(0, arcLis, deplbids, lab0);
-
+        treeDP(0, arcLis, startLabelIndex);
+        dependencyLabelIds[0] = dependencyInstance.getDependencyLabelIds()[0];
+        computeDependencyLabels(0, arcLis, dependencyLabelIds, startLabelIndex);
     }
 
     private float getLabelScoreTheta(int[] heads, int[] types, int mod, int order) {
-        ScoreCollector col = new ScoreCollector(parameters.getParamsL());
-        synFactory.createLabelFeatures(col, dependencyInstance, heads, types, mod, order);
-        return col.getScore();
+        ScoreCollector collector = new ScoreCollector(parameters.getParamsL());
+        synFactory.createLabelFeatures(collector, dependencyInstance, heads, types, mod, order);
+        return collector.getScore();
     }
 
-    private void treeDP(int i, DependencyArcList arcLis, int lab0) {
-        Arrays.fill(f[i], 0);
-        int st = arcLis.startIndex(i);
-        int ed = arcLis.endIndex(i);
-        for (int l = st; l < ed; ++l) {
-            int j = arcLis.get(l);
-            treeDP(j, arcLis, lab0);
-            for (int p = lab0; p < ntypes; ++p) {
-                float best = Float.NEGATIVE_INFINITY;
-                for (int q = lab0; q < ntypes; ++q) {
-                    float s = f[j][q] + labelScores[j][q][p];
-                    if (s > best)
-                        best = s;
+    private void treeDP(int indexNode, DependencyArcList dependencyArcs, int startLabelIndex) {
+        Arrays.fill(scoresOrProbabilities[indexNode], 0);
+        int startArcIndex = dependencyArcs.startIndex(indexNode);
+        int endArcIndex = dependencyArcs.endIndex(indexNode);
+        for (int arcIndex = startArcIndex; arcIndex < endArcIndex; ++arcIndex) {
+            int currentNode = dependencyArcs.get(arcIndex);
+            treeDP(currentNode, dependencyArcs, startLabelIndex);
+            for (int labelIndex = startLabelIndex; labelIndex < numberOfLabelTypes; ++labelIndex) {
+                float currentScore = scoresOrProbabilities[currentNode][startLabelIndex];
+                float currentLabelScore = labelScores[currentNode][startLabelIndex][labelIndex];
+                float bestScore = currentScore + currentLabelScore;
+                for (int q = startLabelIndex + 1; q < numberOfLabelTypes; ++q) {
+                    float score = scoresOrProbabilities[currentNode][q] + labelScores[currentNode][q][labelIndex];
+                    if (score > bestScore)
+                        bestScore = score;
                 }
-                f[i][p] += best;
+                scoresOrProbabilities[indexNode][labelIndex] += bestScore;
             }
         }
     }
 
-    private void getType(int i, DependencyArcList arcLis, int[] types, int lab0) {
-        int p = types[i];
-        int st = arcLis.startIndex(i);
-        int ed = arcLis.endIndex(i);
-        for (int l = st; l < ed; ++l) {
-            int j = arcLis.get(l);
-            int bestq = 0;
-            float best = Float.NEGATIVE_INFINITY;
-            for (int q = lab0; q < ntypes; ++q) {
-                float s = f[j][q] + labelScores[j][q][p];
-                if (s > best) {
-                    best = s;
-                    bestq = q;
+    private void computeDependencyLabels(int indexNode,
+                                         DependencyArcList dependencyArcs,
+                                         int[] dependencyLabelIds,
+                                         int startLabelIndex) {
+        int dependencyLabelId = dependencyLabelIds[indexNode];
+        int startArcIndex = dependencyArcs.startIndex(indexNode);
+        int endArcIndex = dependencyArcs.endIndex(indexNode);
+        for (int arcIndex = startArcIndex; arcIndex < endArcIndex; ++arcIndex) {
+            int currentNode = dependencyArcs.get(arcIndex);
+            int bestLabel = 0;
+            float bestScore = Float.NEGATIVE_INFINITY;
+            for (int labelIndex = startLabelIndex; labelIndex < numberOfLabelTypes; ++labelIndex) {
+                float currentScore = scoresOrProbabilities[currentNode][labelIndex];
+                float currentLabelScore = labelScores[currentNode][labelIndex][dependencyLabelId];
+                float totalScore = currentScore + currentLabelScore;
+                if (totalScore > bestScore) {
+                    bestScore = totalScore;
+                    bestLabel = labelIndex;
                 }
             }
-            types[j] = bestq;
-            getType(j, arcLis, types, lab0);
+            if (bestScore == Float.NEGATIVE_INFINITY) {
+                // if all scores are -Infinity, assign the original type
+                bestLabel = dependencyLabelIds[currentNode];
+            }
+            dependencyLabelIds[currentNode] = bestLabel;
+            computeDependencyLabels(currentNode, dependencyArcs, dependencyLabelIds, startLabelIndex);
        }
    }
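
Note (illustration only, not part of the patch): the renamed computeDependencyLabels keeps the previous argmax over scoresOrProbabilities[child][label] + labelScores[child][label][dependencyLabelId], but when every candidate evaluates to Float.NEGATIVE_INFINITY it now keeps the label the child already carries instead of silently defaulting to label index 0, which is what the subject line refers to. A minimal standalone sketch of that selection rule on a toy score table follows; the class and method names (LabelFallbackSketch, pickLabel) are hypothetical and do not exist in the codebase.

import java.util.Arrays;

// Illustrative sketch only: reproduces the argmax-with-fallback rule used by
// computeDependencyLabels in the patch above, on a toy score table.
public class LabelFallbackSketch {

    // scores[label] stands in for
    // scoresOrProbabilities[currentNode][label] + labelScores[currentNode][label][dependencyLabelId].
    static int pickLabel(float[] scores, int currentLabel, int startLabelIndex) {
        int bestLabel = 0;
        float bestScore = Float.NEGATIVE_INFINITY;
        for (int labelIndex = startLabelIndex; labelIndex < scores.length; ++labelIndex) {
            if (scores[labelIndex] > bestScore) {
                bestScore = scores[labelIndex];
                bestLabel = labelIndex;
            }
        }
        if (bestScore == Float.NEGATIVE_INFINITY) {
            // All candidates were pruned to -Infinity: keep the label the node
            // already has instead of defaulting to label index 0.
            bestLabel = currentLabel;
        }
        return bestLabel;
    }

    public static void main(String[] args) {
        float[] normal = {Float.NEGATIVE_INFINITY, 0.3f, 1.7f, -0.2f};
        float[] allPruned = new float[4];
        Arrays.fill(allPruned, Float.NEGATIVE_INFINITY);

        System.out.println(pickLabel(normal, 3, 1));    // regular case: the argmax wins, prints 2
        System.out.println(pickLabel(allPruned, 3, 1)); // all -Infinity: original label kept, prints 3
    }
}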
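
Note (illustration only, not part of the patch): the DependencyArcList changes are a pure rename (n to headsSize); the underlying structure is a counting-sort, CSR-style child index in which st[h] is the offset into edges where the children of head h start and endIndex(h) bounds them. A small self-contained sketch of the same construction follows, using a hypothetical buildChildIndex helper since the real constructor is package-private.

import java.util.Arrays;

// Illustrative sketch only: rebuilds the CSR-style child index that
// DependencyArcList.constructDepTreeArcList produces.
public class ChildIndexSketch {

    // Returns {st, edges}: the children of head h are edges[st[h]] .. edges[end - 1],
    // where end is st[h + 1] (or headsSize - 1 for the last index, as in endIndex()).
    static int[][] buildChildIndex(int[] heads) {
        int headsSize = heads.length;
        int[] st = new int[headsSize];
        int[] edges = new int[headsSize];

        for (int i = 1; i < headsSize; ++i)          // count children per head (token 0 is the root)
            ++st[heads[i]];
        for (int i = 1; i < headsSize; ++i)          // prefix sums -> start offsets
            st[i] += st[i - 1];
        for (int i = headsSize - 1; i > 0; --i) {    // place tokens, walking each offset backwards
            --st[heads[i]];
            edges[st[heads[i]]] = i;
        }
        return new int[][]{st, edges};
    }

    public static void main(String[] args) {
        // heads[i] = head of token i; token 0 is the artificial root.
        int[] heads = {-1, 2, 0, 2, 2};
        int[][] index = buildChildIndex(heads);
        System.out.println(Arrays.toString(index[0])); // st    = [0, 1, 1, 4, 4]
        System.out.println(Arrays.toString(index[1])); // edges = [2, 1, 3, 4, 0] (last slot unused)
    }
}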