pytoch索引操作

列举了一些,pytorch索引操作。

取出矩阵元素

假设我有一个四维的矩阵,我想取出第二维的某两个元素,其他维不变。则如下:

1
a = torch.rand((2,3,3,2))

输出:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
tensor([[[[0.6026, 0.8672],
[0.8203, 0.3348],
[0.0586, 0.4303]],

[[0.9675, 0.6671],
[0.4051, 0.4087],
[0.1357, 0.4891]],

[[0.9431, 0.1633],
[0.8467, 0.9947],
[0.7302, 0.8947]]],


[[[0.4329, 0.3623],
[0.7188, 0.1713],
[0.1396, 0.5085]],

[[0.8540, 0.8602],
[0.5326, 0.5237],
[0.2139, 0.6461]],

[[0.8101, 0.1998],
[0.9879, 0.1140],
[0.2582, 0.7770]]]])

操作:

1
2
simple = a[:,[1,2],:,:]
simple.shape

输出:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
torch.Size([2, 2, 3, 2])
tensor([[[[0.9675, 0.6671],
[0.4051, 0.4087],
[0.1357, 0.4891]],

[[0.9431, 0.1633],
[0.8467, 0.9947],
[0.7302, 0.8947]]],


[[[0.8540, 0.8602],
[0.5326, 0.5237],
[0.2139, 0.6461]],

[[0.8101, 0.1998],
[0.9879, 0.1140],
[0.2582, 0.7770]]]])

torch.gather()

假设有一个三维矩阵

1
2
3
4
5
6
t = torch.Tensor(
[
[ [1,2,3],[4,5,6]],
[ [7,8,9],[10,11,12]]
]
)

场景1:想取出第2维的指定元素

若想取出第二维的前两个元素,可以这样去做:

1
2
3
4
5
6
indices  = torch.LongTensor(
[
[[0,1],[0,1]],
[[0,1],[0,1]]
])
torch.gather(t,2,indices)

输出:

1
2
3
4
5
tensor([[[ 1.,  2.],
[ 4., 5.]],

[[ 7., 8.],
[10., 11.]]])

其索引在第2个维度上,相当于

1
2
3
[ [t[0,0,0],t[0,0,1]], [t[0,1,0],t[0,1,1]]]

[[t[1,0,0],t[1,0,1]], t[1,1,0],t[1,1,1]]

场景2:想取出第1维的指定元素

这里的索引下标对应第1维的索引,其他维度固定

1
2
3
4
5
6
7
indices = torch.LongTensor(
[
[ [0,0,0] ],
[[0,0,0]]
]
) # [2,1,3]
torch.gather(t,1,indices) # [2,1,3]

输出:

1
2
3
tensor([[[1., 2., 3.]],

[[7., 8., 9.]]])

这里传进去的索引相当于

1
2
>t[0,_,0],t[0,_,1],t[0,_,2]
>t[1,_,0],t[1,_,1],t[1,_,2]

场景3:想取出第0维的指定元素

这里的索引下标对应的第0维的索引,其他维度固定

1
2
3
4
5
indices = torch.LongTensor([
[[0,1,0],
[0,1,0]]
]) # [1,2,3]
torch.gather(t,0,indices)

输出:

1
2
tensor([[[ 1.,  8.,  3.],
[ 4., 11., 6.]]]) #[1,2,3]

其选取的元素,其索引在第0个维度上,相当于

1
2
3
t(0,0,0) t(1,0,1),t(0,0,2)

t(0,1,0) t(1,1,1),t(0,1,2)

进阶操作:从4维矩阵的第2维中取出元素

已知一个4维矩阵,其形状为[2,3,2,3]:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
noise_neg_vector = torch.Tensor(
[[[[0.6187, 0.2877, 0.0165],
[0.1382, 0.3398, 0.5056]],

[[0.1142, 0.9186, 0.9289],
[0.8017, 0.9092, 0.5331]],

[[0.0331, 0.4678, 0.6310],
[0.8356, 0.7720, 0.9275]]],


[[[0.2631, 0.9801, 0.7762],
[0.7321, 0.9805, 0.3595]],

[[0.2959, 0.7852, 0.8006],
[0.0096, 0.1796, 0.6223]],

[[0.3650, 0.0877, 0.4225],
[0.3855, 0.9757, 0.9067]]]])

使用一个在第2维上的索引矩阵,其中形状为[2,3,1]

1
2
3
4
5
6
indices = torch.Tensor(
[
[[1],[0],[1]],

[[1],[0],[1]]
])

首先需要将其变为如下形式,其形状为[2,3,1,3]

1
2
3
4
5
6
indices_test = torch.LongTensor(
[
[[[1,1,1]],[[0,0,0]],[[1,1,1]]],
[[[1,1,1]],[[0,0,0]],[[1,1,1]]],
]
)

其可以通过以下方式变形得到:

1
indices_q = indices.repeat((1,1,3)).unsqueeze(2)

repeate(1,1,3):是指第0、1维度,数据量不变,在第2维度数据扩增三倍

unsqueeze(2):是指在扩增后的矩阵,的第二维插入一个维度即可

作者

bd160jbgm

发布于

2021-06-30

更新于

2021-07-17

许可协议