Skip to content

Commit

Permalink
Fix docstring parsing issue in Python 3.13 (#1082)
Browse files Browse the repository at this point in the history
Python 3.13 auto dedents docstrings, breaking the parsing.

With this fix, parsing works the same as before.
  • Loading branch information
raphaelrubrice authored Jan 18, 2025
1 parent 547d3ff commit bb1bac4
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 13 deletions.
23 changes: 14 additions & 9 deletions skorch/classifier.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""NeuralNet subclasses for classification tasks."""

import re
import textwrap

import numpy as np
from sklearn.base import ClassifierMixin
Expand Down Expand Up @@ -35,21 +36,23 @@
skorch behavior should be restored, i.e. raising an
``AttributeError``, pass an empty list."""

neural_net_clf_additional_attribute = """classes_ : array, shape (n_classes, )
neural_net_clf_additional_attribute = """ classes_ : array, shape (n_classes, )
A list of class labels known to the classifier.
"""


def get_neural_net_clf_doc(doc):
doc = neural_net_clf_doc_start + " " + doc.split("\n ", 4)[-1]
pattern = re.compile(r'(\n\s+)(criterion .*\n)(\s.+){1,99}')
indentation = " "
# dedent/indent roundtrip required for consistent indention in both
# Python <3.13 and Python >=3.13
# Because <3.13 => not automatic dedent, but it is the case in >=3.13
doc = neural_net_clf_doc_start + " " + textwrap.indent(textwrap.dedent(doc.split("\n", 5)[-1]), indentation)
pattern = re.compile(r'(\n\s+)(criterion .*\n)(\s.+|.){1,99}')
start, end = pattern.search(doc).span()
doc = doc[:start] + neural_net_clf_additional_text + doc[end:]
doc = doc + neural_net_clf_additional_attribute
return doc


# pylint: disable=missing-docstring
class NeuralNetClassifier(ClassifierMixin, NeuralNet):
__doc__ = get_neural_net_clf_doc(NeuralNet.__doc__)
Expand Down Expand Up @@ -249,15 +252,17 @@ def predict(self, X):
Probabilities above this threshold is classified as 1. ``threshold``
is used by ``predict`` and ``predict_proba`` for classification."""


def get_neural_net_binary_clf_doc(doc):
doc = neural_net_binary_clf_doc_start + " " + doc.split("\n ", 4)[-1]
pattern = re.compile(r'(\n\s+)(criterion .*\n)(\s.+){1,99}')
indentation = " "
# dedent/indent roundtrip required for consistent indention in both
# Python <3.13 and Python >=3.13
# Because <3.13 => not automatic dedent, but it is the case in >=3.13
doc = neural_net_binary_clf_doc_start + " " + textwrap.indent(textwrap.dedent(doc.split("\n", 5)[-1]), indentation)
pattern = re.compile(r'(\n\s+)(criterion .*\n)(\s.+|.){1,99}')
start, end = pattern.search(doc).span()
doc = doc[:start] + neural_net_binary_clf_criterion_text + doc[end:]
return doc


class NeuralNetBinaryClassifier(ClassifierMixin, NeuralNet):
# pylint: disable=missing-docstring
__doc__ = get_neural_net_binary_clf_doc(NeuralNet.__doc__)
Expand Down
11 changes: 7 additions & 4 deletions skorch/regressor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""NeuralNet subclasses for regression tasks."""

import re
import textwrap

from sklearn.base import RegressorMixin
import torch
Expand All @@ -23,15 +24,17 @@
criterion : torch criterion (class, default=torch.nn.MSELoss)
Mean squared error loss."""


def get_neural_net_reg_doc(doc):
doc = neural_net_reg_doc_start + " " + doc.split("\n ", 4)[-1]
pattern = re.compile(r'(\n\s+)(criterion .*\n)(\s.+){1,99}')
indentation = " "
# dedent/indent roundtrip required for consistent indention in both
# Python <3.13 and Python >=3.13
# Because <3.13 => not automatic dedent, but it is the case in >=3.13
doc = neural_net_reg_doc_start + " " + textwrap.indent(textwrap.dedent(doc.split("\n", 5)[-1]), indentation)
pattern = re.compile(r'(\n\s+)(criterion .*\n)(\s.+|.){1,99}')
start, end = pattern.search(doc).span()
doc = doc[:start] + neural_net_reg_criterion_text + doc[end:]
return doc


# pylint: disable=missing-docstring
class NeuralNetRegressor(RegressorMixin, NeuralNet):
__doc__ = get_neural_net_reg_doc(NeuralNet.__doc__)
Expand Down

0 comments on commit bb1bac4

Please sign in to comment.