Human-in-the-Loop Interactive Model Training
A method is described for training a predictive model in a way that increases the interpretability and trustworthiness of the model for end-users. The model is trained from data having a multitude of features. Each feature is associated with a real value and a time component. Many predicates (atomic elements for training the model) are defined as binary functions operating on the features, typically on time sequences of the features or logical combinations thereof. The predicates can be limited to those functions which have human understandability or encode expert knowledge relative to a prediction task of the model. We iteratively train a boosting model with input from an operator or human-in-the-loop. The human-in-the-loop is provided with tools to inspect the model as it is iteratively built and to remove one or more of the predicates in the model, e.g., if a predicate does not have indicia of trustworthiness, is not causally related to a prediction of the model, or is not understandable. We repeat the iterative process several times to ultimately generate a final boosting model. The final model is then evaluated, e.g., for accuracy, complexity, trustworthiness and post-hoc explainability.
The present application is a national stage entry of PCT/US2017/054213, filed Sep. 29, 2017, which claims priority to U.S. Provisional Patent Application 62/552,088, filed Aug. 30, 2017, the contents of which are hereby incorporated by reference.
PRIORITY
This application claims priority benefits of U.S. Provisional Application Ser. No. 62/552,088 filed Aug. 30, 2017.
BACKGROUND
This disclosure relates to the field of machine learning, and more particularly to a method of training a predictive model from underlying data.
Machine learning models, for example neural network models used in the health sciences to make predictions or establish a predictive test, tend to suffer from a problem that they are difficult to understand by end-users, such as physicians or medical researchers. The lack of understanding of how the models work leads to a lack of trust in the models. In other words, the models are not “interpretable”, and are often thought of as some unknowable “black box.” As machine learning models become more widely adopted to aid experts like judges and doctors in making consequential decisions, there is significant interest in ensuring that such systems are more than simply accurate: they must also be understandable and instill trust, a collection of traits generally referred to as “interpretable.” Z. Lipton, The Mythos of Model Interpretability, arXiv:1606.03490 [cs.LG] (June 2016).
Interpretability has no universally agreed upon technical definition in the machine learning community, but some have proposed the following properties:
Complexity or model size. A model that can be understood by a human in its entirety, like a sparse linear model, has low complexity. A variant of this property is whether a human could perform the model's inference in a reasonable amount of time. This has also been called simulatability.
Understandability. Each input is considered by the model in a clearly identifiable way, like a node in a decision tree. This has also been called decomposability.
Training Transparency. The method of the training, like convex optimization, has well understood properties, like those used to train linear models.
After-the-fact end-user interpretability. That is, the model allows for an after-the-fact explanation of a prediction, like a saliency map, or examples of cases with similar predictions.
This disclosure presents a solution to this problem of generating interpretable models. In this regard, we describe a method of generating a predictive model that is interpretable by end-users. While the disclosure provides an example of a method of training a predictive model in the context of electronic health records, it is offered by way of example and not limitation as the method could be used in other situations where there is a desire to generate more understandable or interpretable predictive models for other types of end-users.
SUMMARY
This disclosure relates to a computer-implemented method of training a predictive model which is interpretable to end-users and inherently more understandable, and hence more trustworthy, than other types of models, such as deep neural networks. Several aspects contribute to this goal, including representation of “knowledge” in the model in a human-understandable form and the use of input from a human operator or expert in the middle of model training. In the illustrated embodiment, knowledge in the model is in the form of human-understandable predicates. The model consists of a set of predicates and weights. The input from the human during model training allows for the deselection of proposed predicates which are deemed by the human to be not trustworthy or otherwise undesirable in the model. Accordingly, the whole model is understandable and modifiable by a human. The model also has very desirable expressiveness due to a flexible design of the predicate types.
In one embodiment, the model is built up gradually over many iterations, a technique known as boosting. The method makes use of data having a multitude of features (e.g., unstructured data such as words in text notes, medications, lab results, vital signs, previous hospitalizations, etc.). Every instance of each feature is associated with a real value (such as a vital sign or a word in a note) and a time component. The time component could be an index in a time sequence, or a time in the past relative to a current time when a prediction is generated by the model, such as some number of days, months or minutes in the past. In one embodiment, the data is structured in a tuple format of the type {X, xi, ti} where X is the name of feature, xi is a real value of the feature and ti is a time component for the real value xi.
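For illustration, a minimal sketch of this tuple representation in Python follows; the FeatureTuple class and the example values are hypothetical conventions for this document, not part of the disclosed implementation:

```python
from typing import List, NamedTuple

class FeatureTuple(NamedTuple):
    """One observation {X, xi, ti}: feature name, real value, time component."""
    name: str     # X, e.g. "note:sepsis" or "heart_rate_beats_per_minute"
    value: float  # xi, the real value (1.0 for a token that merely occurs)
    time: float   # ti, e.g. seconds in the past relative to prediction time

# A patient record is simply a time-stamped sequence of such tuples,
# mirroring the tuple examples given later in this disclosure.
patient: List[FeatureTuple] = [
    FeatureTuple("note:sepsis", 1.0, 1000.0),                     # 1000 s ago
    FeatureTuple("heart_rate_beats_per_minute", 120.0, 86400.0),  # 1 day ago
]
```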
The method includes a step of defining a multitude of “predicates.” The predicates are binary functions operating on sequences of the tuples and return a result of 0 or 1. Predicates could also be binary functions of logical combinations of sequences of tuples, such as Predicate 1 OR Predicate 2, including nested combinations such as Predicate 1 OR Predicate 2 where Predicate 2 = Predicate 2a AND Predicate 2b. As another example, a predicate could be a combination of two Exists predicates for the medications vancomycin AND zosyn over some time period. The predicates can be grouped into types, such as “relatively human understandable” predicates such as Exists or Counts type predicates, and relatively less human understandable predicates. An example of an Exists predicate for feature X is “did the token/feature X exist in the electronic health record for a patient at any time?” If so, a 1 is returned and if not a 0 is returned. An example of a Counts predicate is “does the number of counts of feature X over all time in the electronic health record for a patient exceed some value C?” If so a 1 is returned, otherwise a 0 is returned. In a complex data set such as unstructured electronic health records over a large number of patients, the number of possible predicates is extremely large, potentially in the millions. However, the predicates can be designed or structured in a human understandable way. That is, the definition of the predicates can be specified by an expert (e.g., end-user) so that they are conceptually related and relevant to predictions that may be made by the model.
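By way of example only, the Exists and Counts predicate types, and a logical combination of two Exists predicates, might be sketched as follows (assuming the hypothetical FeatureTuple representation above; these helper functions are illustrative, not the actual implementation):

```python
def exists(name):
    """Exists predicate: did token/feature X occur at any time in the record?"""
    return lambda record: 1 if any(t.name == name for t in record) else 0

def counts(name, c):
    """Counts predicate: did feature X occur more than C times over all time?"""
    return lambda record: 1 if sum(t.name == name for t in record) > c else 0

def p_and(p1, p2):
    """A logical combination of predicates is itself a predicate."""
    return lambda record: p1(record) & p2(record)

# E.g., a combined predicate for the medications vancomycin AND zosyn:
vanco_and_zosyn = p_and(exists("medication:vancomycin"),
                        exists("medication:zosyn"))
```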
The method includes a step of iteratively training a boosting model. The boosting model can be seeded or initialized with a bias term, such as a predicate that always returns 1. The iterative training method includes the following:
1) generating a number of new predicates selected at random (in one possibility these predicates are human understandable predicates only, but this is not essential; additionally it may be possible to automatically exclude predicates that a human would delete as untrustworthy or irrelevant anyway). In one embodiment 5,000 predicates are selected at random.
2) scoring all the new random predicates by weighted information gain with respect to a class label associated with a prediction of the boosting model (e.g., the diagnostic billing code at discharge, inpatient mortality, etc.).
3) selecting a number, e.g., 10, of the new random predicates with the highest weighted information gain and adding them to the boosting model.
4) computing weights for all the predicates in the boosting model; and
5) removing one or more of the selected new predicates with the highest information gain from the boosting model in response to input from an operator or human-in-the-loop (e.g., a human expert views the predicates and removes those that are deemed to be less trustworthy, not understandable, irrelevant, or otherwise).
Steps 1, 2, 3, 4 and 5 are repeated iteratively, for example 10 or 20 times, gradually building up a boosting model. The use of a human-in-the-loop enhances the interpretability and reduces the complexity of the model by removing predicates that are not trustworthy, irrelevant, add unnecessary complexity, etc. This iterative process generates a final iteratively trained boosting model.
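The five steps and the surrounding loop can be summarized in the following skeleton. This is a sketch only: the helper functions named here (initial_bias, sample_random_predicates, score_by_weighted_info_gain, top_k, fit_weights_l1, human_review) are hypothetical stand-ins for steps 1 through 5, not disclosed code.

```python
def train_boosting_model(records, labels, rounds=10,
                         n_candidates=5000, n_keep=10):
    """Human-in-the-loop boosting; the model is a list of (predicate, weight)."""
    model = [(lambda record: 1, initial_bias(labels))]  # seed: bias term
    for _ in range(rounds):
        candidates = sample_random_predicates(records, n_candidates)   # step 1
        scored = score_by_weighted_info_gain(candidates, model,
                                             records, labels)          # step 2
        model += [(p, 0.0) for p in top_k(scored, n_keep)]             # step 3
        model = fit_weights_l1(model, records, labels)                 # step 4
        model = human_review(model)  # expert deletes untrustworthy    # step 5
    return model
```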
In one embodiment, after the final iteratively trained boosting model is generated it is evaluated, e.g., for accuracy or performance, and for indicia of interpretability such as trustworthiness, complexity, human understandability, post-hoc explainability, etc.
The disclosure includes several methods for visualizing the model in the evaluation step. These can include, among others, (i) displaying the iterative process of generating the boosting model by addition of predicates in each boosting round, (ii) displaying the grouping of the predicates in the final iteratively trained boosting model, e.g., by subject matter or related concepts, (iii) visualizing predicates, to make them more human understandable, as well as (iv) user interface tools for presenting proposed predicates with the highest weighted information gain and allowing an expert user to deselect one or more of the proposed new predicates.
In another aspect, a computer-implemented method of training a predictive model from electronic health record data for a multitude of patients is disclosed. The data includes a multitude of features, each feature associated with real values and a time component, wherein the data is in a tuple format of the type {X, xi, ti} where X is the name of feature, xi is a real value of the feature and ti is a time component for the real value xi. The method includes implementing the following instructions or steps in a processor of the computer:
- a) defining a multitude of predicates as binary functions operating on sequences of the tuples or logical operations on the sequences of the tuples;
- b) dividing the multitude of predicates into groups based on understandability, namely a first group of relatively more human understandable predicates and a second group of relatively less human understandable predicates; and
- c) iteratively training a boosting model by performing the following:
1) generating a number of new random predicates from the first group of predicates;
2) scoring all the new random predicates by weighted information gain with respect to a class label associated with a prediction of the boosting model;
3) selecting a number of the new random predicates with the highest weighted information gain and adding them to the boosting model;
4) computing weights for all the predicates in the boosting model;
5) removing one or more of the selected new predicates with the highest information gain from the boosting model in response to input from an operator; and
6) repeating the performance of steps 1, 2, 3, 4 and 5 a plurality of times and thereby generating a final iteratively trained boosting model.
In still another aspect, we have disclosed an improved computing platform, e.g., a general purpose computer, implementing a machine learning model. The improvement takes the form of the machine learning model being an iteratively trained boosted model built from predicates defined as binary functions operating on sequences of features having both a real value and a time component. The predicates are defined with operator input, and the selection of predicates for inclusion in the iteratively trained boosted model is subject to review and selection or deselection by an operator during iterative training of the boosting model.
In one embodiment the features are features in electronic health records. Other types of training data sets could be used and the use of electronic health records is offered by way of example and not limitation.
In still another aspect, a workstation is disclosed for providing operator input into iteratively training a boosting model. The workstation includes an interface displaying predicates selected as having a weighted information gain for making a prediction of the boosting model, and the interface providing a tool for selection or deselection of one or more of the predicates in iteratively training the boosting model.
It will be noted that in the broadest sense, the methods of this disclosure can be used for “features” in training data where the term “features” is used in its traditional sense in machine learning as individual atomic elements in the training data which are used to build classifiers, for example individual words in the notes of a medical record, laboratory test results, etc. In the following description we describe features in the form of binary functions (predicates) which offer more complex ways of determining whether particular elements are present in the training data, taking into account time information associated with the elements. More generally, the methodology may make use of a test (or query) in the form of a function applicable to any member of the training data to detect the presence of one or more of the features in that member of the training data.
Accordingly, in one further aspect a computer-implemented method of generating a predictive model from training data is described, the predictive model being for predicting a label based on input data which, for each of a plurality of features X, indicates a value x of the feature at each of a plurality of times, and the training data comprising a plurality of samples, each sample indicating the value of one or more of the features at each of one or more times and a corresponding label. The method comprises implementing the following steps as instructions with a processor:
defining a set of predicates, each predicate being a function which generates an output when applied to time sequences of the features or logical combinations of the time sequences of the features;
generating a boosting model, the boosting model receiving as input the respective outputs of each of the set of predicates when applied to the samples of the training data; and
performing a plurality of times the sequence of steps of:
(i) automatically generating a plurality of additional predicates;
(ii) adding the plurality of additional predicates to predicates already in the boosting model to form an updated set of predicates;
(iii) displaying a plurality of the updated set of predicates;
(iv) receiving data input rejecting one or more of the updated set of predicates; and
(v) removing the rejected one or more predicates from the updated set of predicates.
DETAILED DESCRIPTION
This disclosure relates to a computer-implemented method of training a predictive model which is interpretable to end-users and inherently understandable, and hence trustworthy. Several aspects contribute to this goal, including representation of “knowledge” in the model in a human-understandable form and the use of human operator input in the middle of model training.
This document will explain how the method works in the context of a particular problem domain, but as noted above the method can be applied more generally to other types of problems.
In the following discussion, the input to the model is an electronic health record (EHR) data set, which is the set of medical information collected by a health system or hospital about patients, including time-stamped structured information (e.g., all medications and dosages given to patients, laboratory values, diagnoses, vital signs, procedures, etc.) and unstructured data (e.g., clinical notes). Recent rapid adoption of EHRs in the United States makes modeling on this data particularly important to improve care delivery.
A patient quickly accumulates hundreds of thousands of data-points, and in clinical practice, this information cannot even be visualized in a single EHR screen. This is particularly true in the context of high-dimensional inputs with correlated features, as in personalized medicine.
In the present disclosure we describe by way of example the generation of models to make two predictions:
1. Diagnosis: Predict the primary billing diagnosis of a patient. These predictions may save the physician time looking up codes, and accurate predictions can promote better secondary use of the data by health systems and researchers.
2. In-Patient Mortality: Predict whether a patient is going to die during their hospital stay. The predictions of the model can be used to guide a doctor to intensify monitoring and checkups, or to discuss prognosis with patients, in case of an (unexpectedly) high predicted risk of mortality.
In both cases, in order to make use of the predictions the doctor needs to understand why a prediction is what it is; in other words the model needs to be interpretable.
We will now construct a toy example of two models that are equivalent when measuring their accuracy, complexity, decomposability, training transparency and end-user interpretability. However, their intuitive interpretability varies significantly.
Example 1: Model A only counts the number of breakfasts the patient had in the hospital, as documented by a nurse in the EHR. There is a positive correlation between this feature and mortality. Model B instead uses the number of days stayed at the hospital. Both models use only a single (derived) feature, may have the same accuracy, were trained the same way and can be used to explain predictions. But a clinician finds Model B easier to interpret.
This example motivates the addition of another property of interpretability that we call “feature-trustworthiness.” Like interpretability, it is a notion difficult to measure. We offer the following definition: an input feature is “trustworthy” if it is easy to understand by itself and end-users of the model believe that the feature is directly or causally related to the predicted outcome. A model is trustworthy if the features used for explaining the model's predictions are trustworthy.
Previously, a handful of features were hand-crafted and chosen with trustworthiness in mind, and models were built with these features. This approach incorporates a domain expert's knowledge, but is not data driven. With the advent of scalable machine learning, better results were achieved with models that operate on all the features and automate the feature selection process. This approach is at the opposite end: it is data-driven, no domain knowledge is required, and the results are not interpretable. Our method can be considered a hybrid of data-driven and domain-expert-guided machine learning that achieves state-of-the-art results.
A dimension of model interpretability that is underexplored in the literature is dealing with data that may not be immediately interpretable. For example, an electronic health record contains time series of structured and unstructured data that require domain expertise to nominally understand. The pre-processing, feature engineering, and data-augmentation used to transform the raw data into features for an algorithm are necessary for end-users to understand how raw data was entered into the algorithm; the understandability of these steps is what we call “pre-processing interpretability.”
There has been less research about the interaction of these different components of interpretability. In this document we describe a new machine learning model that promotes multiple aspects of interpretability, and report results on classifying diagnoses and predicting in-patient mortality using electronic medical records.
We developed a novel machine learning method which we have called the Space-Time Aware Boosting LEarner (STABLE), which is described in detail below.
Our data set for model generation was the MIMIC-III dataset, which contains de-identified health record data on critical care patients at Beth Israel Deaconess Medical Center in Boston, Mass., between 2002 and 2012. The data set is described in A. E. Johnson et al., MIMIC-III, a freely accessible critical care database, Sci. Data (2016).
The EHR data looks like a sequence of events with associated time stamps. For example, a medical record might contain historical values including vital measurements such as blood pressure, weight and heart rate. Lab values over time are also present, at various time scales from daily to weekly to once every few years. There are also medical notes associated with particular times. Hence, for such data, the model architecture is not a straightforward choice of standard features and labels, because the features here occur at particular times.
Methodology
The data in the data set 12 contains a multitude of features, potentially hundreds of thousands or more. In the example of electronic health records, the features could be specific words or phrases in unstructured clinical notes (text) created by a physician or nurse. The features could be specific laboratory values, vital signs, diagnoses, medical encounters, medications prescribed, symptoms, and so on. Each feature is associated with real values and a time component. At step 16, we format the data in a tuple format of the type {X, xi, ti} where X is the name of the feature, xi is a real value of the feature (e.g., the word or phrase, the medication, the symptom, etc.) and ti is a time component for the real value xi. The time component could be an index (e.g., an index indicating the place of the real value in a sequence of events over time), or the time elapsed between when the real value occurred and when the model is generated or makes a prediction. The generation of the tuples at step 16 is performed for every electronic health record for every patient in the data set. Examples of tuples are {“note:sepsis”, 1, 1000 seconds} and {“heart_rate_beats_per_minute”, 120, 1 day}.
At step 18, in order to deal with the time series nature of the data, via software instructions we binarize all features as predicates, so that real-valued features might be represented by a space-time predicate such as heart_rate>120 beats per minute within the last hour. The term “predicate” in this document is defined as a binary function which operates on a sequence of one or more of the tuples of step 16, or a binary function operating on logical combinations of sequences of the tuples. All predicates are functions that return 1 if true, 0 otherwise. As an example, the predicate Exists “heart_rate_beats_per_minute” applied to the last week returns 1 if there is a tuple such as {“heart_rate_beats_per_minute”, 120, 1 day} in the sequence of heart_rate_beats_per_minute tuples over the last week. Predicates could also be binary functions of logical combinations of sequences of tuples, such as Predicate 1 OR Predicate 2, including nested combinations such as Predicate 1 OR Predicate 2 where Predicate 2 = Predicate 2a AND Predicate 2b. As another example, a predicate could be a combination of two Exists predicates for the medications vancomycin AND zosyn over some time period.
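A space-time predicate of the kind just mentioned (heart rate above 120 beats per minute within the last hour) might be sketched as follows, again using the hypothetical tuple representation introduced earlier:

```python
def value_exceeds_within(name, v, t_max):
    """Space-time predicate: did feature X take a value > V at a time
    t <= T in the past? Returns 1 if true, 0 otherwise."""
    def predicate(record):
        return 1 if any(t.name == name and t.value > v and t.time <= t_max
                        for t in record) else 0
    return predicate

# heart_rate > 120 beats per minute within the last hour (3600 seconds):
hr_over_120_last_hour = value_exceeds_within(
    "heart_rate_beats_per_minute", 120.0, 3600.0)
```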
At step 20, there is the optional step of grouping the predicates into two groups based on human understandability (i.e., understandable to an expert in the field). Examples of predicates in Group 1, which are the maximally human understandable predicates, are:
Exists: X—did the token/feature X exist at any point in a patient's timeline. Here X can be a word in a note, or the name of a lab or a procedure code among other things.
Counts: #X>C. Did the number of existences of the token/feature X over all time exceed C. More generally, a Counts predicate returns a result of 0 or 1 depending on the number of counts of a feature in the electronic health record data for a given patient relative to a numeric parameter C.
Depending on the type of prediction made by the model, other types of human understandable predicates could be selected as belonging to Group 1. Additionally, human understandable predicates could be generated or defined during model training by an operator or expert.
The predicates in Group 2, which are less human-understandable, can be for example:
Any x(i)>V at t(i)<T. Did the value of x(i) exceed V at a time less than T in the past (or alternatively x(i)<=V)?
Max/Min/Avg_i x(i)>V. Did the maximum, minimum or average of x(i) exceed V (or alternatively fall at or below V) over all time?
Hawkes process. Did the sum of exponentially time-decayed impulses when x(i)>V exceed some activation A over some time window T? Activation = sum_i I(x(i)>V) * exp(−t(i)/T). (A sketch of this predicate appears after this list.)
Decision List predicates where any two conjunctions of the above predicates are used.
True—always returns 1. This is the first predicate (seed) in the boosting model and acts as the bias term. It is initialized to the log odds ratio of the positive class in the first batch.
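The Hawkes-process predicate above can be made concrete as follows; this is a minimal sketch of the stated formula, Activation = sum_i I(x(i)>V) * exp(−t(i)/T), under the same hypothetical tuple representation:

```python
import math

def hawkes(name, v, a, t_window):
    """Did the sum of exponentially time-decayed impulses where x(i) > V
    exceed activation A over time window T? Recent events decay less."""
    def predicate(record):
        activation = sum(math.exp(-t.time / t_window)
                         for t in record
                         if t.name == name and t.value > v)
        return 1 if activation > a else 0
    return predicate
```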
Referring again to the overall method, at step 22 we iteratively train a boosting model with human input.
In order to overcome the difficulty of understanding or interpreting deep neural networks, we focused on creating a boosting model that could generate parsimonious (less complex) and human-understandable rules, to make the models interpretable and facilitate a natural human evaluation of them. Boosting algorithms generally combine a series of weak learners that are iteratively added if they improve performance. We use input from a human in the loop during training to selectively remove or deselect predicates which are candidates for inclusion in the boosting model. After multiple iterations of selecting predicates and removing or deselecting some of them, we arrive at a final trained boosting model, which is defined as a set of predicates and associated weights.
At step 24, we then proceed to evaluate the finally trained boosting model.
The iterative training procedure of step 22 will now be described in more detail.
At step 202, a large number of new random predicates are generated or selected; for example, 5,000 new random predicates are generated. Since the number of potential predicates can be very large (they are the cross product of the number of tokens/features, feature values and different times), we do not generate all possible predicates per round. The actual instances of each of the rules, including the selection of variables, value thresholds and time thresholds, were generated as follows. First, pick a random patient (alternating between those with a positive or negative label, as for some coding tasks the positive label is very rare), a random variable X, and a random time T in the patient's timeline. Time is chosen by index, since events are not uniformly spaced. V is the corresponding value of X at time T, and C is the count of times X occurs in the patient's timeline. Thus, for a picked patient, if feature X has M tuples, pick j uniformly from [0, M−1] to locate the tuple {X, x(j), t(j)}; then T=t(j) and V=x(j).
Then, generate all possible predicate types using these values. Alternatively, we could restrict the model to use only the predicate types of Group 1, to gain interpretability in the final model. Note that here it is possible to shape, via human input, the selection of predicates used to generate the model so as to increase the interpretability and trustworthiness of the model.
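The threshold sampling just described might look like the following sketch (illustrative only; the function name and the caller's alternation of the wanted label are assumptions, not the actual implementation):

```python
import random

def sample_thresholds(patients, labels, wanted):
    """Pick a random patient with the wanted label (the caller alternates
    wanted between 1 and 0), a random variable X in that timeline, and a
    random tuple index j; return X, V = x(j), T = t(j) and C = counts of X,
    from which concrete instances of each predicate type are generated."""
    record = random.choice([r for r, y in zip(patients, labels) if y == wanted])
    x = random.choice([t.name for t in record])   # random variable X
    xs = [t for t in record if t.name == x]       # its M tuples
    j = random.randrange(len(xs))                 # j uniform in [0, M-1]
    return x, xs[j].value, xs[j].time, len(xs)    # X, V, T, C
```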
At step 204, we then score each of the 5,000 random predicates by weighted information gain with respect to a class label associated with a prediction of the boosting model (e.g., inpatient mortality, discharge billing code, etc.). The weights for each sample (patient EHR) come from computing the probability p of the sample given the current boosting model. The importance q is then q=|label−prediction|. This means that samples on which the boosting model makes errors are more important in the current boosting round. Using the importance q and the label of the samples, one can then compute the weighted information gain of the candidate predicates with respect to the label and the current boosting model. Alternatively, one can select predicates randomly and then perform a gradient step with L1 regularization. Another method is to sample groups of predicates and evaluate them for information gain, in accordance with the methods described at https://en.wikipedia.org/wiki/Information_gain_in_decision_trees, or use the techniques described in Trivedi et al., An Interactive Tool for Natural Language Processing on Clinical Text, arXiv:1707.01890 [cs.HC] (July 2017).
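The importance weighting and the weighted information gain of step 204 might be computed as in the following sketch (a simplified illustration that assumes a logistic link for the boosting scores and 0/1 labels; not the actual implementation):

```python
import math

def importances(model, records, labels):
    """q = |label - prediction|: samples the current model gets wrong
    carry more weight in the current boosting round."""
    qs = []
    for record, y in zip(records, labels):
        score = sum(w * pred(record) for pred, w in model)
        p = 1.0 / (1.0 + math.exp(-score))  # predicted probability
        qs.append(abs(y - p))
    return qs

def weighted_entropy(labels, weights):
    """Binary entropy of the labels under the given sample weights."""
    total = sum(weights)
    pos = sum(w for y, w in zip(labels, weights) if y == 1)
    if total == 0 or pos == 0 or pos == total:
        return 0.0
    p = pos / total
    return -(p * math.log2(p) + (1 - p) * math.log2(1 - p))

def weighted_info_gain(predicate, records, labels, qs):
    """Importance-weighted entropy of the labels, minus the weighted
    entropy after splitting on the candidate predicate's 0/1 output."""
    outs = [predicate(r) for r in records]
    total_q = sum(qs) or 1.0
    gain = weighted_entropy(labels, qs)
    for side in (0, 1):
        idx = [i for i, o in enumerate(outs) if o == side]
        gain -= (sum(qs[i] for i in idx) / total_q) * weighted_entropy(
            [labels[i] for i in idx], [qs[i] for i in idx])
    return gain
```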
At step 206 we select a number of the new random predicates with the highest weighted information gain on a given prediction task, such as 5, 10 or 20 of them.
At step 208, we then perform a gradient fit to compute weights for all predicates, using gradient descent with log loss and L1 regularization to compute the new weights for all previous and newly added predicates. We use the FOBOS algorithm to perform the fit; see Duchi and Singer, Efficient Online and Batch Learning Using Forward Backward Splitting, J. Mach. Learn. Res. (2009).
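One FOBOS update takes a gradient step on the log loss and then applies the closed-form L1 proximal (soft-thresholding) step. The sketch below shows a single such update and is an illustration of the idea, not the cited paper's full algorithm:

```python
import math

def fobos_l1_step(weights, grads, lr, lam):
    """Forward (gradient) step followed by the backward (proximal) step
    for the L1 penalty, which truncates small weights exactly to zero,
    keeping the boosting model sparse."""
    updated = []
    for w, g in zip(weights, grads):
        w = w - lr * g                                      # forward step
        w = math.copysign(max(abs(w) - lr * lam, 0.0), w)   # soft-threshold
        updated.append(w)
    return updated
```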
At step 210, we then remove selected new predicates in response to operator input. In particular, an expert such as a physician 212 operating a computer 214 views the randomly selected predicates with the highest information gain and then removes those that are deemed not trustworthy or causally unrelated to the prediction task of the model. For example, if one of the predicates was “number_of_breakfasts” and the prediction task is inpatient mortality, the operator may choose to deselect that predicate because it is not causally connected to whether the patient is at risk of inpatient mortality.
In one embodiment, we show the predicates to a human (212) in an interface on the computer 214 that allows them to delete predicates based on a loose criterion of “trustworthiness,” which we defined as whether the human participant believes that the predicate strongly relates to the task at hand. In this “human-in-the-loop” approach, we prefer to build the model in the iterative manner described above.
Additionally, it is possible to have the user interface of the workstation include a tool, such as a box for entry of text, where the operator can define a predicate during building of the boosting model. For example, at steps 206 or 210 the operator could insert a new predicate, and it is added to the boosting model.
At step 216, there is a check to see if the training process is complete. Normally, after the first iteration the No branch is taken and loop 218 repeats steps 202, 204, 206, 208 and 210 multiple times, such as ten or twenty times. Each iteration through the loop 218 results in the gradual buildup of more and more predicates, each with a high weighted information gain score (from step 204), and with inspection and possible deselection of some predicates by the human operator in step 210. Accordingly, the methodology gradually builds up an accurate, trustworthy and interpretable model. Moreover, by virtue of the design and selection of human-understandable predicates, and the human inspection and possible removal of predicates that lack sufficient trustworthiness, the methodology results in a final generated boosted model that is interpretable to end-users and overcomes the problems with the prior art.
After a sufficient number of boosting rounds (loop 218) have been performed, for example when the performance metrics meet expected criteria, the Yes branch 220 is taken and the process proceeds to the evaluation step 24 described above.
As noted previously, the evaluation can take the form of human evaluation of the model for trustworthiness, complexity (did the model have a reasonable number of features), and accuracy. For measurements of accuracy one can investigate how the model performed on a test set relative to other models generated from the data, as well as the use of test metrics such as the area under a receiver operating characteristic curve (AUROC), a known performance metric in machine learning.
In order to analyze the performance of a model built in accordance with the foregoing procedure, several evaluation criteria can be used.
In one embodiment, the evaluation step 24 could consist of the following:
- 1. Accuracy. We used the AUROC for performance of the model on a validation set.
- 2. Complexity. We counted the number of predicates at the end of training.
- 3. Trustworthiness. For each task, we randomly picked X predicates from each of the models (inpatient mortality, diagnosis at discharge). We had a physician evaluate each predicate on a scale of 1 to 3, with 1 indicating a predicate was not related to the task at hand (e.g., an antibiotic not related to heart failure) and 3 indicating a predicate was strongly related to the task. We report the “Trust Score” or trustworthiness of a model as the averaged score of all its predicates.
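A sketch of this three-part evaluation follows (assuming scikit-learn for the AUROC computation; the model format and the trust_ratings input are hypothetical conventions):

```python
from sklearn.metrics import roc_auc_score

def evaluate(model, records, labels, trust_ratings):
    """Accuracy (AUROC on a validation set), complexity (predicate count),
    and Trust Score (average of the physician's 1-3 ratings)."""
    scores = [sum(w * pred(r) for pred, w in model) for r in records]
    return {
        "auroc": roc_auc_score(labels, scores),
        "complexity": len(model),
        "trust_score": sum(trust_ratings) / len(trust_ratings),
    }
```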
As noted previously, one of the ways of evaluating the model generated in accordance with this disclosure is by visualizing it.
Our interactive visualization allows a user to dynamically explore the learned predicates by choosing from several sorting and coloring options.
Example Text Interface for Training
The workstation 214 can provide a text interface for the operator/expert to use during model training. This section will provide an example of a text interface for building a model for prediction of congestive heart failure as the diagnosis at discharge.
Each line represents a predicate in the model. The information at the beginning of each line is the meta-info about each predicate: its index, the human decision about whether to keep it, a visual tag indicating to the human whether it is a new predicate, and the predicate weight. The second part of each line is the predicate itself. “E” means the existence of a feature, and “#” means the count of a feature with a threshold. “TRUE” simply captures the bias of the label in the data set. In the example below, the human decides to ‘delete’ the predicate at index 2, since the feature count's threshold is not trustworthy. This model is very simple, because this is at the very beginning of model training; later the model will become much larger and more complex. Since the model is composed of a set of predicates, it is still possible for a human to inspect the whole model, e.g., by scrolling through the lines or by use of visualization techniques such as those described above.
- [0, Y, −, 0.0244] E:obsloinc:33762-6 pg/mL (Natriuretic peptide.B prohormone N-Terminal)
- [1, Y, −, 0.0240] E:Composition.section.text.div.tokenized failure
- [2, Y, −, 0.0237] #:Composition.section.text.div.tokenized ventricular>=11
- [3, Y, −, 0.0237] E:Composition.section.text.div.tokenized congestive
- [4, Y, −, 0.0232] #:Composition.section.text.div.tokenized regurgitation>=3
- [5, Y, −, 0.0232] E:Observation.code.loinc.display.tokenized
- [6, Y, −, 0.0228] #:Composition.section.text.div.tokenized exertion>=2
- [7, Y, −, 0.0224] E:Composition.section.text.div.tokenized lasix
- [8, Y, −, 0.0220] E:Composition.section.text.div.tokenized la
- [9, Y, −, 0.0216] E:Composition.section.text.div.tokenized regurgitation
- [10, Y, −, 0.0206] Context age_in_years>=60.000000 @ t<=1.000000
- [11, Y, −, −0.0101] E:Context Patient.gender male
- [12, Y, −, −0.0220] Context age_in_years>=40.000000 @ t<=1.000000
- [13, Y, −, −0.0244] Context age_in_years>=18.000000 @ t<=1.000000
- [14, Y, −, −0.0256] E:Context Patient.gender female
- [15, Y, −, −3.3718] TRUE
- New Model Test Score: 0.883712, Rules: 16
- BOOST>delete 2
A user interface for interactive model training in accordance with another embodiment includes the following elements:
A header bar 802 which identifies the current model labeling or prediction task (in this case prediction of acute myocardial infarction). The header bar 802 also includes some statistics about the current session, shown at the right hand edge of the bar and available at a glance, such as loss and area under the receiver operating characteristic curve.
A content area 804 which provides the display of tools for modifying learner behavior and working with predicates (i.e., selecting or deselecting predicates), and showing statistics such as the weight of predicates, as described further below.
A control bar 806 which provides for the display of tools for requesting and saving models and a history of user actions in the current session.
A timeline 808 which summarizes the user's session with the learner by showing performance and model size metrics.
The content area 804 is a scrollable region containing “cards” (individual graphical display regions) that drive the bulk of the interaction between the user and the learner. There are two kinds or types of cards, Setting Cards and Predicate Cards.
The Predicate Cards display the individual predicates currently in the boosting model, along with their weights.
The number and identification of categories of predicates can of course vary, but in the present context the following categories are recognized: demographics, doctor notes, medications, lab results, nurse observations, previous conditions, admission/discharge and medical procedures. If a predicate does not fit into one of these categories it is placed in a further category called Other.
The Timeline shown at the bottom of the interface summarizes the user's session with the learner over the boosting rounds.
Results
In our work we have developed models using the procedure described above.
We explored the effects of using predicates of type Group 2 (more complex, less human understandable) in training the purely machine learning model (“MM”) versus using Group 1 (less complex, more human understandable) predicates in the human-in-the-loop model (“HM”). We found that the effect of using Group 2 predicates depends on the nature of the prediction task. For the tasks of predicting the discharge diagnosis code, the gap between two different MM models, one with both Group 1 and Group 2 predicates (MM1) and one using just Group 1 predicates, i.e., existence and counts predicates (MM2), is rather insignificant. For example, in one discharge code task, using the AUROC metric, MM1 achieves 0.910 vs. MM2's 0.896 (a gap of 1.4%). In another discharge code task, the comparison is 0.916 vs. 0.914 (a gap of 0.2%). In the more complex task of mortality prediction, the gap is somewhat significant, i.e., 0.791 vs. 0.814 (a gap of 2.3%). However, since one of the goals of the present inventive method is to improve model interpretability, machine models which use the simple predicate types are preferred; otherwise it is very hard for a human to understand the model. This shows the tradeoff between model quality and interpretability, but we believe it is a good tradeoff to make in the medical domain, since interpretability and trustworthiness are extremely important.
We also explored the effect of putting the human in the loop by comparing the performance of the resulting model (HM1, constructed per the procedure described above) with that of the machine-only models.
We have two general observations about human behavior in this process: 1) The domain expert makes decisions about whether to keep or delete a predicate based mostly on trustworthiness. Under this mindset, the expert is acting on behalf of the end-users who will use this model. 2) We have a mechanism to evaluate the current model on demand, in order to help the human make decisions. However, we observe that the expert almost never relies on it in making decisions. This may explain why the HM1 model got a much higher “trust score,” as shown below. Table 1 shows the quality (AUROC), size and trust scores for the three models in a task of classifying congestive heart failure as the diagnosis at discharge.
Similar quality and size results were obtained for a task of classifying dysrhythmia as a diagnosis at discharge (CCS code 106). From a model quality perspective, the human model (HM1) is very comparable with the machine models (MM2 and MM3) in the two coding tasks. In the more challenging task of predicting inpatient mortality, the HM model did worse (˜5%) than MM2, and is comparable with MM3. In this task, the model was not able to suggest very interpretable predicates, and hence they were frequently deleted by the human, leading to an overly small model with only 23 predicates.
From a model size perspective, the human model is much smaller than the machine model (MM2). Having a smaller model allows others to inspect the model more easily; this is not strictly required but it is highly desirable, especially in the medical domain.
The most striking result is the “Trust Score” of the different models. The human expert model (HM1) is rated much higher in trustworthiness, which is a very desirable result. When we prune the machine model's predicates to only include the ones with the highest weights (MM3), its “Trust Score” also improves (from 1.70 to 1.97), suggesting that the machine model associates higher weights with the more trustworthy predicates. Nevertheless, given the much higher “Trust Score” of the human model (HM1), its smaller model size, and comparable quality, HM1 demonstrates that our objective of obtaining an interpretable, trustworthy machine learning model has been achieved.
Further Considerations
In order to further assist the user in probing and improving the model during model training, it may be desirable to add additional features to the workstation described above.
As another example, some more complex predicates may be initially difficult to understand even to an expert, but they may be rendered in graphical form which increases understanding by the expert and may allow them to choose them for inclusion in the model.
Additionally, many predicates may be redundant and it is preferable to select and use for model building a particular one based on its greater ability to be understood by the end-user. In order to reduce the amount of time needed to build the model it is preferable to delete or remove from the training process not only the redundant predicates but also those that the human would delete anyway, for example irrelevant ones or ones that are not human understandable.
Also, it is possible to rank the predicates such that more specific predicates have a higher priority. For example, a lab test result predicate could be preferred or ranked higher than a lab test name predicate. This can be done by using policy rules and adjusting the weighted information gain scores (or the weights for the model) during the training iterations.
Additionally, it may be preferable to use bigrams (two words) over unigrams (one word) in predicates obtained from unstructured medical notes, because bigrams provide more context and make the predicate easier to understand. The bigrams could be weighted or scored using policy rules or otherwise. Furthermore, the user interface of the workstation could provide tools for expressing such preferences.
Other preferences could be defined, either as predicates defined by the user during the training iterations or as policy rules applied during predicate generation.
Additionally, to aid the user in deciding to select or deselect predicates, or define new predicates for use by the model, it may be useful to provide statistics to assist the user. For example, one can define “coverage” as the number of examples for which a particular predicate is true, “precision” as the number of examples with a true label for which this predicate is true divided by coverage, and “recall” as the number of examples with a true label for which this predicate is true divided by the number of examples with a true label, and a correlation between the predicate and the label.
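These per-predicate statistics could be computed as in the following sketch (illustrative only; the function name and return format are assumptions):

```python
def predicate_statistics(predicate, records, labels):
    """Coverage, precision and recall of one predicate, as defined above,
    shown to the operator to support keep/delete decisions."""
    fired = [predicate(r) for r in records]
    coverage = sum(fired)
    hits = sum(1 for f, y in zip(fired, labels) if f == 1 and y == 1)
    positives = sum(labels)
    return {
        "coverage": coverage,
        "precision": hits / coverage if coverage else 0.0,
        "recall": hits / positives if positives else 0.0,
    }
```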
It is desirable to be able to build the models quickly by distributing the processing tasks among several servers or computing platforms during model training, with a goal of reducing fatigue on the human-in-the-loop. Basically, the predicate generation and scoring steps of each boosting round can be distributed across multiple workers.
Another enhancement to the method is to restrict the time periods (in the sequences of tuples in the defined predicates) to human-friendly time periods, such as the last hour, the last day, the last week or the last month, instead of arbitrary time periods.
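For example, a sampled time threshold could be snapped to the nearest enclosing human-friendly window, as in this hypothetical helper (the window set and granularity are assumptions):

```python
# Human-friendly look-back windows, in seconds.
HUMAN_WINDOWS = [("last hour", 3_600), ("last day", 86_400),
                 ("last week", 604_800), ("last month", 2_592_000)]

def snap_to_human_window(t_seconds):
    """Round an arbitrary sampled time threshold up to the nearest
    human-friendly period for use in a displayed predicate."""
    for label, limit in HUMAN_WINDOWS:
        if t_seconds <= limit:
            return label, limit
    return "all time", float("inf")
```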
Claims
1. A computer-implemented method of training a predictive model from data comprising a multitude of features, each feature associated with a real value and a time component, comprising the steps of executing the following instructions in a processor of the computer:
- a) defining a multitude of predicates as binary functions operating on time sequences of the features or logical operations on the time sequences of the features;
- b) iteratively training a boosting model by performing the following:
- 1) generating a number of new random predicates;
- 2) scoring all the new random predicates by weighted information gain with respect to a class label associated with a prediction of the boosting model;
- 3) selecting a number of the new random predicates with the highest weighted information gain and adding them to the boosting model;
- 4) computing weights for all the predicates in the boosting model;
- 5) removing one or more of the selected new predicates with the highest information gain from the boosting model in response to input from an operator; and
- 6) repeating the performance of steps 1, 2, 3, 4 and 5 a plurality of times and thereby generating a final iteratively trained boosting model.
2. The method of claim 1, further comprising the step of c) evaluating the final iteratively trained boosting model.
3. The method of claim 1, wherein the data is in a tuple format of the type {X, xi, ti} where X is the name of feature, xi is a real value of the feature and ti is a time component for the real value xi, and wherein the predicates are defined as binary functions operating on sequences of tuples or logical operations on sequences of the tuples.
4. The method of claim 1, wherein the data comprises electronic health record data for a multitude of patients.
5. The method of claim 1, wherein the method further comprises the step of dividing the predicates into groups based on understandability, namely a first group of relatively more human understandable predicates and a second group of relatively less human understandable predicates and wherein the new random predicates are selected from the first group.
6. The method of claim 1, wherein step b) 5) further comprises the step of graphically representing the predicates currently in the boosting model and providing the operator with the ability to remove one or more of the predicates.
7. The method of claim 1, further comprising the step of graphically representing a set of predicates added to the boosting model after each of the iterations of step b) 6).
8. The method of claim 6, further comprising the step of graphically representing the weights computed for each of the predicates in step b) 4).
9. The method of claim 5, wherein the data comprises electronic health record data for a multitude of patients, and wherein the set of predicates are represented in a manner to show the subject matter or source within the electronic health record data of the predicate.
10. The method of claim 2, wherein the evaluation step (c) comprises evaluating the final iteratively trained boosting model for at least one of accuracy, complexity, or trustworthiness.
11. The method of claim 9, wherein the predicates comprise an existence predicate returning a result of 0 or 1 depending on whether a feature exists in the electronic health record data for a given patient in the multitude of patients; and a counts predicate returning a result of 0 or 1 depending on the number of counts of a feature in the electronic health record data for a given patient in the multitude of patients relative to a numeric parameter C.
12. The method of claim 1, wherein step b) further comprises the step of providing the operator with the ability to define a predicate during model training.
13. The method of claim 1, wherein step b) further comprises the step of removing redundant predicates.
14. The method of claim 3, wherein the sequences of tuples are defined by time periods selected from the group consisting of 1 or more days, 1 or more hours, 1 or more minutes, or 1 or more months.
15. The method of claim 1, further comprising the step of ranking the predicates selected in step b) 3).
16. The method of claim 1, further comprising the step of generating statistics of predicates in the boosting model and presenting them to the operator.
17. A computer-implemented method of training a predictive model from electronic health record data for a multitude of patients, the data comprising a multitude of features, each feature associated with real values and a time component, wherein the data is in a tuple format of the type {X, xi, ti} where X is the name of feature, xi is a real value of the feature and ti is a time component for the real value xi, comprising the steps of implementing the following instructions in a processor of the computer:
- a) defining a multitude of predicates as binary functions operating on sequences of the tuples or logical operations on the sequences of the tuples;
- b) dividing the multitude of predicates into groups based on understandability, namely a first group of relatively more human understandable predicates and a second group of relatively less human understandable predicates;
- c) iteratively training a boosting model by performing the following: 1) generating a number of new random predicates from the first group of predicates; 2) scoring all the new random predicates by weighted information gain with respect to a class label associated with a prediction of the boosting model; 3) selecting a number of the new random predicates with the highest weighted information gain and adding them to the boosting model; 4) computing weights for all the predicates in the boosting model; 5) removing one or more of the selected new predicates with the highest information gain from the boosting model in response to input from an operator; and 6) repeating the performance of steps 1, 2, 3, 4 and 5 a plurality of times and thereby generating a final iteratively trained boosting model.
18. The method of claim 17, further comprising the step d) of evaluating the final iteratively trained boosting model.
19. In a computing platform implementing a machine learning model, the improvement comprising: the machine learning model comprises an iteratively trained boosted model built from predicates defined as binary functions operating on sequences of features having both a real value and time component, wherein the predicates are defined with operator input and wherein the selection of predicates for inclusion in the iteratively trained boosted model are subject to review and selection or deselection by an operator during iterative training of the boosting model.
20. The improvement of claim 19, wherein the features comprise features in electronic health records.
21. A workstation for providing operator input into iteratively training a boosting model, wherein the workstation comprises an interface displaying predicates selected as having a weighted information gain for making a prediction of the boosting model, and the interface providing a tool for selection or deselection of one or more of the predicates in the boosting model.
22. The workstation of claim 21, wherein predicates are defined as binary functions operating on sequences of features having both a real value component and a time component or logical operations on sequences of the features.
23. The workstation of claim 21, wherein the interface further comprises a tool for allowing an operator to define a predicate.
24. A computer-implemented method of generating a predictive model from training data, the predictive model being for predicting a label based on input data which, for each of a plurality of features X, indicates a value x of the feature at each of a plurality of times, and the training data comprising a plurality of samples, each sample indicating the value of one or more of the features at each of one or more times and a corresponding label;
- the method comprising implementing the following steps as instructions with a processor:
- defining a set of predicates, each predicate being a function which generates an output when applied to time sequences of the features or logical combinations of the time sequences of the features;
- generating a boosting model, the boosting model receiving as input the respective outputs of each of the set of predicates when applied to the samples of the training data; and
- performing a plurality of times the sequence of steps of:
- (i) automatically generating a plurality of additional predicates;
- (ii) adding the plurality of additional predicates to predicates already in the boosting model to form an updated set of predicates;
- (iii) displaying a plurality of the updated set of predicates;
- (iv) receiving data input rejecting one or more of the updated set of predicates; and
- (v) removing the rejected one or more predicates from the updated set of predicates.
25. A method according to claim 24 in which the step (i) of automatically generating the plurality of additional predicates comprises:
- (a) generating candidate predicates by a pseudo-random algorithm;
- (b) scoring the candidate predicates for weighted information gain in the boosting model;
- (c) selecting the additional predicates from the candidate predicates based on the scores.
26. The method according to claim 24, wherein the output of each predicate is a binary value.
27. The method according to claim 24, wherein each sample in the training data is formatted as a plurality of data items having a tuple format of the type {X, xi, ti}, where xi indicates the value of feature X at a time ti, and i labels the tuple of the sample, each predicate being a function performed on a plurality of data items of the sample.
28. The method according to claim 24 in which the training data comprises electronic health record data for a plurality of patients.
29. The method of claim 24, in which each predicate is a function of a part of the sample relating to a single corresponding one of the features.
30. The method of claim 29, in which the additional predicates comprise at least one of existence predicates which are each indicative of a specific feature taking a value in a specific range at at least one time, and count predicates which are each indicative of a specific feature taking a value in a specific range at more than, less than, or equal to a specific number of times C.
31. The method according to claim 24, in which the features are each associated with a corresponding one of a set of human understandable categories or groups, and step (iv) of displaying a plurality of the set of predicates includes displaying grouped together the predicates which are functions of data relating to features of each category or group.
32. The method of claim 24, in which step (iv) of displaying a plurality of the set of predicates includes displaying a respective weight value of the regenerated boosting model.
33. The method of claim 24, further comprising evaluating the accuracy of the boosting model in predicting the label using a validation sub-set of the training data.
34. The method of claim 1, wherein in step b) 5) the one or more predicates are removed which are not causally related to the prediction of the boosting model.
35. The improvement of claim 19 wherein the predicates deselected by an operator are not causally related to a prediction of the boosting model.
36. The workstation of claim 22, wherein the predicates deselected by an operator are not causally related to a prediction of the boosting model.
37. The method of claim 24, wherein the rejected one or more of the updated set of predicates are not causally related to a prediction of the boosting model.
Type: Application
Filed: Sep 29, 2017
Publication Date: Nov 18, 2021
Patent Grant number: 12191007
Inventors: Kai Chen (San Bruno, CA), Eyal Oren (Los Gatos, CA), Hector Yee (Mountain View, CA), James Wilson (Littleton, MA), Alvin Rajkomar (Mountain View, CA), Michaela Hardt (Mountain View, CA)
Application Number: 16/618,656