From a77518e8f5b47b4290faacaae43c3a3126f77d5e Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 4 Jul 2023 10:11:21 +0200 Subject: [PATCH] plot --- tests/unittests/utilities/test_plot.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index cfbe5d4ba60..9cd0320a2a2 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -632,6 +632,7 @@ def test_plot_methods(metric_class: object, preds: Callable, target: Callable, n assert isinstance(fig, plt.Figure) assert isinstance(ax, matplotlib.axes.Axes) + plt.close(fig) @pytest.mark.parametrize( @@ -699,17 +700,17 @@ def test_plot_methods_special_image_metrics(metric_class, preds, target, index_0 assert isinstance(fig, plt.Figure) assert isinstance(ax, matplotlib.axes.Axes) + plt.close(fig) -@pytest.mark.skipif(not hasattr(torch, "inference_mode"), reason="`inference_mode` is not supported") def test_plot_methods_special_text_metrics(): """Test the plot method for text metrics that does not fit the default testing format.""" metric = BERTScore() - with torch.inference_mode(): - metric.update(_text_input_1(), _text_input_2()) - fig, ax = metric.plot() + metric.update(_text_input_1(), _text_input_2()) + fig, ax = metric.plot() assert isinstance(fig, plt.Figure) assert isinstance(ax, matplotlib.axes.Axes) + plt.close(fig) @pytest.mark.parametrize( @@ -782,6 +783,7 @@ def test_plot_methods_retrieval(metric_class, preds, target, indexes, num_vals): assert isinstance(fig, plt.Figure) assert isinstance(ax, matplotlib.axes.Axes) + plt.close(fig) @pytest.mark.parametrize( @@ -821,6 +823,7 @@ def test_confusion_matrix_plotter(metric_class, preds, target, labels, use_label cond1 = isinstance(axs, matplotlib.axes.Axes) cond2 = isinstance(axs, np.ndarray) and all(isinstance(a, matplotlib.axes.Axes) for a in axs) assert cond1 or cond2 + plt.close(fig) @pytest.mark.parametrize("together", [True, False]) @@ -859,6 +862,7 @@ def test_plot_method_collection(together, num_vals): fig, ax = plt.subplots(nrows=len(m_collection) + 1, ncols=1) with pytest.raises(ValueError, match="Expected argument `ax` to be a sequence of matplotlib axis objects with.*"): m_collection.plot(ax=ax.tolist()) + plt.close(fig) @pytest.mark.parametrize( @@ -915,6 +919,7 @@ def test_plot_method_curve_metrics(metric_class, preds, target, thresholds, scor fig, ax = metric.plot(score=score) assert isinstance(fig, plt.Figure) assert isinstance(ax, matplotlib.axes.Axes) + plt.close(fig) def test_tracker_plotter(): @@ -927,3 +932,4 @@ def test_tracker_plotter(): fig, ax = tracker.plot() # plot all epochs assert isinstance(fig, plt.Figure) assert isinstance(ax, matplotlib.axes.Axes) + plt.close(fig)