[ML] add sentence piece pre-compiled normalizer (elastic#87575)
This is one of the many prerequisites for supporting sentence-piece tokenization within NLP. Sentence piece is a fairly complicated and involved tokenization scheme. This commit contains the normalization logic that transforms the provided string from its current UTF-8 bytes into a standard, normalized set of UTF-8 bytes. The typical storage for this normalizer is a compressed representation of a DARTS (DoubleArray Trie System) array and a NULL-delimited normalization string.
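For context on how the class added below is meant to be used, here is a minimal hedged sketch. The blob and the helper method are hypothetical; in practice the base64 string comes from a sentence-piece model's precompiled charsmap, and the methods involved are package-private, so a real caller would live in the same package:

    // Hedged sketch: base64CharsmapBlob is assumed to hold a real precompiled charsmap.
    static String normalizeWithCharsmap(String base64CharsmapBlob, String input) {
        PrecompiledCharMapNormalizer normalizer = PrecompiledCharMapNormalizer.fromBase64Str(base64CharsmapBlob);
        // e.g. the ligature "\uFB01" ("fi" ligature) is expected to normalize to "fi"
        return normalizer.normalize(input);
    }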
Showing 4 changed files with 306 additions and 0 deletions.
209 changes: 209 additions & 0 deletions
...ava/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/PrecompiledCharMapNormalizer.java
@@ -0,0 +1,209 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 *
 * This Java port of the DoubleArray Trie structure, precompiled charmap parsing and sentence piece normalizer was derived from
 * Huggingface's spm-precompiled project at https://github.com/huggingface/spm_precompiled
 */

package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import com.ibm.icu.text.BreakIterator;

import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.UnicodeUtil;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Locale;
import java.util.Optional;
import java.util.OptionalInt;

/**
 * This is custom normalizer logic, purpose-built to replicate the logic of the DoubleArray Trie System (DARTS)
 * object and the sentence piece normalizer.
 *
 * Links with further explanation of various parts of the algorithm:
 * - <a href="https://github.com/huggingface/spm_precompiled/blob/81b911a362adef3ad3cc6d5835d2980690dbb871/src/lib.rs">
 * huggingface lib
 * </a>
 * - <a href="https://github.com/google/sentencepiece/blob/bc53923a9147dc8ffa54034c8ed774de78cc4d39/third_party/darts_clone/darts.h#L469">
 * DARTS
 * </a>
 * - <a href="https://github.com/google/sentencepiece/blob/91809e5c70ed0e6364267a0f0fed66c144482ce4/src/normalizer.cc">SP normalizer</a>
 */
public class PrecompiledCharMapNormalizer {

    static PrecompiledCharMapNormalizer fromBase64Str(String s) {
        int offset = 0;
        byte[] bytes = Base64.getDecoder().decode(s);
        int trieSize = ByteBuffer.wrap(bytes, offset, 4).order(java.nio.ByteOrder.LITTLE_ENDIAN).getInt();
        offset += 4;
        int size = trieSize / 4;
        int[] offsets = new int[size];
        for (int i = 0; i < size; i++) {
            offsets[i] = ByteBuffer.wrap(bytes, offset, 4).order(java.nio.ByteOrder.LITTLE_ENDIAN).getInt();
            offset += 4;
        }
        String utf8Str = new String(bytes, offset, bytes.length - offset, StandardCharsets.UTF_8);
        return new PrecompiledCharMapNormalizer(offsets, utf8Str);
    }

    // The offsets for each normalization piece. Used in the DARTS algorithm to iterate and find the appropriate section
    // in normalizedStrUtf8Bytes
    private final int[] offsets;
    // The entire normalized byte representation, with sections delimited by NULL
    private final byte[] normalizedStrUtf8Bytes;
    // Continually reused to copy a single char into utf8 bytes
    private final byte[] reusableCharByteBuffer = new byte[4];

    public PrecompiledCharMapNormalizer(int[] offsets, String normalizedStr) {
        this.offsets = offsets;
        this.normalizedStrUtf8Bytes = normalizedStr.getBytes(StandardCharsets.UTF_8);
    }

    boolean hasLeaf(int v) {
        return ((v >>> 8) & 1) == 1;
    }

    int label(int v) {
        return (v & ((1 << 31) | 0xFF));
    }

    int value(int v) {
        return (v & ((1 << 31) - 1));
    }

    int offset(int v) {
        return (v >>> 10) << ((v & (1 << 9)) >>> 6);
    }

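    // Note (assumption, not part of the original change): based on the accessors above and the
    // darts_clone layout they mirror, each trie unit appears to pack its fields as:
    //   bits 0-7 : transition label (the byte consumed by the traversal)
    //   bit 8    : has-leaf flag, checked by hasLeaf()
    //   bit 9    : when set, the offset stored in the upper bits is shifted left by a further 8 bits
    //   bits 10+ : offset to the unit's children, see offset()
    //   bit 31   : set on value units; label() keeps it so a value unit never matches an input byte,
    //              while value() masks it off
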
    OptionalInt commonPrefix(byte[] inputBytes) {
        return commonPrefix(inputBytes, 0, inputBytes.length);
    }

    /**
     * This finds a common prefix position within the normalization byte string.
     *
     * Since the normalization string is NULL-delimited, start at the returned index and continue until you hit the NULL byte. That is
     * then the normalized string.
     *
     * The prefix search is done according to the DoubleArray Trie System (DARTS).
     *
     * See:
     * <a href="https://github.com/google/sentencepiece/blob/bc53923a9147dc8ffa54034c8ed774de78cc4d39/third_party/darts_clone/darts.h#L469">
     * DARTS
     * </a>
     * @param inputBytes utf8 bytes to normalize
     * @param offset offset position to start in the input
     * @param len the length of bytes to consider
     * @return The starting position in the normalization string of the normalized bytes, if found.
     */
    OptionalInt commonPrefix(byte[] inputBytes, int offset, int len) {
        int pos = 0;
        OptionalInt vs = OptionalInt.empty();
        int v = offsets[pos];
        pos ^= offset(v);
        for (int i = offset; i < offset + len; i++) {
            // bytes are signed in Java; convert to the unsigned value the trie expects
            int k = inputBytes[i];
            if (k < 0) {
                k += 256;
            }
            if (k == 0) {
                break;
            }
            pos ^= k;
            v = offsets[pos];
            if (label(v) != k) {
                return vs;
            }
            pos ^= offset(v);
            if (hasLeaf(v)) {
                vs = OptionalInt.of(value(offsets[pos]));
                return vs;
            }
        }
        return vs;
    }

    Optional<BytesRef> normalizePart(byte[] strBytes, int offset, int len) {
        OptionalInt index = commonPrefix(strBytes, offset, len);
        if (index.isEmpty()) {
            return Optional.empty();
        }
        int firstIndex = index.getAsInt();
        int secondIndex = firstIndex;
        // Parsed normalized string has normalization sections partitioned by \0 (NULL) byte
        while (secondIndex < normalizedStrUtf8Bytes.length && normalizedStrUtf8Bytes[secondIndex] != 0) {
            secondIndex++;
        }
        if (secondIndex == firstIndex) {
            return Optional.empty();
        }
        return Optional.of(new BytesRef(normalizedStrUtf8Bytes, firstIndex, secondIndex - firstIndex));
    }

    String normalize(String str) {
        // We need to iterate actual Unicode graphemes (this includes surrogate pairs, etc.)
        // I would much rather translate the entire input string text into utf-8 bytes, and then iterate to the appropriate
        // break points from there. But, this seemed the easiest way for now.
        //
        // Keep in mind, these break points aren't necessarily surrogate pairs, but also codepoints that contain a combining mark
        BreakIterator b = BreakIterator.getCharacterInstance(Locale.ROOT);
        b.setText(str);
        int start = b.first();
        // If we knew the utf-8 length ahead of time (and iterated over the bytes in the appropriate chunks)
        // we could pre-populate the known length here.
        BytesRefBuilder strBuilder = new BytesRefBuilder();
        for (int end = b.next(); end != BreakIterator.DONE; start = end, end = b.next()) {
            // TODO: It would be awesome if we could translate these starts and ends to byte positions; if we could,
            // performance would be dramatically improved
            String unicodeStr = str.substring(start, end);
            byte[] unicode = unicodeStr.getBytes(StandardCharsets.UTF_8);
            // The trie only goes up to a depth of 5 bytes.
            // So even looking at it for graphemes (with combining, surrogate, etc.) that are 6+ bytes in length is useless.
            if (unicode.length < 6) {
                Optional<BytesRef> subStr = normalizePart(unicode, 0, unicode.length);
                if (subStr.isPresent()) {
                    strBuilder.append(subStr.get());
                    continue;
                }
            }
            int charIndex = 0;
            int charByteIndex = 0;
            char[] unicodeCharArray = unicodeStr.toCharArray();
            for (char c : unicodeCharArray) {
                Optional<BytesRef> subStr = normalizePart(unicode, charByteIndex, numUtf8Bytes(c));
                if (subStr.isPresent()) {
                    strBuilder.append(subStr.get());
                } else {
                    int numBytes = UnicodeUtil.UTF16toUTF8(unicodeCharArray, charIndex, 1, reusableCharByteBuffer);
                    strBuilder.append(reusableCharByteBuffer, 0, numBytes);
                }
                charByteIndex += numUtf8Bytes(c);
                ++charIndex;
            }
        }
        return strBuilder.get().utf8ToString();
    }

    private static int numUtf8Bytes(int c) {
        if (c < 128) {
            return 1;
        }
        if (c < 2048) {
            return 2;
        }
        if (c < 65536) {
            return 3;
        }
        return 4;
    }

}
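To make the storage format parsed by fromBase64Str above concrete, here is a hedged, self-contained sketch of the inverse operation. Only the framing (little-endian trie size in bytes, the trie units, then the NULL-delimited UTF-8 normalization string) is taken from the parser; the class name is hypothetical, and building a valid DARTS trie is out of scope, so the units passed in are placeholders.

    import java.nio.ByteBuffer;
    import java.nio.ByteOrder;
    import java.nio.charset.StandardCharsets;
    import java.util.Base64;

    class CharsmapBlobSketch {
        // Packs placeholder trie units and a NULL-delimited normalization string into the
        // base64 layout that PrecompiledCharMapNormalizer.fromBase64Str expects.
        static String encode(int[] trieUnits, String nullDelimitedNormalizations) {
            byte[] norm = nullDelimitedNormalizations.getBytes(StandardCharsets.UTF_8);
            ByteBuffer buf = ByteBuffer.allocate(4 + trieUnits.length * 4 + norm.length).order(ByteOrder.LITTLE_ENDIAN);
            buf.putInt(trieUnits.length * 4); // trie size in bytes, little-endian
            for (int unit : trieUnits) {
                buf.putInt(unit);             // each DARTS unit stored as a little-endian int
            }
            buf.put(norm);                    // normalization sections, each terminated by \0
            return Base64.getEncoder().encodeToString(buf.array());
        }
    }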
39 changes: 39 additions & 0 deletions
...src/test/java/org/elasticsearch/xpack/ml/inference/nlp/tokenizers/PreCompiledCharMap.java
@@ -0,0 +1,39 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;

import java.io.IOException;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;

record PreCompiledCharMap(String charMapStr) {
    static ParseField PRECOMPILED_CHARSMAP = new ParseField("precompiled_charsmap");
    static ConstructingObjectParser<PreCompiledCharMap, Void> PARSER = new ConstructingObjectParser<>(
        "precompiled_charsmap_config",
        true,
        a -> new PreCompiledCharMap((String) a[0])
    );
    static {
        PARSER.declareString(constructorArg(), PRECOMPILED_CHARSMAP);
    }

    static PreCompiledCharMap fromResource(String resourcePath) throws IOException {
        try (
            XContentParser parser = XContentType.JSON.xContent()
                .createParser(XContentParserConfiguration.EMPTY, PreCompiledCharMap.class.getResourceAsStream(resourcePath))
        ) {
            return PreCompiledCharMap.PARSER.apply(parser, null);
        }
    }
}
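For reference, the parser above expects a JSON document of the form {"precompiled_charsmap": "<base64 blob>"}. A minimal hedged sketch of wiring it to the normalizer follows; the resource path is hypothetical, and the test class below does the same thing against a bundled fixture:

    // Hypothetical resource at /org/example/precompiled_char_map.json containing:
    //   {"precompiled_charsmap": "<base64-encoded trie + normalization string>"}
    PreCompiledCharMap map = PreCompiledCharMap.fromResource("/org/example/precompiled_char_map.json");
    PrecompiledCharMapNormalizer normalizer = PrecompiledCharMapNormalizer.fromBase64Str(map.charMapStr());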
55 changes: 55 additions & 0 deletions
...rg/elasticsearch/xpack/ml/inference/nlp/tokenizers/PrecompiledCharMapNormalizerTests.java
@@ -0,0 +1,55 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import org.elasticsearch.test.ESTestCase;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.OptionalInt;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;

public class PrecompiledCharMapNormalizerTests extends ESTestCase {

    public void testCommonPrefix() throws IOException {
        PrecompiledCharMapNormalizer parsed = loadTestCharMap();
        OptionalInt local = parsed.commonPrefix("\uFB01".getBytes(StandardCharsets.UTF_8));
        assertThat(local.isPresent(), is(true));
        assertThat(local.getAsInt(), equalTo(2130));
        String transformed = parsed.normalize("\uFB01");
        assertThat(transformed, equalTo("fi"));
        assertThat(parsed.normalize("𝔾"), equalTo("G"));
        assertThat(parsed.normalize("\uD835\uDD60"), equalTo("o"));
        assertThat(parsed.normalize("\u200D"), equalTo(" "));
        assertThat(parsed.normalize("เขาไม่ได้พูดสักคำ"), equalTo("เขาไม\u0E48ได\u0E49พ\u0E39ดส\u0E31กค\u0E4Dา"));
    }

    public void testAdverseScenario() throws IOException {
        PrecompiledCharMapNormalizer parsed = loadTestCharMap();
        assertThat(parsed.normalize("คำ"), equalTo("ค\u0e4dา"));
    }

    public void testAdverseScenarioHindi() throws IOException {
        PrecompiledCharMapNormalizer parsed = loadTestCharMap();
        assertThat(parsed.normalize("ड़ी दुख"), equalTo("ड\u093cी द\u0941ख"));
    }

    public void testTwoCharUnicode() throws IOException {
        PrecompiledCharMapNormalizer parsed = loadTestCharMap();
        assertThat(parsed.normalize("آ"), equalTo("آ"));
    }

    private static PrecompiledCharMapNormalizer loadTestCharMap() throws IOException {
        PreCompiledCharMap map = PreCompiledCharMap.fromResource(
            "/org.elasticsearch.xpack.ml.inference.nlp.tokenizers/precompiled_char_map.json"
        );
        return PrecompiledCharMapNormalizer.fromBase64Str(map.charMapStr());
    }
}
3 changes: 3 additions & 0 deletions
...t/resources/org.elasticsearch.xpack.ml.inference.nlp.tokenizers/precompiled_char_map.json
Large diffs are not rendered by default.