MNISTでレイヤーのサイズを確認していた時にハマったのでメモします。
import torch
# 1d: [batch_size]
# ターゲットのラベルや予測に使用する
torch.Size([10])
# 2d: [batch_size, num_features (channels * height * width)]
# nn.Linear() の入力に使用する
torch.Size([64, 128])
# 3d: [batch_size, channels, num_features (height * width)]
# nn.Conv1d()の入力として利用される場合
# RNNに与える場合は、[seq_len, batch_size, num_features]
torch.Size([64, 1, 784])
# 4d: [batch_size, channels, height, width]
# nn.Conv2d()の入力に使用する
torch.Size([64, 1, 28, 28])
# 5D: [batch_size, channels, depth, height, width]
# nn.Conv3d()の入力に使用する
torch.Size([64, 1, 3, 28, 28])