[ML] add sentence piece pre-compiled normalizer (elastic#87575)
This is one of the many prerequisites for supporting SentencePiece tokenization within NLP.

SentencePiece is a fairly complicated and involved tokenization scheme.

This commit contains the normalization logic that transforms the provided string from its current UTF-8 bytes into a standard, normalized set of UTF-8 bytes.

The typical storage for this normalizer is a compressed representation of a DARTS array and a NULL-delimited normalization string.
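Concretely, the fromBase64Str parser in this commit reads the decoded blob with the following layout (a sketch of what the code parses, not an official spec):

    [0, 4)              little-endian int: size of the trie in bytes
    [4, 4 + trieSize)   DARTS array: trieSize / 4 little-endian ints
    [4 + trieSize, end) UTF-8 normalization strings, delimited by NULL bytes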
benwtrent authored Jun 15, 2022
1 parent 65009b4 commit 2d571a0
Showing 4 changed files with 306 additions and 0 deletions.
@@ -0,0 +1,209 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 *
 * This Java port of the DoubleArray Trie structure, precompiled charmap parsing, and SentencePiece normalizer was derived from
 * Huggingface's spm_precompiled project at https://github.com/huggingface/spm_precompiled
 */

package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import com.ibm.icu.text.BreakIterator;

import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.UnicodeUtil;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Locale;
import java.util.Optional;
import java.util.OptionalInt;

/**
 * This is custom normalizer logic, purpose-built to replicate the logic of the DoubleArray Trie System (DARTS)
 * object and the SentencePiece normalizer.
 *
 * Links with further explanation of various parts of the algorithm:
 * - <a href="https://github.com/huggingface/spm_precompiled/blob/81b911a362adef3ad3cc6d5835d2980690dbb871/src/lib.rs">
 * huggingface lib
 * </a>
 * - <a href="https://github.com/google/sentencepiece/blob/bc53923a9147dc8ffa54034c8ed774de78cc4d39/third_party/darts_clone/darts.h#L469">
 * DARTS
 * </a>
 * - <a href="https://github.com/google/sentencepiece/blob/91809e5c70ed0e6364267a0f0fed66c144482ce4/src/normalizer.cc">SP normalizer</a>
 */
public class PrecompiledCharMapNormalizer {

    static PrecompiledCharMapNormalizer fromBase64Str(String s) {
        int offset = 0;
        byte[] bytes = Base64.getDecoder().decode(s);
        int trieSize = ByteBuffer.wrap(bytes, offset, 4).order(java.nio.ByteOrder.LITTLE_ENDIAN).getInt();
        offset += 4;
        int size = trieSize / 4;
        int[] offsets = new int[size];
        for (int i = 0; i < size; i++) {
            offsets[i] = ByteBuffer.wrap(bytes, offset, 4).order(java.nio.ByteOrder.LITTLE_ENDIAN).getInt();
            offset += 4;
        }
        String utf8Str = new String(bytes, offset, bytes.length - offset, StandardCharsets.UTF_8);
        return new PrecompiledCharMapNormalizer(offsets, utf8Str);
    }

    // The offsets for each normalization piece. Used in the DARTS algorithm to iterate and find the appropriate section
    // in normalizedStrUtf8Bytes
    private final int[] offsets;
    // The entire normalized bytes representation, delimited by NULL
    private final byte[] normalizedStrUtf8Bytes;
    // Continually reused to copy a single char into utf8 bytes
    private final byte[] reusableCharByteBuffer = new byte[4];

    public PrecompiledCharMapNormalizer(int[] offsets, String normalizedStr) {
        this.offsets = offsets;
        this.normalizedStrUtf8Bytes = normalizedStr.getBytes(StandardCharsets.UTF_8);
    }

    boolean hasLeaf(int v) {
        return ((v >>> 8) & 1) == 1;
    }

    int label(int v) {
        return (v & ((1 << 31) | 0xFF));
    }

    int value(int v) {
        return (v & ((1 << 31) - 1));
    }

    int offset(int v) {
        return (v >>> 10) << ((v & (1 << 9)) >>> 6);
    }
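
    // Worked example of decoding a 32-bit DARTS unit (hypothetical value, not taken from a real charsmap):
    //   v = 0x00000141
    //   hasLeaf(v): (0x141 >>> 8) & 1 == 1, so this node owns a leaf unit
    //   label(v):   0x141 & 0x800000FF == 0x41, i.e. the transition byte 'A' (bit 31 stays in the mask so
    //               leaf units, which have the sign bit set, never match an input byte)
    //   offset(v):  bits 10..30 hold the relative offset to the next unit, shifted left by 8 when bit 9 is set
    //   value(v):   on a leaf unit, masking off the sign bit yields the index into normalizedStrUtf8Bytes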

    OptionalInt commonPrefix(byte[] inputBytes) {
        return commonPrefix(inputBytes, 0, inputBytes.length);
    }

    /**
     * This finds a common prefix position within the normalization byte string.
     *
     * Since the normalization string is NULL delimited, start at the returned index and continue until a NULL byte is hit; the bytes
     * in between are the normalized string.
     *
     * The prefix search is done according to the DoubleArray Trie System (DARTS).
     *
     * See:
     * <a href="https://github.com/google/sentencepiece/blob/bc53923a9147dc8ffa54034c8ed774de78cc4d39/third_party/darts_clone/darts.h#L469">
     * DARTS
     * </a>
     * @param inputBytes utf8 bytes to normalize
     * @param offset offset position to start given the input
     * @param len the length of bytes to consider
     * @return The starting position in the normalization string of the normalized bytes, if found.
     */
    OptionalInt commonPrefix(byte[] inputBytes, int offset, int len) {
        int pos = 0;
        OptionalInt vs = OptionalInt.empty();
        int v = offsets[pos];
        pos ^= offset(v);
        for (int i = offset; i < offset + len; i++) {
            // bytes are signed in Java; convert to the unsigned value the trie expects
            int k = inputBytes[i];
            if (k < 0) {
                k += 256;
            }
            if (k == 0) {
                break;
            }
            pos ^= k;
            v = offsets[pos];
            if (label(v) != k) {
                return vs;
            }
            pos ^= offset(v);
            if (hasLeaf(v)) {
                vs = OptionalInt.of(value(offsets[pos]));
                return vs;
            }
        }
        return vs;
    }

    Optional<BytesRef> normalizePart(byte[] strBytes, int offset, int len) {
        OptionalInt index = commonPrefix(strBytes, offset, len);
        if (index.isEmpty()) {
            return Optional.empty();
        }
        int firstIndex = index.getAsInt();
        int secondIndex = firstIndex;
        // Parsed normalized string has normalization sections partitioned by \0 (NULL) byte
        while (secondIndex < normalizedStrUtf8Bytes.length && normalizedStrUtf8Bytes[secondIndex] != 0) {
            secondIndex++;
        }
        if (secondIndex == firstIndex) {
            return Optional.empty();
        }
        return Optional.of(new BytesRef(normalizedStrUtf8Bytes, firstIndex, secondIndex - firstIndex));
    }

    String normalize(String str) {
        // We need to iterate actual Unicode graphemes (this includes surrogate pairs, etc.)
        // It would be preferable to translate the entire input string into utf-8 bytes and then iterate to the appropriate
        // break points from there, but this seemed the easiest way for now.
        //
        // Keep in mind, the graphemes between break points aren't necessarily surrogate pairs; they may also be codepoints
        // followed by a combining mark
        BreakIterator b = BreakIterator.getCharacterInstance(Locale.ROOT);
        b.setText(str);
        int start = b.first();
        // If we knew the utf-8 length ahead of time (and iterated over the bytes in the appropriate chunks)
        // we could pre-populate the known length here.
        BytesRefBuilder strBuilder = new BytesRefBuilder();
        for (int end = b.next(); end != BreakIterator.DONE; start = end, end = b.next()) {
            // TODO: It would be awesome if we could translate these starts and ends to byte positions; if we could,
            // performance would be dramatically improved
            String unicodeStr = str.substring(start, end);
            byte[] unicode = unicodeStr.getBytes(StandardCharsets.UTF_8);
            // The trie only goes up to a depth of 5 bytes,
            // so even looking at graphemes (with combining, surrogate, etc.) that are 6+ bytes in length is useless.
            if (unicode.length < 6) {
                Optional<BytesRef> subStr = normalizePart(unicode, 0, unicode.length);
                if (subStr.isPresent()) {
                    strBuilder.append(subStr.get());
                    continue;
                }
            }
            int charIndex = 0;
            int charByteIndex = 0;
            char[] unicodeCharArray = unicodeStr.toCharArray();
            for (char c : unicodeCharArray) {
                Optional<BytesRef> subStr = normalizePart(unicode, charByteIndex, numUtf8Bytes(c));
                if (subStr.isPresent()) {
                    strBuilder.append(subStr.get());
                } else {
                    int numBytes = UnicodeUtil.UTF16toUTF8(unicodeCharArray, charIndex, 1, reusableCharByteBuffer);
                    strBuilder.append(reusableCharByteBuffer, 0, numBytes);
                }
                charByteIndex += numUtf8Bytes(c);
                ++charIndex;
            }
        }
        return strBuilder.get().utf8ToString();
    }

    private static int numUtf8Bytes(int c) {
        if (c < 128) {
            return 1;
        }
        if (c < 2048) {
            return 2;
        }
        if (c < 65536) {
            return 3;
        }
        return 4;
    }

}
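For reference, a minimal usage sketch (base64Charsmap is a stand-in for a real precompiled charsmap blob; the tests further below load one from a JSON resource):

    String base64Charsmap = "..."; // base64-encoded trie + normalization string, elided here
    PrecompiledCharMapNormalizer normalizer = PrecompiledCharMapNormalizer.fromBase64Str(base64Charsmap);
    String normalized = normalizer.normalize("\uFB01"); // per the tests below, the ligature U+FB01 becomes "fi"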
@@ -0,0 +1,39 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;

import java.io.IOException;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;

record PreCompiledCharMap(String charMapStr) {
    static ParseField PRECOMPILED_CHARSMAP = new ParseField("precompiled_charsmap");
    static ConstructingObjectParser<PreCompiledCharMap, Void> PARSER = new ConstructingObjectParser<>(
        "precompiled_charsmap_config",
        true,
        a -> new PreCompiledCharMap((String) a[0])
    );
    static {
        PARSER.declareString(constructorArg(), PRECOMPILED_CHARSMAP);
    }

    static PreCompiledCharMap fromResource(String resourcePath) throws IOException {
        try (
            XContentParser parser = XContentType.JSON.xContent()
                .createParser(XContentParserConfiguration.EMPTY, PreCompiledCharMap.class.getResourceAsStream(resourcePath))
        ) {
            return PreCompiledCharMap.PARSER.apply(parser, null);
        }
    }
}
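Given the parser above, the JSON resource is expected to carry the blob under a single field, along the lines of (the value is an illustrative stand-in):

    {
      "precompiled_charsmap": "<base64-encoded trie and normalization string>"
    }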
@@ -0,0 +1,55 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.ml.inference.nlp.tokenizers;

import org.elasticsearch.test.ESTestCase;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.OptionalInt;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;

public class PrecompiledCharMapNormalizerTests extends ESTestCase {

    public void testCommonPrefix() throws IOException {
        PrecompiledCharMapNormalizer parsed = loadTestCharMap();
        OptionalInt local = parsed.commonPrefix("\uFB01".getBytes(StandardCharsets.UTF_8));
        assertThat(local.isPresent(), is(true));
        assertThat(local.getAsInt(), equalTo(2130));
        String transformed = parsed.normalize("\uFB01");
        assertThat(transformed, equalTo("fi"));
        assertThat(parsed.normalize("𝔾"), equalTo("G"));
        assertThat(parsed.normalize("\uD835\uDD60"), equalTo("o"));
        assertThat(parsed.normalize("\u200D"), equalTo(" "));
        assertThat(parsed.normalize("เขาไม่ได้พูดสักคำ"), equalTo("เขาไม\u0E48ได\u0E49\u0E39ดส\u0E31กค\u0E4Dา"));
    }

    public void testAdverseScenario() throws IOException {
        PrecompiledCharMapNormalizer parsed = loadTestCharMap();
        assertThat(parsed.normalize("คำ"), equalTo("ค\u0e4dา"));
    }

    public void testAdverseScenarioHindi() throws IOException {
        PrecompiledCharMapNormalizer parsed = loadTestCharMap();
        assertThat(parsed.normalize("ड़ी दुख"), equalTo("ड\u093cी द\u0941ख"));
    }

    public void testTwoCharUnicode() throws IOException {
        PrecompiledCharMapNormalizer parsed = loadTestCharMap();
        assertThat(parsed.normalize("آ"), equalTo("آ"));
    }

    private static PrecompiledCharMapNormalizer loadTestCharMap() throws IOException {
        PreCompiledCharMap map = PreCompiledCharMap.fromResource(
            "/org.elasticsearch.xpack.ml.inference.nlp.tokenizers/precompiled_char_map.json"
        );
        return PrecompiledCharMapNormalizer.fromBase64Str(map.charMapStr());
    }
}

Large diffs are not rendered by default.
