In the video "The spelled-out intro to language modeling: building makemore", Andrej Karpathy briefly mentioned a bug that can occur when PyTorch broadcasts tensors during mathematical operations.

In the lines below, N is a 27x27 tensor. The code is trying to normalize the rows of the tensor by dividing each row by the sum of the row.

Here is the incorrect calculation:

P = N.float()
P /= P.sum(1)

The correct calculation should be:

P = N.float()
P /= P.sum(1, keepdim=True)
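To see the difference on a tensor of the same size, here is a minimal sketch that runs both calculations side by side. The N from the video isn't reproduced here, so the sketch assumes a hypothetical stand-in filled with the integers 1 to 729; any 27x27 tensor of positive counts would behave the same way. Note that the incorrect version runs without any error.

```python
import torch

# Hypothetical stand-in for the 27x27 tensor N from the video:
# the integers 1..729 laid out row by row.
N = torch.arange(1, 27 * 27 + 1).reshape(27, 27)

# Incorrect: P.sum(1) has shape (27,), which broadcasts like a row of shape (1, 27).
P_wrong = N.float()
P_wrong /= P_wrong.sum(1)

# Correct: P.sum(1, keepdim=True) has shape (27, 1), a column.
P_right = N.float()
P_right /= P_right.sum(1, keepdim=True)

print(P_wrong.sum(1)[:3])  # these rows do not sum to 1
print(P_right.sum(1)[:3])  # each row sums to 1 (up to float rounding)
```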

Andrej briefly explained why this happens, but his explanation didn't land for me. I walked through a few smaller examples to understand the mechanics of the tensor operations.

P / P.sum(1, keepdim=True)

import torch

P = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]

P = torch.tensor(P)
P_sum = P.sum(1, keepdim=True)

# P_sum has shape (3, 1): this is what a column vector looks like in PyTorch
# [[6],
#  [15],
#  [24]]

Here, P.sum(1, keepdim=True) computes the sum of each row in P and returns the results as a 3x1 column vector. The 1 argument is the dimension to sum over. Dimensions are numbered from 0, so dimension 0 indexes the rows and dimension 1 indexes the columns; summing over dimension 1 adds up the entries across each row and produces one sum per row. keepdim=True tells PyTorch to keep the summed dimension in the output with size 1 instead of dropping it, so the result has shape (3, 1) rather than (3,). In other words, the output has the same number of dimensions as the input, and each row's sum stays in its own row.
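A small sketch of how the dim argument and keepdim change the result, using the same 3x3 tensor. The variable names are just for illustration; .tolist() is used so the printed values are plain Python lists.

```python
import torch

P = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

# dim=0 collapses the rows: one sum per column.
col_sums = P.sum(0)
# dim=1 collapses the columns: one sum per row, as a 1-D tensor.
row_sums = P.sum(1)
# keepdim=True keeps the collapsed dimension with size 1, giving a column.
row_sums_kept = P.sum(1, keepdim=True)

print(col_sums.tolist())       # [12, 15, 18]
print(row_sums.tolist())       # [6, 15, 24]
print(row_sums_kept.tolist())  # [[6], [15], [24]]
print(tuple(row_sums.shape), tuple(row_sums_kept.shape))  # (3,) (3, 1)
```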

Now when you divide P by P_sum, broadcasting stretches the (3, 1) column across the three columns of P, so each element in a row of P is divided by the sum of that row. This normalizes each row so that its elements sum to 1.

P / P.sum(1, keepdim=True)  

# [[1/6, 2/6, 3/6],
#  [4/15, 5/15, 6/15],
#  [7/24, 8/24, 9/24]]
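One way to confirm the fractions above is to write them out by hand and compare. This sketch uses a float version of P so the comparison is between float tensors; torch.allclose allows for rounding in the last few bits.

```python
import torch

P = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.],
                  [7., 8., 9.]])

P_norm = P / P.sum(1, keepdim=True)

# The same matrix written out by hand, row by row.
expected = torch.tensor([[1/6, 2/6, 3/6],
                         [4/15, 5/15, 6/15],
                         [7/24, 8/24, 9/24]])

print(torch.allclose(P_norm, expected))               # True
print(torch.allclose(P_norm.sum(1), torch.ones(3)))   # True: every row sums to 1
```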

P / P.sum(1)

import torch

P = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]

P = torch.tensor(P)
P_sum = P.sum(1)

# P_sum has shape (3,): a flat, 1-D tensor
# [6, 15, 24]

P.sum(1) also computes the sum of each row in P, but it returns the results as a flat, 1-D tensor of shape (3,). The summed dimension is dropped from the output instead of being kept with size 1.
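Since the only difference is the missing dimension, putting it back by hand should recover the keepdim=True result. A minimal sketch using the standard unsqueeze method and the equivalent None-indexing shorthand:

```python
import torch

P = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

flat = P.sum(1)               # shape (3,): the summed dimension was dropped
col = P.sum(1, keepdim=True)  # shape (3, 1): the summed dimension was kept

# Re-inserting the dropped dimension at position 1 gives back the column.
restored = flat.unsqueeze(1)  # shape (3, 1)
print(torch.equal(restored, col))       # True

# Indexing with None inserts a size-1 dimension in the same place.
print(torch.equal(flat[:, None], col))  # True
```

Using keepdim=True is usually simpler than restoring the dimension afterwards, but seeing the two agree makes it clear that the values are identical and only the shape differs.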

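When you divide P by P.sum(1), PyTorch's broadcasting rules try to make the two shapes compatible by comparing them from the last dimension backwards. The 1-D tensor of shape (3,) is treated as if it had shape (1, 3), a single row, and that row is stretched down all three rows of P. As a result, the 3 elements of P.sum(1) line up with the columns of P, which is not the intended behavior, and the normalization comes out wrong. Notice how each row's sum ends up dividing a column rather than a row: it's operating on the wrong dimension. Because P is square, the shapes are compatible and PyTorch raises no error, so the mistake is silent.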

P / P.sum(1)  

# [[1/6, 2/15, 3/24],
#  [4/6, 5/15, 6/24],
#  [7/6, 8/15, 9/24]]
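The broadcasting behavior can be checked directly. This sketch uses torch.broadcast_shapes to show the combined shape, compares the division against the sums laid out explicitly as a row and as a column, and then repeats the mistake on a hypothetical non-square 2x3 tensor Q, where the shapes no longer line up and PyTorch raises an error instead of silently producing the wrong result.

```python
import torch

P = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.],
                  [7., 8., 9.]])
row_sums = P.sum(1)  # shape (3,)

# Shapes are compared from the right; the missing leading dimension counts as 1,
# so (3, 3) and (3,) broadcast to (3, 3), with row_sums acting like shape (1, 3).
print(torch.broadcast_shapes(P.shape, row_sums.shape))  # torch.Size([3, 3])

# Dividing by the 1-D sums matches dividing by them laid out as one row...
print(torch.equal(P / row_sums, P / row_sums.reshape(1, 3)))  # True
# ...and does not match dividing by them laid out as a column.
print(torch.equal(P / row_sums, P / row_sums.reshape(3, 1)))  # False

# With a non-square tensor, the same mistake raises an error instead.
Q = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])  # shape (2, 3); Q.sum(1) has shape (2,)
mismatch_error = None
try:
    Q / Q.sum(1)  # trailing sizes 3 and 2 cannot be broadcast together
except RuntimeError as err:
    mismatch_error = err
print(type(mismatch_error).__name__)  # RuntimeError

Q_norm = Q / Q.sum(1, keepdim=True)  # (2, 3) / (2, 1) broadcasts as intended
print(Q_norm.sum(1))  # each row sums to 1
```

This is why the bug is easy to miss with a 27x27 tensor: a square shape happens to satisfy the broadcasting rules in both orientations.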

PyTorch mechanics can seem a little unintuitive and mind-bendy when you're learning them, so it can be helpful to pick them apart like we did here to understand them.