Source code for napari_dmc_brainmap.utils.data_loader

from pathlib import Path
import json
from typing import List, Union
import pandas as pd
from natsort import natsorted
from napari_dmc_brainmap.utils.path_utils import get_info
from napari_dmc_brainmap.utils.atlas_utils import get_bregma, coord_mm_transform
from bg_atlasapi import BrainGlobeAtlas
from napari.utils.notifications import show_info

[docs] class DataLoader: def __init__(self, input_path: Path, atlas: BrainGlobeAtlas, animal_list: List[str], channels: List[str], data_type: str = 'cells', hemisphere: str = 'both'): """ Initialize the DataLoader. Parameters: input_path (Path): Path to the input data directory. atlas: The atlas object containing structure and hierarchy information. animal_list (List[str]): List of animal identifiers. channels (List[str]): List of channel identifiers. data_type (str): Type of data to load (e.g., 'cells', 'optic_fiber'). Default is 'cells'. hemisphere (str): Hemisphere to analyze ('both', 'ipsi', 'contra'). Default is 'both'. """ self.input_path = input_path self.atlas = atlas self.animal_list = animal_list self.channels = channels self.data_type = data_type self.hemisphere = hemisphere self.bregma = get_bregma(atlas.atlas_name)
[docs] def load_data(self) -> pd.DataFrame: """ Load and merge data for all animals. Returns: pd.DataFrame: Merged DataFrame containing data for all animals and channels. """ results_data_merged = pd.DataFrame() for animal_id in self.animal_list: if self.data_type in ["optic_fiber", "neuropixels_probe"]: seg_super_dir = get_info(self.input_path.joinpath(animal_id), 'results', seg_type=self.data_type, only_dir=True) self.channels = natsorted([f.parts[-1] for f in seg_super_dir.iterdir() if f.is_dir()]) params = self._load_params(self.input_path.joinpath(animal_id, 'params.json')) for channel in self.channels: if self.data_type == 'swc': results_data = self._load_swc_data(animal_id, channel) else: results_data = self._load_channel_data(animal_id, channel) if results_data is None: continue if self.data_type in ["optic_fiber", "neuropixels_probe"]: results_data = results_data[results_data['inside_brain']] results_data['ml_mm'] *= -1 # Convert negative values for the left hemisphere results_data['animal_id'] = animal_id results_data['channel'] = f"{animal_id}_{channel}" if ( self.data_type in ["optic_fiber", "neuropixels_probe"] and len(self.animal_list) > 1) else channel results_data['injection_site'] = params.get('injection_site', 'right') results_data['genotype'] = params.get('genotype', 0) results_data['group'] = params.get('group', 0) results_data = self._get_ipsi_contra(results_data) results_data_merged = pd.concat([results_data_merged, results_data], ignore_index=True) show_info(f"Loaded data from {animal_id}") if self.atlas.metadata['name'] == 'allen_mouse': results_data_merged = self._clean_results_df(results_data_merged) if self.hemisphere in ['ipsi', 'contra']: results_data_merged = results_data_merged[results_data_merged['ipsi_contra'] == self.hemisphere] return results_data_merged.reset_index(drop=True)
[docs] def get_tgt_data_only(self, df: pd.DataFrame, tgt_list: List[str], negative: bool = False, use_na: bool = False) -> pd.DataFrame: """ Filter data to include only target regions. Parameters: df (pd.DataFrame): Input DataFrame containing structure data. tgt_list (List[str]): List of target region names. negative (bool): Whether to exclude target regions. Default is False. use_na (bool): Whether to include NA regions. Default is False. Returns: pd.DataFrame: Filtered DataFrame containing only target regions. """ ids = [] for reg in tgt_list: if not reg == 'NA': try: ids.append(self.atlas.structures[reg]['id']) except KeyError: show_info(f'No region called >> {reg} <<, skipping that region.') pass ids_child = [child_id for tgt_id in ids for child_id in self.atlas.hierarchy.is_branch(tgt_id)] ids.extend(ids_child) if 'NA' in tgt_list and use_na: ids.append(-42) df.loc[~df.structure_id.isin(ids), 'structure_id'] = -42 condition = df['structure_id'].isin(ids) tgt_data = df[~condition] if negative else df[condition] tgt_data.loc[:, 'tgt_name'] = tgt_data['structure_id'].map(lambda s: self._get_tgt_name(s, tgt_list)) return tgt_data.reset_index(drop=True)
def _get_tgt_name(self, structure_id: int, tgt_list: List[str]) -> str: """ Map a structure ID to its name or closest ancestor in the target list. Parameters: structure_id (int): The ID of the structure. tgt_list (List[str]): List of target region names. Returns: str: The name of the target region. """ try: acronym = self.atlas.structures[structure_id]['acronym'] if acronym in tgt_list: return acronym ancestors = list(set(self.atlas.get_structure_ancestors(structure_id)) & set(tgt_list)) return ancestors[0] if ancestors else 'NA' except KeyError: return 'NA' def _load_params(self, params_file: Path) -> dict: """ Load parameters from a JSON file. Parameters: params_file (Path): Path to the params.json file. Returns: dict: Loaded parameters. """ try: with open(params_file) as fn: return json.load(fn).get('general', {}) except FileNotFoundError: show_info(f"WARNING: Params file not found for {params_file}, using defaults.") return {} def _load_channel_data(self, animal_id: str, channel: str) -> Union[pd.DataFrame, None]: """ Load data for a specific channel of an animal. Parameters: animal_id (str): Identifier of the animal. channel (str): Identifier of the channel. Returns: pd.DataFrame or None: Loaded channel data or None if file is not found. """ results_dir = get_info(self.input_path.joinpath(animal_id), 'results', seg_type=self.data_type, channel=channel, only_dir=True) results_file = results_dir.joinpath(f"{animal_id}_{self.data_type}.csv") if not results_file.exists(): show_info(f"WARNING: Data file {results_file} does not exist. Skipping...") return None return pd.read_csv(results_file) def _load_swc_data(self, animal_id: str, channel: str) -> Union[pd.DataFrame, None]: """ Load swc data for a specific channel of an animal. Parameters: animal_id (str): Identifier of the animal. channel (str): Identifier of the channel. Returns: pd.DataFrame or None: Loaded channel data or None if file is not found. """ swc_merge = pd.DataFrame() results_dir = get_info(self.input_path.joinpath(animal_id), 'results', seg_type=self.data_type, channel=channel, only_dir=True) results_file = results_dir.joinpath(f"{animal_id}_{self.data_type}.csv") if not results_file.exists(): show_info(f"WARNING: Data file {results_file} does not exist. Skipping...") return None results_descriptor = pd.read_csv(results_file) for index, row in results_descriptor.iterrows(): swc_fn = results_dir.joinpath(row['swc_file']) cols = ["id", "type", "ml_coords", "ap_coords", "dv_coords", "radius", "parent"] swc_df = pd.read_csv(swc_fn, comment="#", delim_whitespace=True, names=cols) swc_df['neuron_id'] = row['swc_file'].split('.')[0] swc_df['group_id'] = row['group'] swc_df['animal_id'] = animal_id swc_merge = pd.concat((swc_merge, swc_df), ignore_index=True) coords = swc_merge[["ap_coords", "dv_coords", "ml_coords"]].to_numpy() mm_coords = [coord_mm_transform(row, self.bregma, self.atlas.space.resolution) for row in coords] swc_merge[["ap_mm", "dv_mm", "ml_mm"]] = pd.DataFrame(mm_coords, index=swc_merge.index) return swc_merge def _get_ipsi_contra(self, df: pd.DataFrame) -> pd.DataFrame: """ Determine if cells are ipsilateral ('ipsi') or contralateral ('contra') to the injection site. Parameters: df (pd.DataFrame): Input DataFrame containing 'ml_mm' and 'injection_site' columns. Returns: pd.DataFrame: DataFrame with an additional 'ipsi_contra' column. """ if 'injection_site' not in df.columns or df.empty: raise ValueError("The dataframe must contain a non-empty 'injection_site' column.") injection_site = df['injection_site'].iloc[0] df['ipsi_contra'] = 'ipsi' if injection_site == 'left': df.loc[df['ml_mm'] < 0, 'ipsi_contra'] = 'contra' elif injection_site == 'right': df.loc[df['ml_mm'] > 0, 'ipsi_contra'] = 'contra' else: raise ValueError(f"Unexpected value in 'injection_site': {injection_site}. Expected 'left' or 'right'.") return df def _clean_results_df(self, df: pd.DataFrame) -> pd.DataFrame: """ Clean the results DataFrame by removing unwanted structures. Parameters: df (pd.DataFrame): DataFrame containing atlas data. Returns: pd.DataFrame: Cleaned DataFrame with unwanted structures removed. """ try: list_delete = ['root'] for item in ['fiber tracts', 'VS']: list_delete.extend(self.atlas.get_structure_descendants(item)) df = df.drop(df[df['acronym'].isin(list_delete)].index) df = df.reset_index(drop=True) except KeyError: pass return df