The Unreasonable Effectiveness of Overparametrized Models*
As Bartlett and his coauthors argue, deep neural networks display at least three properties that might be surprising to most practitioners with a background in statistics or even in classical machine learning (ML):
They find solutions to highly non-convex optimization problems (with a plethora of local minima) using a local, gradient-based optimization algorithm (gradient descent or any of its more powerful descendants, like Adam).
The solutions they find are almost perfect on the training data.
They display remarkable predictive power on the test data, without the need to control for model complexity.
As a matter of fact, as Yann LeCun has repeated many times, the third point (training overparameterized models without explicit complexity control) appears to be critical for the first one.1
In this post I’m going to discuss the relevance of this and other related facts about overparameterization.
What is overparameterization?
In its simplest incarnation, a model is overparameterized when the number of parameters is larger than the sample size. Let’s start with a linear regression of the form:

y = Xβ + ε
Here we have K parameters, and we collect a training sample of size N. We train the model using ordinary least squares (OLS):

β̂_OLS = (X'X)^{-1} X'y
So for the estimator to exist, it’s both necessary and sufficient that the N×K matrix X has full column rank (rank K), so that X'X is invertible. If your model is overparameterized (N < K), the feature matrix will be at most of rank N, X'X will be singular, and you won’t be able to estimate the model.2
You can use the following code snippet to play with different scenarios of under- or overparameterization in linear regression. Note that in Python, np.linalg.inv(A) will often return an answer even when the matrix is singular in exact arithmetic, because floating-point round-off keeps it numerically invertible. You can check, however, that the results are wrong, and that you end up computing nonsensical estimates with overparameterized models.
# simulate a linear model y = X beta + epsilon
import numpy as np
import pandas as pd

np.random.seed(1626)
# set sample size and number of parameters
N = 10
k = 1.1  # k < 1 gives an underparametrized model, k > 1 an overparametrized one
K = int(k * N)
x = np.random.randn(N, K)
xtx = np.matmul(x.T, x)
beta = np.random.randint(low=-20, high=20, size=(K, 1))
y = np.dot(x, beta) + np.sqrt(0.1) * np.random.randn(N, 1)
# with K > N, X has rank N at most, so the KxK matrix X'X cannot have full rank
rank_x = np.linalg.matrix_rank(x)
rank_xx = np.linalg.matrix_rank(xtx)
print(f'Size of X = {x.shape}, Rank of X = {rank_x}')
print(f"Size of X'X = {xtx.shape}, Rank of X'X = {rank_xx}")
# "find" OLS: np.linalg.inv typically returns an answer anyway,
# even for overparameterized models :(
inv_xtx = np.linalg.inv(xtx)
beta_ols = np.dot(inv_xtx, np.dot(x.T, y))
# assemble in a DataFrame to help visualization and comparison with true values
df = pd.DataFrame(beta_ols, columns=['OLS'])
df['true'] = beta.ravel()
print(df)
# let's check if the inverse is an inverse! A^{-1}A = I = AA^{-1}
# both assertions fail when K > N, because X'X is singular
check_inv_l = np.matmul(inv_xtx, xtx)
check_inv_r = np.matmul(xtx, inv_xtx)
assert np.all(np.abs(check_inv_l - np.eye(K)) < 1e-10)
assert np.all(np.abs(check_inv_r - np.eye(K)) < 1e-10)
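As a side note, here is a minimal sketch reusing the x and y simulated above (in the overparameterized case the assertions above fail, so run this part separately): even when K > N, NumPy’s np.linalg.lstsq returns a minimum-norm least-squares solution, which is perfectly well defined and fits the training data exactly, a first glimpse of the near-perfect training fit of overparameterized models mentioned at the top of the post.
# minimum-norm least-squares solution, well defined even when K > N
beta_mn, *_ = np.linalg.lstsq(x, y, rcond=None)
# when K > N the fitted values reproduce y exactly: the model interpolates the training data
print(np.allclose(np.dot(x, beta_mn), y))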
Intuitively speaking, the information provided by the N examples is not enough to estimate a larger number of parameters. As I argue in Chapter 10 of Data Science: The Hard Parts, OLS provides a useful benchmark that helps ground the intuition for many other ML models.
This demonstrates that overparameterization actually impedes learning in linear models! With non-linear or non-parametric models, we have a bit more leeway, but the more general principle of the bias-variance trade-off still applies. The idea is that increasing model complexity reduces bias but increases variance. In practice, your training error decreases as model complexity increases, but this comes at the cost of a diminished ability to generalize to other contexts (e.g., test samples or real-life scenarios when your model is in production).
Consider the case of tree-based models, where the depth of the trees serves as a measure of the complexity of the algorithm. If you reach the maximum depth (d_max = log_2(N), at which point the tree has enough leaves to give every training example its own leaf), you will have perfect predictions in the training sample, but a higher prediction error out-of-sample.
Typical solutions to this problem of overfitting include some type of regularization, where we effectively penalize the loss function for excess complexity.
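As a quick illustration, here is a minimal sketch (it assumes scikit-learn, which the rest of the post doesn’t use) of a regression tree overfitting: the fully grown tree drives the training error to essentially zero but does worse out-of-sample than a depth-limited one.
# compare a shallow and a fully grown regression tree on simulated data
import numpy as np
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 5))
y = 2 * X[:, 0] + rng.normal(size=500)  # only the first feature matters
X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=0)

# max_depth=None lets the tree grow until every leaf is pure (maximum complexity)
for depth in [3, None]:
    tree = DecisionTreeRegressor(max_depth=depth, random_state=0).fit(X_tr, y_tr)
    print(f"max_depth={depth}: "
          f"train MSE = {mean_squared_error(y_tr, tree.predict(X_tr)):.3f}, "
          f"test MSE = {mean_squared_error(y_te, tree.predict(X_te)):.3f}")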
What happens with deep neural networks?
Neural networks fall in the class of parametric algorithms, so we can go back to thinking about model complexity in terms of parameter counting. As an example, the parameters in transformer-based models correspond to the weight matrices for the attention and the feed-forward layers.3
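To make the parameter counting concrete, here is a back-of-the-envelope sketch of my own (it counts only the attention and feed-forward weight matrices per layer, ignoring embeddings and biases, in line with the non-embedding convention in the footnote):
# rough non-embedding parameter count for a decoder-only transformer
def transformer_weight_count(n_layers, d_model, d_ff=None):
    d_ff = d_ff or 4 * d_model                 # the usual feed-forward width
    attention = 4 * d_model * d_model          # W_Q, W_K, W_V and the output projection
    feed_forward = 2 * d_model * d_ff          # the two feed-forward projections
    return n_layers * (attention + feed_forward)

# a GPT-3-sized configuration: 96 layers with d_model = 12288
print(f"{transformer_weight_count(96, 12288):,}")  # ~1.7e11, close to the quoted 175B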
As Yann LeCun’s tweet suggests, deep neural nets are generally considered overparameterized, either by the definition above or by a less strict one: having a model complexity above some threshold, for instance the point at which the training error is close to zero.
Several image recognition models are overparameterized; the most famous example is probably AlexNet, which had ~61M parameters and was trained on the ImageNet dataset of ~1.2M images. The less strict notion of overparameterization would apply to some of the best-known large language models (LLMs): OpenAI’s GPT-3 had 175B parameters and was trained on a dataset of 300B tokens, Meta’s Llama-2 was trained on 2T tokens and its largest model has 70B parameters, and the largest BERT model had 340M parameters but was trained on 3.3B words.
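To put these figures on a common footing, here is a quick sketch of the implied tokens-per-parameter ratios (my own arithmetic on the numbers just quoted; BERT’s word count is only a rough proxy for tokens):
# tokens-to-parameter ratios implied by the figures above
models = {
    "GPT-3": (300e9, 175e9),        # (training tokens, parameters)
    "Llama-2-70B": (2e12, 70e9),
    "BERT-Large": (3.3e9, 340e6),   # words used as a rough proxy for tokens
}
for name, (tokens, params) in models.items():
    print(f"{name}: ~{tokens / params:.1f} tokens per parameter")
By this crude measure, GPT-3 sits far below the ~21:1 compute-optimal ratio from the Chinchilla paper discussed below.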
Overparameterization in deep learning gives rise to some other interesting phenomena. As already mentioned, it’s been hypothesized that it helps gradient-based procedures find good enough solutions to highly non-convex problems.4 This can be seen most clearly in the phenomenon known as “double dip” or “double descent”.
Double descent
The next figure, taken from Figure 1 in Nakkiran et al. (2019), shows the typical case of double descent: the test error first falls, then starts increasing (suggesting overfitting, as in classical machine learning), and finally falls again. Notice that, by sufficiently increasing the complexity of the model, we’re able to find superior solutions in terms of a lower test error.
In case you’re wondering whether double descent is good or bad, the answer is nuanced. On the bright side, there’s no doubt that larger models are superior in terms of performance, giving rise to models that exhibit emergent behaviors they were never explicitly pretrained or fine-tuned for. On the less bright side, training these larger models becomes extremely costly, both financially and environmentally, which leaves only a handful of companies with the ability to train them.
Scaling Laws
Several recent papers have somewhat challenged this notion that “larger is better”. In “Scaling Laws for Neural Language Models,” researchers from OpenAI empirically quantify the intuitive notion that model complexity and sample size should grow in tandem. Otherwise, one becomes a bottleneck for the other.
Among several scaling laws, they find that “dataset size may grow sub-linearly in model size while avoiding overfitting”. This implies that if model complexity increases by 8x, the dataset only needs to grow by roughly 5x to avoid becoming a performance bottleneck. The sublinear requirement is great news, especially since some believe that we may soon run out of data. These unequal growth requirements support the trend toward overparameterization.
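The 8x/5x rule of thumb is just this sublinear relation made concrete; assuming the roughly D ∝ N^0.74 scaling reported in the paper, the arithmetic is a one-liner:
# data growth implied by D ∝ N**0.74 when the model grows by a factor g
for g in [2, 8, 100]:
    print(f"model x{g} -> data x{g ** 0.74:.1f}")  # 8x model -> ~4.7x (roughly 5x) data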
However, in the later Chinchilla paper, researchers from DeepMind find that, as the compute budget increases, model size and dataset size should grow one-to-one. They set out to answer the following question: for a given compute budget, what are the optimal model and dataset sizes? Alternatively, how should you optimally trade off model complexity and training-dataset size?
The next table is taken from their Table 3: for a given model size (parameters), they find the optimal compute budget and dataset size. As the compute budget increases, model and dataset size increase at roughly the same rate, always keeping a ratio of roughly 21 tokens per parameter. In other words, the degree of overparameterization itself should not change, at least under compute-optimal considerations.
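Taking the ~21:1 ratio from the table at face value (a small sketch of my own, not a calculation from the paper), the compute-optimal dataset sizes look like this:
# compute-optimal dataset size implied by a ~21 tokens-per-parameter ratio
for params in [1e9, 70e9, 175e9]:
    tokens = 21 * params
    print(f"{params / 1e9:.0f}B parameters -> ~{tokens / 1e12:.2f}T tokens")
By this yardstick, GPT-3’s 300B training tokens look small for its 175B parameters, which is precisely the Chinchilla paper’s point about earlier models being undertrained.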
What’s in a dataset?
As we’ve seen, the concept of overparameterization revolves around the complexity of a model relative to the size of the dataset. But what exactly constitutes a dataset?
With tabular data, the dataset size is given by the number of rows; image recognition tasks count the number of images in the training sample; and in language models the dataset comprises all the words in the corpus. For instance, BERT was trained on the Toronto BookCorpus and English Wikipedia, with a total of ~3.3B words.5
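Words and tokens are not the same thing (see the footnote), so the word count is only a proxy for the sample size a language model actually sees. A quick sketch, assuming the tiktoken package and its cl100k_base encoding (any subword tokenizer makes the same point), shows the gap for a single sentence:
# count whitespace words versus subword tokens for one sentence
import tiktoken

text = "Overparameterized models generalize surprisingly well."
enc = tiktoken.get_encoding("cl100k_base")
print(len(text.split()), "whitespace words")
print(len(enc.encode(text)), "tokens")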
Is there a better way to measure dataset size? It’s possible that our models aren’t as overparameterized as we believe, which would help resolve this apparent paradox. What if we consider the amount of information present in the data? Deep learning excels at representation learning, where algorithms try to capture the underlying patterns that make the learning task easier.
Good datasets should provide more information about these underlying patterns. For instance, consider two alternative datasets, each consisting of only two sentences:6
Dataset 1: There’s trash on my bed. My small dog ran and jumped over my bed.
Dataset 2: There’s trash on my bed. The boat lies on the bed of the river.
The sample size for both datasets is the same, but there’s some sense in which the second one provides more information, at least with respect to the meaning of the word bed. Under this interpretation, one could train a larger model on the latter, even though it would appear to be more overparameterized.
Creating such datasets from the ground up is a formidable endeavour, but maybe current LLMs can help. Below is an example using ChatGPT. I’m not sure whether I was successful, but it shows that this can in principle be done. Nonetheless, I suspect we need better ways to quantify the amount of information contained in text and image data.
Naturally, researchers have already attempted to do so7, and we should expect more of this in the future. Unfortunately, we shouldn’t expect LLMs to be our new source of data, as it has already been shown that this might lead to model collapse.
Last words
The fact that overparameterization is actually beneficial in deep learning is remarkable. There are still many open questions, but we can be confident that our understanding of how, why, and when these models work will only improve in the future.
* You may recognize that I’m borrowing the title from Andrej Karpathy’s superb 2015 blog post on recurrent neural nets.
Another “classical” overview of these remarkable facts can be found in the paper by Jianqing Fan et al. (2019). In the same tweet, LeCun argues that this has amazed statisticians for decades. See also his interview with Lex Fridman (min. 8:19).
You can find more about the overparameterized linear regression case in Ch.18 of the superb Elements of Statistical Learning by Hastie and coauthors.
Embedding (positional or token) parameters are generally not included in these counts. See for example Kaplan et al. (2020). For further parameter counting in transformers, see also here.
There are some theoretical results that support this. See for instance the paper by Bartlett at the beginning of this post and this paper for the case of generative adversarial networks (GANs).
Language models are trained on tokenized text, so one must convert the word count to tokens to get the actual sample size.
I’m adapting this example from one in Figure 23.11 in Jurafsky and Martin (2023).
The TinyStories paper is one example of using LLM-generated synthetic data.