I am implementing a trigram character-level language model following Andrej Karpathy's makemore series. I have two implementations (both shown in the snippets below), and I want to understand whether they are mathematically equivalent or fundamentally different models.
Implementation 1: direct 27x27x27 weight tensor:
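import torch

# xs1, xs2, ys and num (assumed: the number of training examples) are defined earlier
W = torch.randn((27, 27, 27), requires_grad=True)
for k in range(200):
    logits = W[xs1, xs2]  # one row of 27 logits per (char1, char2) context
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdim=True)  # softmax over the next character
    loss = -probs[torch.arange(num), ys].log().mean()  # mean negative log-likelihood
    W.grad = None
    loss.backward()
    W.data += -50 * W.grad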
Here xs1 and xs2 are integer tensors of character indices. W[xs1, xs2] directly indexes into the 3D weight tensor to get logits of shape (N, 27).
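To make the indexing concrete, here is a minimal sketch with made-up indices. The tensors are random stand-ins rather than my real data, and `pair_idx` and `W_flat` are names introduced only for this illustration. As far as I can tell, it also shows that the lookup is the same as one-hot encoding the (char1, char2) pair into 27*27 = 729 classes and multiplying by a (729, 27) matrix.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N = 5  # hypothetical number of trigram examples
xs1 = torch.randint(0, 27, (N,))  # stand-in for first-character indices
xs2 = torch.randint(0, 27, (N,))  # stand-in for second-character indices

W = torch.randn(27, 27, 27)

# Indexing with two index tensors picks one 27-vector per example.
logits = W[xs1, xs2]
print(logits.shape)

# Same result via a one-hot over the 729 possible (char1, char2) pairs.
pair_idx = xs1 * 27 + xs2           # flat index of each pair
W_flat = W.view(27 * 27, 27)        # row a*27+b holds W[a, b]
logits_onehot = F.one_hot(pair_idx, num_classes=27 * 27).float() @ W_flat
print(torch.allclose(logits, logits_onehot))
```

If this is right, Implementation 1 is a linear model too, just over a much larger one-hot input than Implementation 2 uses.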
Implementation 2: concatenated one-hot vectors with a 54x27 weight matrix:
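import torch
import torch.nn.functional as F

# xs1, xs2 and ys are the same index tensors as in Implementation 1
W = torch.randn((54, 27), requires_grad=True)
for k in range(200):
    xenc1 = F.one_hot(xs1, num_classes=27).float()
    xenc2 = F.one_hot(xs2, num_classes=27).float()
    xenc = torch.cat([xenc1, xenc2], dim=1)  # (N, 54): two one-hots side by side
    logits = xenc @ W  # (N, 27)
    loss = F.cross_entropy(logits, ys)  # softmax + mean negative log-likelihood
    W.grad = None
    loss.backward()
    W.data -= 50 * W.grad.data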
My understanding so far:
Implementation 1 has 27x27x27 = 19683 parameters. Every (char1, char2) pair has a completely unique and independent set of 27 weights.
Implementation 2 has 54x27 = 1458 parameters. Because of the concatenation and matrix multiply, the contributions of char1 and char2 are additive: char1 selects one row from W[0:27], char2 selects one row from W[27:54], and the two rows are summed to give the logits.
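A quick sanity check of the additive reading, as a sketch with random stand-in indices (not my real data): the concatenated one-hot matmul should match the sum of two row lookups, `W[xs1] + W[27 + xs2]`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N = 8  # hypothetical number of examples
xs1 = torch.randint(0, 27, (N,))
xs2 = torch.randint(0, 27, (N,))
W = torch.randn(54, 27)

# Path used in Implementation 2: concatenate the one-hots, then matmul.
xenc = torch.cat(
    [F.one_hot(xs1, num_classes=27), F.one_hot(xs2, num_classes=27)], dim=1
).float()
logits_matmul = xenc @ W

# Equivalent path: look up one row per character and add them.
logits_sum = W[xs1] + W[27 + xs2]
print(torch.allclose(logits_matmul, logits_sum))
```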
So my understanding is that these are NOT equivalent models: Implementation 1 is more expressive but needs more data, while Implementation 2 makes an additive assumption but generalizes better with less data.
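One way I tried to test the "not equivalent" claim is a sketch like the following. The helper `interaction` is a name I made up: it takes the mixed difference of the logits over two first characters and two second characters, then removes any constant shift across the 27 outputs (which softmax ignores). For an additive model this should be zero for every choice of characters, while a generic 27x27x27 table has no such constraint.

```python
import torch

torch.manual_seed(0)

def interaction(table, a, a2, b, b2):
    """Mixed difference of logits, centered so softmax-invariant shifts drop out."""
    d = table[a, b] - table[a, b2] - table[a2, b] + table[a2, b2]
    return d - d.mean()

# Full logit table implied by the additive model: entry [a, b] = W[a] + W[27 + b].
W_add = torch.randn(54, 27)
table_add = W_add[:27].unsqueeze(1) + W_add[27:].unsqueeze(0)  # (27, 27, 27)

# A generic table as in Implementation 1.
W_full = torch.randn(27, 27, 27)

add_term = interaction(table_add, 0, 1, 2, 3)
full_term = interaction(W_full, 0, 1, 2, 3)
print(add_term.abs().max())   # zero up to float rounding
print(full_term.abs().max())  # generally far from zero
```

If I am reading this correctly, any trigram distribution with a non-zero interaction term can be represented by Implementation 1 but not by Implementation 2.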
My questions:
Is my understanding correct that these two models are fundamentally different and not mathematically equivalent?
Which one should converge to a lower loss on a small dataset like the names.txt used in Karpathy's video (makemore 1, ~32k words)?
Is Implementation 1 essentially just a lookup table being trained with gradient descent, making it equivalent to the counting model?
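For the last question, this is the check I had in mind, as a sketch on synthetic data with a tiny hypothetical vocabulary (`V = 4`, every trigram included at least once so no count is zero). It plugs the log of the raw trigram counts into the lookup-table model and looks at the gradient. It only tests whether the unsmoothed counting solution is a stationary point of Implementation 1's loss; it says nothing about how fast gradient descent gets there, or about contexts that never appear in the data.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
V = 4  # tiny hypothetical vocabulary so every context is observed

# Synthetic trigrams: every (c1, c2, c3) once, plus random extras.
grid = torch.cartesian_prod(torch.arange(V), torch.arange(V), torch.arange(V))
extra = torch.randint(0, V, (500, 3))
data = torch.cat([grid, extra], dim=0)
xs1, xs2, ys = data[:, 0], data[:, 1], data[:, 2]

# Counting model: counts[c1, c2, c3] = number of occurrences of the trigram.
counts = torch.zeros(V, V, V)
counts.index_put_((xs1, xs2, ys), torch.ones(len(ys)), accumulate=True)
probs = counts / counts.sum(-1, keepdim=True)
nll_counts = -probs[xs1, xs2, ys].log().mean()

# Lookup-table model with log-counts as weights.
W = counts.log().requires_grad_(True)
loss = F.cross_entropy(W[xs1, xs2], ys)
loss.backward()

max_grad = W.grad.abs().max().item()
print(loss.item(), nll_counts.item())  # the two losses should agree
print(max_grad)                        # should be zero up to float rounding
```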