Implementation considerations for deep learning with diffusion MRI streamline tractography

One area of medical imaging that has recently seen innovative deep learning advances is diffusion MRI (dMRI) streamline tractography with recurrent neural networks (RNNs). Unlike traditional imaging studies that use voxel-based learning, these studies model dMRI features at points in continuous space, off the voxel grid, in order to propagate streamlines, or virtual estimates of axons. However, implementing such models is non-trivial, and an open-source implementation is not yet widely available. Here, we describe a series of considerations for implementing tractography with RNNs and demonstrate that they allow one to approximate a deterministic streamline propagator with performance comparable to existing algorithms. We release this trained model and the associated implementations, which leverage popular deep learning libraries. We hope the availability of these resources will lower the barrier to entry into this field, spurring further innovation.


Introduction
Deep learning has transformed diffusion MRI (dMRI) processing, with many recent studies focusing on streamline tractography with recurrent neural networks (RNNs) (Poulin et al., 2019). Instead of stepping through temporal features to propagate a signal in time, these studies step through voxel-based dMRI features to propagate a streamline, or a sequence of points approximating a white matter (WM) tract in the brain, in space. However, implementing RNNs that predict sequences of spatial points of arbitrary length, which may not lie on the voxel grid, with batch-wise backpropagation is non-trivial. Further, an open-source implementation using commonly supported deep learning libraries is not yet widely available. To fill this gap, we detail the considerations needed to implement such a model, assess how a model trained with these implementations performs against traditional tractography algorithms, and release the model and associated code implemented in PyTorch (v1.12).

Methods
Defining and computing ground truth labels and losses. We define a batch of K streamlines, S = {s^1, ..., s^K}, as a list of streamlines of non-uniform length. Specifically, we define streamline s^k of length n^k as a list of points, s^k = {x^k_1, ..., x^k_{n^k}}, where x^k_i is a point in continuous 3-dimensional voxel space. We define the label for x^k_i as the Cartesian unit vector pointing from x^k_i to the next point, ∆x^k_i = (x^k_{i+1} − x^k_i) / ∥x^k_{i+1} − x^k_i∥. We remove the last point from each streamline so that inputs and labels have the same length, setting n^k = n^k − 1. However, as unit vectors have two degrees of freedom, we do not have the RNN directly predict the labels in Cartesian space. Rather, we predict the labels in spherical coordinates as ∆x^k_i = (ϕ^k_i, θ^k_i) and convert to Cartesian as ∆x^k_i = (sin ϕ^k_i cos θ^k_i, sin ϕ^k_i sin θ^k_i, cos ϕ^k_i) prior to loss computation. We utilize a cosine similarity loss between each prediction and its label for each point of each streamline. Streamlines can then be propagated from the i-th point to the next as x^k_{i+1} = x^k_i + γ∆x^k_i, where γ is the step size.
Differentiably sampling dMRI features off the voxel grid. x^k_i, defined as a 3-dimensional coordinate in voxel space, provides little utility for efficiently querying dMRI information at its location off the voxel grid. Thus, we instead convert each x^k_i to c^k_i, an 11-dimensional vector. Considering x^k_i as an off-grid point contained within a lattice of 8 on-grid points, the first 3 elements of c^k_i are the distances of x^k_i from the lowest lattice point along the 3 spatial axes in voxel space, x^k_i − ⌊x^k_i⌋. The remaining 8 elements are the linear indices of the 8 on-grid points in the image volume. With these 11 values, the lattice values can be queried and interpolated trilinearly to obtain off-grid features for each point in s^k as q^k_i = dMRI(c^k_i) (Kang, 2006). As trilinear interpolation is differentiable, this allows for end-to-end training between input voxel grids and output losses at points off the grid.
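The conversion and interpolation steps above can be sketched in PyTorch. This is a minimal illustration, not the released implementation: `sph_to_cart` and `trilinear_sample` are hypothetical names, and the sampler indexes the 8 enclosing lattice corners directly (rather than through precomputed linear indices) and assumes all points lie strictly inside the grid.

```python
import torch


def sph_to_cart(ang):
    # ang: (..., 2) tensor of (phi, theta) network predictions.
    phi, theta = ang[..., 0], ang[..., 1]
    return torch.stack((torch.sin(phi) * torch.cos(theta),
                        torch.sin(phi) * torch.sin(theta),
                        torch.cos(phi)), dim=-1)


def trilinear_sample(vol, pts):
    # vol: (X, Y, Z, C) feature grid; pts: (N, 3) continuous voxel coordinates.
    # Assumes every point lies strictly inside the grid so base + 1 is in range.
    base = pts.floor().long()          # lowest corner of the enclosing lattice
    frac = pts - base.to(pts.dtype)    # the 3 fractional offsets, x - floor(x)
    out = torch.zeros(pts.shape[0], vol.shape[-1], dtype=vol.dtype)
    for dx in (0, 1):
        for dy in (0, 1):
            for dz in (0, 1):
                # Trilinear weight of this lattice corner for each point.
                w = ((frac[:, 0] if dx else 1 - frac[:, 0]) *
                     (frac[:, 1] if dy else 1 - frac[:, 1]) *
                     (frac[:, 2] if dz else 1 - frac[:, 2]))
                out = out + w[:, None] * vol[base[:, 0] + dx,
                                             base[:, 1] + dy,
                                             base[:, 2] + dz]
    return out
```

Because the fractional offsets enter the weights linearly, gradients flow to both the point coordinates and the grid values, which is what permits end-to-end training through off-grid samples.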
Organizing data during training. As an example, we assume each q^k_i is a 45-dimensional feature vector, as is commonly the case if the dMRI grid stores fiber orientation distribution (FOD) spherical harmonic (SH) coefficients. Thus, S can be represented as a list of length K where each s^k is a matrix of size n^k × 45. However, the variability of n^k across S is inefficient for the tensor-based parallelization frameworks utilized by deep learning libraries. Thus, we convert S into a "padded packed" tensor for training.
When aligned by the first element of each s^k, S can be "padded" with zeros to a tensor of size M × K × 45, where M = max(n^1, ..., n^K) is the length of the longest streamline in the batch. This padded tensor can then be "packed" to a tensor of size N × 45, where N = Σ_{k=1}^{K} n^k. The packed formulation allows for batch-wise steps in RNNs for input sequences of different lengths, and the padded formulation allows for easier querying of specific points in their corresponding streamlines for loss aggregation. Both these operations and their inverses are natively supported in PyTorch.
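In PyTorch, the pad/pack round trip described above maps onto `pad_sequence` and `pack_padded_sequence` (with `pad_packed_sequence` as the inverse). A minimal sketch with made-up batch sizes:

```python
import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

# Hypothetical batch of K = 3 streamlines with n_k = 7, 4, 5 points,
# each point carrying 45 interpolated FOD SH coefficients.
streamlines = [torch.randn(n, 45) for n in (7, 4, 5)]
lengths = torch.tensor([s.shape[0] for s in streamlines])

# Padded: (M, K, 45) with M = 7, shorter streamlines zero-filled at the end.
padded = pad_sequence(streamlines)

# Packed: flat (N, 45) data with N = 7 + 4 + 5 = 16, plus per-step batch sizes,
# which is what lets an RNN step through sequences of different lengths at once.
packed = pack_padded_sequence(padded, lengths, enforce_sorted=False)
```

`enforce_sorted=False` lets the batch stay in its original order; PyTorch sorts internally and restores the order on unpacking.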
The network predictions are also packed tensors, of size N × 3 after conversion from spherical to Cartesian coordinates. To compute the batch-wise loss, we convert the packed predictions to padded representations of size M × K × 3, use a mask to ignore the padding, and average the per-point losses ℓ^k_i across all streamline points as L = (1/N) Σ_{k=1}^{K} Σ_{i=1}^{n^k} ℓ^k_i. For efficiency, we compute masks and save the labels in padded form before training.
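A sketch of this masked loss aggregation, assuming the per-point loss is 1 minus cosine similarity (the exact form of the cosine loss is an assumption here) and that `masked_mean_loss` is a hypothetical helper name:

```python
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence


def masked_mean_loss(packed_preds, padded_labels, lengths):
    # packed_preds: PackedSequence of (N, 3) Cartesian direction predictions.
    # padded_labels: (M, K, 3) unit-vector labels, zero-filled past each n_k.
    preds, _ = pad_packed_sequence(packed_preds)        # (M, K, 3)
    M = preds.shape[0]
    # mask[i, k] is True at real streamline points, False at padding.
    mask = torch.arange(M)[:, None] < lengths[None, :]
    loss = 1 - F.cosine_similarity(preds, padded_labels, dim=-1)  # (M, K)
    return loss[mask].mean()                            # average over N points
```

Precomputing the mask and padded labels once per batch, as the text suggests, avoids rebuilding them at every training step.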
Parallelizing inference. Unlike traditional tractography algorithms, which parallelize tracking at the streamline level, RNNs must parallelize at the point level. In other words, each step of the RNN must advance all streamlines in a batch, as outlined in Algorithm 1.
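Algorithm 1 is not reproduced here, but the point-level parallelism can be sketched as follows. `model` and `sample_feats` are hypothetical stand-ins for the trained RNN cell and the off-grid feature sampler, and a real propagator would also track per-streamline stopping criteria (mask exits, maximum curvature, etc.), which are omitted:

```python
import torch


def track_batch(model, sample_feats, seeds, step=0.5, n_steps=50):
    # seeds: (K, 3) seed points in voxel coordinates. Every iteration of the
    # loop advances all K streamlines by one point simultaneously.
    pts = seeds.clone()
    hidden = None                            # recurrent state, carried forward
    out = [pts.clone()]
    for _ in range(n_steps):
        q = sample_feats(pts)                # (K, C) off-grid dMRI features
        ang, hidden = model(q, hidden)       # (K, 2) spherical predictions
        phi, theta = ang[..., 0], ang[..., 1]
        d = torch.stack((torch.sin(phi) * torch.cos(theta),
                         torch.sin(phi) * torch.sin(theta),
                         torch.cos(phi)), dim=-1)
        pts = pts + step * d                 # x_{i+1} = x_i + gamma * dx_i
        out.append(pts.clone())
    return torch.stack(out)                  # (n_steps + 1, K, 3)
```

This inverts the usual layout: rather than one worker per streamline, one tensor operation per step moves the whole batch, which is what makes RNN inference efficient on GPUs.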
bioRxiv preprint doi: https://doi.org/10.1101/2023.04.03.535465; this version posted April 6, 2023. The copyright holder for this preprint (which was not certified by peer review) is the author/funder, who has granted bioRxiv a license to display the preprint in perpetuity. It is made available under a CC-BY 4.0 International license.