Label get param #306

Open
wants to merge 115 commits into
base: development
Conversation

BalzaniEdoardo
Collaborator

@BalzaniEdoardo BalzaniEdoardo commented Feb 10, 2025

This PR addresses the following

  1. Add a __getitem__ to Basis. Labels can be used to get the desired basis in the tree.
  2. Improved label logic:
    • User provided labels in composite basis must be unique.
    • Default labels in composite bases (the class name) are disambiguated by appending _<number>.
    • Labeling a basis using another basis type class name results in an error.
  3. get_params and set_params override default behavior: instead of reflecting the full attribute structure, basis1__basis2...__parameter_name, the parameter key is set to <label>_parameter_name.
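
The label rules and the label-based lookup listed above can be illustrated with a minimal, self-contained sketch. The classes `Leaf` and `Composite` and the helper `relabel` are hypothetical stand-ins, not the nemos implementation, and the numbering scheme for default labels (starting at `_0`) is an assumption of this sketch.

```python
from collections import Counter


class Leaf:
    """Hypothetical atomic basis (not the nemos class)."""

    def __init__(self, label=None):
        self.user_label = label  # None means "fall back to the class name"
        self.label = label if label is not None else type(self).__name__

    def iter_tree(self):
        yield self

    def __getitem__(self, label):
        # Search the whole tree for a basis with the requested label.
        for node in self.iter_tree():
            if node.label == label:
                return node
        raise KeyError(f"no basis labeled {label!r}")


class Composite(Leaf):
    """Hypothetical composite basis holding two children."""

    def __init__(self, basis1, basis2, label=None):
        super().__init__(label)
        self.basis1, self.basis2 = basis1, basis2
        relabel(self)

    def iter_tree(self):
        yield self
        yield from self.basis1.iter_tree()
        yield from self.basis2.iter_tree()


def relabel(root):
    """Enforce unique user labels and disambiguate default labels."""
    nodes = list(root.iter_tree())
    user = [n.user_label for n in nodes if n.user_label is not None]
    dupes = {name for name in user if user.count(name) > 1}
    if dupes:
        raise ValueError(f"duplicate user labels: {sorted(dupes)}")
    defaults = [n for n in nodes if n.user_label is None]
    totals = Counter(type(n).__name__ for n in defaults)
    seen = Counter()
    for n in defaults:
        name = type(n).__name__
        if totals[name] > 1:
            # Assumed scheme: append _<number> in traversal order.
            n.label = f"{name}_{seen[name]}"
            seen[name] += 1
        else:
            n.label = name


a, b, c = Leaf(), Leaf(), Leaf(label="speed")
tree = Composite(Composite(a, b), c)
# Attribute access needs the full path; label access does not.
same_basis = tree.basis1.basis2 is tree["Leaf_1"]
```

In this sketch, attribute access requires spelling out the path `tree.basis1.basis2`, while `tree["speed"]` or `tree["Leaf_1"]` reaches a basis at any depth directly.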

For a concrete example, refer to the new how-to note: handling_composite_bases.md

Key Implementation: Setter of Basis

For the PR review, focus on the behavior of @basis1.setter and @basis2.setter. These setters determine how set_params is finalized. In particular, when set_params is used to set a basis, some extra attributes must be modified:

  • If one sets a basis with set_params, the new basis gets the _parent attribute from the current one.
  • The same applies to the input shape.
  • The new basis also tries to take over the current label, if possible.
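
A minimal sketch of the setter behavior described above, using hypothetical classes `ToyBasis` and `ToyComposite` rather than the actual nemos code. The attribute names `_parent` and `_input_shape` follow the description; the rule that the current label is taken over only when the new basis has no user-provided label is an assumption about what "if possible" means.

```python
class ToyBasis:
    """Hypothetical basis; not the nemos implementation."""

    def __init__(self, label=None):
        self._user_label = label  # None means no user-provided label
        self.label = label if label is not None else type(self).__name__
        self._parent = None
        self._input_shape = None


class ToyComposite(ToyBasis):
    def __init__(self, basis1, basis2, label=None):
        super().__init__(label)
        self._basis1 = None
        self._basis2 = None
        self.basis1 = basis1  # goes through the setters below
        self.basis2 = basis2

    def _adopt(self, new, current):
        """Copy bookkeeping attributes from the current child to the new one."""
        new._parent = self
        if current is not None:
            new._input_shape = current._input_shape
            # Assumption: keep the old label only if the new basis has none.
            if new._user_label is None:
                new.label = current.label
        return new

    @property
    def basis1(self):
        return self._basis1

    @basis1.setter
    def basis1(self, new):
        self._basis1 = self._adopt(new, self._basis1)

    @property
    def basis2(self):
        return self._basis2

    @basis2.setter
    def basis2(self, new):
        self._basis2 = self._adopt(new, self._basis2)


comp = ToyComposite(ToyBasis(label="position"), ToyBasis(label="speed"))
comp.basis1._input_shape = (3,)
replacement = ToyBasis()
comp.basis1 = replacement  # replacement inherits parent, shape and label
```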

Member

@billbrod billbrod left a comment

I played around with this a bit locally, and I think it behaves as expected.

In addition to the comments:

  • We need some linting to happen on the notebooks: we're getting a lot of changes that I think are just coming from differences in how editor vs. jupyter lab is saving them, which is distracting
  • I want to draw particular attention to my comment in basis/__init__.py on __all__, something is weird there
  • I don't understand what's going on with sklearn_get_params / get_params and with _map_parameters / remap_parameters -- can you explain? Basically, I think I want a description that's a bit more detailed than the high-level one in the PR now.
  • We only support __getitem__ for labels, not __getattr__, which is fine, but do we want to raise a more specific error message if someone tries to grab an attribute using something we know is a label?
  • Relatedly, if they use attributes, they need to do basis1.basis2.etc but getitem allows for labels. Probably worth explaining that in the new doc page?

:::


Note that if you don't provide a label, basis class name is used to construct the keys.
Member

From here to the end belongs in the second section, I think. it's not about sklearn

Collaborator Author

I tried to move the content, let me know


Parameters
----------
deep
Member

missing description

Collaborator Author

now it is described

@BalzaniEdoardo
Collaborator Author

Since the last version of the PR, many things got moved around after talking with @billbrod.

Composite basis tree-structure traversing

Many operations on composite bases require updating child bases. So far, most of the machinery relied on the child bases' methods and properties. This included the label setter property implemented in this PR.

With this revised PR, I moved all the label-editing logic, as well as other functions that operate over the whole basis tree, into a dedicated module. These functions are now independent of the child bases' internal machinery. This independence is not tested here but will be the subject of a dedicated PR, in which I will add a mock basis class with a single method, compute_features, and make sure that I can successfully add it to a nemos basis and that the overall machinery keeps working.

Get Params Logic

The overridden get_params / set_params logic has been moved to CompositeBasis, which is the only basis that needs special machinery. Atomic bases rely on the standard sklearn machinery.

Implementation details

  1. __sklearn_get_params__: implements the standard sklearn parameter-extraction recursion, which relies on the class signature. It is needed for Base.set_params to work correctly; otherwise the signature names and the key structure of the generated parameter dict would not match.
  2. _get_params_and_key_map: a recursion similar to __sklearn_get_params__ that relabels the keys of the parameter dictionary using basis.label, and additionally stores a key_map dictionary mapping the new keys to the standard structure-based ones.
  3. _map_params: calls _get_params_and_key_map, then drops the label from the keys of the parameters of self (the top-level basis). It returns two dictionaries: one storing the parameters and one with the key mapping.
  4. get_params: calls _map_params and returns only the dictionary with the parameters.
  5. remap_params: decorator that wraps the composite basis set_params.
    • Pre-processing: it receives a parameter dictionary keyed either by attributes (if set_params is called internally by sklearn or Base) or by labels (the remapped dictionary), and maps it back to the standard sklearn attribute-structure-based dictionary.
    • Processing: calls the nemos.Base set_params on the reformatted dictionary.
    • Post-processing: reassigns labels.
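
The five pieces above can be sketched end to end with toy classes. Everything here is hypothetical: `Leaf` and `Composite` are stand-ins for nemos bases, the method names are borrowed from the list above, the `__` separator between label and parameter name is an assumption of this sketch, and the label re-assignment done in post-processing is left out.

```python
import functools


class Leaf:
    """Hypothetical atomic basis with a single parameter."""

    def __init__(self, n_basis_funcs, label):
        self.n_basis_funcs = n_basis_funcs
        self.label = label

    def get_params(self, deep=True):
        return {"n_basis_funcs": self.n_basis_funcs}

    def set_params(self, **params):
        for name, value in params.items():
            setattr(self, name, value)
        return self


def remap_params(set_params):
    """Translate label-based keys into structural keys before setting."""

    @functools.wraps(set_params)
    def wrapper(self, **params):
        _, key_map = self._map_params()
        # Structural keys are not in key_map and pass through unchanged.
        structural = {key_map.get(k, k): v for k, v in params.items()}
        return set_params(self, **structural)

    return wrapper


class Composite:
    def __init__(self, basis1, basis2, label):
        self.basis1, self.basis2, self.label = basis1, basis2, label

    def __sklearn_get_params__(self, deep=True):
        # Structural keys: "basis1__basis2__<param>".
        out = {"basis1": self.basis1, "basis2": self.basis2}
        if deep:
            for name, child in list(out.items()):
                getter = getattr(child, "__sklearn_get_params__", child.get_params)
                out.update({f"{name}__{k}": v for k, v in getter(deep=True).items()})
        return out

    def _get_params_and_key_map(self, prefix=""):
        # Label-based keys, plus a map from each of them to its structural key.
        params, key_map = {}, {}
        for name in ("basis1", "basis2"):
            child = getattr(self, name)
            params[f"{self.label}__{name}"] = child
            key_map[f"{self.label}__{name}"] = prefix + name
            if isinstance(child, Composite):
                sub, sub_map = child._get_params_and_key_map(f"{prefix}{name}__")
            else:
                own = child.get_params()
                sub = {f"{child.label}__{k}": v for k, v in own.items()}
                sub_map = {f"{child.label}__{k}": f"{prefix}{name}__{k}" for k in own}
            params.update(sub)
            key_map.update(sub_map)
        return params, key_map

    def _map_params(self):
        params, key_map = self._get_params_and_key_map()
        own = f"{self.label}__"

        def strip(key):
            # Drop the top-level label from the keys of self's own parameters.
            return key[len(own):] if key.startswith(own) else key

        return ({strip(k): v for k, v in params.items()},
                {strip(k): v for k, v in key_map.items()})

    def get_params(self, deep=True):
        return self._map_params()[0]

    @remap_params
    def set_params(self, **params):
        for key, value in params.items():
            name, _, rest = key.partition("__")
            if rest:
                getattr(self, name).set_params(**{rest: value})
            else:
                setattr(self, name, value)
        return self


inner = Composite(Leaf(5, "x"), Leaf(6, "y"), label="xy")
top = Composite(inner, Leaf(7, "z"), label="top")
top.set_params(x__n_basis_funcs=10)               # label-based key
top.set_params(basis1__basis2__n_basis_funcs=11)  # structural key
```

Both calls to `set_params` reach the same kind of leaf parameter: the decorator turns the label-based key into its structural equivalent, while a key that is already structural passes through untouched.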

BalzaniEdoardo and others added 5 commits February 21, 2025 16:25
Co-authored-by: William F. Broderick <billbrod@gmail.com>
Collaborator Author

@BalzaniEdoardo BalzaniEdoardo left a comment

worked on the revisions! ready for a second round

@@ -90,6 +90,8 @@ def _shift_time_axis_and_convolve(array: NDArray, eval_basis: NDArray, axis: int
-----
This function supports arrays of any dimensionality greater or equal than 1.
"""
# convert axis
axis = axis if axis >= 0 else array.ndim + axis
Collaborator Author

I thought I already merged this, unsure why it shows up. I would keep it since it is a bugfix.

@@ -156,6 +156,8 @@ def _pad_dimension(
"acausal": ((pad_size) // 2, pad_size - (pad_size) // 2),
"anti-causal": (0, pad_size),
}
# convert negative axis in jax jit compilable way
axis = axis * (axis >= 0) + (array.ndim + axis) * (axis < 0)
Collaborator Author

Same thing, I thought I merged this into main.
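
The two axis conversions in the hunks above compute the same value. The sketch below compares them in plain Python; the function names are hypothetical. The arithmetic form avoids a Python `if` on `axis`, which is presumably what makes it usable under `jax.jit`, where branching on a traced value raises an error.

```python
def normalize_axis_branch(axis, ndim):
    """Branching form, as in _shift_time_axis_and_convolve."""
    return axis if axis >= 0 else ndim + axis


def normalize_axis_arith(axis, ndim):
    """Branch-free form, as in _pad_dimension."""
    # Exactly one of the two comparisons is True (1), the other False (0).
    return axis * (axis >= 0) + (ndim + axis) * (axis < 0)


# The two forms agree on every valid axis of 1-D to 4-D arrays.
agree = all(
    normalize_axis_branch(axis, ndim) == normalize_axis_arith(axis, ndim)
    for ndim in range(1, 5)
    for axis in range(-ndim, ndim)
)
```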


all_params = obj.get_params(deep=False)
# use special method for basis
get_params = getattr(obj, "__sklearn_get_params__", obj.get_params)
Collaborator Author

I explained above, but since it is tricky I'll add another explanation. The idea is that we need to be able to retrieve both the actual attribute structure and the remapping that uses labels.

Why is that? Because scikit-learn (and our base class) recursively use the signature of the classes implementing get/set parameters to determine which parameters can be edited. To work correctly, the recursion must be able to retrieve a dictionary with the actual structure ("basis1__basis2__..."), while the user-facing get_params should use the labels.
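
A small sketch of the dispatch in the quoted line: objects that define `__sklearn_get_params__` expose structural keys to the recursion, and everything else falls back to the regular `get_params`. The classes `Plain` and `Labeled` and the helper `structural_params` are hypothetical, as are the parameter names.

```python
class Plain:
    """Hypothetical estimator with only the regular get_params."""

    def get_params(self, deep=True):
        return {"alpha": 1.0}


class Labeled:
    """Hypothetical basis whose user-facing keys are label-based."""

    def get_params(self, deep=True):
        return {"my_label__n_basis_funcs": 5}

    def __sklearn_get_params__(self, deep=True):
        return {"basis1__n_basis_funcs": 5}


def structural_params(obj):
    # Prefer the structural getter when it exists, else use get_params.
    getter = getattr(obj, "__sklearn_get_params__", obj.get_params)
    return getter(deep=True)


plain_keys = list(structural_params(Plain()))
labeled_keys = list(structural_params(Labeled()))
```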


all_params = obj.get_params(deep=False)
# use special method for basis
get_params = getattr(obj, "__sklearn_get_params__", obj.get_params)
Collaborator Author

Btw, here we don't need to use the private one, so I changed that, but we still need the method.


4 participants