Source code for piel.analysis.signals.time.core.threshold
import numpy as np
from scipy.signal import find_peaks
from piel.types import TimeSignalData, MultiTimeSignalData
from typing import Optional, List
import logging
logger = logging.getLogger(__name__)
[docs]
def extract_signal_above_threshold(
    signal_data: TimeSignalData,
    threshold: float,
    min_pulse_width_s: float = 0.0,
    noise_floor: float = 0.0,
) -> MultiTimeSignalData:
    """
    Extracts every contiguous run of samples whose value is strictly greater
    than the threshold and returns each run as its own signal.

    Args:
        signal_data (TimeSignalData): The original signal; its ``time_s`` and ``data``
            sequences must have the same length.
        threshold (float): Samples strictly greater than this value belong to a pulse.
        min_pulse_width_s (float, optional): Minimum duration, measured from the first to the
            last sample of the run, for a pulse to be kept. Shorter pulses are skipped.
            Defaults to 0.0.
        noise_floor (float, optional): Accepted for interface compatibility but not used:
            pulses are returned as cropped signals, so there are no non-pulse regions to
            fill. Defaults to 0.0.

    Returns:
        MultiTimeSignalData: A list of TimeSignalData instances, one per retained pulse,
        named ``<data_name>_pulse_<n>``. ``n`` counts every detected run (including the
        ones skipped for being too short), so the numbering may have gaps. The list is
        empty when the input signal has no samples.

    Raises:
        ValueError: If ``time_s`` and ``data`` have different lengths.
    """
    time = np.array(signal_data.time_s)
    data = np.array(signal_data.data)

    if len(time) != len(data):
        raise ValueError("Time and data arrays must have the same length.")

    # An empty signal contains no pulses. Return early because the boundary
    # handling below indexes the first and last samples.
    if len(data) == 0:
        return []

    above_threshold = data > threshold

    # In the differenced mask, +1 marks a rising edge and -1 a falling edge.
    # np.diff shifts positions left by one, hence the +1: starts point at the
    # first sample above the threshold, ends at the first sample back below it
    # (i.e. an exclusive end index).
    edges = np.diff(above_threshold.astype(int))
    pulse_start_indices = np.where(edges == 1)[0] + 1
    pulse_end_indices = np.where(edges == -1)[0] + 1

    # A signal that begins or finishes above the threshold has no matching
    # edge inside the array, so the missing boundary is added explicitly.
    if above_threshold[0]:
        pulse_start_indices = np.insert(pulse_start_indices, 0, 0)
    if above_threshold[-1]:
        pulse_end_indices = np.append(pulse_end_indices, len(data))

    logger.debug("Detected %s potential pulses.", len(pulse_start_indices))

    extracted_pulses: MultiTimeSignalData = []

    for idx, (start_idx, end_idx) in enumerate(
        zip(pulse_start_indices, pulse_end_indices), start=1
    ):
        # end_idx is exclusive, so the last sample of the pulse is end_idx - 1.
        pulse_start_s = time[start_idx]
        pulse_end_s = time[end_idx - 1]
        pulse_duration = pulse_end_s - pulse_start_s

        if pulse_duration < min_pulse_width_s:
            logger.debug(
                "Pulse %s ignored due to insufficient width: %ss < %ss.",
                idx,
                pulse_duration,
                min_pulse_width_s,
            )
            continue

        # Each pulse carries only its own cropped time and data samples.
        pulse_signal = TimeSignalData(
            time_s=time[start_idx:end_idx].tolist(),
            data=data[start_idx:end_idx].tolist(),
            data_name=f"{signal_data.data_name}_pulse_{idx}",
        )
        extracted_pulses.append(pulse_signal)

        logger.debug(
            "Pulse %s extracted: Start=%ss, End=%ss, Duration=%ss.",
            idx,
            pulse_start_s,
            pulse_end_s,
            pulse_duration,
        )

    logger.info("Total pulses extracted: %s.", len(extracted_pulses))
    return extracted_pulses
[docs]
def extract_pulses_from_signal(
    full_data: TimeSignalData,
    pre_pulse_time_s: float = 0.01,
    post_pulse_time_s: float = 0.01,
    noise_std_multiplier: float = 3.0,
    min_pulse_height: Optional[float] = None,
    min_pulse_distance_s: Optional[float] = None,
    data_time_signal_kwargs: Optional[dict] = None,
) -> List[TimeSignalData]:
    """
    Detects peaks in a TimeSignalData instance and extracts a time window
    around each one.

    Parameters:
        full_data (TimeSignalData): The input signal data containing multiple pulses.
        pre_pulse_time_s (float): Time (in seconds) to include before each detected peak.
        post_pulse_time_s (float): Time (in seconds) to include after each detected peak.
        noise_std_multiplier (float): Multiplier for the standard deviation of the data used
            to build the default detection threshold.
        min_pulse_height (float, optional): Minimum height of a peak to be detected. If not
            provided, it is set to mean(data) + noise_std_multiplier * std(data), with both
            statistics computed over the whole signal.
        min_pulse_distance_s (float, optional): Minimum distance (in seconds) between
            consecutive peaks. If not provided, pre_pulse_time_s + post_pulse_time_s is used.
            The value is converted to samples with the mean sampling interval and is never
            allowed to drop below one sample.
        data_time_signal_kwargs (dict, optional): Additional keyword arguments forwarded to
            each TimeSignalData that is created.

    Returns:
        List[TimeSignalData]: One TimeSignalData per detected peak, named
        ``<data_name>_pulse_<peak index>`` where the peak index is the sample index of the
        peak in the input signal.

    Raises:
        ValueError: If time_s and data differ in length, if fewer than two samples are
            given, if the mean sampling interval is not positive, or if no peak satisfies
            the detection criteria.
    """
    if data_time_signal_kwargs is None:
        data_time_signal_kwargs = {}

    data = np.array(full_data.data)
    time_s = np.array(full_data.time_s)

    # Validate the input before computing any statistics so that degenerate
    # signals fail with a clear error rather than NumPy warnings.
    if len(time_s) != len(data):
        raise ValueError("time_s and data must have the same length.")
    if len(time_s) < 2:
        raise ValueError(
            "time_s array must contain at least two elements to calculate sampling rate."
        )

    # Detection threshold: explicit height if given, otherwise the mean of the
    # data plus a multiple of its standard deviation.
    if min_pulse_height is None:
        detection_threshold = np.mean(data) + noise_std_multiplier * np.std(data)
    else:
        detection_threshold = min_pulse_height

    mean_sampling_interval = np.mean(np.diff(time_s))
    # `not (x > 0)` also rejects NaN, which `x <= 0` would let through.
    if not mean_sampling_interval > 0:
        raise ValueError(
            "time_s must be increasing on average to calculate sampling rate."
        )
    sampling_rate = 1.0 / mean_sampling_interval

    # Convert the minimum peak separation from seconds to samples.
    if min_pulse_distance_s is None:
        separation_s = pre_pulse_time_s + post_pulse_time_s
    else:
        separation_s = min_pulse_distance_s
    # find_peaks rejects distances below one sample, so clamp the result.
    min_distance_samples = max(1.0, separation_s * sampling_rate)

    peaks, _ = find_peaks(
        data,
        height=detection_threshold,
        distance=min_distance_samples,
    )

    if len(peaks) == 0:
        raise ValueError("No pulses detected based on the provided criteria.")

    extracted_pulses = []
    for peak_idx in peaks:
        peak_time = time_s[peak_idx]

        # Clip the requested window to the time span covered by the signal.
        window_start_s = max(peak_time - pre_pulse_time_s, time_s[0])
        window_end_s = min(peak_time + post_pulse_time_s, time_s[-1])

        # NOTE(review): searchsorted assumes time_s is sorted in ascending
        # order - confirm against callers.
        # side="right" keeps a sample that lies exactly at window_end_s.
        first_idx = np.searchsorted(time_s, window_start_s, side="left")
        stop_idx = np.searchsorted(time_s, window_end_s, side="right")

        extracted_pulse = TimeSignalData(
            time_s=time_s[first_idx:stop_idx].tolist(),
            data=data[first_idx:stop_idx].tolist(),
            data_name=f"{full_data.data_name}_pulse_{peak_idx}",
            **data_time_signal_kwargs,
        )
        extracted_pulses.append(extracted_pulse)

    return extracted_pulses
[docs]
def is_pulse_above_threshold(pulse: TimeSignalData, threshold: float) -> bool:
    """
    Checks whether the largest sample of a pulse reaches the given threshold.

    Parameters:
        pulse (TimeSignalData): The pulse whose ``data`` samples are inspected.
        threshold (float): The amplitude the peak sample is compared against.

    Returns:
        bool: True when the maximum of ``pulse.data`` is greater than or equal to the
        threshold. False when it is smaller, or when ``pulse.data`` is empty or otherwise
        falsy.
    """
    samples = pulse.data
    if samples:
        # The comparison is inclusive: a peak equal to the threshold counts.
        return max(samples) >= threshold
    return False