用途
从批量Tensor
中按照指定索引获取值,官方文档:torch.gather — PyTorch 2.0 documentation
实战
生成一个二维张量矩阵:
1 2 3 4 5 6 7 8 9 10 11 12
| import torch
tensor_0 = torch.arange(3, 12).view(3, 3) print(tensor_0)
''' 输出结果: tensor([[ 3, 4, 5], [ 6, 7, 8], [ 9, 10, 11]]) '''
|
输入行向量 index,并替换行索引(dim=0)
1 2 3 4 5 6 7
| index = torch.tensor([[2, 1, 0]]) tensor_1 = tensor_0.gather(dim=0, index=index) print(tensor_1) ''' 输出结果: tensor([[9, 7, 5]]) '''
|
图解过程:
输入行向量 index,并替换列索引(dim=1)
1 2 3 4 5 6 7
| index = torch.tensor([[2, 1, 0]]) tensor_2 = tensor_0.gather(1, index) print(tensor_2) ''' 输出结果: tensor([[5, 4, 3]]) '''
|
图解过程:
输入列向量 index,并替换行索引(dim=0)
1 2 3 4 5 6 7 8 9
| index = torch.tensor([[2, 1, 0]]).t() tensor_3 = tensor_0.gather(0, index) print(tensor_3) ''' 输出结果: tensor([[9], [6], [3]]) '''
|
图解过程:
输入列向量 index,并替换列索引(dim=1)
1 2 3 4 5 6 7 8 9
| index = torch.tensor([[2, 1, 0]]).t() tensor_4 = tensor_0.gather(1, index) print(tensor_4) ''' 输出结果: tensor([[5], [7], [9]]) '''
|
图解过程:
输入二维矩阵 index,并替换列索引(dim=1)
1 2 3 4 5 6 7 8 9
| index = torch.tensor([[0, 2], [1, 2]]) tensor_5 = tensor_0.gather(1, index) print(tensor_5) ''' 输出结果: tensor([[3, 5], [7, 8]]) '''
|
图解过程: