"""Widgets that visualise volume images in .nii files."""
import nibabel as nib
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import interact, fixed, IntSlider
import ipywidgets as widgets
import inspect
import scipy.ndimage
from .colormaps import get_cmap_dropdown
# import pathlib & backwards compatibility
try:
# on >3 this ships by default
from pathlib import Path
except ModuleNotFoundError:
# on 2.7 this should work
try:
from pathlib2 import Path
except ModuleNotFoundError:
raise ModuleNotFoundError('On python 2.7, niwidgets requires '
'pathlib2 to be installed.')
[docs]class NiftiWidget:
"""Turn .nii files into interactive plots using ipywidgets.
Args
----
filename : str
The path to your ``.nii`` file. Can be a string, or a
``PosixPath`` from python3's pathlib.
"""
def __init__(self, filename):
"""
Turn .nii files into interactive plots using ipywidgets.
Args
----
filename : str
The path to your ``.nii`` file. Can be a string, or a
``PosixPath`` from python3's pathlib.
"""
if hasattr(filename, 'get_data'):
self.data = filename
else:
filename = Path(filename).resolve()
if not filename.is_file():
raise OSError('File ' + filename.name + ' not found.')
# load data in advance
# this ensures once the widget is created that the file is of a
# format readable by nibabel
self.data = nib.load(str(filename))
# initialise where the image handles will go
self.image_handles = None
[docs] def nifti_plotter(self, plotting_func=None, colormap=None, figsize=(15, 5),
**kwargs):
"""
Plot volumetric data.
Args
----
plotting_func : function
A plotting function for .nii files, most likely
mask_background : bool
Whether the background should be masked (set to NA).
This parameter only works in conjunction with the default
plotting function (`plotting_func=None`). It finds clusters
of values that round to zero and somewhere touch the edges
of the image. These are set to NA. If you think you are
missing data in your image, set this False.
colormap : str | list
The matplotlib colormap that should be applied to the data.
By default, the widget will allow you to pick from all that
are available, but you can pass a string to fix the
colormap or a list of strings to offer the user a few
options.
figsize : tup
The figure height and width for matplotlib, in inches.
If you are providing a custom plot function, any kwargs you provide
to nifti_plotter will be passed to that function.
"""
kwargs['colormap'] = get_cmap_dropdown(colormap)
kwargs['figsize'] = fixed(figsize)
if plotting_func is None:
self._default_plotter(**kwargs)
else:
self._custom_plotter(plotting_func, **kwargs)
def _default_plotter(self, mask_background=False, **kwargs):
"""Plot three orthogonal views.
This is called by nifti_plotter, you shouldn't call it directly.
"""
plt.gcf().clear()
plt.ioff() # disable interactive mode
data_array = self.data.get_data()
if not ((data_array.ndim == 3) or (data_array.ndim == 4)):
raise ValueError('Input image should be 3D or 4D')
# mask the background
if mask_background:
# TODO: add the ability to pass 'mne' to use a default brain mask
# TODO: split this out into a different function
if data_array.ndim == 3:
labels, n_labels = scipy.ndimage.measurements.label(
(np.round(data_array) == 0))
else: # 4D
labels, n_labels = scipy.ndimage.measurements.label(
(np.round(data_array).max(axis=3) == 0)
)
mask_labels = [lab for lab in range(1, n_labels+1)
if (np.any(labels[[0, -1], :, :] == lab) |
np.any(labels[:, [0, -1], :] == lab) |
np.any(labels[:, :, [0, -1]] == lab))]
if data_array.ndim == 3:
data_array = np.ma.masked_where(
np.isin(labels, mask_labels), data_array)
else:
data_array = np.ma.masked_where(
np.broadcast_to(
np.isin(labels, mask_labels)[:, :, :, np.newaxis],
data_array.shape
),
data_array
)
# init sliders for the various dimensions
for dim, label in enumerate(['x', 'y', 'z']):
if label not in kwargs.keys():
kwargs[label] = IntSlider(
value=(data_array.shape[dim] - 1)/2,
min=0, max=data_array.shape[dim] - 1,
continuous_update=False
)
if (data_array.ndim == 3) or (data_array.shape[3] == 1):
kwargs['t'] = fixed(None) # time is fixed
else:
kwargs['t'] = IntSlider(
value=0, min=0, max=data_array.shape[3] - 1,
continuous_update=False
)
widgets.interact(self._plot_slices, data=fixed(data_array), **kwargs)
plt.close() # clear plot
plt.ion() # return to interactive state
def _plot_slices(self, data, x, y, z, t,
colormap='viridis', figsize=(15, 5)):
"""
Plot x,y,z slices.
This function is called by _default_plotter
"""
fresh = self.image_handles is None
if fresh:
self._init_figure(data, colormap, figsize)
coords = [x, y, z]
# add plot titles to the subplots
views = ['Sagittal', 'Coronal', 'Axial']
for i, ax in enumerate(self.fig.axes):
ax.set_title(views[i])
for ii, imh in enumerate(self.image_handles):
slice_obj = 3 * [slice(None)]
if data.ndim == 4:
slice_obj.append(t)
slice_obj[ii] = coords[ii]
# update the image
imh.set_data(
np.flipud(np.rot90(data[slice_obj], k=1))
if views[ii] != 'Sagittal' else
np.fliplr(np.flipud(np.rot90(data[slice_obj], k=1)))
)
# draw guides to show selected coordinates
guide_positions = [val for jj, val in enumerate(coords)
if jj != ii]
imh.axes.lines[0].set_xdata(2*[guide_positions[0]])
imh.axes.lines[1].set_ydata(2*[guide_positions[1]])
imh.set_cmap(colormap)
if not fresh:
return self.fig
def _init_figure(self, data, colormap, figsize):
# init an empty list
self.image_handles = []
# open the figure
self.fig, axes = plt.subplots(1, 3, figsize=figsize)
for ii, ax in enumerate(axes):
ax.set_facecolor('black')
ax.tick_params(
axis='both', which='both', bottom='off', top='off',
labelbottom='off', right='off', left='off', labelleft='off'
)
# fix the axis limits
axis_limits = [limit for jj, limit in enumerate(data.shape[:3])
if jj != ii]
ax.set_xlim(0, axis_limits[0])
ax.set_ylim(0, axis_limits[1])
img = np.zeros(axis_limits[::-1])
# img[1] = data_max
im = ax.imshow(img, cmap=colormap,
vmin=data.min(), vmax=data.max())
# add "cross hair"
ax.axvline(x=0, color='gray', alpha=0.8)
ax.axhline(y=0, color='gray', alpha=0.8)
# append to image handles
self.image_handles.append(im)
# plt.show()
def _custom_plotter(self, plotting_func, **kwargs):
"""Collect data and start interactive widget for custom plot."""
self.plotting_func = plotting_func
plt.gcf().clear()
plt.ioff()
# XYZ Sliders if plot supports it and user didn't provide any:
if ('cut_coords' in inspect.getargspec(self.plotting_func)[0]
and 'cut_coords' not in kwargs.keys()):
for label in ['x', 'y', 'z']:
if label not in kwargs.keys():
# cut_coords should be given in MNI coordinates
kwargs[label] = IntSlider(value=0, min=-90, max=90,
continuous_update=False)
# Create the widget:
interact(self._custom_plot_wrapper, data=fixed(self.data), **kwargs)
plt.ion()
def _custom_plot_wrapper(self, data, **kwargs):
"""Wrap a custom function."""
# start the figure
fig = plt.figure(figsize=kwargs.pop('figsize', None))
# The following should provide a colormap option to most plots:
if 'colormap' in kwargs.keys():
if 'cmap' in inspect.getargspec(self.plotting_func)[0]:
# if cmap is valid argument to plot func, rename colormap
kwargs['cmap'] = kwargs.pop('colormap')
else:
# if cmap is not valid for plot func, try and coerce it
plt.set_cmap(kwargs.pop('colormap'))
# reconstruct manually added x-y-z-sliders:
if ('cut_coords' in inspect.getargspec(self.plotting_func)[0]
and 'x' in kwargs.keys()):
# add the x-y-z as cut_coords
if ('display_mode' not in kwargs.keys()
or not any([label in kwargs['display_mode']
for label in ['x', 'y', 'z']])):
# If no xyz combination of display modes was requested:
kwargs['cut_coords'] = [kwargs[label]
for label in ['x', 'y', 'z']]
else:
kwargs['cut_coords'] = [kwargs[label]
for label in ['x', 'y', 'z']
if label in kwargs['display_mode']]
# remove x-y-z from kwargs
[kwargs.pop(label, None) for label in ['x', 'y', 'z']]
# Actually plot the image
self.plotting_func(data, figure=fig, **kwargs)
plt.show()