Source code for napari_dmc_brainmap.preprocessing.preprocessing
"""
DMC-BrainMap widget for preprocessing of .tif files.
2024 - FJ
"""
import os
import platform
from pathlib import Path
from qtpy.QtCore import Signal
from qtpy.QtWidgets import QPushButton, QWidget, QVBoxLayout, QMessageBox, QProgressBar
from superqt import QCollapsible
from joblib import Parallel, delayed
from napari.qt.threading import thread_worker
from napari.utils.notifications import show_info
from napari_dmc_brainmap.utils.path_utils import get_image_list, get_info
from napari_dmc_brainmap.utils.general_utils import get_animal_id
from napari_dmc_brainmap.utils.params_utils import load_params, clean_params_dict, update_params_dict
from napari_dmc_brainmap.utils.dropdown_utils import get_threshold_dropdown
from magicgui import magicgui, widgets
from magicgui.widgets import FunctionGui
from napari_dmc_brainmap.preprocessing.preprocessing_tools import preprocess_images, create_dirs, get_channels, \
chunk_list
from napari_dmc_brainmap.utils.gui_utils import check_input_path
from typing import List, Dict, Tuple, Union, Optional
[docs]
@thread_worker(progress={"total": 100})
def do_preprocessing(
input_path: Path,
channels: List[str],
img_list: List[str],
preprocessing_params: Dict[str, Union[str, dict]],
resolution: Tuple[int, int],
save_dirs: Dict[str, str]
) -> str:
"""
Perform preprocessing on a list of images in a multithreaded manner.
Parameters:
input_path (Path): Path to the input directory containing images.
channels (List[str]): List of channels to process.
img_list (List[str]): List of image file names to process.
preprocessing_params (Dict[str, Union[str, dict]]): Parameters for preprocessing operations.
resolution (Tuple[int, int]): Tuple containing resolution information for preprocessing.
save_dirs (Dict[str, str]): Dictionary containing paths to save preprocessed images.
Yields:
int: Progress of the preprocessing operation in percentage.
Returns:
str: Animal ID for which preprocessing was performed.
"""
if "operations" in preprocessing_params.keys():
resolution_tuple = tuple(resolution) if 'sharpy_track' in preprocessing_params['operations'] else False
num_cores = os.cpu_count()
# overwrite parallelization to 1 if detects Darwin OS
if platform.system() == 'Darwin':
num_cores = 1
chunk_img_list = chunk_list(img_list, chunk_size=num_cores)
progress_value = 0
progress_step = 100 / len(chunk_img_list)
for chunk in chunk_img_list:
Parallel(n_jobs=num_cores)(
delayed(preprocess_images)(
im, channels, input_path, preprocessing_params, save_dirs, resolution_tuple
) for im in chunk
)
progress_value += progress_step
yield int(progress_value)
preprocessing_params = clean_params_dict(preprocessing_params, "operations")
update_params_dict(input_path, preprocessing_params)
else:
show_info("No preprocessing operations selected. Expand the respective windows and tick the checkbox.")
yield 100
return get_animal_id(input_path)
[docs]
def create_general_widget(
widget_type: str,
channels: List[str],
downsampling_default: int = 3,
contrast_limits: Optional[Dict[str, str]] = None
) -> widgets.Container:
"""
Create a generalized MagicGUI widget for image processing.
Parameters:
widget_type (str): The type of widget being created (e.g., 'RGB', 'Single Channel').
channels (List[str]): List of available channels to select.
downsampling_default (int): Default value for the downsampling factor.
contrast_limits (Optional[Dict[str, str]]): Default contrast limit values for each channel.
Returns:
widgets.Container: The created MagicGUI widget container.
"""
if widget_type != 'Binary':
contrast_limits = contrast_limits or {
'dapi': '50,2000',
'green': '50,1000',
'cy3': '50,2000',
'n3': '50,2000',
'cy5': '50,1000'
}
# Create the base widget
container = widgets.Container(widgets=[
widgets.CheckBox(value=False, label=f'Process {widget_type}', tooltip=f'Tick to process {widget_type} images'),
widgets.Select(choices=['all'] + channels, value='all', label='Select channels', tooltip='Select channels to process'),
widgets.SpinBox(value=downsampling_default, min=1, label='Downsampling Factor', tooltip='Enter scale factor for downsampling'),
widgets.CheckBox(value=True, label=f'Adjust Contrast for {widget_type}', tooltip=f'Option to adjust contrast for {widget_type} images')
],
labels=True
)
if widget_type == 'SHARPy':
container.pop(-2)
# Add contrast widgets for each channel
for channel in channels:
container.append(widgets.LineEdit(value=contrast_limits[channel], label=f'Set contrast limits for {channel}', tooltip=f'Enter contrast limits: min,max for {channel}'))
else:
contrast_limits = contrast_limits or {
'dapi': '4000',
'green': '1000',
'cy3': '2000',
'n3': '2000',
'cy5': '2000'
}
# Create the base widget
container = widgets.Container(widgets=[
widgets.CheckBox(value=False, label=f'Process {widget_type}',
tooltip=f'Tick to process {widget_type} images'),
widgets.Select(choices=['all'] + channels, value='all', label='Select channels',
tooltip='Select channels to process'),
widgets.SpinBox(value=downsampling_default, min=1, label='Downsampling Factor',
tooltip='Enter scale factor for downsampling'),
widgets.ComboBox(choices=get_threshold_dropdown(), label='Thresholding method',
tooltip='select a method to compute the threshold value (from:'
' https://scikit-image.org/docs/stable/api/skimage.filters.html#module-skimage.filters'),
widgets.CheckBox(value=False, label=f'Manually set threshold for {widget_type}',
tooltip=f'Option to manually set threshold for {widget_type} images '
f'(if not ticked, thresholding method will be used)')
],
labels=True
)
# Modify for SHARPy or Binary widget
# if widget_type == 'SHARPy':
# container.pop(-2)
if widget_type == 'Binary':
container.append(
widgets.ComboBox(choices=get_threshold_dropdown(), label='Thresholding Method',
tooltip='Select a thresholding method (see skimage.filters).')
)
# Add contrast widgets for each channel
for channel in channels:
container.append(
widgets.LineEdit(value=contrast_limits[channel], label=f'Set threshold for {channel}',
tooltip=f'Enter threshold for {channel}'))
return container
[docs]
def initialize_header_widget() -> FunctionGui:
"""
Initialize a header widget for selecting the input path and imaged channels.
Returns:
FunctionGui: The initialized header widget.
"""
@magicgui(
input_path=dict(widget_type='FileEdit',
label='Input Path (Animal ID):',
mode='d',
tooltip='Directory containing subfolders with stitched images.'),
chans_imaged=dict(widget_type='Select',
label='Imaged Channels',
choices=['dapi', 'green', 'n3', 'cy3', 'cy5'],
value=['green', 'cy3'],
tooltip='Select all imaged channels. Hold Ctrl/Shift for multiple selections.'),
call_button=False
)
def header_widget(input_path: Path, chans_imaged: List[str]) -> None:
"""
Header widget for selecting input path and imaged channels.
Parameters:
input_path (Path): Path to the input directory.
chans_imaged (List[str]): List of imaged channels.
"""
pass
return header_widget
[docs]
class PreprocessingWidget(QWidget):
"""
QWidget for configuring and performing preprocessing operations.
"""
progress_signal = Signal(int)
"""Signal emitted to update the progress bar with an integer value."""
def __init__(self, parent: Optional[QWidget] = None) -> None:
"""
Initialize the PreprocessingWidget.
Parameters:
parent (Optional[QWidget]): Parent widget.
"""
super().__init__(parent)
self.setLayout(QVBoxLayout())
# Header widget
self.header = initialize_header_widget()
self.header.native.layout().setSizeConstraint(QVBoxLayout.SetFixedSize)
# Add generalized widgets for different operations
self.rgb_widget = create_general_widget('RGB', ['dapi', 'green', 'cy3'])
self.sharpy_widget = create_general_widget('SHARPy', ['dapi', 'green', 'n3', 'cy3', 'cy5'], contrast_limits={
'dapi': '50,1000',
'green': '50,300',
'cy3': '50,2000',
'n3': '50,500',
'cy5': '50,500'
})
self.single_channel_widget = create_general_widget('Single Channel', ['dapi', 'green', 'n3', 'cy3', 'cy5'])
self.stack_widget = create_general_widget('Stack', ['dapi', 'green', 'n3', 'cy3', 'cy5'])
self.binary_widget = create_general_widget('Binary', ['dapi', 'green', 'n3', 'cy3', 'cy5'])
# Add preprocessing button
self.btn = QPushButton("Do the Preprocessing!")
self.btn.clicked.connect(self._do_preprocessing)
# Progress bar
self.progress_bar = QProgressBar(self)
self.progress_bar.setMinimum(0)
self.progress_bar.setMaximum(100)
self.progress_bar.setValue(0)
# Add widgets to layout
self.layout().addWidget(self.header.native)
self._add_gui_section('Create RGB: expand for more', self.rgb_widget)
self._add_gui_section('Create SHARPy-track images: expand for more', self.sharpy_widget)
self._add_gui_section('Process Single Channels: expand for more', self.single_channel_widget)
self._add_gui_section('Create Image Stacks: expand for more', self.stack_widget)
self._add_gui_section('Create Binary Images: expand for more', self.binary_widget)
self.layout().addWidget(self.btn)
self.layout().addWidget(self.progress_bar)
self.progress_signal.connect(self.progress_bar.setValue)
def _add_gui_section(self, name: str, widget: FunctionGui) -> None:
"""
Add a collapsible GUI section to the layout.
Parameters:
name (str): The name of the collapsible section.
widget (FunctionGui): The widget to add within the collapsible section.
"""
collapsible = QCollapsible(name, self)
collapsible.addWidget(widget.native)
self.layout().addWidget(collapsible)
def _get_widget_info(self, widget: FunctionGui, item: str) -> Dict[str, Union[List[int], int, str]]:
"""
Retrieve information from a given widget based on the type of item.
Parameters:
widget (FunctionGui): The widget to retrieve information from.
item (str): Type of operation (e.g., 'rgb', 'sharpy_track').
Returns:
Dict[str, Union[List[int], int, str]]: Information extracted from the widget.
"""
chan_list = ['dapi', 'green', 'cy3'] if item == 'rgb' else ['dapi', 'green', 'n3', 'cy3', 'cy5']
imaged_chan_list = (widget[1].value if 'all' not in widget[1].value
else self.header.chans_imaged.value)
imaged_chan_list = [i for i in imaged_chan_list if i in self.header.chans_imaged.value]
base_info = {"channels": imaged_chan_list, "downsampling": widget[2].value}
if item == 'sharpy_track':
base_info["contrast_adjustment"] = widget[2].value
elif item != 'binary':
base_info["contrast_adjustment"] = widget[3].value
if item == 'binary':
if widget[4].value: # manual thresholds
base_info.update({"manual_threshold": widget[4].value})
base_info.update({channel: [int(i) for i in widget[4 + idx].value.split(',')] for idx, channel in
enumerate(chan_list) if channel in imaged_chan_list})
else:
base_info.update({"manual_threshold": widget[4].value, "thresh_method": widget[3].value.value})
else:
base_info.update({
channel: [int(i) for i in widget[(3 if item == 'sharpy_track' else 4) + idx].value.split(',')]
for idx, channel in enumerate(chan_list) if channel in imaged_chan_list
})
return base_info
def _get_preprocessing_params(self) -> Dict[str, Union[str, Dict[str, Union[str, List[int], int]]]]:
"""
Retrieve preprocessing parameters based on user selections.
Returns:
- Dict[str, Union[str, Dict[str, Union[str, List[int], int]]]]: Dictionary of preprocessing parameters.
"""
op_widg_dict = {
"rgb": self.rgb_widget,
"sharpy_track": self.sharpy_widget,
"single_channel": self.single_channel_widget,
"stack": self.stack_widget,
"binary": self.binary_widget
}
params_dict = {
"general":
{
"animal_id": get_animal_id(self.header.input_path.value),
"chans_imaged": self.header.chans_imaged.value
},
}
k = 0
for op, widget in op_widg_dict.items():
if widget[0].value:
if k < 1:
params_dict["operations"] = {}
k += 1
params_dict["operations"][op] = widget[0].value
params_dict[f"{op}_params"] = self._get_widget_info(widget, op)
return params_dict
def _check_preprocessing_success(self) -> List[str]:
"""
Check if preprocessing was successful for the given animal ID.
Returns:
List[str]: Return list of directories containing missing files.
"""
input_path = self.header.input_path.value
params_dict = load_params(input_path)
missing_files = []
for op, op_bool in params_dict["operations"].items():
if op_bool:
if op == "rgb":
_, op_data_list, _ = get_info(input_path, op)
if not op_data_list:
missing_files.append(f"{op}")
else:
for chan in params_dict[f"{op}_params"]["channels"]:
_, op_data_list, _ = get_info(input_path, op, chan)
if not op_data_list:
missing_files.append(f"{op}_{chan}")
return missing_files
def _show_success_message(self, animal_id: str) -> None:
"""
Display a success message after preprocessing is complete.
Parameters:
animal_id (str): The Animal ID for which preprocessing was performed.
"""
missing_files = self._check_preprocessing_success()
if len(missing_files) == 0:
msg_box = QMessageBox()
msg_box.setIcon(QMessageBox.Information)
msg_box.setText(f"Preprocessing finished successfully for {animal_id}!")
msg_box.setWindowTitle("Preprocessing Complete")
msg_box.exec_()
else:
msg_box = QMessageBox()
msg_box.setIcon(QMessageBox.Warning)
msg_box.setText(f"Preprocessing finished, but the following files are missing: {', '.join(missing_files)}")
# msg_box.setText(f"Preprocessing failed for {animal_id}:\n".join(missing_files))
msg_box.setWindowTitle("Preprocessing Error")
msg_box.exec_()
self.btn.setText("Do the Preprocessing!") # Reset button text
self.progress_signal.emit(0)
def _update_progress(self, value: int) -> None:
"""
Update the progress bar with the current progress value.
Parameters:
value (int): Progress value to set.
"""
self.progress_signal.emit(value)
def _do_preprocessing(self) -> None:
"""
Execute the preprocessing of images based on user input.
"""
input_path = self.header.input_path.value
# Validate input path
if not check_input_path(input_path):
return
# Retrieve preprocessing parameters
preprocessing_params = self._get_preprocessing_params()
save_dirs = create_dirs(preprocessing_params, input_path)
channels = get_channels(preprocessing_params)
for chan in channels:
img_list = get_image_list(input_path, chan)
params_dict = load_params(input_path)
resolution = params_dict['atlas_info']['resolution']
# Start the preprocessing worker
preprocessing_worker = do_preprocessing(input_path, channels, img_list, preprocessing_params, resolution, save_dirs)
preprocessing_worker.yielded.connect(self._update_progress)
preprocessing_worker.started.connect(lambda: self.btn.setText("Preprocessing Images..."))
preprocessing_worker.returned.connect(self._show_success_message)
preprocessing_worker.start()