Python将实例对象作为方法使用

本文最后更新于:2024年3月29日 上午

问题

刚接触神经网络,使用nn.Module,看到下面代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import KarateClub

dataset = KarateClub()


class GCN(torch.nn.Module):
def __init__(self):
super(GCN, self).__init__()
torch.manual_seed(12345)
self.conv1 = GCNConv(dataset.num_features, 4)
self.conv2 = GCNConv(4, 4)
self.conv3 = GCNConv(4, 2)
self.classifier = Linear(2, dataset.num_classes)

def forward(self, x, edge_index):
h = self.conv1(x, edge_index)
h = h.tanh()
h = self.conv2(h, edge_index)
h = h.tanh()
h = self.conv3(h, edge_index)
h = h.tanh() # Final GNN embedding space.

# Apply a final (linear) classifier.
out = self.classifier(h)

return out, h


model = GCN()
print(model)
data = dataset[0]
_, h = model(data.x, data.edge_index)
print(f'Embedding shape: {list(h.shape)}')

前面都好理解,主要是第32行,model是GCN类的一个实例,也可以理解。

但是第35行model(data.x, data.edge_index)有点搞不明白了,为什么类的实例还能传参数?

查询了一些资料,发现Python通过一个特殊函数__call__()让类实例也可以变成一个可调用对象。

__init__

1
2
3
4
5
6
7
8
9
class A:
def __init__(self):
print('init函数')

def __call__(self, param):
print('call 函数', param)


a = A()

输出:

分析:a=A()进行了类的实例化,会自动调用__init__()方法。

__call__

1
2
3
4
5
6
7
8
9
10
class A:
def __init__(self):
print('init函数')

def __call__(self, param):
print('call 函数', param)


a = A()
a(1)

输出:

分析:a是类A的实例对象,a(1)相当于调用了实例(不知道这么说对不对,意思就是实例对象也可以被调用,后面加括号传参数),会自动调用__call__()方法。

__call__()中可以调用其它函数,如forward函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class A():
def __init__(self):
print('init函数')

def __call__(self, param):
print('call 函数', param)
res = self.forward(param)
return res

def forward(self, input):
print('forward 函数', input)
output = input + 1
return output


a = A()
b = a(1)
print('结果b =', b)

输出:

到这就有nn.Module那味了,下面这个例子更接近文章开头展示的内容:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch


class A(torch.nn.Module):
def __init__(self):
super().__init__()
print('init函数')

def forward(self, input):
print('forward 函数', input)
output = input + 1
return output


a = A()
b = a(1)
print('结果b =', b)

输出:

这里并没有调用__call__()(甚至我们都没有实现),还是调用了forward()方法,原因是因为父类nn.Module实现了__call__()方法。

我们可以重写__call__()方法,让其不调用forward

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch.nn


class A(torch.nn.Module):
def __init__(self):
super().__init__()
print('init函数')

def forward(self, input):
print('forward 函数', input)
output = input + 1
return output

def __call__(self, input):
print('重写 call 函数', input)
return input


a = A()
b = a(1)
print('结果b =', b)

输出:

参考链接:
https://blog.csdn.net/qq_43745026/article/details/125537774


Python将实例对象作为方法使用
https://summersong.top/post/c4b70d70.html
作者
SummerSong
发布于
2023年8月2日
更新于
2024年3月29日
许可协议