图解torch.gather函数

本文最后更新于:2024年3月29日 上午

用途

从批量Tensor中按照指定索引获取值,官方文档:torch.gather — PyTorch 2.0 documentation

实战

生成一个二维张量矩阵:

1
2
3
4
5
6
7
8
9
10
11
12
import torch

tensor_0 = torch.arange(3, 12).view(3, 3)
print(tensor_0)

'''
输出结果:
tensor([[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
'''

输入行向量 index,并替换行索引(dim=0)

1
2
3
4
5
6
7
index = torch.tensor([[2, 1, 0]])
tensor_1 = tensor_0.gather(dim=0, index=index)
print(tensor_1)
'''
输出结果:
tensor([[9, 7, 5]])
'''

图解过程:

91abb806f021adc4056c0d6b4d2780a

输入行向量 index,并替换列索引(dim=1)

1
2
3
4
5
6
7
index = torch.tensor([[2, 1, 0]])
tensor_2 = tensor_0.gather(1, index)
print(tensor_2)
'''
输出结果:
tensor([[5, 4, 3]])
'''

图解过程:

e12e60c7df3500d3526cb881d0a2e88

输入列向量 index,并替换行索引(dim=0)

1
2
3
4
5
6
7
8
9
index = torch.tensor([[2, 1, 0]]).t()
tensor_3 = tensor_0.gather(0, index)
print(tensor_3)
'''
输出结果:
tensor([[9],
[6],
[3]])
'''

图解过程:

bd69b5f6688273de63f60d41a22562d

输入列向量 index,并替换列索引(dim=1)

1
2
3
4
5
6
7
8
9
index = torch.tensor([[2, 1, 0]]).t()
tensor_4 = tensor_0.gather(1, index)
print(tensor_4)
'''
输出结果:
tensor([[5],
[7],
[9]])
'''

图解过程:

5a47f7d7eae582811a6a14363e57c74

输入二维矩阵 index,并替换列索引(dim=1)

1
2
3
4
5
6
7
8
9
index = torch.tensor([[0, 2],
[1, 2]])
tensor_5 = tensor_0.gather(1, index)
print(tensor_5)
'''
输出结果:
tensor([[3, 5],
[7, 8]])
'''

图解过程:

008a9e1b8374e0d4555a8e15a9819b4

图解torch.gather函数
https://summersong.top/post/801a7544.html
作者
SummerSong
发布于
2023年9月14日
更新于
2024年3月29日
许可协议