# Reconstructed from git patch e440ce6e71adc241b9204735ae892ee190781d45
# ("Added pooling functions", goldpulpy, 2024-10-09), which creates
# pysentence_similarity/pooling.py.
"""Module for pooling token embeddings."""
from typing import List

import numpy as np


def _expand_mask(attention_mask: List[int]) -> np.ndarray:
    """Append a trailing axis to the mask so it broadcasts over features."""
    return np.expand_dims(attention_mask, axis=-1)


def _masked_extreme(
    model_output: np.ndarray,
    attention_mask: List[int],
    reducer,
    fill_value: float
) -> np.ndarray:
    """Reduce attended tokens along axis 1 with ``reducer`` (np.max/np.min).

    Masked positions are replaced by ``fill_value`` (the identity element of
    the reduction) so they can never win. Rows without a single attended
    token would otherwise yield ``fill_value`` itself (+/-inf); they are
    mapped to 0 instead so downstream similarity maths stays finite.
    """
    keep = _expand_mask(attention_mask) > 0
    pooled = reducer(np.where(keep, model_output, fill_value), axis=1)
    has_tokens = np.any(keep, axis=1)
    return np.where(has_tokens, pooled, 0.0)


def max_pooling(
    model_output: np.ndarray,
    attention_mask: List[int]
) -> np.ndarray:
    """
    Perform max pooling on token embeddings.

    Only positions whose mask value is greater than 0 take part in the
    maximum. Padded positions are excluded rather than zeroed, so features
    whose attended values are all negative are pooled correctly. A sequence
    with no attended tokens produces a zero vector.

    :param model_output: Model output (token embeddings); reduced along
        axis 1, i.e. expected as (batch, tokens, features).
    :type model_output: np.ndarray
    :param attention_mask: Attention mask for the tokens, one value per
        token (expected as (batch, tokens)).
    :type attention_mask: List[int]
    :return: Embedding vector for the entire sentence.
    :rtype: np.ndarray
    """
    return _masked_extreme(model_output, attention_mask, np.max, -np.inf)


def mean_pooling(
    model_output: np.ndarray,
    attention_mask: List[int]
) -> np.ndarray:
    """
    Perform mean pooling on token embeddings.

    Embeddings are weighted by the mask, summed along axis 1 and divided by
    the mask sum. The divisor is clipped to at least 1e-9 so that a sequence
    with no attended tokens produces a zero vector instead of dividing by
    zero.

    :param model_output: Model output (token embeddings); reduced along
        axis 1, i.e. expected as (batch, tokens, features).
    :type model_output: np.ndarray
    :param attention_mask: Attention mask for the tokens, one value per
        token (expected as (batch, tokens)).
    :type attention_mask: List[int]
    :return: Embedding vector for the entire sentence.
    :rtype: np.ndarray
    """
    mask = _expand_mask(attention_mask)
    weighted_sum = np.sum(model_output * mask, axis=1)
    token_count = np.clip(np.sum(mask, axis=1), 1e-9, None)
    return weighted_sum / token_count


def min_pooling(
    model_output: np.ndarray,
    attention_mask: List[int]
) -> np.ndarray:
    """
    Perform min pooling on token embeddings.

    Only positions whose mask value is greater than 0 take part in the
    minimum. A sequence with no attended tokens produces a zero vector.

    :param model_output: Model output (token embeddings); reduced along
        axis 1, i.e. expected as (batch, tokens, features).
    :type model_output: np.ndarray
    :param attention_mask: Attention mask for the tokens, one value per
        token (expected as (batch, tokens)).
    :type attention_mask: List[int]
    :return: Embedding vector for the entire sentence.
    :rtype: np.ndarray
    """
    return _masked_extreme(model_output, attention_mask, np.min, np.inf)