diff --git a/mindnlp/injection.py b/mindnlp/injection.py
index b19b25cba..2ddafe1d2 100644
--- a/mindnlp/injection.py
+++ b/mindnlp/injection.py
@@ -628,17 +628,13 @@ def __init__(self,
             raise ValueError(f"The argument 'group' should be divisible by 'in_channels' " \
                              f"and 'out_channels', but got group:{group}, in_channels:{in_channels}, " \
                              f"out_channels:{out_channels}.")
-        kernel_size = (kernel_size,)
-        if mindspore.__version__ == '2.0.0':
-            stride = (1, stride,)
-        else:
-            stride = (stride,)
-        dilation = (dilation,)
+        stride = (1, stride)
+        dilation = (1, dilation)
         super().__init__(
             in_channels,
             out_channels,
-            kernel_size,
+            (kernel_size,),
             stride,
             pad_mode,
             padding,
@@ -647,12 +643,26 @@ def __init__(self,
             has_bias,
             None,
             None)
-        self.padding = padding
+        self.padding = (0, 0, padding, padding)
+        self.padding = (0, 0, padding, padding)
+        Validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.cls_name)
+        self.conv2d = ops.Conv2D(out_channel=self.out_channels,
+                                 kernel_size=(1, kernel_size),
+                                 mode=1,
+                                 pad_mode=self.pad_mode,
+                                 pad=self.padding,
+                                 stride=self.stride,
+                                 dilation=self.dilation,
+                                 group=self.group)
 
     def construct(self, x):
-        return ops.conv1d(x, self.weight, self.bias, stride=self.stride, pad_mode=self.pad_mode,
-                          padding=self.padding, dilation=self.dilation, groups=self.group)
+        x = x.expand_dims(2)
+        output = self.conv2d(x, self.weight.expand_dims(2))
+        if self.has_bias:
+            output = ops.bias_add(output, self.bias)
+        output = output.squeeze(2)
+        return output
 
 
 class Conv1dTranspose(_Conv):
     """patched Conv1dTranspose"""
diff --git a/mindnlp/transformers/generation/logits_process.py b/mindnlp/transformers/generation/logits_process.py
index 279b9fbf3..9a51961d5 100644
--- a/mindnlp/transformers/generation/logits_process.py
+++ b/mindnlp/transformers/generation/logits_process.py
@@ -1302,7 +1302,7 @@ def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> min
         for k in range(input_ids.shape[0]):
             timestamp_logprob = logprobs[k, self.timestamp_begin :].logsumexp(axis=-1)
             max_text_token_logprob = logprobs[k, : self.timestamp_begin].max()
-            if timestamp_logprob > max_text_token_logprob:
+            if not ops.isnan(timestamp_logprob) and timestamp_logprob > max_text_token_logprob:
                 scores[k, : self.timestamp_begin] = -float("inf")
 
         return scores