Skip to content

Commit

Permalink
Fixes to Hubert issues that cause problems later
Browse files Browse the repository at this point in the history
  • Loading branch information
Rocketknight1 committed Dec 5, 2023
1 parent 410f995 commit 05fe555
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions src/transformers/models/hubert/modeling_tf_hubert.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,11 +416,6 @@ def _normalize_kernel(self):

def build(self, input_shape):
if not self.built:
input_shape = input_shape.as_list()
# If a specific input shape is passed in, we need to modify it to account for padding
# Not necessary if those portions of the shape are None
if input_shape[-2] is not None:
input_shape[-2] += self.explicit_padding * 2
super().build(input_shape)

self.kernel = tf.Variable(tf.transpose(self.kernel), name="weight_v", trainable=True)
Expand Down Expand Up @@ -531,13 +526,17 @@ def __init__(self, config: HubertConfig, **kwargs: Any) -> None:
)
self.padding = TFHubertSamePadLayer(config.num_conv_pos_embeddings)
self.activation = get_tf_activation(config.feat_extract_activation)
self.config = config

def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.conv(hidden_states)
hidden_states = self.padding(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states

def build(self, input_shape=None):
self.conv.build(self.config.hidden_size)


# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2SamePadLayer with Wav2Vec2->Hubert
class TFHubertSamePadLayer(tf.keras.layers.Layer):
Expand Down

0 comments on commit 05fe555

Please sign in to comment.