DeepLearning

PyTorchのテンソルの次元、サイズの例

MNISTでレイヤーのサイズを確認していた時にハマったのでメモします。

import torch

# 1d: [batch_size]
# ターゲットのラベルや予測に使用する
torch.Size([10])

# 2d: [batch_size, num_features (channels * height * width)]
# nn.Linear() の入力に使用する
torch.Size([64, 128])

# 3d: [batch_size, channels, num_features (height * width)]
# nn.Conv1d()の入力として利用される場合
# RNNに与える場合は、[seq_len, batch_size, num_features]
torch.Size([64, 1, 784])

# 4d: [batch_size, channels, height, width]
# nn.Conv2d()の入力に使用する
torch.Size([64, 1, 28, 28])

# 5D: [batch_size, channels, depth, height, width]
# nn.Conv3d()の入力に使用する
torch.Size([64, 1, 3, 28, 28])