diff --git a/components/segmenter/src/complex/lstm/matrix.rs b/components/segmenter/src/complex/lstm/matrix.rs index 8f97a59a7c7..a22a1f43054 100644 --- a/components/segmenter/src/complex/lstm/matrix.rs +++ b/components/segmenter/src/complex/lstm/matrix.rs @@ -204,6 +204,7 @@ impl<'a, const D: usize> MatrixBorrowedMut<'a, D> { Some(()) } + #[allow(dead_code)] // maybe needed for more complicated bies calculations /// Mutates this matrix by applying a softmax transformation. pub(super) fn softmax_transform(&mut self) { let sm = self.data.iter().map(|v| v.exp()).sum::(); diff --git a/components/segmenter/src/complex/lstm/mod.rs b/components/segmenter/src/complex/lstm/mod.rs index 6cb249e5457..458e6e21dbf 100644 --- a/components/segmenter/src/complex/lstm/mod.rs +++ b/components/segmenter/src/complex/lstm/mod.rs @@ -4,8 +4,6 @@ use crate::grapheme::GraphemeClusterSegmenter; use crate::provider::*; -use alloc::boxed::Box; -use alloc::string::String; use alloc::vec::Vec; use core::char::{decode_utf16, REPLACEMENT_CHARACTER}; use icu_provider::DataPayload; @@ -18,9 +16,8 @@ use matrix::*; struct LstmSegmenterIterator<'s> { input: &'s str, - bies_str: Box<[Bies]>, - pos: usize, pos_utf8: usize, + bies: BiesIterator<'s>, } impl Iterator for LstmSegmenterIterator<'_> { @@ -29,29 +26,27 @@ impl Iterator for LstmSegmenterIterator<'_> { fn next(&mut self) -> Option { #[allow(clippy::indexing_slicing)] // pos_utf8 in range loop { - let bies = *self.bies_str.get(self.pos)?; + let is_e = self.bies.next()?; self.pos_utf8 += self.input[self.pos_utf8..].chars().next()?.len_utf8(); - self.pos += 1; - if bies == Bies::E || self.pos == self.bies_str.len() { + if is_e || self.bies.len() == 0 { return Some(self.pos_utf8); } } } } -struct LstmSegmenterIteratorUtf16 { - bies_str: Box<[Bies]>, +struct LstmSegmenterIteratorUtf16<'s> { + bies: BiesIterator<'s>, pos: usize, } -impl Iterator for LstmSegmenterIteratorUtf16 { +impl Iterator for LstmSegmenterIteratorUtf16<'_> { type Item = usize; fn next(&mut self) -> Option { loop { - let bies = *self.bies_str.get(self.pos)?; self.pos += 1; - if bies == Bies::E || self.pos == self.bies_str.len() { + if self.bies.next()? || self.bies.len() == 0 { return Some(self.pos); } } @@ -67,7 +62,8 @@ pub(super) struct LstmSegmenter<'l> { bw_w: MatrixZero<'l, 3>, bw_u: MatrixZero<'l, 3>, bw_b: MatrixZero<'l, 2>, - time_w: MatrixZero<'l, 3>, + timew_fw: MatrixZero<'l, 2>, + timew_bw: MatrixZero<'l, 2>, time_b: MatrixZero<'l, 1>, grapheme: Option<&'l RuleBreakDataV1<'l>>, } @@ -79,6 +75,11 @@ impl<'l> LstmSegmenter<'l> { grapheme: &'l DataPayload, ) -> Self { let LstmDataV1::Float32(lstm) = lstm.get(); + let time_w = MatrixZero::from(&lstm.time_w); + #[allow(clippy::unwrap_used)] // shape (2, 4, hunits) + let timew_fw = time_w.submatrix(0).unwrap(); + #[allow(clippy::unwrap_used)] // shape (2, 4, hunits) + let timew_bw = time_w.submatrix(1).unwrap(); Self { dic: lstm.dic.as_borrowed(), embedding: MatrixZero::from(&lstm.embedding), @@ -88,42 +89,21 @@ impl<'l> LstmSegmenter<'l> { bw_w: MatrixZero::from(&lstm.bw_w), bw_u: MatrixZero::from(&lstm.bw_u), bw_b: MatrixZero::from(&lstm.bw_b), - time_w: MatrixZero::from(&lstm.time_w), + timew_fw, + timew_bw, time_b: MatrixZero::from(&lstm.time_b), grapheme: (lstm.model == ModelType::GraphemeClusters).then(|| grapheme.get()), } } /// Create an LSTM based break iterator for an `str` (a UTF-8 string). - pub(super) fn segment_str<'s>(&self, input: &'s str) -> impl Iterator + 's { - let lstm_output = self.produce_bies(input); - LstmSegmenterIterator { - input, - bies_str: lstm_output, - pos: 0, - pos_utf8: 0, - } + pub(super) fn segment_str(&'l self, input: &'l str) -> impl Iterator + 'l { + self.segment_str_p(input) } - /// Create an LSTM based break iterator for a UTF-16 string. - pub(super) fn segment_utf16(&self, input: &[u16]) -> impl Iterator { - let input: String = decode_utf16(input.iter().copied()) - .map(|r| r.unwrap_or(REPLACEMENT_CHARACTER)) - .collect(); - let lstm_output = self.produce_bies(&input); - LstmSegmenterIteratorUtf16 { - bies_str: lstm_output, - pos: 0, - } - } - - /// `produce_bies` is a function that gets a "clean" unsegmented string as its input and returns a BIES (B: Beginning, I: Inside, E: End, - /// S: Single) sequence for grapheme clusters. The boundaries of words can be found easily using this BIES sequence. - fn produce_bies(&self, input: &str) -> Box<[Bies]> { - // input_seq is a sequence of id numbers that represents grapheme clusters or code points in the input line. These ids are used later - // in the embedding layer of the model. - // Already checked that the name of the model is either "codepoints" or "graphclsut" - let input_seq: Vec = if let Some(grapheme) = self.grapheme { + // For unit testing as we cannot inspect the opaque type's bies + fn segment_str_p(&'l self, input: &'l str) -> LstmSegmenterIterator<'l> { + let input_seq = if let Some(grapheme) = self.grapheme { GraphemeClusterSegmenter::new_and_segment_str(input, grapheme) .collect::>() .windows(2) @@ -133,8 +113,14 @@ impl<'l> LstmSegmenter<'l> { } else { unreachable!() }; + let grapheme_cluster = if let Some(grapheme_cluster) = input.get(range) { + grapheme_cluster + } else { + return self.dic.len() as u16; + }; + self.dic - .get_copied(UnvalidatedStr::from_str(input.get(range).unwrap_or(input))) + .get_copied(UnvalidatedStr::from_str(grapheme_cluster)) .unwrap_or_else(|| self.dic.len() as u16) }) .collect() @@ -148,162 +134,192 @@ impl<'l> LstmSegmenter<'l> { }) .collect() }; + LstmSegmenterIterator { + input, + pos_utf8: 0, + bies: BiesIterator::new(self, input_seq), + } + } - /// `compute_hc1` implemens the evaluation of one LSTM layer. - fn compute_hc<'a>( - x_t: MatrixZero<'a, 1>, - mut h_tm1: MatrixBorrowedMut<'a, 1>, - mut c_tm1: MatrixBorrowedMut<'a, 1>, - w: MatrixZero<'a, 3>, - u: MatrixZero<'a, 3>, - b: MatrixZero<'a, 2>, - ) { - #[cfg(debug_assertions)] - { - let hunits = h_tm1.dim(); - let embedd_dim = x_t.dim(); - c_tm1.as_borrowed().debug_assert_dims([hunits]); - w.debug_assert_dims([4, hunits, embedd_dim]); - u.debug_assert_dims([4, hunits, hunits]); - b.debug_assert_dims([4, hunits]); - } + /// Create an LSTM based break iterator for a UTF-16 string. + pub(super) fn segment_utf16(&'l self, input: &[u16]) -> impl Iterator + 'l { + let input_seq = if let Some(grapheme) = self.grapheme { + GraphemeClusterSegmenter::new_and_segment_utf16(input, grapheme) + .collect::>() + .windows(2) + .map(|chunk| { + let range = if let [first, second, ..] = chunk { + *first..*second + } else { + unreachable!() + }; + let grapheme_cluster = if let Some(grapheme_cluster) = input.get(range) { + grapheme_cluster + } else { + return self.dic.len() as u16; + }; - let mut s_t = b.to_owned(); - - s_t.as_mut().add_dot_3d_2(x_t, w); - s_t.as_mut().add_dot_3d_1(h_tm1.as_borrowed(), u); - - #[allow(clippy::unwrap_used)] // first dimension is 4 - s_t.submatrix_mut::<1>(0).unwrap().sigmoid_transform(); - #[allow(clippy::unwrap_used)] // first dimension is 4 - s_t.submatrix_mut::<1>(1).unwrap().sigmoid_transform(); - #[allow(clippy::unwrap_used)] // first dimension is 4 - s_t.submatrix_mut::<1>(2).unwrap().tanh_transform(); - #[allow(clippy::unwrap_used)] // first dimension is 4 - s_t.submatrix_mut::<1>(3).unwrap().sigmoid_transform(); - - #[allow(clippy::unwrap_used)] // first dimension is 4 - c_tm1.convolve( - s_t.as_borrowed().submatrix(0).unwrap(), - s_t.as_borrowed().submatrix(2).unwrap(), - s_t.as_borrowed().submatrix(1).unwrap(), - ); + // The maximum UTF-8 size of a grapheme cluster seems to be 41 bytes + let mut i = 0; + let mut buf = [0; 41]; - #[allow(clippy::unwrap_used)] // first dimension is 4 - h_tm1.mul_tanh(s_t.as_borrowed().submatrix(3).unwrap(), c_tm1.as_borrowed()); + decode_utf16(grapheme_cluster.iter().copied()).for_each(|c| { + debug_assert!(i < 37); + i += c + .unwrap_or(REPLACEMENT_CHARACTER) + .encode_utf8(&mut buf[i..]) + .len() + }); + + self.dic + .get_copied(UnvalidatedStr::from_bytes(&buf[..i])) + .unwrap_or_else(|| self.dic.len() as u16) + }) + .collect() + } else { + decode_utf16(input.iter().copied()) + .map(|c| c.unwrap_or(REPLACEMENT_CHARACTER)) + .map(|c| { + self.dic + .get_copied(UnvalidatedStr::from_str(c.encode_utf8(&mut [0; 4]))) + .unwrap_or_else(|| self.dic.len() as u16) + }) + .collect() + }; + LstmSegmenterIteratorUtf16 { + bies: BiesIterator::new(self, input_seq), + pos: 0, } + } +} - let hunits = self.fw_u.dim().1; +struct BiesIterator<'l> { + segmenter: &'l LstmSegmenter<'l>, + input_seq: core::iter::Enumerate>, + h_bw: MatrixOwned<2>, + curr_fw: MatrixOwned<1>, + c_fw: MatrixOwned<1>, +} - // Forward LSTM - let mut c_fw = MatrixOwned::<1>::new_zero([hunits]); - let mut all_h_fw = MatrixOwned::<2>::new_zero([input_seq.len(), hunits]); - for (i, &g_id) in input_seq.iter().enumerate() { - #[allow(clippy::unwrap_used)] - // embedding has shape (dict.len() + 1, hunit), g_id is at most dict.len() - let x_t = self.embedding.submatrix::<1>(g_id as usize).unwrap(); - if i > 0 { - all_h_fw.as_mut().copy_submatrix::<1>(i - 1, i); - } - #[allow(clippy::unwrap_used)] - compute_hc( - x_t, - all_h_fw.submatrix_mut(i).unwrap(), // shape (input_seq.len(), hunits) - c_fw.as_mut(), - self.fw_w, - self.fw_u, - self.fw_b, - ); - } +impl<'l> BiesIterator<'l> { + // input_seq is a sequence of id numbers that represents grapheme clusters or code points in the input line. These ids are used later + // in the embedding layer of the model. + fn new(segmenter: &'l LstmSegmenter<'l>, input_seq: Vec) -> Self { + let hunits = segmenter.fw_u.dim().1; // Backward LSTM let mut c_bw = MatrixOwned::<1>::new_zero([hunits]); - let mut all_h_bw = MatrixOwned::<2>::new_zero([input_seq.len(), hunits]); + let mut h_bw = MatrixOwned::<2>::new_zero([input_seq.len(), hunits]); for (i, &g_id) in input_seq.iter().enumerate().rev() { - #[allow(clippy::unwrap_used)] - // embedding has shape (dict.len() + 1, hunit), g_id is at most dict.len() - let x_t = self.embedding.submatrix::<1>(g_id as usize).unwrap(); if i + 1 < input_seq.len() { - all_h_bw.as_mut().copy_submatrix::<1>(i + 1, i); + h_bw.as_mut().copy_submatrix::<1>(i + 1, i); } #[allow(clippy::unwrap_used)] compute_hc( - x_t, - all_h_bw.submatrix_mut(i).unwrap(), // shape (input_seq.len(), hunits) + segmenter.embedding.submatrix::<1>(g_id as usize).unwrap(), // shape (dict.len() + 1, hunit), g_id is at most dict.len() + h_bw.submatrix_mut(i).unwrap(), // shape (input_seq.len(), hunits) c_bw.as_mut(), - self.bw_w, - self.bw_u, - self.bw_b, + segmenter.bw_w, + segmenter.bw_u, + segmenter.bw_b, ); } - #[allow(clippy::unwrap_used)] // shape (2, 4, hunits) - let timew_fw = self.time_w.submatrix(0).unwrap(); - #[allow(clippy::unwrap_used)] // shape (2, 4, hunits) - let timew_bw = self.time_w.submatrix(1).unwrap(); - - // Combining forward and backward LSTMs using the dense time-distributed layer - (0..input_seq.len()) - .map(|i| { - #[allow(clippy::unwrap_used)] // shape (input_seq.len(), hunits) - let curr_fw = all_h_fw.submatrix::<1>(i).unwrap(); - #[allow(clippy::unwrap_used)] // shape (input_seq.len(), hunits) - let curr_bw = all_h_bw.submatrix::<1>(i).unwrap(); - let mut weights = [0.0; 4]; - let mut curr_est = MatrixBorrowedMut { - data: &mut weights, - dims: [4], - }; - curr_est.add_dot_2d(curr_fw, timew_fw); - curr_est.add_dot_2d(curr_bw, timew_bw); - #[allow(clippy::unwrap_used)] // both shape (4) - curr_est.add(self.time_b).unwrap(); - curr_est.softmax_transform(); - Bies::from_probabilities(weights) - }) - .collect() + Self { + input_seq: input_seq.into_iter().enumerate(), + h_bw, + c_fw: MatrixOwned::<1>::new_zero([hunits]), + curr_fw: MatrixOwned::<1>::new_zero([hunits]), + segmenter, + } } } -// TODO(#421): Use common BIES normalizer code -#[derive(Debug, PartialEq, Copy, Clone)] -enum Bies { - B, - I, - E, - S, +impl ExactSizeIterator for BiesIterator<'_> { + fn len(&self) -> usize { + self.input_seq.len() + } } -impl Bies { - /// Returns the value the largest probability - fn from_probabilities(arr: [f32; 4]) -> Bies { - let [b, i, e, s] = arr; - let mut result = Bies::B; - let mut max = b; - if i > max { - result = Bies::I; - max = i; - } - if e > max { - result = Bies::E; - max = e; - } - if s > max { - result = Bies::S; - // max = s; - } - result +impl Iterator for BiesIterator<'_> { + type Item = bool; + + fn next(&mut self) -> Option { + let (i, g_id) = self.input_seq.next()?; + + #[allow(clippy::unwrap_used)] + compute_hc( + self.segmenter + .embedding + .submatrix::<1>(g_id as usize) + .unwrap(), // shape (dict.len() + 1, hunit), g_id is at most dict.len() + self.curr_fw.as_mut(), + self.c_fw.as_mut(), + self.segmenter.fw_w, + self.segmenter.fw_u, + self.segmenter.fw_b, + ); + + #[allow(clippy::unwrap_used)] // shape (input_seq.len(), hunits) + let curr_bw = self.h_bw.submatrix::<1>(i).unwrap(); + let mut weights = [0.0; 4]; + let mut curr_est = MatrixBorrowedMut { + data: &mut weights, + dims: [4], + }; + curr_est.add_dot_2d(self.curr_fw.as_borrowed(), self.segmenter.timew_fw); + curr_est.add_dot_2d(curr_bw, self.segmenter.timew_bw); + #[allow(clippy::unwrap_used)] // both shape (4) + curr_est.add(self.segmenter.time_b).unwrap(); + // For correct BIES weight calculation we'd now have to apply softmax, however + // we're only doing a naive argmax, so a monotonic function doesn't make a difference. + + Some(weights[2] > weights[0] && weights[2] > weights[1] && weights[2] > weights[3]) } +} - #[cfg(test)] - fn as_char(&self) -> char { - match self { - Bies::B => 'b', - Bies::I => 'i', - Bies::E => 'e', - Bies::S => 's', - } +/// `compute_hc1` implemens the evaluation of one LSTM layer. +fn compute_hc<'a>( + x_t: MatrixZero<'a, 1>, + mut h_tm1: MatrixBorrowedMut<'a, 1>, + mut c_tm1: MatrixBorrowedMut<'a, 1>, + w: MatrixZero<'a, 3>, + u: MatrixZero<'a, 3>, + b: MatrixZero<'a, 2>, +) { + #[cfg(debug_assertions)] + { + let hunits = h_tm1.dim(); + let embedd_dim = x_t.dim(); + c_tm1.as_borrowed().debug_assert_dims([hunits]); + w.debug_assert_dims([4, hunits, embedd_dim]); + u.debug_assert_dims([4, hunits, hunits]); + b.debug_assert_dims([4, hunits]); } + + let mut s_t = b.to_owned(); + + s_t.as_mut().add_dot_3d_2(x_t, w); + s_t.as_mut().add_dot_3d_1(h_tm1.as_borrowed(), u); + + #[allow(clippy::unwrap_used)] // first dimension is 4 + s_t.submatrix_mut::<1>(0).unwrap().sigmoid_transform(); + #[allow(clippy::unwrap_used)] // first dimension is 4 + s_t.submatrix_mut::<1>(1).unwrap().sigmoid_transform(); + #[allow(clippy::unwrap_used)] // first dimension is 4 + s_t.submatrix_mut::<1>(2).unwrap().tanh_transform(); + #[allow(clippy::unwrap_used)] // first dimension is 4 + s_t.submatrix_mut::<1>(3).unwrap().sigmoid_transform(); + + #[allow(clippy::unwrap_used)] // first dimension is 4 + c_tm1.convolve( + s_t.as_borrowed().submatrix(0).unwrap(), + s_t.as_borrowed().submatrix(2).unwrap(), + s_t.as_borrowed().submatrix(1).unwrap(), + ); + + #[allow(clippy::unwrap_used)] // first dimension is 4 + h_tm1.mul_tanh(s_t.as_borrowed().submatrix(3).unwrap(), c_tm1.as_borrowed()); } #[cfg(test)] @@ -379,17 +395,18 @@ mod tests { }; // Testing - for test_case in test_text.data.testcases { - let lstm_output = lstm.produce_bies(&test_case.unseg); + for test_case in &test_text.data.testcases { + let lstm_output = lstm + .segment_str_p(&test_case.unseg) + .bies + .map(|is_e| if is_e { 'e' } else { '?' }) + .collect::(); println!("Test case : {}", test_case.unseg); println!("Expected bies : {}", test_case.expected_bies); - println!("Estimated bies : {lstm_output:?}"); + println!("Estimated bies : {lstm_output}"); println!("True bies : {}", test_case.true_bies); println!("****************************************************"); - assert_eq!( - test_case.expected_bies, - lstm_output.iter().map(Bies::as_char).collect::() - ); + assert_eq!(test_case.expected_bies.replace(['b','i','s'], "?"), lstm_output); } } }