PyTorchのTensorはどうやってデータを持っているのか?

お仕事でPyTorchを扱っているのですが、以下のような疑問がふつふつと湧いてきましたので、Tensorのデータが実際にはどうやって保持・管理されているのかを調べて整理しました。

  • image_tensor = minibatch_tensor[i, :, :, :]はメモリコピーが発生するのか?
  • input_tensorをコピーして別々のネットワークにフォワードしたいのだけどどうすればいいのか?
  • 効率的にメモリアクセスできるレイアウトになっているのか?

この投稿ではPyTorch 1.1.0を使ってます。

import torch
import numpy as np

print(torch.__version__)
# 1.1.0

Tensorでのデータの持ち方

Tensorはnumpyのビューとよく似ています。

メモリ上の実体はStorageオブジェクトが持つ

Tensorのデータはメモリ上では全て1次元配列として保持されており、その実体を管理しているのがStorageオブジェクトです

Tensorからstorageメソッドを呼び出すことで、Storageオブジェクトが取得できます。ここではFloatStorageがオブジェクトが返されてますが、型ごとに定義されてます。
Tensorを更新すれば、Storageも更新されます。

a = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
a_storage = a.storage()
print(a_storage)
#  1.0
#  2.0
#  3.0
#  4.0
#  5.0
#  6.0
# [torch.FloatStorage of size 6]

a[1, 1] = 10

print(a)
# tensor([[ 1.,  2.,  3.],
#         [ 4., 10.,  6.]])

print(a_storage)
#  1.0
#  2.0
#  3.0
#  4.0
#  10.0
#  6.0
# [torch.FloatStorage of size 6]

オフセットとストライド

TensorはStorageに対するビューの役割を果たしているのですが、TensorからStorage (1次元配列) へマッピングするために、Tensorではオフセットとストライドを持ってます。それぞれstorage_offsetメソッドとstrideメソッドで取得できます。

上のaの場合、オフセットは0、ストライドは(3, 1)となります。
これは0 + 次数0のインデックス*3 + 次数1のインデックス*1 = Storage (1次元配列) 上のインデックスでマッピングされます。

print(a.storage_offset(), a.stride())
# 0 (3, 1)

Tensorにインデックスを使ってアクセスすると、Storageは同じでオフセット・ストライドのみが異なるTensorオブジェクトが生成されます。つまりメモリコピーが発生しません。

以下のbcはいずれもaから生成されたTensorオブジェクトですが、Storageオブジェクトは全て同じです。aの値を更新すると、bcの値も変更されていることが確認できます。

  • transposeは転置を行う関数で、次数0と1が入れ替わっているので、ストライドも(3, 1) -> (1, 3) となってます
b = a[1:]

print(b)
# tensor([[ 4., 10.,  6.]])

print(b.storage_offset(), b.stride())
# 3 (3, 1)

c = a.transpose(0, 1)

print(c)
# tensor([[ 1.,  4.],
#         [ 2., 10.],
#         [ 3.,  6.]])

print(c.storage_offset(), c.stride())
# 0 (1, 3)

a[1, 2] = 11
print(a)
# tensor([[ 1.,  2.,  3.],
#         [ 4., 10., 11.]])

print(b)
# tensor([[ 4., 10., 11.]])

print(c)
# tensor([[ 1.,  4.],
#         [ 2., 10.],
#         [ 3., 11.]])

contiguous

オフセットとストライドによってマッピングされますので、メモリレイアウトは変更されません。

TensorからStorageへのマッピングが、メモリの連続した領域へのアクセスとなっているかどうかをチェックするメソッドとしてis_contiguousが用意されています。

  • 連続している方がCPU / GPUのキャッシュが効きやすいため、場合によってはこういったチェックが必要になります

上の例ではabは連続 (返り値True) してますが、cはメモリ上では0番目 -> 3番目 -> 1番目 -> ... という配置になってるため非連続 (返り値False) となってます。

print(a.is_contiguous(), b.is_contiguous(), c.is_contiguous())
# True True False

連続した領域になるように再配置するには、contiguousメソッドを使います。これはメモリコピーが行われることに注意です。

c_cont = c.contiguous()

print(c_cont.is_contiguous())
# True

print(c_cont.storage())
#  1.0
#  4.0
#  2.0
#  10.0
#  3.0
#  6.0
# [torch.FloatStorage of size 6]

メモリコピーが発生する場合

上のcontiguous以外でメモリコピーが発生するケースを見ていきます。

明示的にメモリコピーを行う (= Storageを生成する) 場合、cloneメソッドを使います。

a_clone = a.clone()
a_clone[0, 0] = 100

print(a)
# tensor([[ 1.,  2.,  3.],
#         [ 4., 10., 11.]])

boolean indexingの場合は、オフセットとストライドでマッピングできないので、暗黙的にコピーが作られます。このあたりもnumpyと同じですね。

a_filterd = a[a >= 10]
print(a_filterd)
# tensor([10., 11.])

a_filterd[0] = 100
print(a)
# tensor([[ 1.,  2.,  3.],
#         [ 4., 10., 11.]])

値そのものの変更や元のStorageのサイズを変更が発生するメソッドでもコピーが発生します。

  • 上で見たtransposeやviewなどはオフセットやストライドの変更のみで済むため、コピーは発生しません
  • 末尾"_"のTensorメソッドはStorage内のメモリを上書きするため (in-place) 、これもコピーは発生しません
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])

c = torch.pow(a, 2)
print(c.storage())
#  1
#  4
#  9
# [torch.LongStorage of size 3]

c_ = c.pow_(2)
print(c_.storage())
#  1
#  16
#  81
# [torch.LongStorage of size 3]

print(c.storage())
#  1
#  16
#  81
# [torch.LongStorage of size 3]

d = torch.cat((a, b))
print(d.storage())
#  1
#  2
#  3
#  4
#  5
#  6
# [torch.LongStorage of size 6]

もちろんGPUに送るとメモリコピーになります。

  • storageメソッドでアクセスするとtorch.cuda.FloatStorageオブジェクトが返されており、CPUとは別のクラスになっていることがわかります
a = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)

a_gpu = a.to(device="cuda")

print(a_gpu)
# tensor([[1., 2., 3.],
#         [4., 5., 6.]], device='cuda:0')

print(a_gpu.storage())
#  1.0
#  2.0
#  3.0
#  4.0
#  5.0
#  6.0
# [torch.cuda.FloatStorage of size 6]

まとめ

PyTorchのTensorのデータの持ち方について調べました。

  • 実体はStorageオブジェクトが持っている
  • Tensorはオフセットとストライドを持っており、Storageオブジェクトとマッピングされている
  • clone、contiguous、gpuメソッドや値やサイズの変更を伴うメソッドなどでメモリコピーが発生する

参考

Deep Learning with PyTorch | PyTorch

  • 2章を主に参考にしました