Source code for napari_dmc_brainmap.segment.processing.presegmentation_tools

import numpy as np
import pandas as pd
from pathlib import Path
import json
import cv2
from typing import List, Tuple, Union
from concurrent.futures import ProcessPoolExecutor, as_completed
from aicsimageio import AICSImage
from skimage.morphology import remove_small_objects
from skimage.measure import label, regionprops
from aicssegmentation.core.pre_processing_utils import intensity_normalization, image_smoothing_gaussian_slice_by_slice
from aicssegmentation.core.seg_dot import dot_3d
from aicssegmentation.core.utils import hole_filling
from aicsimageio.writers import OmeTiffWriter
from bg_atlasapi import BrainGlobeAtlas
from napari_dmc_brainmap.registration.sharpy_track.sharpy_track.model.calculation import fitGeoTrans, mapPointTransform
from napari_dmc_brainmap.segment.processing.atlas_utils import loadAnnotBool, angleSlice
from napari_dmc_brainmap.utils.path_utils import get_info
from napari_dmc_brainmap.utils.atlas_utils import get_bregma

from napari.utils.notifications import show_info

class PreSegmenter:
    """
    Base class for pre-segmentation tasks including loading data, preparing directories,
    and excluding objects based on registration data.
    """

    def __init__(self, input_path: Path, general_params: dict) -> None:
        """
        Initialize the PreSegmenter with input path and general parameters.

        Parameters:
            input_path (Path): Path to the input directory.
            general_params (dict): General configuration parameters.
        """
        self.input_path = input_path
        self.general_params = general_params
        # Filled in by load_registration_data(); None means "no registration available".
        # Defining it here guarantees the attribute exists even if loading is never run.
        self.registration_data = None

    def load_registration_data(self) -> None:
        """
        Load registration data including transformation matrices and atlas information.

        When general_params["regi_bool"] is falsy, or when a FileNotFoundError is raised
        while loading, self.registration_data is set to None. Otherwise it becomes a dict
        with the keys "regi_data", "annot_bool", "z_idx", "z_res" and "bregma".
        """
        # Start from a defined state so a previous (stale) load is never reused.
        self.registration_data = None
        if not self.general_params["regi_bool"]:
            return

        atlas_id = self.general_params["atlas_id"]
        try:
            regi_dir = get_info(self.input_path, 'sharpy_track',
                                channel=self.general_params["regi_chan"], only_dir=True)
            regi_fn = regi_dir.joinpath("registration.json")
            with open(regi_fn, 'r') as f:
                regi_data = json.load(f)
            annot_bool = loadAnnotBool(atlas_id)
            show_info(f"check existence of local version of {atlas_id} atlas ...")
            show_info(f"loading reference atlas {atlas_id} ...")
            atlas = BrainGlobeAtlas(atlas_id)
            z_idx = atlas.space.axes_description.index(self.general_params["xyz_dict"]['z'][0])
            z_res = self.general_params["xyz_dict"]['z'][2]
            bregma = get_bregma(atlas_id)
            self.registration_data = {
                "regi_data": regi_data,
                "annot_bool": annot_bool,
                "z_idx": z_idx,
                "z_res": z_res,
                "bregma": bregma
            }
        except FileNotFoundError:
            show_info('NO REGISTRATION DATA FOUND')
            self.registration_data = None

    def get_regi_index(self, im: str, seg_im_suffix: str) -> int:
        """
        Get the registration index corresponding to the given image.

        The image name is shortened by len(seg_im_suffix) - 1 trailing characters and the
        result is used as a prefix to match against the registered image names. If several
        names match, the last one encountered wins.

        Parameters:
            im (str): Image name.
            seg_im_suffix (str): Suffix of the segmented image.

        Returns:
            int: Registration index, i.e. the matching key of regi_data['imgName']
                (keys are strings when the data comes from json.load).

        Raises:
            ValueError: If no registered image name matches.
        """
        prefix = im[:-(len(seg_im_suffix) - 1)]
        regi_index = None
        for key, name in self.registration_data["regi_data"]['imgName'].items():
            if name.startswith(prefix):
                regi_index = key
        if regi_index is None:
            # Previously this surfaced as an UnboundLocalError with no hint about the image.
            raise ValueError(f"no registration entry found for image '{im}'")
        return regi_index

    def exclude_segment_objects(self, im: str, segmented_image: np.ndarray, seg_idx: np.ndarray,
                                seg_im_suffix: str) -> np.ndarray:
        """
        Exclude objects from the segmentation based on atlas registration.

        Parameters:
            im (str): Image name.
            segmented_image (np.ndarray): Binary segmentation mask.
            seg_idx (np.ndarray): Array of segment indices (column 0 = y, column 1 = x).
            seg_im_suffix (str): Suffix of the segmented image.

        Returns:
            np.ndarray: Updated segment indices after exclusion. The input is returned
                unchanged when no registration data is loaded or no mask could be computed.
        """
        # Registration was requested but could not be loaded: keep every object
        # instead of failing on a None lookup.
        if self.registration_data is None:
            return seg_idx

        dim_image = segmented_image.shape
        x_res = self.general_params["xyz_dict"]['x'][1]
        y_res = self.general_params["xyz_dict"]['y'][1]
        # Rescale pixel positions from the image size to the xyz_dict x/y extents.
        x_im = seg_idx[:, 1] / dim_image[1] * x_res
        y_im = seg_idx[:, 0] / dim_image[0] * y_res
        regi_index = self.get_regi_index(im, seg_im_suffix)
        drop_mask = self.get_drop_mask(regi_index, x_im, y_im)
        if len(drop_mask) > 0:
            seg_idx = seg_idx[drop_mask]
        return seg_idx

    def get_drop_mask(self, regi_index: int, x_im: np.ndarray, y_im: np.ndarray) -> np.ndarray:
        """
        Get a mask for excluding objects based on atlas registration.

        Parameters:
            regi_index (int): Registration index.
            x_im (np.ndarray): X-coordinates of image objects.
            y_im (np.ndarray): Y-coordinates of image objects.

        Returns:
            np.ndarray: Boolean mask, True for objects to keep. An empty list is returned
                when the mask could not be computed for this image.
        """
        regi_data = self.registration_data["regi_data"]
        try:
            # get transformation
            tform = fitGeoTrans(regi_data['sampleDots'][regi_index], regi_data['atlasDots'][regi_index])
            # slice annotation volume
            x_angle, y_angle, z = regi_data['atlasLocation'][regi_index]
            annot_slice = angleSlice(x_angle, y_angle, z,
                                     self.registration_data["annot_bool"],
                                     self.registration_data["z_idx"],
                                     self.registration_data["z_res"],
                                     self.registration_data["bregma"],
                                     self.general_params["xyz_dict"])
            x_max = self.general_params["xyz_dict"]['x'][1]
            y_max = self.general_params["xyz_dict"]['y'][1]
            # mark invalid coordinates
            keep = []
            for x, y in zip(x_im, y_im):
                x_atlas, y_atlas = mapPointTransform(x, y, tform)
                x_atlas, y_atlas = int(x_atlas), int(y_atlas)
                inside = 0 <= x_atlas < x_max and 0 <= y_atlas < y_max
                # Keep a point only if it lands inside the slice on a non-zero annotation.
                keep.append(bool(inside and annot_slice[y_atlas, x_atlas] != 0))
            drop_mask = np.array(keep, dtype=bool)
        except Exception:
            # Best effort: any failure for this image means "do not filter".
            # .get() keeps the message itself from raising if the index is unknown.
            show_info(f"No registration data for {regi_data['imgName'].get(regi_index, regi_index)}")
            drop_mask = []
        return drop_mask

    def save_to_csv(self, data: pd.DataFrame, file_path: Path) -> None:
        """
        Save data to a CSV file (without the index column).

        Parameters:
            data (pd.DataFrame): Data to save.
            file_path (Path): Path to save the CSV file.
        """
        data.to_csv(file_path, index=False)

    def load_image_list(self, chan: str, im_class: str) -> Tuple[Path, List[str], str]:
        """
        Load the list of images to be segmented.

        If general_params["start_end_im"] holds exactly two entries, the list is sliced to
        that inclusive [start, end] range.

        Parameters:
            chan (str): Channel identifier.
            im_class (str): Class of the image (e.g., 'rgb', 'single_channel').

        Returns:
            Tuple[Path, List[str], str]: Directory path, list of images, and image suffix.
        """
        if im_class == 'rgb':
            seg_im_dir, seg_im_list, seg_im_suffix = get_info(self.input_path, im_class)
        else:
            seg_im_dir, seg_im_list, seg_im_suffix = get_info(self.input_path, im_class, channel=chan)

        start_end = self.general_params["start_end_im"]
        if start_end and len(start_end) == 2:
            seg_im_list = seg_im_list[start_end[0]:start_end[1] + 1]
        return seg_im_dir, seg_im_list, seg_im_suffix

    def prepare_segmentation_folders(self, seg_folder: str, chan: str) -> Union[Tuple[Path, Path], Path]:
        """
        Prepare directories for segmentation tasks.

        Parameters:
            seg_folder (str): Folder name for segmentation masks.
            chan (str): Channel identifier.

        Returns:
            Union[Tuple[Path, Path], Path]: Only the output directory when
                general_params["seg_type"] is 'projections'; otherwise the tuple
                (mask_dir, output_dir).
        """
        seg_type = self.general_params["seg_type"]
        output_dir = get_info(self.input_path, self.general_params["output_folder"], channel=chan,
                              seg_type=seg_type, create_dir=True, only_dir=True)
        if seg_type == 'projections':
            return output_dir
        mask_dir = get_info(self.input_path, seg_folder, channel=chan,
                            seg_type=seg_type, create_dir=True, only_dir=True)
        return mask_dir, output_dir
#%%
class CellsSegmenter(PreSegmenter):
    """
    Class for performing cell segmentation tasks including preprocessing, segmentation,
    and centroid detection.
    """

    def __init__(self, input_path: Path, general_params: dict, cells_params: dict, preseg_params: dict) -> None:
        """
        Initialize the CellsSegmenter.

        Parameters:
            input_path (Path): Path to the input directory.
            general_params (dict): General configuration parameters.
            cells_params (dict): Parameters specific to cell segmentation.
            preseg_params (dict): Pre-segmentation parameters.
        """
        super().__init__(input_path, general_params)
        self.cells_params = cells_params
        self.preseg_params = preseg_params

    def load_image(self, image_path: Path, chan: str) -> np.ndarray:
        """
        Load an image for segmentation.

        Parameters:
            image_path (Path): Path to the image file.
            chan (str): Channel identifier; for RGB input it must be one of
                'cy3', 'green' or 'dapi'.

        Returns:
            np.ndarray: The selected 2D plane as float32, stacked twice along a new
                first axis.
        """
        reader = AICSImage(str(image_path))
        img = reader.data.astype(np.float32)
        if self.cells_params["single_channel"]:
            img_struct = img[0, 0, 0, :, :].copy()
        else:
            # for RGB images: the last axis is indexed by colour channel
            chan_dict = {
                'cy3': 0,
                'green': 1,
                'dapi': 2
            }
            img_struct = img[0, 0, 0, :, :, chan_dict[chan]].copy()
        return np.array([img_struct, img_struct])  # Duplicate layer stack

    def save_mask_image(self, segmentation: np.ndarray, mask_save_fn: Path) -> None:
        """
        Save the segmentation mask as a TIFF image.

        Only the first element along axis 0 of `segmentation` is written.

        Parameters:
            segmentation (np.ndarray): Segmentation mask.
            mask_save_fn (Path): File path to save the mask.
        """
        writer = OmeTiffWriter()
        writer.save(segmentation[0], str(mask_save_fn))

    def preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """
        Preprocess the image for segmentation.

        Applies intensity normalization (preseg_params["intensity_norm"]) followed by
        slice-by-slice Gaussian smoothing (preseg_params["gaussian_smoothing_sigma"]).

        Parameters:
            image (np.ndarray): Input image.

        Returns:
            np.ndarray: Preprocessed image.
        """
        normalized = intensity_normalization(image, scaling_param=self.preseg_params["intensity_norm"])
        return image_smoothing_gaussian_slice_by_slice(
            normalized, sigma=self.preseg_params["gaussian_smoothing_sigma"])

    def segment_image(self, image: np.ndarray) -> np.ndarray:
        """
        Perform segmentation on the preprocessed image.

        Steps: dot filter response, threshold at preseg_params["dot_3d_cutoff"], hole
        filling, removal of objects smaller than preseg_params["minArea"].

        Parameters:
            image (np.ndarray): Preprocessed image.

        Returns:
            np.ndarray: uint8 mask (0 or 255) taken from the first element along axis 0.
        """
        response = dot_3d(image, log_sigma=self.preseg_params["dot_3d_sigma"])
        above_cutoff = response > self.preseg_params["dot_3d_cutoff"]
        hole_min, hole_max = self.preseg_params["hole_min_max"][0], self.preseg_params["hole_min_max"][1]
        filled = hole_filling(above_cutoff, hole_min, hole_max, True)
        cleaned = remove_small_objects(filled, min_size=self.preseg_params["minArea"], connectivity=1)
        # Encode foreground as 255 in an 8-bit mask.
        mask = (cleaned > 0).astype(np.uint8) * np.uint8(255)
        return mask[0]

    def find_centroids(self, segmented_image: np.ndarray) -> np.ndarray:
        """
        Find centroids in the segmented image.

        Parameters:
            segmented_image (np.ndarray): Segmentation mask.

        Returns:
            np.ndarray: Array of shape (n_regions, 2) with one (y, x) centroid per
                connected region.
        """
        # Label connected regions, then read each region's centroid.
        regions = regionprops(label(segmented_image))
        centroids = np.zeros((len(regions), 2))
        for row, props in zip(centroids, regions):
            row[0] = props.centroid[0]  # Y-coordinate
            row[1] = props.centroid[1]  # X-coordinate
        return centroids

    def segment_cells(self, image_path: Path, save_path: Path, mask_save_fn: Path, chan: str,
                      seg_im_suffix: str) -> None:
        """
        Perform segmentation on a single image and save the results.

        The centroids are written to `save_path` as a CSV with the columns
        "Position Y" and "Position X".

        Parameters:
            image_path (Path): Path to the image file.
            save_path (Path): Path to save segmentation results.
            mask_save_fn (Path): Path to save segmentation mask (accepted for interface
                compatibility; this method currently does not write a mask).
            chan (str): Channel identifier.
            seg_im_suffix (str): Image suffix.
        """
        image = self.load_image(image_path, chan)
        preprocessed_image = self.preprocess_image(image)
        segmented_image = self.segment_image(preprocessed_image)
        seg_cells = self.find_centroids(segmented_image)
        if self.general_params['regi_bool']:
            seg_cells = self.exclude_segment_objects(image_path.stem, segmented_image, seg_cells, seg_im_suffix)
        csv_to_save = pd.DataFrame(seg_cells, columns=["Position Y", "Position X"])
        self.save_to_csv(csv_to_save, save_path)

        # Disabled: create a binary image with only centroid points for further visualization.
        # centroid_binary = np.zeros(segmented_image.shape, dtype='uint8')
        # centroids_int = np.round(seg_cells).astype(int)
        # for val in centroids_int:
        #     centroid_binary[val[0], val[1]] = 255
        # cv2.imwrite(str(mask_save_fn), centroid_binary)

    def process_images(self, progress_callback: Union[None, callable] = None) -> None:
        """
        Process all images for segmentation, one after another.

        Parameters:
            progress_callback (Union[None, callable]): Callback function for progress
                updates; called with an integer percentage after each image.
        """
        self.load_registration_data()
        im_class = 'single_channel' if self.cells_params['single_channel'] else 'rgb'

        # First pass: count images over all channels so progress can be a percentage.
        total_images = 0
        for chan in self.general_params["channels"]:
            _, seg_im_list, _ = self.load_image_list(chan, im_class)
            total_images += len(seg_im_list)

        processed_images = 0
        for chan in self.general_params["channels"]:
            mask_dir, output_dir = self.prepare_segmentation_folders(self.cells_params["mask_folder"], chan)
            seg_im_dir, seg_im_list, seg_im_suffix = self.load_image_list(chan, im_class)
            for im_name in seg_im_list:
                path_to_im = seg_im_dir.joinpath(im_name)
                save_path = output_dir.joinpath(f"{im_name.split('.')[0]}_{self.general_params['seg_type']}.csv")
                mask_save_fn = mask_dir.joinpath(im_name[:-len(seg_im_suffix)] + '_masks.tiff')
                self.segment_cells(path_to_im, save_path, mask_save_fn, chan, seg_im_suffix)
                processed_images += 1
                if progress_callback:
                    progress_callback(int((processed_images / total_images) * 100))

    def process_images_parallel(self, max_workers: int = 4,
                                progress_callback: Union[None, callable] = None) -> None:
        """
        Process images in parallel.

        A failure in one image is printed and does not stop the remaining tasks.

        Parameters:
            max_workers (int): Number of parallel workers.
            progress_callback (Union[None, callable]): Callback function for progress
                updates; called with an integer percentage after each finished task,
                whether it succeeded or failed.
        """
        self.load_registration_data()
        im_class = 'single_channel' if self.cells_params['single_channel'] else 'rgb'
        tasks = []
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            # Submit one task per image, over all channels
            for chan in self.general_params["channels"]:
                mask_dir, output_dir = self.prepare_segmentation_folders(self.cells_params["mask_folder"], chan)
                seg_im_dir, seg_im_list, seg_im_suffix = self.load_image_list(chan, im_class)
                for im_name in seg_im_list:
                    image_path = seg_im_dir.joinpath(im_name)
                    save_path = output_dir.joinpath(
                        f"{im_name.split('.')[0]}_{self.general_params['seg_type']}.csv")
                    mask_save_fn = mask_dir.joinpath(im_name[:-len(seg_im_suffix)] + '_masks.tiff')
                    tasks.append(executor.submit(self.segment_cells, image_path, save_path,
                                                 mask_save_fn, chan, seg_im_suffix))

            # Monitor the progress as tasks complete
            for finished, future in enumerate(as_completed(tasks), start=1):
                try:
                    # Re-raises any exception that occurred in the worker
                    future.result()
                except Exception as e:
                    print(f"Error processing image: {e}")
                # Report outside the try so failed tasks still advance the progress;
                # otherwise the percentage could never reach 100 after an error.
                if progress_callback:
                    progress_callback(int(finished / len(tasks) * 100))
#%%
class ProjectionSegmenter(PreSegmenter):
    """
    Class for performing projection segmentation tasks.
    """

    def __init__(self, input_path: Path, general_params: dict,
                 projection_params: Union[dict, None] = None) -> None:
        """
        Initialize the ProjectionSegmenter.

        Parameters:
            input_path (Path): Path to the input directory.
            general_params (dict): General configuration parameters.
            projection_params (Union[dict, None]): Parameters specific to projection
                segmentation; the processing methods read its 'binary_folder' entry.
        """
        super().__init__(input_path, general_params)
        self.projection_params = projection_params

    def segment_projection(self, image_path: Path, save_path: Path, binary_suffix: str) -> None:
        """
        Perform segmentation on a single projection image and save the results.

        Every pixel with value 255 in the grayscale image is recorded; the positions are
        written to `save_path` as a CSV with the columns "Position Y" and "Position X".

        Parameters:
            image_path (Path): Path to the image file.
            save_path (Path): Path to save segmentation results.
            binary_suffix (str): Suffix of the binary image.

        Raises:
            OSError: If the image could not be read.
        """
        # Load the image
        image = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"could not read image: {image_path}")

        # Positions of segmented pixels as (y, x) rows; float64 as in a zeros-initialized array.
        rows, cols = np.where(image == 255)
        seg_proj = np.column_stack((rows, cols)).astype(np.float64)

        if self.general_params["regi_bool"]:
            seg_proj = self.exclude_segment_objects(image_path.stem, image, seg_proj, binary_suffix)

        # Save results
        csv_to_save = pd.DataFrame(seg_proj, columns=["Position Y", "Position X"])
        self.save_to_csv(csv_to_save, save_path)

    def process_images(self, progress_callback: Union[None, callable] = None) -> None:
        """
        Process all projection images for segmentation, one after another.

        Parameters:
            progress_callback (Union[None, callable]): Callback function for progress
                updates; called with an integer percentage after each image.
        """
        self.load_registration_data()
        im_class = self.projection_params['binary_folder']

        # First pass: count images over all channels so progress can be a percentage.
        total_images = 0
        for chan in self.general_params["channels"]:
            _, seg_im_list, _ = self.load_image_list(chan, im_class)
            total_images += len(seg_im_list)

        processed_images = 0
        for chan in self.general_params["channels"]:
            # NOTE(review): a single Path is only returned when seg_type == 'projections'.
            output_dir = self.prepare_segmentation_folders(self.projection_params["binary_folder"], chan)
            binary_dir, binary_images, binary_suffix = self.load_image_list(chan, im_class)
            for im_name in binary_images:
                path_to_im = binary_dir.joinpath(im_name)
                save_path = output_dir.joinpath(f"{im_name.split('.')[0]}_{self.general_params['seg_type']}.csv")
                self.segment_projection(path_to_im, save_path, binary_suffix)
                processed_images += 1
                if progress_callback:
                    progress_callback(int((processed_images / total_images) * 100))

    def process_images_parallel(self, max_workers: int = 4,
                                progress_callback: Union[None, callable] = None) -> None:
        """
        Process projection images in parallel using multiple worker processes.

        A failure in one image is printed and does not stop the remaining tasks.

        Parameters:
            max_workers (int): Number of worker processes to run in parallel.
            progress_callback (Union[None, callable]): Callback function to report
                progress; called with an integer percentage after each finished task,
                whether it succeeded or failed.
        """
        # Load registration once up front; it does not depend on the individual image.
        self.load_registration_data()
        im_class = self.projection_params['binary_folder']
        tasks = []
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            # Submit one task per image, over all channels
            for chan in self.general_params["channels"]:
                output_dir = self.prepare_segmentation_folders(self.projection_params["binary_folder"], chan)
                binary_dir, binary_images, binary_suffix = self.load_image_list(chan, im_class)
                for im_name in binary_images:
                    image_path = binary_dir.joinpath(im_name)
                    save_path = output_dir.joinpath(
                        f"{im_name.split('.')[0]}_{self.general_params['seg_type']}.csv")
                    tasks.append(executor.submit(self.segment_projection, image_path, save_path,
                                                 binary_suffix))

            # Monitor the progress as tasks complete
            for finished, future in enumerate(as_completed(tasks), start=1):
                try:
                    # Re-raises any exception that occurred in the worker
                    future.result()
                except Exception as e:
                    print(f"Error processing image: {e}")
                # Report outside the try so failed tasks still advance the progress.
                if progress_callback:
                    progress_callback(int(finished / len(tasks) * 100))