-
Notifications
You must be signed in to change notification settings - Fork 9.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
word-level timestamps in transcribe()
#869
Merged
Merged
Changes from all commits
Commits
Show all changes
26 commits
Select commit
Hold shift + click to select a range
8f9357f
word-level timestamps in `transcribe()`
jongwook 46ea501
moving to `timing.py`
jongwook cfd2b81
Merge branch 'main' into word-level-timestamps
jongwook 742d2f4
numba implementation for dtw, replacing dtw-python
jongwook fb12414
Merge branch 'main' into word-level-timestamps
jongwook 80331c0
triton implementation for dtw
jongwook 1d2ed66
add test for dtw implementations
jongwook b61e8f4
triton implementation of median_filter
jongwook 54f2901
a simple word-level timestamps test
jongwook 8ce6277
add scipy as dev dependency
jongwook 812f446
Merge branch 'main' into word-level-timestamps
jongwook cd5191f
installs an older version of Triton if CUDA < 11.4
jongwook f64d8bc
Merge branch 'main' into word-level-timestamps
jongwook 89133bd
Merge branch 'main' into word-level-timestamps
jongwook d4f9399
fix broken merge
jongwook 040aa04
Merge branch 'main' into word-level-timestamps
jongwook 8e2756b
loosen nvcc version match regex
jongwook 6c431c4
find_alignment() function
jongwook ff6cbfd
Merge branch 'main' into word-level-timestamps
jongwook 5fa4356
miscellaneous improvements
jongwook 48537aa
skip median filtering when the input is too small
jongwook 8eb29c3
Expose punctuation options in cli and transcribe() (#973)
ryanheise 6ed4c11
Merge branch 'main' into word-level-timestamps
jongwook 31cd418
fix merge error
jongwook 145f325
fix merge error 2
jongwook 2b079c4
annotating that word_timestamps is experimental
jongwook File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
numba | ||
numpy | ||
torch | ||
tqdm | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import random as rand | ||
|
||
import numpy | ||
import pytest | ||
|
||
|
||
def pytest_configure(config): | ||
config.addinivalue_line("markers", "requires_cuda") | ||
|
||
|
||
@pytest.fixture | ||
def random(): | ||
rand.seed(42) | ||
numpy.random.seed(42) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import pytest | ||
import numpy as np | ||
import scipy.ndimage | ||
import torch | ||
|
||
from whisper.timing import dtw_cpu, dtw_cuda, median_filter | ||
|
||
|
||
sizes = [ | ||
(10, 20), (32, 16), (123, 1500), (234, 189), | ||
] | ||
shapes = [ | ||
(10,), (1, 15), (4, 5, 345), (6, 12, 240, 512), | ||
] | ||
|
||
|
||
@pytest.mark.parametrize("N, M", sizes) | ||
def test_dtw(N: int, M: int): | ||
steps = np.concatenate([np.zeros(N - 1), np.ones(M - 1)]) | ||
np.random.shuffle(steps) | ||
x = np.random.random((N, M)).astype(np.float32) | ||
|
||
i, j, k = 0, 0, 0 | ||
trace = [] | ||
while True: | ||
x[i, j] -= 1 | ||
trace.append((i, j)) | ||
|
||
if k == len(steps): | ||
break | ||
|
||
if k + 1 < len(steps) and steps[k] != steps[k + 1]: | ||
i += 1 | ||
j += 1 | ||
k += 2 | ||
continue | ||
|
||
if steps[k] == 0: | ||
i += 1 | ||
if steps[k] == 1: | ||
j += 1 | ||
k += 1 | ||
|
||
trace = np.array(trace).T | ||
dtw_trace = dtw_cpu(x) | ||
|
||
assert np.allclose(trace, dtw_trace) | ||
|
||
|
||
@pytest.mark.requires_cuda | ||
@pytest.mark.parametrize("N, M", sizes) | ||
def test_dtw_cuda_equivalence(N: int, M: int): | ||
x_numpy = np.random.randn(N, M).astype(np.float32) | ||
x_cuda = torch.from_numpy(x_numpy).cuda() | ||
|
||
trace_cpu = dtw_cpu(x_numpy) | ||
trace_cuda = dtw_cuda(x_cuda) | ||
|
||
assert np.allclose(trace_cpu, trace_cuda) | ||
|
||
|
||
@pytest.mark.parametrize("shape", shapes) | ||
def test_median_filter(shape): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Question: Is there a licensing issue using |
||
x = torch.randn(*shape) | ||
|
||
for filter_width in [3, 5, 7, 13]: | ||
filtered = median_filter(x, filter_width) | ||
|
||
# using np.pad to reflect-pad, because Scipy's behavior is different near the edges. | ||
pad_width = filter_width // 2 | ||
padded_x = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect") | ||
scipy_filtered = scipy.ndimage.median_filter(padded_x, [1] * (x.ndim - 1) + [filter_width]) | ||
scipy_filtered = scipy_filtered[..., pad_width:-pad_width] | ||
|
||
assert np.allclose(filtered, scipy_filtered) | ||
|
||
|
||
@pytest.mark.requires_cuda | ||
@pytest.mark.parametrize("shape", shapes) | ||
def test_median_filter_equivalence(shape): | ||
x = torch.randn(*shape) | ||
|
||
for filter_width in [3, 5, 7, 13]: | ||
filtered_cpu = median_filter(x, filter_width) | ||
filtered_gpu = median_filter(x.cuda(), filter_width).cpu() | ||
|
||
assert np.allclose(filtered_cpu, filtered_gpu) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was your reason to not use the dtw library licensing concerns or just speedup?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dtw-python is GPL, as mentioned here -
#869 (comment)