import os
import numpy as np
# Public plotting helpers exported by ``from <module> import *``.
__all__ = ['plot_trace', 'plot_gather']
[docs]
def plot_trace(*args, axis=None, colour='black', line_style='solid', title=None, **kwargs):
    """
    Utility function to plot individual traces using matplotlib.

    Plotting is best-effort: when the ``DISPLAY`` environment variable is
    unset or empty, or matplotlib cannot be imported, nothing is drawn and
    ``None`` is returned.

    Parameters
    ----------
    args : arrays
        Optional time grid and signal to be plotted; forwarded positionally
        to ``Axes.plot``.
    axis : matplotlib Axes, optional
        Axes in which to make the plotting, defaults to the axes of a new
        empty figure.
    colour : str, optional
        Colour of the line, defaults to black.
    line_style : str, optional
        Line style to be used, defaults to solid.
    title : str, optional
        Axes title, defaults to no title.
    **kwargs
        Extra keyword arguments forwarded to ``Axes.plot``; they take
        precedence over ``colour`` and ``line_style``.

    Returns
    -------
    matplotlib Axes or None
        The axes the trace was drawn on, or None if plotting was skipped.
    """
    # Headless environment (no X display): skip plotting altogether.
    if not os.environ.get('DISPLAY', None):
        return None
    try:
        import matplotlib.pyplot as plt
    except ModuleNotFoundError:
        return None

    if axis is None:
        _, axis = plt.subplots(1, 1)

    # Explicit kwargs override the colour / line-style defaults.
    plot_kwargs = dict(c=colour, linestyle=line_style)
    plot_kwargs.update(kwargs)
    axis.plot(*args, **plot_kwargs)

    if title is not None:
        axis.set_title(title)

    return axis
[docs]
# Split plot_gather's positional arguments into (trace_axis, time_axis, signal).
def _unpack_gather_args(args):
    if not args:
        raise TypeError('plot_gather() requires at least the signal to plot')
    if len(args) > 2:
        # Any positional arguments beyond the third are ignored.
        trace_axis, time_axis, signal = args[0], args[1], args[2]
    elif len(args) == 2:
        trace_axis = None
        time_axis, signal = args[0], args[1]
    else:
        trace_axis = None
        time_axis = None
        signal = args[0]
    return trace_axis, time_axis, signal


# Apply the optional global and per-trace (last-axis) normalisations to a gather.
def _normalise_gather(signal, norm, norm_trace):
    # The 1e-30 epsilon avoids division by zero for an all-zero gather or trace.
    if norm:
        signal = signal / (np.max(np.abs(signal)) + 1e-30)
    if norm_trace:
        signal = signal / (np.max(np.abs(signal), axis=-1, keepdims=True) + 1e-30)
    return signal


def plot_gather(*args, skip=1, time_range=None, norm=True, norm_trace=True,
                colour='black', line_style='solid', title=None, axis=None, **kwargs):
    """
    Utility function to plot a gather using matplotlib.

    Each selected trace is drawn as a vertical wiggle, offset horizontally by
    2 units from its neighbour, with time running along the y axis. Plotting
    is best-effort: when the ``DISPLAY`` environment variable is unset or
    empty, or matplotlib cannot be imported, nothing is drawn and
    ``(None, None)`` is returned.

    Parameters
    ----------
    args : arrays
        ``signal``, ``(time_axis, signal)`` or ``(trace_axis, time_axis, signal)``.
        ``signal`` is indexed as ``signal[trace, sample]``. ``time_axis`` may
        span either the whole signal (it is then cropped to ``time_range``)
        or only the plotted window. ``trace_axis`` provides the trace labels.
    skip : int, optional
        Plot only every ``skip``-th trace, defaults to 1.
    time_range : tuple, optional
        ``(start, stop)`` sample indices along the last axis of ``signal`` to
        plot, defaults to all samples.
    norm : bool, optional
        Whether or not to normalise the gather by its maximum absolute
        value, defaults to True.
    norm_trace : bool, optional
        Whether or not to normalise trace by trace, defaults to True.
    colour : str, optional
        Colour of the lines, defaults to black.
    line_style : str, optional
        Line style to be used, defaults to solid.
    title : str, optional
        Axes title, defaults to no title.
    axis : matplotlib Axes, optional
        Axes in which to make the plotting, defaults to the axes of a new
        empty figure.
    **kwargs
        Extra keyword arguments forwarded to ``Axes.plot``; they take
        precedence over ``colour`` and ``line_style``.

    Returns
    -------
    shift : ndarray or None
        Column vector with the horizontal offset of each plotted trace, or
        None if plotting was skipped.
    axis : matplotlib Axes or None
        The axes the gather was drawn on, or None if plotting was skipped.

    Raises
    ------
    TypeError
        If no positional arguments are given (only checked when plotting).
    """
    # Headless environment (no X display): skip plotting altogether.
    if not os.environ.get('DISPLAY', None):
        return None, None
    try:
        import matplotlib.pyplot as plt
    except ModuleNotFoundError:
        return None, None

    trace_axis, time_axis, signal = _unpack_gather_args(args)
    signal = np.asarray(signal)

    num_traces = signal.shape[0]
    if time_range is None:
        time_range = (0, signal.shape[-1])
    start, stop = time_range[0], time_range[1]

    signal = _normalise_gather(signal, norm, norm_trace)

    # Keep every `skip`-th trace and only the requested time window.
    signal_under = signal[0:num_traces:skip, start:stop]
    num_under_traces = signal_under.shape[0]
    num_samples = signal_under.shape[-1]

    # Offset consecutive traces by 2 so traces normalised to [-1, 1] do not overlap.
    shift = (np.arange(num_under_traces) * 2.0).reshape((num_under_traces, 1))
    signal_shifted = np.transpose(signal_under + shift)

    if time_axis is None:
        # Default to 0-based sample indices of the plotted window (unit step).
        time_axis = np.arange(num_samples, dtype=float)
    else:
        time_axis = np.asarray(time_axis)
        # A grid spanning the whole signal is cropped to the same window as the
        # signal; a grid the caller already cropped is used as is.
        if time_axis.shape[0] == signal.shape[-1]:
            time_axis = time_axis[start:stop]
    time_axis = np.broadcast_to(np.reshape(time_axis, (time_axis.shape[0], 1)),
                                signal_shifted.shape)

    # Create the figure only once the data has been prepared successfully.
    if axis is None:
        _, axis = plt.subplots(1, 1)

    plot_kwargs = dict(c=colour, linestyle=line_style)
    plot_kwargs.update(kwargs)
    axis.plot(signal_shifted, time_axis, **plot_kwargs)
    # Put the first sample at the top of the plot.
    axis.set_ylim(time_axis[-1, 0], time_axis[0, 0])
    axis.set_xlabel('trace')
    axis.set_ylabel('time')

    if trace_axis is None:
        # Label each plotted trace with its index in the original gather.
        trace_labels = np.arange(0, num_traces, skip)
    else:
        trace_labels = trace_axis[::skip]
    trace_labels = [str(each) for each in trace_labels]
    # Only every other trace gets a tick to keep the labels readable.
    axis.set_xticks(shift.flatten()[::2])
    axis.set_xticklabels(trace_labels[::2])

    if title is not None:
        axis.set_title(title)

    return shift, axis