Abstract
The proliferation of healthcare data has created opportunities to apply data-driven approaches, such as machine learning methods, to assist diagnosis. Recently, many deep learning methods have shown impressive success in predicting disease status from raw input data. However, the “black-box” nature of deep learning and the high-reliability requirement of biomedical applications have created new challenges regarding the existence of confounding factors. In this paper, after briefly arguing that inappropriate handling of confounding factors leads to sub-optimal model performance in real-world applications, we present an efficient method that removes the influence of confounding factors such as age or gender to improve the across-cohort prediction accuracy of neural networks. One distinct advantage of our method is that it requires only minimal changes to the baseline model’s architecture, so it can be plugged into most existing neural networks. We conduct experiments on CT scans, MRA, and EEG brain-wave data with convolutional neural networks and LSTMs to verify the efficiency of our method.
1. Introduction
The increasing amount of data has led healthcare into a new era in which diagnoses can be made directly from raw data, such as CT scans or MRI, with data-driven approaches. Machine learning methods, especially deep learning methods, have achieved significant success in biomedical and healthcare applications, such as classifying lung nodules,1 breast lesions,2 or brain lesions3 from CT scans, segmenting brain regions with MRI,4,5 or classifying emotions with EEG data.6,7
However, unlike in the many other applications that deep learning has revolutionized, the “black-box” nature of deep learning and the high-reliability requirement of the healthcare industry have created new challenges.8 One of these challenges is removing the false signals that deep learning methods extract due to the existence of confounding factors. Given the recognition mistakes made by neural networks9–11 and empirical evidence that deep neural networks can learn signals from confounding factors,12 it is likely that a well-trained deep learning model will exhibit limited predictive performance on external data sets despite its high predictive power on lab-collected data sets. The hazard of inappropriate control of confounding factors in healthcare-related science has been discussed extensively,13–15 but these discussions are mainly in the scope of causal analyses or association studies.
In addition to a very recent result showing that confounding factors can adversely affect the predictive performance of neural network models,16 we offer a straightforward example as further motivation: a neural network model for Hodgkin lymphoma diagnosis is trained on a data set collected from young volunteers and achieves high predictive performance, but when the model is applied to the general population, it may report more false positives than expected. One reason could be that the gender ratio reverses toward adolescence in Hodgkin lymphoma,17 so a model trained on data collected from young volunteers is very likely to learn a gender bias different from what is expected in data collected from other age groups. In fact, even if the gender ratio did not change with age, it would still be inappropriate for a model to predict based on features related to gender, because these features are not directly associated with disease status. As another example, skin cancer18 and colorectal cancer19 also exhibit gender bias, and a higher false negative rate has already been observed in colorectal cancer diagnosis for women19 with traditional methods. Confounding factors are not limited to gender: other factors, such as age20 or demographic information,21 also affect a model’s performance if not handled appropriately. Considering that the generalization theory of neural networks is still an open research topic and it remains unclear how neural networks make predictions, it is particularly important to design methods that handle the influence of these confounding factors explicitly.
In this paper, inspired by previous de-confounding techniques applied to deep learning models,12 we propose a Confounder Filtering (CF) method. A distinct advantage of our method is that CF builds directly upon the original confounded neural network with a minimal change: the original top layer is replaced with a layer that predicts the confounding factors. Further, we apply our method to a broad spectrum of related tasks:
improved lung adenocarcinoma prediction with convolutional neural networks (CNN) by removing contrast material as confounding factors.
improved heart right ventricle segmentation with U-net by removing subject identifications as confounding factors.
improved students’ confusion status prediction with Bidirectional LSTM by removing students’ demographic information as confounding factors.
improved brain tumor prediction with CNN by removing gender associated information as confounding factors.
We have observed consistent improvements in predictive performance after removing the confounding factors. These four empirical contributions are summarized in Figure 1, which illustrates the experiments we perform in this paper, including the predictive task, the model we use, the data, and the confounding factors.
The remainder of this paper is organized as follows. In Section 2, we first briefly discuss the related work of this paper, mainly in the methodological perspective. In Section 3, we formally introduce our method, namely Confounder Filtering. Then in Section 4, we apply our method to a wide spectrum of experiments to show the effectiveness of our method and report relevant analysis. Finally, we conclude this paper with discussion of limitations and future directions in Section 5.
2. Related Work
The recent boom in deep learning techniques has led to the rapid development of a large number of neural network methods for healthcare applications. Readers can refer to comprehensive reviews on how deep learning can be applied to healthcare and biomedical areas.8,22–24 In this section, we mainly discuss work related to our paper from the methodological perspective.
To the best of our knowledge, few deep learning works control the effects of confounding factors explicitly. Wang et al. presented a two-phase algorithm named Select-Additive Learning.12 In the first (selection) phase, the model uses information about confounding factors to select which components of the representation learned by the neural network are associated with confounding factors; in the addition phase, the algorithm forces the neural network to discard these components by adding noise. Zhong et al. also discussed how confounding factors affect the predictive performance of neural networks. They presented an augmented training framework that requires little additional computational cost.25 The idea is to add another neural classifier that predicts confounding factors alongside the original labels, and gradient descent optimizes both classifiers. The additional structure is very similar to the Confounder Filtering method that we present, but our method trains the network differently so that we can identify the weights associated with confounding factors and filter them out explicitly.
In a broader view, correcting for confounding factors is related to reducing the representations that neural networks learn from components of the raw data that are unrelated to the predictive task. From this perspective, a significant number of neural network methods can be considered related work, covering fields such as domain adaptation,26 transfer learning,27,28 and domain generalization.29 Interested readers can refer to the survey papers cited and the references therein. Within the scope of this paper, we do not discuss these methods further for two reasons: 1) these methods are not designed for correcting confounding factors explicitly, so they may or may not be applicable in this specific situation; 2) even if our CF method performs similarly to, or slightly below, these methods, it retains a distinct advantage: CF is simple enough to be plugged into any neural network with almost no change to the architecture.
3. Confounder Filtering (CF) Method
In this section, we formally introduce the Confounder Filtering (CF) method. The goal of CF is to reduce the effects of confounders and thereby improve the generalizability of deep neural networks. We first offer an intuitive overview of the main idea of CF, then formalize our method, and finally discuss the availability of the implementation.
3.1. Overview
The CF method aims to remove the effects of confounding factors by removing the weights that are associated with them. Therefore, the core step is to identify such weights. We first train a model, namely G, conventionally for the predictive task. Then we replace the top layer of the model with another classifier that predicts the labels of the confounding factors, and we continue to train the model. During this training phase, we keep track of the updates to the weights. Finally, we filter the weights that are frequently updated during this training phase out of G by replacing them with zeros, leading to a new confounder-free model. This process is illustrated in Fig. 2.
3.2. Method
We continue to formalize our method. For convenience of discussion, we split a deep neural network architecture into two components: a representation learner component and a classification component, denoted by g(·; θ) and f(·; ϕ) respectively, where θ and ϕ stand for the corresponding parameters. Therefore, the complete neural network classifier is denoted as f(g(·; θ); ϕ). Given data < y, X >, the classical training process of the neural network is achieved by solving the following problem:
$$\hat{\theta}, \hat{\phi} = \arg\min_{\theta, \phi} \; c\big(y, f(g(X; \theta); \phi)\big) \tag{1}$$
where c(·, ·) stands for the cost function, with well-known examples such as mean-squared-error loss or cross-entropy loss.
Ideally, to effectively remove the effects of confounding factors, a method needs the labels of the confounding factors. In other words, we need data in the form of < X, y, s >, where s stands for the label of the confounding factors (e.g., age, gender, or physical factors of medical devices). This is also required by similar previous work.12,25 However, our method does not require full correspondence between X, y, and s. For example, later in our experiments, we show that with two independently collected data sets < X1, y1 > and < X2, s2 > (i.e., we only have correspondence between X1 and y1, and between X2 and s2, but not between y1 and s2), we are able to correct the confounding factors between X1 and y1 with the help of X2 and s2. For simplicity, we still present our method with < X, y, s >.
After we train the neural network in the conventional manner, as shown in Equation 1, with < X, y > and obtain $\hat{\theta}$ and $\hat{\phi}$, we continue to identify the weights associated with confounding factors by tuning the classification component with < X, s >. Formally, we solve the following problem:
$$\min_{\phi} \; c\big(s, f(g(X; \hat{\theta}); \phi)\big)$$
During the optimization, our method inspects how the gradient of the cost function with respect to < X, s > updates the weights previously trained with < X, y > (i.e., $\hat{\phi}$). For the ith value of ϕ (denoted as ϕi), we calculate the frequency of updating it during the entire training process (denoted as πi). Formally, we have:
$$\pi_i = \frac{1}{n} \sum_{t=1}^{n} \left| \Delta\phi_{i,t} \right|$$
where n is the total number of steps, t is the index of the step, and $\Delta\phi_{i,t}$ is the update applied to ϕi at step t.
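The bookkeeping described above can be sketched in plain Python. This is only an illustration under stated assumptions: the classification component is a single logistic layer over fixed, already-extracted features, training uses plain SGD, and the update frequency is approximated by the mean absolute change per step. The function name `track_update_frequency`, the toy data, and that choice of statistic are hypothetical and are not taken from the released implementation.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def track_update_frequency(features, s_labels, phi, lr=0.1, epochs=20):
    """Tune phi to predict confounder labels s and return (phi, pi).

    pi[i] is the mean absolute change applied to phi[i] per SGD step,
    an assumed proxy for the update frequency of that weight.
    """
    phi = list(phi)               # start from the previously trained weights
    total = [0.0] * len(phi)      # running sum of |delta phi_i|
    n_steps = 0
    for _ in range(epochs):
        for x, s in zip(features, s_labels):
            p = sigmoid(sum(w * xi for w, xi in zip(phi, x)))
            for i, xi in enumerate(x):
                delta = -lr * (p - s) * xi   # SGD step on the cross-entropy loss
                phi[i] += delta
                total[i] += abs(delta)
            n_steps += 1
    pi = [t / n_steps for t in total]
    return phi, pi

# Toy data: feature 0 tracks the confounder label, feature 1 is small noise.
features = [[1.0, 0.1], [0.9, -0.1], [-1.0, 0.1], [-0.8, -0.1]]
s_labels = [1, 1, 0, 0]
_, pi = track_update_frequency(features, s_labels, phi=[0.0, 0.0])
print(pi)  # the weight on feature 0 receives much larger updates
```

In this toy setting the weight attached to the confounder-related feature accumulates the larger statistic, which is the property the filtering step relies on.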
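Further, we construct a masking matrix/tensor M of the same shape as ϕ, where Mi is constructed according to πi. For example, common choices are Bernoulli sampling or a straightforward thresholding procedure:
$$M_i \sim \mathrm{Bernoulli}(1 - \pi_i) \quad \text{or} \quad M_i = \begin{cases} 0, & \pi_i > \tau \\ 1, & \text{otherwise} \end{cases}$$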
In the following experiments, we use the thresholding procedure, with τ chosen so that it falls between the top 20% and top 25% of the πi values.
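As a rough illustration of the thresholding choice, the sketch below picks τ as the k-th largest πi for a given top fraction and zeroes the mask entries at or above it. The helper name `threshold_mask`, the use of the k-th largest value as τ, and the example πi values are assumptions made for this example only.

```python
def threshold_mask(pi, top_fraction=0.2):
    """Return (mask, tau): mask[i] is 0 for the most-updated weights, else 1."""
    k = max(1, int(round(top_fraction * len(pi))))   # number of weights to drop
    tau = sorted(pi, reverse=True)[k - 1]            # k-th largest pi value
    # Entries at or above tau are dropped; ties at tau would drop extra weights.
    mask = [0 if p >= tau else 1 for p in pi]
    return mask, tau

# Hypothetical update statistics for ten weights.
pi = [0.01, 0.50, 0.02, 0.03, 0.40, 0.05, 0.04, 0.06, 0.07, 0.08]
mask, tau = threshold_mask(pi, top_fraction=0.2)
print(mask, tau)
```

With ten weights and a top fraction of 0.2, the two weights with the largest statistics are masked out and the remaining eight are kept.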
Finally, we have $\tilde{\phi} = \hat{\phi} \otimes M$, where ⊗ stands for the element-wise product, and the final trained neural network, after the weights associated with confounding factors are filtered out, is as follows:
$$f(g(\cdot; \hat{\theta}); \tilde{\phi})$$
which is ready for confounder-free prediction.
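To make the filtering step concrete, here is a minimal sketch of applying a mask to trained weights by element-wise product and comparing predictions before and after. It assumes a single logistic classification layer; the names `apply_mask` and `predict`, the weights, the mask, and the input representation are all made up for illustration.

```python
import math

def apply_mask(phi_hat, mask):
    """Element-wise product of trained weights and a 0/1 mask."""
    return [w * m for w, m in zip(phi_hat, mask)]

def predict(representation, phi):
    """Logistic output of a single linear layer (assumed classifier form)."""
    z = sum(w * r for w, r in zip(phi, representation))
    return 1.0 / (1.0 + math.exp(-z))

phi_hat = [2.0, -1.5, 0.5]   # hypothetical trained weights
mask = [1, 0, 1]             # the second weight was flagged as confounder-related
phi_filtered = apply_mask(phi_hat, mask)

# An input whose second feature (the confounder-driven one) is large.
rep = [0.2, 4.0, 0.6]
print(predict(rep, phi_hat), predict(rep, phi_filtered))
```

Here the unfiltered layer is dominated by the masked feature, while the filtered layer predicts from the remaining two features only.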
3.3. Availability
The implementation of our method in TensorFlow is available online (footnote a), with a simple example that trains a CNN on the CIFAR-10 dataset, onto which we add some image patterns as confounding factors. Users can follow the online instructions to apply CF to their own customized neural networks.
4. Experiments
In this section, we verify the performance of our CF method on four different tasks by adding CF to the current baseline models. For each task, we first introduce the data set, and then introduce the methods we compare and the results. After discussing these four tasks, we present analyses of the model behavior to further validate our method.
4.1. Lung adenocarcinoma prediction
4.1.1. Data
We construct a data set to test the model performance in classifying adenocarcinomas and healthy lungs from CT-scans. Our experimental data set is a composition of three data sets:
Data Set 1: The CT images from healthy people are collected from the ELCAP Public Lung Image Database (footnote b). The CT scans were obtained in a single breath hold with a 1.25 mm slice thickness, and the set consists of 1310 DICOM images from 25 persons.
Data Set 2: The CT scans of diseased lungs are collected from 69 different patients by Grove et al.30 These are diagnostic contrast-enhanced CT scans, performed at diagnosis and prior to surgery, with slice thickness varying from 3 to 6 mm.
Data Set 3: Since these two data sets are collected differently, and one of them is a collection of contrast-enhanced CT scans, the contrast material will likely serve as a confounding factor in prediction. To correct for this confounding factor, we use a processed version (footnote c) of Data Set 2 that includes explicit labels of contrast information. This data set contains 475 series from 69 different patients, selected so that 50% are with contrast and 50% are without contrast.
Therefore, we use the 1290 healthy images from 20 persons in Data Set 1 and 1214 diseased lung images from 61 patients in Data Set 2 as the training set, and the rest from these two data sets as the testing set. We use the images from Data Set 3 with corresponding contrast labels to correct confounding factors.
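A split of this kind, where whole persons are assigned to either the training or the testing set, can be sketched as follows. This is an assumed reading of the split described above; the record layout, the function name `split_by_subject`, and the toy subject IDs are hypothetical.

```python
def split_by_subject(records, train_subjects):
    """records: list of (subject_id, image_id) pairs. Returns (train, test)."""
    train = [r for r in records if r[0] in train_subjects]
    test = [r for r in records if r[0] not in train_subjects]
    return train, test

# Toy example: 5 subjects with 3 images each; subject 4 is held out.
records = [(subject, f"img_{subject}_{k}") for subject in range(5) for k in range(3)]
train, test = split_by_subject(records, train_subjects={0, 1, 2, 3})
print(len(train), len(test))
```

Splitting at the subject level keeps images of the same person from appearing on both sides of the split.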
4.1.2. Results
We experiment with the most popular CNN architectures, including AlexNet,31 CifarNet,32 LeNet,33 VGG16,34 and VGG19.34 We first train these baseline models with an appropriate learning rate until the training accuracy converges, and then use our CF method to correct the confounding factors. We test the prediction accuracy of both vanilla CNNs and CF-improved CNNs. Fig. 3 shows the results. CF consistently improves the predictive results across a variety of CNNs.
4.2. Segmentation of the right ventricle (RV) of the heart
4.2.1. Data
The data set35 contains 243 physician-segmented CT images (216×256 pixels) from 16 patients. Data augmentation techniques, such as random rotations, translations, zooms, shears, and elastic deformations (which locally stretch and compress the image), are used to increase the number of samples. More information regarding the data set, including how the training/testing data sets are split, can be found online (footnote d).
4.2.2. Results
The main baseline in this experiment is U-net, a convolutional network architecture for fast and precise image segmentation. Previous experiments show that U-net can perform well even with a small data set.36 We first test U-net following the previous setting35 and, interestingly, achieve a higher accuracy than what was reported: vanilla U-net achieves an accuracy of 0.9477. Then, we use the CF method to remove the subject identities as confounding factors, improving the accuracy from 0.9477 to 0.9565.
4.3. Students’ confusion status prediction
4.3.1. Data
The data set37 contains EEG brainwave data from 10 college students recorded while they watch MOOC video clips (footnote e). The EEG data are collected from MindSet equipment worn by the students while they watch ten video clips, five of which are confusing. The students’ identities are considered confounding factors in this experiment.
Following previous work,38 we normalize the training data in a feature-wise fashion (i.e., each feature representation is normalized to have a mean of 0 and standard deviation of 1 across each batch of samples). The batch size is set to 20.
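The per-batch, feature-wise normalization described above can be sketched with the standard library alone. The helper names `normalize_batch` and `batches`, the small `eps` guard against zero variance, and the toy samples are assumptions for illustration; only the batch size of 20 and the zero-mean, unit-variance target come from the text.

```python
import statistics

def normalize_batch(batch, eps=1e-8):
    """Standardize each feature column of one batch to mean 0 and std 1."""
    columns = list(zip(*batch))
    means = [statistics.mean(col) for col in columns]
    stds = [statistics.pstdev(col) for col in columns]   # population std
    return [[(v - m) / (sd + eps) for v, m, sd in zip(row, means, stds)]
            for row in batch]

def batches(samples, batch_size=20):
    """Yield consecutive batches of at most batch_size samples."""
    for start in range(0, len(samples), batch_size):
        yield samples[start:start + batch_size]

# Toy data: 40 samples with 2 features each, giving 2 batches of 20.
samples = [[float(i), 2.0 * i + 1.0] for i in range(40)]
normalized = [normalize_batch(b) for b in batches(samples, batch_size=20)]
print(len(normalized), len(normalized[0]))
```

Because the statistics are computed per batch, each batch ends up standardized independently of the others.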
4.3.2. Results
We use the state-of-the-art method applied to this data set,38 namely a Bidirectional LSTM, as the baseline. The model is configured as follows: the LSTM layer has 50 units with tanh as the activation function, and its output is connected to a fully connected layer with a sigmoid activation. We compare five-fold cross-validated results from the CF-improved Bidirectional LSTM with the results reported previously.38 The results are shown in Table 1. As we can see, the CF method improves the predictive performance once plugged in.
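For readers unfamiliar with the evaluation protocol, the sketch below generates five-fold cross-validation splits over sample indices. It assumes contiguous, unshuffled folds and 100 samples (one recording per student and clip), neither of which is specified in the text; the function name `k_fold_indices` is hypothetical.

```python
def k_fold_indices(n_samples, k=5):
    """Yield (train_idx, test_idx) pairs; each sample is tested exactly once."""
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0)
                  for i in range(k)]
    start = 0
    for size in fold_sizes:
        test_idx = list(range(start, start + size))
        train_idx = list(range(0, start)) + list(range(start + size, n_samples))
        yield train_idx, test_idx
        start += size

# Assumed sample count: 10 students x 10 clips = 100 recordings.
folds = list(k_fold_indices(100, k=5))
for train_idx, test_idx in folds:
    print(len(train_idx), len(test_idx))
```

The reported score would then be the average of the per-fold test results; whether folds should instead be grouped by student is a design choice this sketch does not settle.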
4.4. Brain tumor prediction
4.4.1. Data
We construct another data set for the last experiment of this paper. We test our method in predicting brain tumors using MRA scans of healthy brains (footnote f) and CT scans of brains with tumors.39 The healthy data set consists of brain images from 100 healthy subjects, with 20 subjects scanned per decade and each group equally divided by sex. The tumor data set is collected from 120 patients. Gender information is regarded as the confounding factor in this experiment.
4.4.2. Results
Similar to the lung adenocarcinoma prediction experiment, we compare against the same set of popular CNNs. The results are shown in Fig. 4. CF improves the prediction performance in most cases; the exception is VGG19, where the model’s performance deteriorates after CF is plugged in.
4.5. Analyses of the method behaviors
To further understand how CF identifies the weights that are associated with the confounding factors, we inspect how the weights are updated during the training process and visualize which parts of the input data are related to confounding factors.
Fig. 5(a) visualizes the weights during each epoch. The figure is split into two panels: the left panel is for the lung adenocarcinoma prediction experiment, and the right panel is for the brain tumor prediction experiment. The figure shows only eight weights of the top layer (in a 4 × 2 rectangle) and visualizes how these weights change as training progresses, over 96 epochs for each of the two experiments. The blue dots visualize the weights when the model is trained during the first phase, and the green dots visualize the weights when the model is trained in the second phase to predict the confounding factors. The darker a dot is, the more frequently the weight is updated in that epoch. As we can see, for the same 4 × 2 layer, the update frequencies of the weights differ between the first and second training phases. This difference in update frequencies supports the primary assumption of our method, namely that the weights associated with the task differ from the weights associated with the confounding factors. Therefore, we can remove the effects of confounding factors by removing the weights associated with them.
Further, we investigate which parts of the input data correspond to the confounding factors. With the help of the Deep Feature Selection40 method, we select the pixels of the image that are associated with the confounding factors. Fig. 5 visualizes these pixels with yellow dots. From left to right, the four images are examples of a healthy lung, a diseased lung, a healthy brain, and a tumorous brain, respectively. Interestingly, we do not see clear patterns on the images that are related to the confounding factors. This observation further supports the importance of our CF method, because it indicates that it is barely possible to first exclude the confounding information from raw images by conventional methods, since these yellow dots do not form any clear pattern.
5. Conclusion
In this paper, we proposed a straightforward method, named Confounder Filtering, which aims to reduce the effects of confounders and improve the generalizability of deep neural networks, in order to achieve a confounding-factor-free predictive model for healthcare applications. One distinct advantage of our method is that it requires only minimal changes to the existing network model. Our method still has limitations: although it requires only a minimal change to the network architecture, it needs a repeated training process (the second training phase with confounding factors). Another limitation is that our method requires switching the top classification layer from a label predictor to a confounder predictor, which may lose the one-to-one correspondence of weights at the top layer. In the future, from the methodological perspective, we look forward to further improving the training process of our method. On the practical side, as we have released our code, we hope to help the community improve the performance of other predictive models for healthcare applications by removing confounding factors.
6. Acknowledgement
The authors would like to thank Mingze Cao and Yin Chen for discussions and for the creation of Fig. 1 and Fig. 5. This work is funded and supported by the Department of Defense under Contract No. FA8721-05-C-0003 with Carnegie Mellon University for the operation of the Software Engineering Institute, a federally funded research and development center. This work is also supported by the National Institutes of Health grants R01-GM093156 and P30-DA035778. The MR brain images from healthy volunteers used in this paper were collected and made available by the CASILab at The University of North Carolina at Chapel Hill and were distributed by the MIDAS Data Server at Kitware, Inc.
Footnotes
E-mail: haohanw{at}cs.cmu.edu
To appear at Pacific Symposium on Biocomputing (PSB) 2019
Footnote d: https://blog.insightdatascience.com/heart-disease-diagnosis-with-deep-learning-c2d92c27e730