# Source code for niwidgets.streamlines

from __future__ import print_function
import nibabel as nib
import numpy as np
from ipywidgets import interact, fixed, widgets
import os
import ipyvolume as ipv


def length(x):
    """Return the total path length of a polyline.

    ``x`` is indexed as a 2-D array with one point per row; the result is
    the sum of the Euclidean distances between consecutive rows.
    """
    # Difference between each point and its successor, computed once.
    steps = x[:-1, :] - x[1:, :]
    squared_norms = np.sum(steps * steps, axis=1)
    return np.sum(np.sqrt(squared_norms))


def color(x):
    """Return an approximate color for a line based on its endpoints.

    The color is the unit vector pointing from the last point of ``x`` to
    its first point, so each component lies in ``[-1, 1]``.

    Args
    ----
        x : array-like, indexed as ``x[row, :]``
            Points of the line, one point per row.

    Returns
    -------
        numpy.ndarray
            The normalised direction as float64. When the first and last
            points coincide (e.g. a closed loop) there is no direction, and
            an all-zero vector is returned instead of NaNs.
    """
    dirvec = np.asarray(x[0, :] - x[-1, :], dtype=np.float64)
    norm = np.sqrt(np.sum(dirvec * dirvec, axis=-1))
    if norm == 0:
        # Coincident endpoints: dividing would yield NaN and a RuntimeWarning.
        return np.zeros_like(dirvec)
    # NOTE(review): components keep their sign, so opposite orientations of
    # the same tract yield opposite-signed colors -- confirm that is intended.
    return dirvec / norm


class StreamlineWidget:
    """
    Turns nibabel track files into interactive plots using ipyvolume.

    Color for each line is rendered as a function of the endpoints of the
    streamline. Currently, this resolves to red for left-right, green for
    anterior-posterior, and blue for inferior-superior.

    Args
    ----
        filename : str, pathlib.Path
            The path to your ``.trk`` file. Can be a string, or a
            ``PosixPath`` from python3's pathlib.
        streamlines : a nibabel streamline object
            A streamlines attribute of an object loaded by
            nibabel.streamlines.load
    """

    def __init__(self, filename=None, streamlines=None):
        """Load streamlines from ``filename`` or store the given ones.

        ``filename`` takes precedence when it is truthy.

        Raises
        ------
            IOError
                If ``filename`` is given but is not an existing file.
            ValueError
                If the file is not supported by nibabel, or if neither a
                filename nor a non-empty ``streamlines`` is supplied.
        """
        if filename:
            if not os.path.isfile(str(filename)):
                # for Python3 should have FileNotFoundError here
                raise IOError('file {} not found'.format(filename))
            if not nib.streamlines.is_supported(str(filename)):
                raise ValueError(('File {0} is not a streamline file supported'
                                  ' by nibabel').format(filename))
            # load data in advance
            self.streamlines = nib.streamlines.load(str(filename)).streamlines
        elif streamlines is not None and len(streamlines) > 0:
            # Explicit None/length test: the truth value of a multi-element
            # NumPy array is ambiguous and would raise.
            self.streamlines = streamlines
        else:
            raise ValueError('One of filename or streamlines must be specified'
                             )
        # Subset of streamlines selected for display; filled in by plot().
        self.lines2use_ = None

    def plot(self, display_fraction=0.1, **kwargs):
        """
        This is the main method for this widget.

        Args
        ----
            display_fraction : float or None
                The fraction of streamlines to show, in ``(0, 1]``. ``None``
                shows all streamlines.
            percentile : int
                The initial number of streamlines to show using a percentile
                of length distribution
            width : int
                The width of the figure
            height : int
                The height of the figure

        Raises
        ------
            ValueError
                If ``display_fraction`` is outside ``(0, 1]``.
        """
        if display_fraction is None:
            # None passes validation, so treat it as "no subsampling".
            display_fraction = 1.0
        elif display_fraction > 1 or display_fraction <= 0:
            raise ValueError('display_fraction is a float between 0 and 1'
                             ' (0 excluded) or None')
        N = len(self.streamlines)
        # Keep at least one line when any exist; an empty selection cannot be
        # concatenated into a mesh.
        num_streamlines = max(int(display_fraction * N), min(N, 1))
        indices = np.random.permutation(N)[:num_streamlines]
        self.lines2use_ = self.streamlines[indices]
        self._default_plotter(**kwargs)

    def _create_mesh(self, indices2use=None):
        """Build the vertex, line-index and color arrays for the line mesh.

        Args
        ----
            indices2use : index or None
                Optional selection applied to ``self.lines2use_`` and
                ``self.colors``; ``None`` uses all of them.

        Returns
        -------
            tuple
                ``(x, y, z, indices, colors, line_pointers)`` where ``x``,
                ``y``, ``z`` are the concatenated vertex coordinates,
                ``indices`` holds vertex-index pairs for every segment,
                ``colors`` has one row per vertex, and ``line_pointers`` holds
                ``[offset into indices, vertex count, index values]`` per line.
        """
        if indices2use is None:
            lines2use = self.lines2use_
            local_colors = self.colors
        else:
            lines2use = self.lines2use_[indices2use]
            local_colors = self.colors[indices2use]
        x, y, z = np.concatenate(lines2use).T
        # will contain indices to the vertices, [0, 1, 1, 2, 2, 3, 3, 4, 4, 5..]
        # Builtin sum: passing a generator to np.sum is deprecated.
        indices = np.zeros(sum((len(line) - 1) * 2 for line in lines2use),
                           dtype=np.uint32)
        colors = np.zeros((len(x), 3), dtype=np.float32)
        vertex_offset = 0
        line_offset = 0
        line_pointers = []
        # if we have a line of 4 vertices, we need to add the indices:
        # offset + [0, 1, 1, 2, 2, 3]
        # so we have approx 2x the number of indices compared to vertices
        for idx, line in enumerate(lines2use):
            line_length = len(line)
            # repeat all but the start and end vertex
            line_indices = np.repeat(
                np.arange(vertex_offset, vertex_offset + line_length,
                          dtype=indices.dtype), 2)[1:-1]
            indices[line_offset:
                    line_offset + line_length * 2 - 2] = line_indices
            line_pointers.append([line_offset, line_length, line_indices])
            colors[vertex_offset:
                   vertex_offset + line_length] = local_colors[idx]
            line_offset += line_length * 2 - 2
            vertex_offset += line_length
        return x, y, z, indices, colors, line_pointers

    def _default_plotter(self, **kwargs):
        """
        Basic plot function to be used if no custom function is specified.

        This is called by plot, you shouldn't call it directly.

        Recognised keyword arguments: ``grayscale`` (truthy for uniform gray
        lines), ``width`` and ``height`` (default 600), ``percentile``
        (default 80, sets the initial length threshold) and ``style``
        (assigned to the figure's ``style``).
        """
        self.lengths = np.array([length(x) for x in self.lines2use_])
        if not kwargs.get('grayscale'):
            self.colors = np.array([color(x) for x in self.lines2use_])
        else:
            self.colors = np.zeros((len(self.lines2use_), 3),
                                   dtype=np.float16)
            self.colors[:] = [0.5, 0.5, 0.5]

        # Shared with _plot_lines through interact(); tracks the last
        # threshold and the indices of the currently visible lines.
        self.state = {'threshold': 0, 'indices': []}

        width = kwargs.get('width', 600)
        height = kwargs.get('height', 600)
        perc = kwargs.get('percentile', 80)

        ipv.clear()
        fig = ipv.figure(width=width, height=height)
        self.state['fig'] = fig

        with fig.hold_sync():
            x, y, z, indices, colors, self.line_pointers = self._create_mesh()
            # One common range for all three axes.
            limits = np.array([min([x.min(), y.min(), z.min()]),
                               max([x.max(), y.max(), z.max()])])
            mesh = ipv.Mesh(x=x, y=y, z=z, lines=indices, color=colors)
            fig.meshes = [mesh]
            if 'style' not in kwargs:
                fig.style = {'axes': {'color': 'black',
                                      'label': {'color': 'black'},
                                      'ticklabel': {'color': 'black'},
                                      'visible': False},
                             'background-color': 'white',
                             'box': {'visible': False}}
            else:
                fig.style = kwargs['style']
            ipv.pylab._grow_limits(limits, limits, limits)
            fig.camera_fov = 1
        ipv.show()

        # NOTE(review): the slider maximum is ``max - 1`` (not ``max + 1``),
        # mirroring ``min - 1`` -- confirm this is intended.
        interact(self._plot_lines, state=fixed(self.state),
                 threshold=widgets.FloatSlider(
                     value=np.percentile(self.lengths, perc),
                     min=self.lengths.min() - 1,
                     max=self.lengths.max() - 1,
                     continuous_update=False))

    def _plot_lines(self, state, threshold):
        """
        Plots streamlines

        This function is called by _default_plotter

        Lines whose length is not greater than ``threshold`` are hidden by
        zeroing their entries in the mesh's ``lines`` array; lowering the
        threshold writes the stored indices back.
        """
        if threshold < state['threshold']:
            # when threshold is reduced, increase the number of lines
            state['indices'] = np.where(self.lengths > threshold)[0]
            with state['fig'].hold_sync():
                mesh = state['fig'].meshes[0]
                copy = mesh.lines.copy()
                for idx in state['indices']:
                    (line_offset, line_length,
                     line_indices) = self.line_pointers[idx]
                    copy[line_offset:line_offset + line_length * 2 - 2] = \
                        line_indices
                mesh.lines = copy
                mesh.send_state('lines')
        else:
            # when threshold is increased, decrease the number of lines
            indices = np.where(self.lengths <= threshold)[0]
            with state['fig'].hold_sync():
                mesh = state['fig'].meshes[0]
                copy = mesh.lines.copy()
                for idx in indices:
                    (line_offset, line_length,
                     line_indices) = self.line_pointers[idx]
                    copy[line_offset:line_offset + line_length * 2 - 2] = 0
                mesh.lines = copy
                mesh.send_state('lines')
            state['indices'] = np.where(self.lengths > threshold)[0]
        state['threshold'] = threshold