Machine learning is very good at optimizing predictions to match an observed signal — for instance, given a dataset of images and their labels (e.g. dog, cat), it is very good at correctly predicting the label of a new image. However, performance can quickly break down as soon as we care about criteria other than predicting observables. There are several cases where we might care about such criteria:
- In scientific investigations, we often care less about predicting a specific observable phenomenon, and more about what that phenomenon implies about an underlying scientific theory.
- In economic analysis, we are most interested in what policies will lead to desirable outcomes. This requires predicting what would counterfactually happen if we were to enact the policy, which we (usually) don’t have any data about.
- In machine learning, we may be interested in learning value functions which match human preferences (this is especially important in complex settings where it is hard to specify a satisfactory value function by hand). However, we are unlikely to observe information about the value function directly, and instead must infer it implicitly. For instance, one might infer a value function for autonomous driving by observing the actions of an expert driver.
In all of the above scenarios, the primary object of interest — the scientific theory, the effects of a policy, and the value function, respectively — is not part of the observed data. Instead, we can think of it as an unobserved (or “latent”) variable in the model we are using to make predictions. While we might hope that a model that makes good predictions will also place correct values on unobserved variables, this need not be the case in general, especially if the model is mis-specified.
I am interested in latent variable inference because I think it is a potentially important sub-problem for building AI systems that behave safely and are aligned with human values. The connection is most direct for value learning, where the value function is the latent variable of interest and the fidelity with which it is learned directly impacts the well-behavedness of the system. However, one can imagine other uses as well, such as making sure that the concepts that an AI learns sufficiently match the concepts that the human designer had in mind. It will also turn out that latent variable inference is related to counterfactual reasoning, which has a large number of tie-ins with building safe AI systems that I will elaborate on in forthcoming posts.
The goal of this post is to explain why problems show up if one cares about predicting latent variables rather than observed variables, and to point to a research direction (counterfactual reasoning) that I find promising for addressing these issues. More specifically, in the remainder of this post, I will: (1) give some formal settings where we want to infer unobserved variables and explain why we can run into problems; (2) propose a possible approach to resolving these problems, based on counterfactual reasoning.
1 Identifying Parameters in Regression Problems
Suppose that we have a regression model $p_\theta(y \mid x)$, which outputs a probability distribution over $y$ given a value for $x$. Also suppose we are explicitly interested in identifying the “true” value of $\theta$ rather than simply making good predictions about $y$ given $x$. For instance, we might be interested in whether smoking causes cancer, and so we care not just about predicting whether a given person will get cancer ($y$) given information about that person ($x$), but specifically whether the coefficients in $\theta$ that correspond to a history of smoking are large and positive.
In a typical setting, we are given data points $(x_1, y_1), \ldots, (x_n, y_n)$ on which to fit a model. Most methods of training machine learning systems optimize predictive performance, i.e. they will output a parameter $\hat{\theta}$ that (approximately) maximizes $\sum_{i=1}^{n} \log p_\theta(y_i \mid x_i)$. For instance, for a linear regression problem we have $\log p_\theta(y_i \mid x_i) = -\frac{1}{2}(y_i - \langle \theta, x_i \rangle)^2 + \mathrm{const}$. Various more sophisticated methods might employ some form of regularization to reduce overfitting, but they are still fundamentally trying to maximize some measure of predictive accuracy, at least in the limit of infinite data.
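To make the objective concrete, here is a minimal sketch (numpy-based; the synthetic data and variable names are illustrative, not from the post) of fitting $\hat{\theta}$ for a linear-Gaussian model, where maximizing the log-likelihood above reduces to ordinary least squares:

```python
import numpy as np

# Illustrative data: x_i in R^d, y_i in R, generated from a made-up "true" process.
rng = np.random.default_rng(0)
n, d = 1000, 3
X = rng.normal(size=(n, d))
y = X @ np.array([2.0, -1.0, 0.5]) + rng.normal(scale=0.1, size=n)

# With unit noise variance, log p_theta(y_i | x_i) = -0.5 * (y_i - <theta, x_i>)^2 + const,
# so maximizing the summed log-likelihood is exactly least squares.
theta_hat, *_ = np.linalg.lstsq(X, y, rcond=None)

def avg_log_likelihood(theta, X, y):
    """Average Gaussian log-likelihood (up to an additive constant)."""
    return -0.5 * np.mean((y - X @ theta) ** 2)

print(theta_hat, avg_log_likelihood(theta_hat, X, y))
```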
Call a model well-specified if there is some parameter $\theta^*$ for which $p_{\theta^*}(y \mid x)$ matches the true distribution over $y$, and call a model mis-specified if no such $\theta^*$ exists. One can show that for well-specified models, maximizing predictive accuracy works well (modulo a number of technical conditions). In particular, maximizing $\sum_{i=1}^{n} \log p_\theta(y_i \mid x_i)$ will (asymptotically, as $n \to \infty$) lead to recovering the parameter $\theta^*$.
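A standard argument (sketched here; it is not spelled out in the post) explains why: as $n \to \infty$, the average log-likelihood converges to an expectation under the true distribution $p^*$, and maximizing it is equivalent to minimizing an expected KL divergence,

$$\frac{1}{n}\sum_{i=1}^{n} \log p_\theta(y_i \mid x_i) \;\to\; \mathbb{E}_{(x,y) \sim p^*}\big[\log p_\theta(y \mid x)\big] = \mathrm{const} - \mathbb{E}_{x \sim p^*}\Big[\mathrm{KL}\big(p^*(y \mid x)\,\big\|\,p_\theta(y \mid x)\big)\Big].$$

When the model is well-specified the KL term can be driven to zero, recovering $\theta^*$; when it is mis-specified, the minimizer depends on the distribution of $x$, which is the source of the problems discussed next.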
However, if a model is mis-specified, then it is not even clear what it means to correctly infer $\theta$. We could declare the $\theta$ maximizing predictive accuracy to be the “correct” value of $\theta$, but this has issues:
- While $\theta$ might do a good job of predicting $y$ in the settings we’ve seen, it may not predict $y$ well in very different settings.
- If we care about determining $\theta$ for some scientific purpose, then good predictive accuracy may be an unsuitable metric. For instance, even though margarine consumption might correlate well with (and hence be a good predictor of) divorce rate, that doesn’t mean that there is a causal relationship between the two.
The two problems above also suggest a solution: we will say that we have done a good job of inferring a value for $\theta$ if $\theta$ can be used to make good predictions in a wide variety of situations, and not just the situation we happened to train the model on. (For the latter case of predicting causal relationships, the “wide variety of situations” should include the situation in which the relevant causal intervention is applied.)
Note that both of the problems above are different from the typical statistical problem of overfitting. Classically, overfitting occurs when a model is too complex relative to the amount of data at hand, but the problems above can occur even if we have a large amount of data. This is illustrated in the following graph:
[Figure: a linear model (green) fit to data drawn from a quadratic relationship (blue).]
Here the blue line is the data we have ($p^*(y \mid x)$), and the green line is the model we fit (with slope and intercept parametrized by $\theta$). We have more than enough data to fit a line to it. However, because the true relationship is quadratic, the best linear fit depends heavily on the distribution of the training data. If we had fit to a different part of the quadratic, we would have gotten a potentially very different result. Indeed, in this situation, there is no linear relationship that can do a good job of extrapolating to new situations, unless the domain of those new situations is restricted to the part of the quadratic that we’ve already seen.
I will refer to the type of error in the diagram above as mis-specification error. Again, mis-specification error is different from error due to overfitting. Overfitting occurs when there is too little data and noise is driving the estimate of the model; in contrast, mis-specification error can occur even if there is plenty of data, and instead occurs because the best-performing model is different in different scenarios.
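To see concretely that this is not a small-sample issue, here is a short sketch (with made-up numbers) that fits a line to abundant data from a quadratic relationship on two different input ranges; each fit has plenty of data, yet the two disagree badly:

```python
import numpy as np

rng = np.random.default_rng(0)

def fit_line(x_lo, x_hi, n=10_000):
    """Least-squares (slope, intercept) fit to y ~ x^2 data drawn from [x_lo, x_hi]."""
    x = rng.uniform(x_lo, x_hi, size=n)
    y = x ** 2 + rng.normal(scale=0.01, size=n)   # the true relationship is quadratic
    A = np.column_stack([x, np.ones_like(x)])     # design matrix for slope and intercept
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    return coef

# Plenty of data in both cases, yet the "best" linear model changes completely
# depending on which part of the quadratic we happened to observe.
print(fit_line(0, 1))   # roughly slope 1.0, intercept -0.17
print(fit_line(2, 3))   # roughly slope 5.0, intercept -6.17
```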
2 Structural Equation Models
We will next consider a slightly subtler setting, which in economics is referred to as a structural equation model. In this setting we again have an output $y$ whose distribution depends on an input $x$, but now this relationship is mediated by an unobserved variable $z$. A common example is a discrete choice model, where consumers make a choice among multiple goods ($y$) based on a consumer-specific utility function ($z$) that is influenced by demographic and other information about the consumer ($x$). Natural language processing provides another source of examples: in semantic parsing, we have an input utterance ($x$) and output denotation ($y$), mediated by a latent logical form $z$; in machine translation, we have input and output sentences ($x$ and $y$) mediated by a latent alignment ($z$).
Symbolically, we represent a structural equation model as a parametrized probability distribution $p_\theta(y, z \mid x)$, where we are trying to fit the parameters $\theta$. Of course, we can always turn a structural equation model into a regression model by using the identity $p_\theta(y \mid x) = \sum_{z} p_\theta(y, z \mid x)$, which allows us to ignore $z$ altogether. In economics this is called a reduced form model. We use structural equation models if we are specifically interested in the unobserved variable $z$ (for instance, in the examples above we are interested in the value function for each individual, or in the logical form representing the sentence’s meaning).
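As a toy illustration of the reduced-form identity (a sketch with a made-up binary latent variable, not a model from the post), the predictive distribution over $y$ is obtained by summing the joint model over $z$:

```python
import numpy as np

def p_joint(y, z, x, theta):
    """Toy structural model p_theta(y, z | x) with a binary latent z.

    z ~ Bernoulli(sigmoid(theta[0]*x)); y ~ Bernoulli(sigmoid(theta[1]*x + theta[2]*z)).
    All variables are scalars here purely for illustration.
    """
    sigmoid = lambda t: 1.0 / (1.0 + np.exp(-t))
    p_z = sigmoid(theta[0] * x) if z == 1 else 1 - sigmoid(theta[0] * x)
    p_y = sigmoid(theta[1] * x + theta[2] * z)
    p_y = p_y if y == 1 else 1 - p_y
    return p_z * p_y

def p_reduced(y, x, theta):
    """Reduced-form model: p_theta(y | x) = sum_z p_theta(y, z | x)."""
    return sum(p_joint(y, z, x, theta) for z in (0, 1))

theta = np.array([0.5, -1.0, 2.0])
print(p_reduced(1, x=0.3, theta=theta))  # makes a prediction while ignoring z entirely
```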
In the regression setting where we cared about identifying $\theta$, it was obvious that there was no meaningful “true” value of $\theta$ when the model was mis-specified. In this structural equation setting, we now care about the latent variable $z$, which can take on a meaningful true value (e.g. the actual utility function of a given individual) even if the overall model $p_\theta(y, z \mid x)$ is mis-specified. It is therefore tempting to think that if we fit parameters $\theta$ and use them to impute $z$, we will have meaningful information about the actual utility functions of individual consumers. However, this is a notational sleight of hand — just because we call $z$ “the utility function” does not make it so. The variable $z$ need not correspond to the actual utility function of the consumer, nor do the consumer’s preferences even need to be representable by a utility function.
We can understand what goes wrong by considering the following procedure, which formalizes the proposal above (a code sketch of the procedure follows the list):
- Find $\theta$ to maximize the predictive accuracy on the observed data, $\sum_{i=1}^{n} \log p_\theta(y_i \mid x_i)$, where $p_\theta(y_i \mid x_i) = \sum_{z} p_\theta(y_i, z \mid x_i)$. Call the result $\hat{\theta}$.
- Using this value $\hat{\theta}$, treat $z$ as being distributed according to $p_{\hat{\theta}}(z \mid x, y)$. On a new value $x$ for which $y$ is not observed, treat $z$ as being distributed according to $p_{\hat{\theta}}(z \mid x)$.
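Here is a minimal sketch of this two-step procedure, reusing the toy binary-latent model from the sketch above (again purely illustrative): fit $\hat{\theta}$ by maximizing the marginal likelihood, then impute $z$ from its posterior under the fitted parameters.

```python
import numpy as np
from scipy.optimize import minimize

sigmoid = lambda t: 1.0 / (1.0 + np.exp(-t))

def p_joint(y, z, x, theta):
    """Toy model: z ~ Bern(sigmoid(theta[0]*x)), y ~ Bern(sigmoid(theta[1]*x + theta[2]*z))."""
    pz = sigmoid(theta[0] * x)
    pz = pz if z == 1 else 1 - pz
    py = sigmoid(theta[1] * x + theta[2] * z)
    py = py if y == 1 else 1 - py
    return pz * py

def neg_marginal_log_lik(theta, xs, ys):
    """Step 1 objective: -sum_i log sum_z p_theta(y_i, z | x_i)."""
    return -sum(np.log(sum(p_joint(y, z, x, theta) for z in (0, 1)))
                for x, y in zip(xs, ys))

def posterior_z(x, y, theta):
    """Step 2: imputed distribution p_theta_hat(z = 1 | x, y)."""
    joint = [p_joint(y, z, x, theta) for z in (0, 1)]
    return joint[1] / sum(joint)

# Illustrative data, generated from the model itself (so it is well-specified here).
rng = np.random.default_rng(0)
true_theta = np.array([0.5, -1.0, 2.0])
xs = rng.normal(size=200)
zs = rng.random(200) < sigmoid(true_theta[0] * xs)
ys = (rng.random(200) < sigmoid(true_theta[1] * xs + true_theta[2] * zs)).astype(int)

theta_hat = minimize(neg_marginal_log_lik, x0=np.zeros(3), args=(xs, ys)).x
print(theta_hat, posterior_z(x=0.3, y=1, theta=theta_hat))
```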
As before, if the model is well-specified, one can show that such a procedure asymptotically outputs the correct probability distribution over $z$. However, if the model is mis-specified, things can quickly go wrong. For example, suppose that $y$ represents what choice of drink a consumer buys, and $z$ represents consumer utility (which might be a function of the price, attributes, and quantity of the drink). Now suppose that individuals have preferences which are influenced by unmodeled covariates: for instance, a preference for cold drinks on warm days, while the input $x$ does not have information about the outside temperature when the drink was bought. This could cause any of several effects:
- If there is a covariate that happens to correlate with temperature in the data, then we might conclude that that covariate is predictive of preferring cold drinks.
- We might increase our uncertainty about $z$ to capture the unmodeled variation in $y$.
- We might implicitly increase uncertainty by moving utilities closer together (allowing noise or other factors to more easily change the consumer’s decision).
In practice we will likely have some mixture of all of these, and this will lead to systematic biases in our conclusions about the consumers’ utility functions.
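The first effect in the list above can be reproduced with a few lines of simulation (entirely made-up numbers, purely to illustrate the mechanism): choices are actually driven by temperature, the model only sees a covariate that happens to correlate with temperature, and the fitted model hands that covariate the predictive weight.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
n = 5_000

# Unobserved driver of preferences: outside temperature.
temperature = rng.normal(size=n)
# Observed covariate (say, "bought at a beach-side store") that merely correlates with temperature.
covariate = temperature + rng.normal(scale=0.5, size=n)
# Choice: cold drink (1) vs hot drink (0), actually driven by temperature alone.
cold_drink = (rng.random(n) < 1 / (1 + np.exp(-2.0 * temperature))).astype(int)

# The model never sees temperature, only the correlated covariate.
model = LogisticRegression().fit(covariate.reshape(-1, 1), cold_drink)
print(model.coef_)  # clearly positive weight: the covariate "predicts" preferring cold drinks
```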
The same problems as before arise: while we by design place probability mass on values of $z$ that correctly predict the observation $y$, under model mis-specification this could be due to spurious correlations or other perversities of the model. Furthermore, even though predictive performance is high on the observed data (and data similar to the observed data), there is no reason for this to continue to be the case in settings very different from the observed data, which is particularly problematic if one is considering the effects of an intervention. For instance, while inferring preferences between hot and cold drinks might seem like a silly example, the design of timber auctions constitutes a much more important example with a roughly similar flavour, where it is important to correctly understand the utility functions of bidders in order to predict their behaviour under alternative auction designs (the model is also more complex, allowing even more opportunities for mis-specification to cause problems).
3 A Possible Solution: Counterfactual Reasoning
In general, under model mis-specification we have the following problems:
- It is often no longer meaningful to talk about the “true” value of a latent variable $\theta$ (or at the very least, not one within the specified model family).
- Even when there is a latent variable $z$ with a well-defined meaning, the imputed distribution over $z$ need not match reality.
We can make sense of both of these problems by thinking in terms of counterfactual reasoning. Without defining it too formally, counterfactual reasoning is the problem of making good predictions not just in the actual world, but in a wide variety of counterfactual worlds that “could” exist. (I recommend this paper as a good overview for machine learning researchers.)
While typically machine learning models are optimized to predict well on a specific distribution, systems capable of counterfactual reasoning must make good predictions on many distributions (essentially any distribution that can be captured by a reasonable counterfactual). This stronger guarantee allows us to resolve many of the issues discussed above, while still thinking in terms of predictive performance, which historically seems to have been a successful paradigm for machine learning. In particular:
- While we can no longer talk about the “true” value of $\theta$, we can say that a value of $\theta$ is a “good” value if it makes good predictions on not just a single test distribution, but many different counterfactual test distributions. This allows us to have more confidence in the generalizability of any inferences we draw based on $\theta$ (for instance, if $\theta$ is the coefficient vector for a regression problem, any variable with positive sign is likely to robustly correlate with the response variable for a wide variety of settings).
- The imputed distribution over a variable $z$ must also lead to good predictions for a wide variety of distributions. While this does not force $z$ to match reality, it is a much stronger condition and does at least mean that any aspect of $z$ that can be measured in some counterfactual world must correspond to reality. (For instance, any aspect of a utility function that could at least counterfactually result in a specific action would need to match reality.)
- We will successfully predict the effects of an intervention, as long as that intervention leads to one of the counterfactual distributions considered.
(Note that it is less clear how to actually train models to optimize counterfactual performance, since we typically won’t observe the counterfactuals! But it does at least define an end goal with good properties.)
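As a rough illustration of the evaluation criterion sketched above (with arbitrary distribution shifts standing in for counterfactual worlds), one can at least score a fitted model on a family of test distributions rather than a single held-out set; the mis-specified linear model from Section 1 looks fine on its own training distribution but falls apart elsewhere:

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_quadratic(x_lo, x_hi, n=2_000):
    """Draw (x, y) pairs from the quadratic relationship y = x^2 plus small noise."""
    x = rng.uniform(x_lo, x_hi, size=n)
    return x, x ** 2 + rng.normal(scale=0.01, size=n)

# Fit a linear model on one "factual" training distribution.
x_tr, y_tr = sample_quadratic(0, 1)
A = np.column_stack([x_tr, np.ones_like(x_tr)])
theta, *_ = np.linalg.lstsq(A, y_tr, rcond=None)

# Score it on several test distributions standing in for counterfactual worlds.
for lo, hi in [(0, 1), (1, 2), (2, 3), (-1, 0)]:
    x_te, y_te = sample_quadratic(lo, hi)
    mse = np.mean((y_te - (theta[0] * x_te + theta[1])) ** 2)
    print(f"test distribution x ~ U[{lo}, {hi}]: MSE = {mse:.3f}")
```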
Many people have a strong association between the concepts of “counterfactual reasoning” and “causal reasoning”. It is important to note that these are distinct ideas; causal reasoning is a type of counterfactual reasoning (where the counterfactuals are often thought of as centered around interventions), but I think of counterfactual reasoning as any type of reasoning that involves making robustly correct statistical inferences across a wide variety of distributions. On the other hand, some people take robust statistical correlation to be the definition of a causal relationship, and thus do consider causal and counterfactual reasoning to be the same thing.
I think that building machine learning systems that can do a good job of counterfactual reasoning is likely to be an important challenge, especially in cases where reliability and safety are important, and necessitates changes in how we evaluate machine learning models. In my mind, while the Turing test has many flaws, one thing it gets very right is the ability to evaluate the accuracy of counterfactual predictions (since dialogue provides the opportunity to set up counterfactual worlds via shared hypotheticals). In contrast, most existing tasks focus on repeatedly making the same type of prediction with respect to a fixed test distribution. This latter type of benchmarking is of course easier and more clear-cut, but fails to probe important aspects of our models. I think it would be very exciting to design good benchmarks that require systems to do counterfactual reasoning, and I would even be happy to incentivize such work monetarily.
Acknowledgements
Thanks to Michael Webb, Sindy Li, and Holden Karnofsky for providing feedback on drafts of this post. If any readers have additional feedback, please feel free to send it my way.