Skip to content

Commit

Permalink
filters: adding Wiener and median filter (#7)
Browse files Browse the repository at this point in the history
Co-authored-by: Marius Isken <marius.isken@gfz-potsdam.de>
  • Loading branch information
miili and Marius Isken authored Oct 29, 2024
1 parent de21a19 commit e508494
Showing 1 changed file with 37 additions and 30 deletions.
67 changes: 37 additions & 30 deletions lightguide/blast.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,25 @@
from __future__ import annotations

import logging
import re
from copy import deepcopy
from datetime import datetime, timedelta, timezone
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Iterator,
Literal,
TypeVar,
cast,
)
from typing import Any, Callable, Iterable, Iterator, Literal, TypeVar, cast

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import colors, dates
from matplotlib.colors import Colormap
from pyrocko import io
from pyrocko.trace import Trace
from scipy import signal
from scipy import ndimage, signal

from lightguide.utils import PathStr

from .filters import afk_filter
from .signal import decimation_coefficients

if TYPE_CHECKING:
from matplotlib import image


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -297,6 +285,23 @@ def bandpass(
else:
self.data = signal.sosfilt(sos, self.data, axis=1)

def wiener_filter(self, size: int = 5):
"""Apply Wiener filter in-place.
Args:
size (int, optional): Size of the footprint. Defaults to 5.
"""
self.data = signal.wiener(self.data, mysize=size)

def median_filter(self, size: int = 3) -> None:
"""Apply median filter in-place.
Args:
size (int, optional): Footprint of the filter.
Defaults to 3.
"""
self.data = ndimage.median_filter(self.data, size=size)

def afk_filter(
self,
exponent: float = 0.8,
Expand Down Expand Up @@ -497,7 +502,7 @@ def to_strain(self, detrend: bool = True) -> Blast:
TypeError: Raised when the input Blast is not in strain rate.
Returns:
Blast: In strain strain.
Blast: Copy in strain.
"""
if self.unit == "strain":
return self.copy()
Expand Down Expand Up @@ -541,7 +546,7 @@ def to_relative_displacement(self, detrend: bool = True) -> Blast:
TypeError: Raised when the input Blast is not in strain rate.
Returns:
Blast: As strain strain.
Blast: Copy in strain.
"""
blast = self.to_strain()
if detrend:
Expand All @@ -556,8 +561,9 @@ def plot(
normalize_traces: bool = True,
cmap: str | Colormap = "seismic",
show_date: bool = False,
interpolation: str = "nearest",
show_channel: bool = False,
) -> image.AxesImage:
) -> plt.Axes:
"""Plot data of the blast.
Args:
Expand All @@ -567,6 +573,8 @@ def plot(
Defaults to True.
cmap (str | Colormap, optional): Matplotlib colormap. Defaults to "seismic".
show_date (bool, optional): Shot absolute dates in UTC. Defaults to False.
interpolation (str, optional): Interpolation for plt.imshow.
Defaults to "nearest".
show_channel (bool, optional): Show channels instead of meters.
Defaults to False.
Expand Down Expand Up @@ -594,10 +602,10 @@ def plot(
if normalize_traces:
data /= np.abs(data.max(axis=1, keepdims=True))

img = ax.imshow(
ax.imshow(
data.T,
aspect="auto",
# interpolation="nearest",
interpolation=interpolation,
cmap=cmap,
extent=extent,
norm=colors.CenteredNorm(),
Expand All @@ -614,7 +622,7 @@ def plot(

if axes is None:
plt.show()
return img
return ax

def copy(self) -> Blast:
"""Return a copy of the blast.
Expand Down Expand Up @@ -671,7 +679,7 @@ def from_pyrocko(cls, traces: list[Trace], channel_spacing: float = 4.0) -> Blas
if not traces:
raise ValueError("Empty list of traces")

traces = sorted(traces, key=lambda tr: int(tr.station))
traces = sorted(traces, key=lambda tr: tr.station)
ntraces = len(traces)

tmin = set()
Expand Down Expand Up @@ -702,7 +710,9 @@ def from_pyrocko(cls, traces: list[Trace], channel_spacing: float = 4.0) -> Blas
data=data,
start_time=datetime.fromtimestamp(tmin.pop(), tz=timezone.utc),
sampling_rate=int(1.0 / delta_t.pop()),
start_channel=min(int(tr.station) for tr in traces),
start_channel=min(
int(re.sub(r"[a-zA-Z]", "", tr.station)) for tr in traces
),
channel_spacing=channel_spacing,
)

Expand All @@ -716,7 +726,7 @@ def from_miniseed(cls, file: PathStr, channel_spacing: float = 4.0) -> Blast:
Defaults to 4.0.
Returns:
Blast: Produced Blast.
Blast: Assembled Blast.
"""
from pyrocko import io

Expand Down Expand Up @@ -783,19 +793,16 @@ def __len__(self) -> int:
lowpass = shared_function(Blast.lowpass)
highpass = shared_function(Blast.highpass)
bandpass = shared_function(Blast.bandpass)
afk_filter = shared_function(Blast.afk_filter)
decimate = shared_function(Blast.decimate)
afk_filter = shared_function(Blast.afk_filter)
wiener_filter = shared_function(Blast.wiener_filter)
median_filter = shared_function(Blast.median_filter)

trim_time = shared_function(Blast.trim_time)
trim_channels = shared_function(Blast.trim_channels)

mute_median = shared_function(Blast.mute_median)
one_bit_normalization = shared_function(Blast.one_bit_normalization)
afk_filter = shared_function(Blast.afk_filter)
decimate = shared_function(Blast.decimate)

trim_time = shared_function(Blast.trim_time)
trim_channels = shared_function(Blast.trim_channels)

to_strain = shared_function(Blast.to_strain)
to_relative_velocity = shared_function(Blast.to_relative_velocity)
Expand Down

0 comments on commit e508494

Please sign in to comment.