pytoch索引操作
列举了一些,pytorch索引操作。
取出矩阵元素
假设我有一个四维的矩阵,我想取出第二维的某两个元素,其他维不变。则如下:
1 | a = torch.rand((2,3,3,2)) |
输出:
1 | tensor([[[[0.6026, 0.8672], |
操作:
1 | simple = a[:,[1,2],:,:] |
输出:
1 | torch.Size([2, 2, 3, 2]) |
torch.gather()
假设有一个三维矩阵:
1 | t = torch.Tensor( |
场景1:想取出第2维的指定元素
若想取出第二维的前两个元素,可以这样去做:
1 | indices = torch.LongTensor( |
输出:
1 | tensor([[[ 1., 2.], |
其索引在第2个维度上,相当于
1
2
3 [ [t[0,0,0],t[0,0,1]], [t[0,1,0],t[0,1,1]]]
[[t[1,0,0],t[1,0,1]], t[1,1,0],t[1,1,1]]
场景2:想取出第1维的指定元素
这里的索引下标对应第1维的索引,其他维度固定
1 | indices = torch.LongTensor( |
输出:
1 | tensor([[[1., 2., 3.]], |
这里传进去的索引相当于
1
2 t[0,_,0],t[0,_,1],t[0,_,2]
t[1,_,0],t[1,_,1],t[1,_,2]
场景3:想取出第0维的指定元素
这里的索引下标对应的第0维的索引,其他维度固定
1 | indices = torch.LongTensor([ |
输出:
1 | tensor([[[ 1., 8., 3.], |
其选取的元素,其索引在第0个维度上,相当于
1
2
3 t(0,0,0) t(1,0,1),t(0,0,2)
t(0,1,0) t(1,1,1),t(0,1,2)
进阶操作:从4维矩阵的第2维中取出元素
已知一个4维矩阵,其形状为[2,3,2,3]:
1 | noise_neg_vector = torch.Tensor( |
使用一个在第2维上的索引矩阵,其中形状为[2,3,1]
:
1 | indices = torch.Tensor( |
首先需要将其变为如下形式,其形状为[2,3,1,3]
:
1 | indices_test = torch.LongTensor( |
其可以通过以下方式变形得到:
1 | indices_q = indices.repeat((1,1,3)).unsqueeze(2) |
repeate(1,1,3):是指第0、1维度,数据量不变,在第2维度数据扩增三倍
unsqueeze(2):是指在扩增后的矩阵,的第二维插入一个维度即可