Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add absvariable as an option #31

Merged
merged 1 commit into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pinnicle/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def cmap_Rignot():
cmap = ListedColormap(cmap)
return cmap

def plot_solutions(pinn, path="", X_ref=None, sol_ref=None, cols=None, resolution=200, **kwargs):
def plot_solutions(pinn, path="", X_ref=None, sol_ref=None, cols=None, resolution=200, absvariable=[], **kwargs):
""" plot model predictions

Args:
Expand All @@ -28,6 +28,7 @@ def plot_solutions(pinn, path="", X_ref=None, sol_ref=None, cols=None, resolutio
u_ref (dict): Reference solutions, if None, then just plot the predicted solutions
cols (int): Number of columns of subplot
resolution (int): Number of grid points per row/column for plotting
absvariable (list): Names of variables in the predictions that will need to take abs() before comparison
"""
# generate Cartisian grid of X, Y
# currently only work on 2D
Expand All @@ -44,6 +45,9 @@ def plot_solutions(pinn, path="", X_ref=None, sol_ref=None, cols=None, resolutio
sol_pred = pinn.model.predict(X_nn)
plot_data = {k+"_pred":np.reshape(sol_pred[:,i:i+1], X.shape) for i,k in enumerate(pinn.params.nn.output_variables)}
vranges = {k+"_pred":[pinn.params.nn.output_lb[i], pinn.params.nn.output_ub[i]] for i,k in enumerate(pinn.params.nn.output_variables)}
# take abs
for k in absvariable:
plot_data[k+"_pred"] = np.abs( plot_data[k+"_pred"])

# if ref solution is provided
if (sol_ref is not None) and (X_ref is not None):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def test_plot(tmp_path):
experiment.model_data.data["ISSM"].X_dict['y'].flatten()[:,None]))
assert experiment.plot_predictions(X_ref=X_ref,
sol_ref=experiment.model_data.data["ISSM"].data_dict,
resolution=10) is None
resolution=10, absvariable=['C']) is None
X, Y, im_data, axs = plot_nn(experiment, experiment.model_data.data["ISSM"].data_dict, resolution=10);
assert X.shape == (10,10)
assert Y.shape == (10,10)
Expand Down
Loading