Source code for stride.optimisation.pipelines.steps.filter_traces


from stride.utils import filters

from .utils import name_from_op_name
from ....core import Operator


class FilterTraces(Operator):
    """
    Operator that applies a frequency filter to one or more sets of time traces.

    Depending on which cut-off frequencies are available, the traces are
    low-pass filtered (only ``f_max``), high-pass filtered (only ``f_min``),
    band-pass filtered (both), or copied through unchanged (neither).

    Parameters
    ----------
    f_min : float, optional
        Lower value for the frequency filter, defaults to None (no lower filtering).
    f_max : float, optional
        Upper value for the frequency filter, defaults to None (no upper filtering).
    filter_type : str, optional
        Type of filter to apply, from ``butterworth`` (default for band pass and
        high pass), ``fir``, or ``cos`` (default for low pass).
    filter_relaxation : float, optional
        Relaxation factor for the filter in range (0, 1], defaults to 1 (no dilation).

    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Instance-level defaults; each can be overridden per call through
        # the keyword arguments given to ``forward``/``adjoint``.
        self.f_min = kwargs.get('f_min', None)
        self.f_max = kwargs.get('f_max', None)
        self.filter_type = kwargs.get('filter_type', None)
        self.relaxation = kwargs.get('filter_relaxation', 1.0)

        # Number of inputs seen by the last ``forward`` call, used by
        # ``adjoint`` to decide how many incoming traces to process.
        self._num_traces = None

    def forward(self, *traces, **kwargs):
        """
        Filter every given set of traces.

        Parameters
        ----------
        traces
            One or more trace objects to filter.
        kwargs
            Per-call overrides (``f_min``, ``f_max``, ``filter_type``,
            ``filter_relaxation``); anything else is handed to the
            underlying filter function.

        Returns
        -------
        object or tuple
            The single filtered traces object when one input is given,
            otherwise a tuple with one filtered object per input.

        """
        self._num_traces = len(traces)

        outputs = [self._apply(trace, **kwargs) for trace in traces]

        if len(traces) > 1:
            return tuple(outputs)

        return outputs[0]

    def adjoint(self, *d_traces, **kwargs):
        """
        Filter the incoming traces with ``adjoint=True`` forwarded to the filter.

        Only as many inputs as were given to the preceding ``forward`` call
        are processed; the stored count is cleared afterwards.

        Parameters
        ----------
        d_traces
            One or more trace objects to filter.
        kwargs
            Per-call overrides, as in ``forward``.

        Returns
        -------
        object or tuple
            The single filtered traces object when one input is processed,
            otherwise a tuple with one filtered object per input.

        """
        # A count of None (no forward call recorded) keeps every input.
        d_traces = d_traces[:self._num_traces]

        outputs = [self._apply(d_trace, adjoint=True, **kwargs)
                   for d_trace in d_traces]

        self._num_traces = None

        if len(d_traces) > 1:
            return tuple(outputs)

        return outputs[0]

    def _apply(self, traces, **kwargs):
        """
        Filter a single traces object and return the result as a new object.

        Parameters
        ----------
        traces
            Traces object providing ``time``, ``extended_data`` and ``alike``.
        kwargs
            Per-call overrides; keys not consumed here are forwarded to the
            selected function in ``stride.utils.filters``.

        Returns
        -------
        object
            Object created with ``traces.alike`` holding the filtered data,
            or a copy of the input data when no cut-off frequency is set.

        """
        time = traces.time

        # Consume the options handled here so that they are not forwarded
        # to the filter function further down.
        f_min = kwargs.pop('f_min', self.f_min)
        f_max = kwargs.pop('f_max', self.f_max)
        relaxation = kwargs.pop('filter_relaxation', self.relaxation)

        has_lower = f_min is not None
        has_upper = f_max is not None

        # Cut-offs scaled by the time step; the relaxation factor scales
        # the lower one and its inverse scales the upper one.
        lower_cut = relaxation*f_min*time.step if has_lower else 0
        upper_cut = 1/relaxation*f_max*time.step if has_upper else 0

        result = traces.alike(name=name_from_op_name(self, traces))

        # Nothing to filter: hand back an unmodified copy of the data.
        if not has_lower and not has_upper:
            result.extended_data[:] = traces.extended_data
            return result

        if has_lower and has_upper:
            pass_type = 'bandpass'
            cut_offs = (lower_cut, upper_cut)
        elif has_lower:
            pass_type = 'highpass'
            cut_offs = (lower_cut,)
        else:
            pass_type = 'lowpass'
            cut_offs = (upper_cut,)

        # Pure low-pass falls back to ``cos``; anything with a lower
        # cut-off falls back to ``butterworth``.
        fallback_type = 'butterworth' if has_lower else 'cos'
        filter_type = kwargs.pop('filter_type', self.filter_type or fallback_type)

        # Filter functions are looked up by name, e.g. ``lowpass_filter_cos``.
        filter_fun = getattr(filters, '%s_filter_%s' % (pass_type, filter_type))

        result.extended_data[:] = filter_fun(traces.extended_data, *cut_offs,
                                             zero_phase=False, **kwargs)

        return result