STAMP: Simultaneous Training and Model Pruning for Low Data Regimes in Medical Image Segmentation

Acquisition of high-quality manual annotations is vital for the development of segmentation algorithms; however, creating them requires a substantial amount of expert time and knowledge. Large numbers of labels are needed to train convolutional neural networks because of the vast number of parameters that must be learned during optimisation. Here, we develop the STAMP algorithm, which allows simultaneous training and pruning of a UNet architecture for medical image segmentation, using targeted channelwise dropout to make the network robust to the pruning. We demonstrate the technique across segmentation tasks and imaging modalities, and show that, through online pruning, we are able to train networks with much higher performance than the equivalent standard UNet models while reducing their size by more than 85% in terms of parameters. This has the potential to allow networks to be trained directly on datasets where very low numbers of labels are available.

Deep learning-based methods have become state-of-the-art for medical image segmentation. A network is trained to predict the target segmentations $y_n$ from the input images $x_n$.

Our pruning strategy removes whole convolutional kernels, as shown in Fig. 2, meaning that we do not end up with a sparse representation and are able to use standard libraries and hardware. When pruning whole convolutional filters, we remove the filter from the kernel matrix $w_l$, the corresponding bias, and the matching filter dimension from the next kernel matrix $w_{l+1}$, as in (Li et al., 2017). To decide which filters to remove, we consider the feature activation maps, denoted by $z_l^{(k)} \in \mathbb{R}^{W_l \times H_l \times D_l \times C_l}$, where $l \in \{1, \ldots, L\}$ is the current layer and $C_l$ is the number of channels at layer depth $l$. The activation maps are related to the kernel weights by

$$z_l^{(k)} = \mathrm{ReLU}\!\left(w_l^{(k)} \ast z_{l-1} + b_l^{(k)}\right),$$

where ReLU is the activation function used throughout our network, other than for the final layer.

Figure 2: Whole convolutional filters are pruned, based on the magnitude of the corresponding feature maps $z_l^{(k)}$ (the convolutional kernel to be removed is shown in yellow). This requires the weights to be pruned from the kernel matrices $w_l$ and $w_{l+1}$ (shown in yellow in the kernel matrices). The corresponding biases are also pruned but are not shown.
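To make this structural pruning step concrete, the sketch below (not the authors' code; a minimal PyTorch illustration assuming 2D convolutions with biases, and the hypothetical helper name `prune_filter`) removes one output filter from a convolution, along with its bias and the matching input channel of the following convolution:

```python
import torch
import torch.nn as nn


def prune_filter(conv_l, conv_next, k):
    """Remove output filter k from conv_l (weights and bias) and the matching
    input channel from conv_next. 2D convolutions are used for brevity; the
    same indexing applies to 3D. Any normalisation layer between the two
    convolutions would also need its k-th channel removed (not shown)."""
    keep = [i for i in range(conv_l.out_channels) if i != k]

    new_l = nn.Conv2d(conv_l.in_channels, len(keep),
                      kernel_size=conv_l.kernel_size,
                      padding=conv_l.padding)
    new_next = nn.Conv2d(len(keep), conv_next.out_channels,
                         kernel_size=conv_next.kernel_size,
                         padding=conv_next.padding)

    with torch.no_grad():
        new_l.weight.copy_(conv_l.weight[keep])            # surviving filters of w_l
        new_l.bias.copy_(conv_l.bias[keep])                # surviving biases b_l
        new_next.weight.copy_(conv_next.weight[:, keep])   # drop input channel k of w_{l+1}
        new_next.bias.copy_(conv_next.bias)                # biases of layer l+1 unchanged

    return new_l, new_next
```

Because the filter is removed rather than zeroed, the resulting layers are ordinary dense convolutions of reduced width, so no sparse kernels or specialised hardware are required.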
The kernel weights and biases $b_l^{(k)}$ corresponding to the smallest magnitude filter activations $z_l^{(k)}$ are pruned from the network. To assess the overall magnitude of the filters, the L2 norm (Han et al., 2015), averaged across all training data points (calculated from an additional forward pass), is considered:

$$\Theta_{L2}\!\left(z_l^{(k)}\right) = \frac{1}{N}\sum_{n=1}^{N}\left\lVert z_l^{(k)}(x_n)\right\rVert_2,$$

where $N$ is the number of training images and $x_n$ is the $n$-th input image. The L2 norm is used as it is computationally simple and provides stable performance, but other metrics could be used to evaluate the magnitude of the filters. This was explored in the supplementary material, and it was found that STAMP was robust to this choice.
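A minimal sketch of how this magnitude could be computed is given below, assuming a PyTorch implementation with hypothetical helper and argument names (`filter_magnitudes`, `feature_modules`, `train_loader`), and assuming the feature maps $z_l$ are captured with forward hooks on the modules that produce them:

```python
import torch


@torch.no_grad()
def filter_magnitudes(model, feature_modules, train_loader, device="cpu"):
    """Average L2 norm of each filter's activation map over the training set
    (one extra forward pass). `feature_modules` maps a layer name to the module
    whose output is the feature map z_l (e.g. the ReLU after each convolution);
    `train_loader` is assumed to yield (image, label) pairs."""
    model.eval()
    sums = {name: None for name in feature_modules}
    activations = {}

    # Forward hooks capture z_l for every layer of interest.
    hooks = [m.register_forward_hook(
                 lambda _m, _inp, out, name=name: activations.__setitem__(name, out))
             for name, m in feature_modules.items()]

    n_images = 0
    for x, _ in train_loader:
        model(x.to(device))
        n_images += x.shape[0]
        for name, z in activations.items():
            # L2 norm over the spatial dimensions of each channel, summed over the batch.
            per_filter = z.flatten(start_dim=2).norm(dim=2).sum(dim=0)
            sums[name] = per_filter if sums[name] is None else sums[name] + per_filter

    for h in hooks:
        h.remove()
    # Theta_L2 for every filter, averaged over the N training images.
    return {name: s / n_images for name, s in sums.items()}
```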

These values must then be normalised across the layer depth, because the values are at different scales at different depths (Molchanov et al., 2016). Therefore, a simple L2 normalisation across the values at each layer is employed:

$$\hat{\Theta}_{L2}\!\left(z_l^{(k)}\right) = \frac{\Theta_{L2}\!\left(z_l^{(k)}\right)}{\sqrt{\sum_{j}\Theta_{L2}\!\left(z_l^{(j)}\right)^{2}}},$$

where $j$ iterates over all the filter kernels in layer $l$.
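Continuing the hypothetical helpers above, this per-layer normalisation might look like:

```python
def normalise_per_layer(magnitudes):
    """L2-normalise the Theta values within each layer so that filter
    magnitudes are comparable across layer depths."""
    return {name: theta / theta.norm(p=2).clamp(min=1e-12)
            for name, theta in magnitudes.items()}
```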
To prune a filter, it is removed entirely from the model architecture, rather than just set to zero. To make the model more robust to the pruning, adaptive channelwise targeted dropout is applied during training, summarised in Algorithm 1.

Algorithm 1: Adaptive Targeted Dropout Algorithm. Input: the normalised filter magnitudes for every filter at every depth, output from the pruning metric, $\Theta = [\hat{\Theta}_{L2}(z_l^{(k)})]$. The index function returns the index (or position) of the $\Theta$ for the $k$-th filter at layer $l$ within the full sorted list of $\Theta$ values across the network. From this, the average index at each layer depth is calculated and used, together with the base dropout value, to determine the dropout probability ($p$) for each layer. The output $p$ is a vector of probabilities. [] represents a list of elements, such that they remain together during the sort-descending process.
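The sketch below illustrates the spirit of Algorithm 1 rather than its exact scaling (the function name `targeted_dropout_probs` and the argument `base_p` are hypothetical): layers whose filters sit low in the network-wide ranking of $\Theta$ values, and are therefore the most likely pruning candidates, receive the highest dropout probability, scaled by the base dropout value.

```python
import torch


def targeted_dropout_probs(norm_magnitudes, base_p=0.5):
    """Per-layer dropout probabilities in the spirit of Algorithm 1.

    norm_magnitudes: {layer_name: tensor of shape (C_l,)} of normalised Theta values.
    Returns {layer_name: dropout probability}."""
    names = list(norm_magnitudes)
    all_theta = torch.cat([norm_magnitudes[n] for n in names])

    # Rank every filter across the whole network, descending in Theta.
    order = torch.argsort(all_theta, descending=True)
    rank = torch.empty_like(order)
    rank[order] = torch.arange(len(order))   # rank[i] = position of filter i in the sorted list

    probs, start = {}, 0
    for n in names:
        c = norm_magnitudes[n].numel()
        mean_rank = rank[start:start + c].float().mean()
        # Scale the average rank to [0, 1] and multiply by the base dropout value.
        probs[n] = (mean_rank / (len(all_theta) - 1)) * base_p
        start += c
    return probs
```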
This creates a dropout scheme in which the convolutional kernels that are most likely to be removed by the pruning procedure are also the most likely to be dropped out during training. Our approach builds upon two previous works.

The final layer ($l = 18$) contains the same number of filters as output classes $C$ and so clearly cannot be pruned (in all cases the background class is considered as an additional class). Therefore, only filters in layers 1–17 are considered as candidates to be pruned.

Figure 3: UNet model architecture: it follows the standard pattern of halving resolution and doubling filters at each depth. $l$ corresponds to the layer depth, $C_l$ is the number of channels in that layer and $C$ is the number of classes in the output segmentation. $f$ is the initial number of filters; it is varied across experiments, but is 4 unless otherwise stated.

Several training configurations were considered, which will be referred to as follows:

• STAMP - The proposed method, training and pruning the network simultaneously, without the green blocks shown in Fig. 1; the procedure therefore consists solely of the iterative training and pruning steps, without the adaptive targeted dropout.

The PruneFinetune baseline was chosen to be a fair comparison to standard pruning methods. Given that none of the existing methods were developed for UNet-style architectures and segmentation tasks, nor for low data regimes, implementing the methods as presented in prior work was not appropriate. Instead, we retained the pruning metric and pruning quantity from STAMP+ and combined these with the standard procedure of taking the pretrained model, pruning a filter, and then fine-tuning to convergence.
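To make the contrast concrete, the sketch below outlines the two loops; `train`, `evaluate`, `train_to_convergence`, `prune_smallest_filter` and `feature_modules_of` are hypothetical helpers (the last two building on the earlier sketches), and the default epoch and iteration counts are placeholders rather than values from the paper:

```python
import copy


def stamp_training(model, train_loader, val_loader, recovery_epochs=2, n_prune_iters=100):
    """STAMP-style loop: train from random initialisation, pruning one filter
    per iteration, and keep the best model seen on the validation data."""
    best_dice, best_state = 0.0, None
    for _ in range(n_prune_iters):
        train(model, train_loader, epochs=recovery_epochs)        # a few recovery epochs
        theta = normalise_per_layer(
            filter_magnitudes(model, feature_modules_of(model), train_loader))
        model = prune_smallest_filter(model, theta)               # remove one whole filter
        dice = evaluate(model, val_loader)
        if dice > best_dice:                                      # model selection on validation data
            best_dice, best_state = dice, copy.deepcopy(model.state_dict())
    return best_state


def prune_finetune(model, train_loader, val_loader, n_prune_iters=100):
    """PruneFinetune baseline: fully train first, then alternately prune one
    filter (same metric and quantity as STAMP+) and fine-tune to convergence."""
    train_to_convergence(model, train_loader, val_loader)
    for _ in range(n_prune_iters):
        theta = normalise_per_layer(
            filter_magnitudes(model, feature_modules_of(model), train_loader))
        model = prune_smallest_filter(model, theta)
        train_to_convergence(model, train_loader, val_loader)     # fine-tune after every prune
    return model
```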

In this way we were best able to isolate any advantage of simultaneously training and pruning, ensuring that differences in the results were unlikely to be due to other design decisions.

The same data splits were used for evaluating all methods. Results were then reported for the held-out testing data, using the selected model.

HarP: The task was segmentation of the hippocampus using manual labels, such that C = 2 (hippocampus and background).

IXI-HH: 3T T1 MRI dataset, preprocessed using FSL Anat, with 3D images of 166 subjects for training and 19 subjects for testing. The task was segmentation of the brain, with the labels being automatically generated by the FSL Anat pipeline by non-linearly registering an atlas to the target image, such that C = 2 (foreground and background).

The following section will explore first the low data regime results from across the datasets (Section 3.1), including an ablation study (Section 3.2). Then we present a methods comparison.

Figure 5: STAMP+ segmentation results on the HarP data for both 50 and 200 training data points, with Dice score plotted against the number of parameters remaining in the network architecture. As the network is pruned, the number of parameters reduces, and so the x-axis is inverted. The mean Dice score is shown, with interquartile range. The mean Dice score from the Standard UNet is shown for comparison. Note that STAMP+ begins from random initialisation and so the performance is initially poor. It is evident that the performance improvement is greatest with a low number of data points, but the performance is more stable between iterations with more training data.

The results are shown in Fig. 5, where it can be seen, by comparing the Dice score between pruning iterations, that the training of STAMP+ was more stable with more training data.

We also considered five further datasets, with the results shown in Fig. 6. It can clearly be seen that across the segmentation tasks and imaging modalities, STAMP+ outperformed the Standard UNet.

It can also be seen that while the improvement provided by the targeted dropout was relatively small, it was significant (p = 0.0007) and, in addition, increased the stability of the STAMP+ procedure. There was also a significant improvement compared to using a standard (non-targeted) dropout with STAMP+. Network performance was further hindered at higher dropout values.

It can be seen that even the simplest pruned model, STAMP, outperforms all of the models trained without pruning, which were unable to segment all three labels, as seen in Fig. 8. STAMP+ slightly but significantly outperformed the other approaches for the Dice score averaged over the three regions.

The poor performance of the Standard UNet on this task was almost certainly due to the shortage of training data available for the task. Access to more data for training was not an option in this instance, and so how much more data would be required to improve the performance could not be established. However, the effect of reducing the amount of training data on the performance of the STAMP+ method could be explored.

We then compared STAMP+ to alternative methods for producing smaller models. Results are shown for the HarP dataset, but they were seen to be consistent across the datasets considered.

Figure 10: Dice scores averaged across the three subcortical regions for increasing numbers of available training images for the IXI data; 250 represents the full training set. It can clearly be seen that more pruning iterations were required to reach the same Dice performance as the amount of data was decreased, until there was insufficient data available.

STAMP+ was therefore first compared to PruneFinetune for increasing numbers of training subjects. In Fig. 11 the comparison can be seen for 25, 50 and 100 training subjects. We found that the benefit of simultaneous training and pruning was most apparent when the number of training subjects was very low.

Overall, this demonstrates not only that simultaneous training and pruning can achieve as good or increased performance on the task of interest, but also that, through pruning, better performing models can be produced in low data regimes.

Using the HarP data, the method has been validated on data with manual segmentations. Whilst care has been taken to test the method across modalities and segmentation tasks, the labels used for the IXI dataset were generated using automated tools. The two tools used, FSL FIRST and FSL ANAT, employ differing model assumptions, so the results should be valid for comparison between methods; however, this does potentially limit the maximum achievable performance due to the imperfections in the labels.

The method has also only been explored for the standard UNet architecture. This decision was made as the UNet is the most popular architecture for medical image segmentation, and the majority of methods either use the UNet architecture or derivatives thereof. It is expected that the results would generalise to other similar networks, but this has not been explored explicitly within this work, and so future work should focus on exploring the approach for other network architectures commonly used in medical imaging.

Finally, a potential limitation of the method is that the performance between pruning iterations is unstable, especially compared to the PruneFinetune method. Across this work, the best model was selected using the validation data, and this has corresponded well with good performance on the testing data (although it does not always correspond to the highest performing iteration on the testing data). This has been true across the datasets explored here, but would not necessarily be true if the method were applied to other datasets. The use of the targeted dropout helps to reduce this by increasing the stability between iterations, but if this were not seen to be the case for a given dataset, it may be necessary to increase the number of recovery epochs to improve the stability of the model training. It is probable that much of the instability is due to working in the low data regime, as the stability visibly increases as the number of training subjects is increased. In practice, a large number of recovery epochs could be used and gradually reduced if the training were stable; this could be simply automated.

The layer-wise normalisation ensured that the last filter at any depth was not pruned; with Random pruning, the condition that the last kernel at a given depth cannot be pruned was explicitly coded. Figure 18 shows the Dice scores on the test dataset at each pruning iteration for models trained using STAMP+ with each metric in turn. For each metric, the mean value across the test set is shown as the solid line, and the shaded region indicates the interquartile range.

Figure 18: It can be seen that the three metrics performed similarly, with no significant difference between the L2 and Taylor metrics. All three metrics were better than random pruning. Note that the results are plotted against the remaining parameters rather than the pruning iteration.

The results are plotted against the remaining parameters rather than the pruning iteration. It can first be seen that all three metrics performed better than randomly pruning channels, as would clearly be expected. As the models started from random initialisation, the performance for all four metrics was initially poor, and then improved rapidly for all metrics except Random as the model training continued. It can then be seen that all three of the metrics performed comparably on this task, with no significant difference between the performance of the L2 and Taylor metrics (L2 vs Taylor: p = 0.07, L2 vs L1: p = 0.003, L2 vs Random: p = 5.7 × 10^{-11}). As the L2 norm is computationally more efficient than the Taylor metric, it was used throughout the experiments.

First, it can be seen that the model could be pruned and trained simultaneously, such that the same network performance was reached as for the model trained to convergence; therefore, the pruned network was sufficiently powerful to be able to represent the variation in the data. It is also evident from Fig. 19 that it was possible to prune the models to a fraction of their original size without reducing the network performance. We also found that the same performance could be reached if we began with a larger model and pruned it to be the same size as the smaller model, rather than originally training the smaller model to convergence. The pruning was, however, less stable, showing that the model training was less robust to filters being removed than when the model was deeper.

This shows that the targeted dropout successfully made the model more robust to being pruned, even with this relatively simple task.

Comparing STAMP to PruneFinetune, we can see that pruning the already converged model leads to more stable training than STAMP (Fig. 24). However, the removal of a filter then requires fine-tuning to convergence, greatly increasing the number of epochs needed for the pruning process. It can clearly be seen that STAMP+ requires a fraction of the epochs to train and that the validation performance is much more stable throughout the process.

The distribution of the remaining filters initially follows the pattern indicated by the model architecture (Fig. 3). It can also be seen across the network that the first layer within a pair of layers, at a given depth, is pruned more quickly than the second, for both the encoder and the decoder.

We repeated the experiment presented in Section 3.4.1 with 25 and 100 data points respectively. It can be seen that the pattern is the same as that presented for 50 data points: the more data points available for training, the lower the impact of the number of recovery epochs became.

Figure 28: a) The distribution of the filters with network depth as the model is pruned: the darker the shade of the filter block, the longer the filters at that depth were maintained in the models. b) The magnitudes of the activations, averaged across the training data, maintained in the model as it was pruned, where filter depth is the count of filters from the input. It can be seen that the lower magnitude activations were pruned first and the average value of the activation increased as the model was pruned. The first vertical dashed line at 182 filters corresponds to the distribution of filters which gave the best performance on the testing data. The second dashed line corresponds to the distribution of the filters for the smallest model that was able to complete the segmentation successfully, with no substantial difference in performance from the best performing model. c) The average filter magnitude, with the lower and upper quartile bounds, against pruning iteration. It can be seen that the average value increased consistently with pruning iteration.