diff --git a/lightguide/blast.py b/lightguide/blast.py index 14fa55a..fe9cb3d 100644 --- a/lightguide/blast.py +++ b/lightguide/blast.py @@ -1,19 +1,11 @@ from __future__ import annotations import logging +import re from copy import deepcopy from datetime import datetime, timedelta, timezone from functools import wraps -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Iterable, - Iterator, - Literal, - TypeVar, - cast, -) +from typing import Any, Callable, Iterable, Iterator, Literal, TypeVar, cast import matplotlib.pyplot as plt import numpy as np @@ -21,17 +13,13 @@ from matplotlib.colors import Colormap from pyrocko import io from pyrocko.trace import Trace -from scipy import signal +from scipy import ndimage, signal from lightguide.utils import PathStr from .filters import afk_filter from .signal import decimation_coefficients -if TYPE_CHECKING: - from matplotlib import image - - logger = logging.getLogger(__name__) @@ -297,6 +285,23 @@ def bandpass( else: self.data = signal.sosfilt(sos, self.data, axis=1) + def wiener_filter(self, size: int = 5): + """Apply Wiener filter in-place. + + Args: + size (int, optional): Size of the footprint. Defaults to 5. + """ + self.data = signal.wiener(self.data, mysize=size) + + def median_filter(self, size: int = 3) -> None: + """Apply median filter in-place. + + Args: + size (int, optional): Footprint of the filter. + Defaults to 3. + """ + self.data = ndimage.median_filter(self.data, size=size) + def afk_filter( self, exponent: float = 0.8, @@ -497,7 +502,7 @@ def to_strain(self, detrend: bool = True) -> Blast: TypeError: Raised when the input Blast is not in strain rate. Returns: - Blast: In strain strain. + Blast: Copy in strain. """ if self.unit == "strain": return self.copy() @@ -541,7 +546,7 @@ def to_relative_displacement(self, detrend: bool = True) -> Blast: TypeError: Raised when the input Blast is not in strain rate. Returns: - Blast: As strain strain. + Blast: Copy in strain. """ blast = self.to_strain() if detrend: @@ -556,8 +561,9 @@ def plot( normalize_traces: bool = True, cmap: str | Colormap = "seismic", show_date: bool = False, + interpolation: str = "nearest", show_channel: bool = False, - ) -> image.AxesImage: + ) -> plt.Axes: """Plot data of the blast. Args: @@ -567,6 +573,8 @@ def plot( Defaults to True. cmap (str | Colormap, optional): Matplotlib colormap. Defaults to "seismic". show_date (bool, optional): Shot absolute dates in UTC. Defaults to False. + interpolation (str, optional): Interpolation for plt.imshow. + Defaults to "nearest". show_channel (bool, optional): Show channels instead of meters. Defaults to False. @@ -594,10 +602,10 @@ def plot( if normalize_traces: data /= np.abs(data.max(axis=1, keepdims=True)) - img = ax.imshow( + ax.imshow( data.T, aspect="auto", - # interpolation="nearest", + interpolation=interpolation, cmap=cmap, extent=extent, norm=colors.CenteredNorm(), @@ -614,7 +622,7 @@ def plot( if axes is None: plt.show() - return img + return ax def copy(self) -> Blast: """Return a copy of the blast. @@ -671,7 +679,7 @@ def from_pyrocko(cls, traces: list[Trace], channel_spacing: float = 4.0) -> Blas if not traces: raise ValueError("Empty list of traces") - traces = sorted(traces, key=lambda tr: int(tr.station)) + traces = sorted(traces, key=lambda tr: tr.station) ntraces = len(traces) tmin = set() @@ -702,7 +710,9 @@ def from_pyrocko(cls, traces: list[Trace], channel_spacing: float = 4.0) -> Blas data=data, start_time=datetime.fromtimestamp(tmin.pop(), tz=timezone.utc), sampling_rate=int(1.0 / delta_t.pop()), - start_channel=min(int(tr.station) for tr in traces), + start_channel=min( + int(re.sub(r"[a-zA-Z]", "", tr.station)) for tr in traces + ), channel_spacing=channel_spacing, ) @@ -716,7 +726,7 @@ def from_miniseed(cls, file: PathStr, channel_spacing: float = 4.0) -> Blast: Defaults to 4.0. Returns: - Blast: Produced Blast. + Blast: Assembled Blast. """ from pyrocko import io @@ -783,19 +793,16 @@ def __len__(self) -> int: lowpass = shared_function(Blast.lowpass) highpass = shared_function(Blast.highpass) bandpass = shared_function(Blast.bandpass) - afk_filter = shared_function(Blast.afk_filter) decimate = shared_function(Blast.decimate) + afk_filter = shared_function(Blast.afk_filter) + wiener_filter = shared_function(Blast.wiener_filter) + median_filter = shared_function(Blast.median_filter) trim_time = shared_function(Blast.trim_time) trim_channels = shared_function(Blast.trim_channels) mute_median = shared_function(Blast.mute_median) one_bit_normalization = shared_function(Blast.one_bit_normalization) - afk_filter = shared_function(Blast.afk_filter) - decimate = shared_function(Blast.decimate) - - trim_time = shared_function(Blast.trim_time) - trim_channels = shared_function(Blast.trim_channels) to_strain = shared_function(Blast.to_strain) to_relative_velocity = shared_function(Blast.to_relative_velocity)