In the video The spelled-out intro to language modeling: building makemore, Andrej Karpathy briefly mentioned a bug that can occur when broadcasting tensors during mathematical operations.
In the lines below, N is a 27x27 tensor. The code tries to normalize the rows of the tensor by dividing each row by its sum.
Here is the incorrect calculation:
P = N.float()
P /= P.sum(1)
The correct calculation should be:
P = N.float()
P /= P.sum(1,keepdim=True)
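Part of what makes this bug nasty is that the incorrect version runs without any error. Here is a small sketch (a hypothetical 3x3 stand-in for the 27x27 N) showing that when the matrix is square, dividing by the 1-D row sums broadcasts silently:

```python
import torch

# A small stand-in for the 27x27 N from the video; any square matrix
# reproduces the silent-broadcast behavior.
N = torch.arange(1, 10).reshape(3, 3)
P = N.float()

print(P.sum(1).shape)                # torch.Size([3])  -- summed dim dropped
print(P.sum(1, keepdim=True).shape)  # torch.Size([3, 1])  -- kept as size 1

# Both divisions run without error because P is square,
# but only the keepdim version actually normalizes the rows.
wrong = P / P.sum(1)
right = P / P.sum(1, keepdim=True)
print(right.sum(1))  # each row sums to (approximately) 1
```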
Andrej briefly explained why this happens, but his explanation didn't land for me. I walked through a few smaller examples to understand the mechanics of the tensor operations.
P / P.sum(1, keepdim=True)
P = [[1, 2, 3],
[4, 5, 6],
[7, 8, 9]]
P = torch.tensor(P)
P_sum = P.sum(1, keepdim=True)  # P is already a tensor; no need to wrap it again
# P_sum: this is what a column vector looks like in PyTorch
# [[6],
# [15],
# [24]]
Here, P.sum(1, keepdim=True) computes the sum of each row in P and returns the results as a 3x1 column vector. The 1 argument specifies that the sum should be computed along dimension 1, which runs across the columns, so we get one sum per row. keepdim=True tells PyTorch to keep the summed dimension in the output, with size 1. In other words, the result stays a 2-D column vector rather than being flattened to 1-D during the summation.
Now when you divide P by P_sum, broadcasting stretches the 3x1 column vector across the three columns of P, so each element in a row of P is divided by the sum of that row. This normalizes each row so that its elements sum to 1.
P / P.sum(1, keepdim=True)
# [[1/6, 2/6, 3/6],
# [4/15, 5/15, 6/15],
# [7/24, 8/24, 9/24]]
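A quick sanity check (my own addition, not from the video) confirms the normalization: summing each row of the result gives 1.

```python
import torch

P = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.],
                  [7., 8., 9.]])

# Divide each row by its own sum, then verify the rows are normalized.
normalized = P / P.sum(1, keepdim=True)
print(normalized.sum(1))  # each row sums to (approximately) 1
```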
P / P.sum(1)
P = [[1, 2, 3],
[4, 5, 6],
[7, 8, 9]]
P = torch.tensor(P)
P_sum = P.sum(1)
# [6, 15, 24]
P.sum(1) computes the sum of each row in P, but without keepdim=True the summed dimension is dropped: the result is a flat, 1-D tensor of shape (3,) rather than a 3x1 column vector.
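Checking the shape makes the flattening concrete:

```python
import torch

P = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

print(P.sum(1))        # tensor([ 6, 15, 24])
print(P.sum(1).shape)  # torch.Size([3]) -- the summed dimension was dropped
```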
When you divide P by P.sum(1), PyTorch's broadcasting rules try to align the shapes, starting from the rightmost dimension. The 1-D tensor of shape (3,) is treated as a row of shape (1, 3), so its three elements line up with the columns of P, which is not the intended behavior, and the result is incorrect normalization. Notice in the output below how each row sum divides a column, rather than a row. It's operating on the wrong dimension.
P / P.sum(1)
# [[1/6, 2/15, 3/24],
# [4/6, 5/15, 6/24],
# [7/6, 8/15, 9/24]]
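To make the misalignment explicit, here's a sketch showing that dividing by the flattened sums is exactly the same as dividing by a 1x3 row vector, and that the rows of the result no longer sum to 1:

```python
import torch

P = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.],
                  [7., 8., 9.]])
row_sums = P.sum(1)  # shape (3,)

# Broadcasting pads shapes on the left, so (3,) is treated as (1, 3):
# the three sums line up with P's columns, not its rows.
assert torch.equal(P / row_sums, P / row_sums.reshape(1, 3))

print((P / row_sums).sum(1))  # rows no longer sum to 1
```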
PyTorch mechanics can seem a little unintuitive and mind-bendy when you're learning them, so it can be helpful to pick them apart with small examples like we did here.