Skip to content

Commit

Permalink
Added pooling functions
Browse files Browse the repository at this point in the history
  • Loading branch information
goldpulpy committed Oct 9, 2024
1 parent 84274a4 commit e440ce6
Showing 1 changed file with 67 additions and 0 deletions.
67 changes: 67 additions & 0 deletions pysentence_similarity/pooling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Module for pooling token embeddings."""
from typing import List
import numpy as np


def max_pooling(
    model_output: np.ndarray,
    attention_mask: List[int]
) -> np.ndarray:
    """
    Perform max pooling on token embeddings.

    Padded positions (mask == 0) are replaced with -inf before taking the
    max, so they can never be selected. (Multiplying by the mask instead
    would turn padded slots into 0, which wrongly wins the max whenever all
    valid token values in a dimension are negative.)

    :param model_output: Model output (token embeddings),
        shape (batch, seq_len, hidden_dim).
    :type model_output: np.ndarray
    :param attention_mask: Attention mask for the tokens (1 = real token,
        0 = padding), shape (batch, seq_len).
    :type attention_mask: List[int]
    :return: Embedding vector for the entire sentence,
        shape (batch, hidden_dim).
    :rtype: np.ndarray
    """
    token_embeddings = model_output
    # Expand mask to (batch, seq_len, 1) so it broadcasts over hidden_dim.
    input_mask_expanded = np.expand_dims(attention_mask, axis=-1)
    # Sentinel -inf on padded positions mirrors min_pooling's +inf sentinel.
    pooled_embedding = np.max(
        np.where(input_mask_expanded > 0, token_embeddings, -np.inf), axis=1)
    return pooled_embedding


def mean_pooling(
    model_output: np.ndarray,
    attention_mask: List[int]
) -> np.ndarray:
    """
    Perform mean pooling on token embeddings.

    Averages embeddings over valid (unmasked) tokens only; the divisor is
    clipped to avoid division by zero for fully-masked rows.

    :param model_output: Model output (token embeddings).
    :type model_output: np.ndarray
    :param attention_mask: Attention mask for the tokens.
    :type attention_mask: List[int]
    :return: Embedding vector for the entire sentence.
    :rtype: np.ndarray
    """
    # Broadcast the mask over the hidden dimension: (batch, seq, 1).
    mask = np.expand_dims(attention_mask, axis=-1)
    masked_sum = np.sum(model_output * mask, axis=1)
    # Clip the valid-token count so an all-zero mask cannot divide by zero.
    valid_counts = np.clip(np.sum(mask, axis=1), 1e-9, None)
    return masked_sum / valid_counts


def min_pooling(
    model_output: np.ndarray,
    attention_mask: List[int]
) -> np.ndarray:
    """
    Perform min pooling on token embeddings.

    Padded positions (mask == 0) are replaced with +inf before taking the
    min, so only valid tokens can be selected.

    :param model_output: Model output (token embeddings).
    :type model_output: np.ndarray
    :param attention_mask: Attention mask for the tokens.
    :type attention_mask: List[int]
    :return: Embedding vector for the entire sentence.
    :rtype: np.ndarray
    """
    # Broadcast the mask over the hidden dimension: (batch, seq, 1).
    mask = np.expand_dims(attention_mask, axis=-1)
    # +inf sentinel guarantees padded slots never win the minimum.
    sentineled = np.where(mask > 0, model_output, np.inf)
    return sentineled.min(axis=1)

0 comments on commit e440ce6

Please sign in to comment.