I am implementing a trigram character-level language model following Andrej Karpathy's makemore series. I have two implementations (shown in the snippets below), and I want to understand whether they are mathematically equivalent or fundamentally different models.

Implementation 1: direct 27x27x27 weight tensor


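import torch

# xs1, xs2: indices of the first and second characters; ys: index of the third (target) character
num = xs1.nelement()                              # number of training trigrams
W = torch.randn((27, 27, 27), requires_grad=True)

for k in range(200):
    logits = W[xs1, xs2]                          # (num, 27): one row of logits per (char1, char2) pair
    counts = logits.exp()
    probs  = counts / counts.sum(1, keepdim=True) # softmax over the 27 possible next characters
    loss   = -probs[torch.arange(num), ys].log().mean()
    W.grad = None
    loss.backward()
    W.data += -50 * W.grad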

Here xs1 and xs2 are integer tensors of character indices (the first and second characters of each trigram), and ys holds the index of the third character. W[xs1, xs2] indexes directly into the 3D weight tensor to get logits of shape (N, 27).
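One way to see what this indexing does: it is the same as a linear layer applied to a one-hot vector over all 27 * 27 = 729 possible (char1, char2) pairs. The sketch below uses random hypothetical indices (N, xs1, xs2 are made up for illustration, not taken from names.txt) and checks that the two forms give identical logits.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N = 10                                 # hypothetical number of trigrams
xs1 = torch.randint(0, 27, (N,))       # hypothetical first-character indices
xs2 = torch.randint(0, 27, (N,))       # hypothetical second-character indices
W = torch.randn((27, 27, 27))

# Implementation 1: direct advanced indexing -> shape (N, 27)
logits_index = W[xs1, xs2]

# Same computation as a linear layer on a one-hot encoding of the pair
pair_idx = xs1 * 27 + xs2                          # flatten (char1, char2) into a single index
xenc_pair = F.one_hot(pair_idx, num_classes=729).float()
logits_matmul = xenc_pair @ W.view(729, 27)        # row 27*a + b of the view is W[a, b]

print(logits_index.shape)                           # torch.Size([10, 27])
print(torch.allclose(logits_index, logits_matmul))  # True
```

Viewed this way, Implementation 1 is the bigram-style "one-hot times weight matrix" network from the video, just with 729 input classes instead of 27.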


Implementation 2: concatenated one-hot vectors with a 54x27 weight matrix

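import torch
import torch.nn.functional as F

W = torch.randn((54, 27), requires_grad=True)

# The inputs do not depend on W, so they can be encoded once, outside the loop
xenc1 = F.one_hot(xs1, num_classes=27).float()
xenc2 = F.one_hot(xs2, num_classes=27).float()
xenc  = torch.cat([xenc1, xenc2], dim=1)    # (N, 54): first 27 columns for char1, last 27 for char2

for k in range(200):
    logits = xenc @ W                       # (N, 27)
    loss   = F.cross_entropy(logits, ys)    # softmax + mean negative log-likelihood
    W.grad = None
    loss.backward()
    W.data -= 50 * W.grad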
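The two snippets compute the loss differently, but the objective is the same: F.cross_entropy on raw logits is the mean negative log-likelihood of the softmax probabilities, computed via log-softmax for numerical stability. A small check with random hypothetical logits and targets:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N = 8
logits = torch.randn(N, 27)            # hypothetical logits
ys = torch.randint(0, 27, (N,))        # hypothetical targets

# Loss as written in Implementation 1: explicit softmax, then mean NLL
counts = logits.exp()
probs = counts / counts.sum(1, keepdim=True)
manual_loss = -probs[torch.arange(N), ys].log().mean()

# Loss as written in Implementation 2: fused log-softmax + NLL
ce_loss = F.cross_entropy(logits, ys)

print(torch.allclose(manual_loss, ce_loss))  # True
```

So any difference between the two models comes from how the logits are parameterized, not from the loss.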

My understanding so far:

Implementation 1 has 27x27x27 = 19683 parameters. Every (char1, char2) pair has its own, completely independent set of 27 weights.
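The independence can be checked through the gradients: a batch only produces nonzero gradient in the rows W[a, b] for pairs (a, b) that actually occur in it. The batch below is hypothetical and contains just two distinct pairs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
W = torch.randn((27, 27, 27), requires_grad=True)

# Hypothetical mini-batch with only two distinct (char1, char2) pairs: (1, 2) and (5, 9)
xs1 = torch.tensor([1, 1, 5])
xs2 = torch.tensor([2, 2, 9])
ys = torch.tensor([3, 4, 0])

loss = F.cross_entropy(W[xs1, xs2], ys)
loss.backward()

# List the (char1, char2) rows that received any gradient
touched = (W.grad.abs().sum(dim=2) > 0).nonzero().tolist()
print(touched)  # [[1, 2], [5, 9]]
```

A consequence, assuming plain gradient descent with no regularization: rows for pairs that never appear in the training data keep their random initialization, so their predictions stay arbitrary.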

Implementation 2 has 54x27 = 1458 parameters. Because of the concatenation and the matrix multiply, the contributions of char1 and char2 are additive: char1 selects one row from W[0:27], char2 selects one row from W[27:54], and the two rows are summed.
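The additive structure can be verified directly: multiplying the concatenated one-hot vector by W gives exactly the sum of the two selected rows. The indices below are random and hypothetical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N = 10
xs1 = torch.randint(0, 27, (N,))   # hypothetical first-character indices
xs2 = torch.randint(0, 27, (N,))   # hypothetical second-character indices
W = torch.randn((54, 27))

xenc = torch.cat([F.one_hot(xs1, num_classes=27),
                  F.one_hot(xs2, num_classes=27)], dim=1).float()
logits_matmul = xenc @ W

# Each one-hot half picks exactly one row; the matmul adds the two picked rows
logits_sum = W[xs1] + W[27 + xs2]

print(torch.allclose(logits_matmul, logits_sum))  # True
```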

So my understanding is that these are NOT equivalent models: Implementation 1 is more expressive but needs more data, while Implementation 2 makes an additive assumption (in logit space) but generalizes better with less data.
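The "more expressive" part can be made precise: every Implementation 2 weight matrix can be converted into an Implementation 1 tensor with W3[a, b] = W2[a] + W2[27 + b] that yields identical logits, so Implementation 2's set of possible models is contained in Implementation 1's. The converse does not hold in general: a 27x27x27 tensor is only reproducible if each pair's row can be written as a char1 term plus a char2 term (up to a per-pair constant shift, which softmax ignores). A sketch of the embedding, using random hypothetical weights and inputs:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
W2 = torch.randn((54, 27))   # arbitrary (hypothetical) Implementation 2 weights

# Equivalent Implementation 1 weights: W3[a, b] = W2[a] + W2[27 + b]
W3 = W2[:27, None, :] + W2[None, 27:, :]   # broadcasts to shape (27, 27, 27)

xs1 = torch.randint(0, 27, (100,))
xs2 = torch.randint(0, 27, (100,))
xenc = torch.cat([F.one_hot(xs1, 27), F.one_hot(xs2, 27)], dim=1).float()

# Identical logits, hence identical probabilities and identical loss
print(torch.allclose(xenc @ W2, W3[xs1, xs2]))  # True
```

One implication is that, at their respective optima, Implementation 1 can reach a training loss no higher than Implementation 2's; held-out loss is a separate question that depends on the data.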

My questions:

  1. Is my understanding correct that these two models are fundamentally different and not mathematically equivalent?

  2. Which one should converge to a lower loss on a small dataset like names.txt, used in Karpathy's first makemore video (~32k names)?

  3. Is Implementation 1 essentially just a lookup table being trained with gradient descent, making it equivalent to the counting model?
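On question 3, one relevant identity is softmax(log N[a, b]) = N[a, b] / N[a, b].sum(), so setting Implementation 1's weights to log-counts reproduces a counting model exactly. The counts below are random placeholders (in makemore they would be tallied from names.txt), and the +1 smoothing is an assumption borrowed from the video's counting approach.

```python
import torch

torch.manual_seed(0)
# Hypothetical trigram counts N[char1, char2, char3]
Ncounts = torch.randint(0, 20, (27, 27, 27)).float()

# Counting model with add-one smoothing
P_count = (Ncounts + 1) / (Ncounts + 1).sum(dim=2, keepdim=True)

# Implementation 1 with weights set to log-counts
W = (Ncounts + 1).log()
P_softmax = W.softmax(dim=2)

print(torch.allclose(P_count, P_softmax))  # True
```

The unsmoothed normalized counts minimize the training NLL for every pair that occurs in the data, so gradient descent on Implementation 1 moves toward them; it can only approach zero probabilities, since those need logits of minus infinity. Pairs absent from the data get no gradient at all, which is where smoothing (or a regularization term on W) makes the two approaches behave differently.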