# Source code for piel.experimental.analysis.dataset

from piel.types import ExperimentData


def compose_xarray_dataset_from_experiment_data(
    experiment_data: ExperimentData,
):
    """
    Composes an xarray.Dataset from an ExperimentData instance, using all experiment parameters as coordinates.

    Row ``k`` of ``experiment_data.experiment.parameters`` is paired positionally with
    entry ``k`` of ``experiment_data.data.collection``. Every row of that entry's
    ``measurements.table`` becomes one record. Each record holds:

    * the parameter values;
    * the table's index label as ``metric_name``;
    * the table's ``Value``, ``Mean``, ``Min``, ``Max``, ``Standard Deviation``,
      ``Count`` and ``Unit`` columns, stored under snake_case keys.

    The records are indexed by all parameter columns plus ``metric_name`` and
    converted with ``xr.Dataset.from_dataframe``.

    Args:
        experiment_data (ExperimentData): The experiment data containing parameters and measurements.

    Returns:
        xr.Dataset: An xarray Dataset containing the measurements indexed by all parameters and metric name.

    Raises:
        ValueError: If the number of parameters does not match the number of measurement data entries.
            A ValueError may also propagate from ``xr.Dataset.from_dataframe`` if a
            parameter set repeats a metric name (non-unique combined index).
        AttributeError: If any measurement data entry lacks the `measurements` attribute,
            or its `measurements` is None.
    """
    import xarray as xr
    import pandas as pd

    # Output record key -> measurement-table column it is read from.
    measurement_fields = {
        "value": "Value",
        "mean": "Mean",
        "min": "Min",
        "max": "Max",
        "standard_deviation": "Standard Deviation",
        "count": "Count",
        "unit": "Unit",
    }

    parameters_df = experiment_data.experiment.parameters
    # Expected to be PropagationDelayMeasurementDataCollection; only `.collection` is used.
    measurements_collection = experiment_data.data
    measurement_entries = measurements_collection.collection

    # Validate that the number of parameters matches the number of measurement data entries
    if len(parameters_df) != len(measurement_entries):
        raise ValueError(
            f"Number of parameter entries ({len(parameters_df)}) does not match "
            f"number of measurement data entries ({len(measurement_entries)})."
        )

    parameter_columns = parameters_df.columns.tolist()

    # `to_dict(orient="records")` keeps each column's own type. `iterrows()` builds
    # one Series per row and would upcast e.g. an int parameter to float when the
    # same row also holds a float parameter.
    parameter_rows = parameters_df.to_dict(orient="records")

    data_records = []
    for i, (param_values, measurement_data_i) in enumerate(
        zip(parameter_rows, measurement_entries)
    ):
        # Ensure that measurements are available
        if measurement_data_i.measurements is None:
            raise AttributeError(
                f"This function can only compose a dataset when there is a `measurements` "
                f"attribute in the data collection. Measurement data at index {i} has no measurements."
            )
        measurement_table = measurement_data_i.measurements.table

        # The table's index label is used as the metric name (assumes 'Name' is set
        # as the index). Rows come from to_dict for the same dtype-preservation reason.
        table_rows = measurement_table.to_dict(orient="records")
        for metric_name, row in zip(measurement_table.index, table_rows):
            record = {**param_values, "metric_name": metric_name}
            record.update(
                {key: row[column] for key, column in measurement_fields.items()}
            )
            data_records.append(record)

    # Explicit columns let `set_index` work even when no records were produced
    # (e.g. an empty parameters table). For non-empty input they equal the record
    # keys, in the same order.
    record_columns = list(
        dict.fromkeys([*parameter_columns, "metric_name", *measurement_fields])
    )
    combined_df = pd.DataFrame(data_records, columns=record_columns)

    # Index by every parameter column plus the metric name, so each becomes a dimension.
    index_columns = parameter_columns + ["metric_name"]
    combined_df = combined_df.set_index(index_columns)

    return xr.Dataset.from_dataframe(combined_df)