问题 刚接触神经网络,使用nn.Module
,看到下面代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 import torchfrom torch.nn import Linearfrom torch_geometric.nn import GCNConvfrom torch_geometric.datasets import KarateClub dataset = KarateClub()class GCN (torch.nn.Module): def __init__ (self ): super (GCN, self).__init__() torch.manual_seed(12345 ) self.conv1 = GCNConv(dataset.num_features, 4 ) self.conv2 = GCNConv(4 , 4 ) self.conv3 = GCNConv(4 , 2 ) self.classifier = Linear(2 , dataset.num_classes) def forward (self, x, edge_index ): h = self.conv1(x, edge_index) h = h.tanh() h = self.conv2(h, edge_index) h = h.tanh() h = self.conv3(h, edge_index) h = h.tanh() out = self.classifier(h) return out, h model = GCN()print (model) data = dataset[0 ] _, h = model(data.x, data.edge_index)print (f'Embedding shape: {list (h.shape)} ' )
前面都好理解,主要是第32行,model是GCN类的一个实例,也可以理解。
但是第35行model(data.x, data.edge_index)
有点搞不明白了,为什么类的实例还能传参数?
查询了一些资料,发现Python通过一个特殊函数__call__()
让类实例也可以变成一个可调用对象。
__init__
1 2 3 4 5 6 7 8 9 class A : def __init__ (self ): print ('init函数' ) def __call__ (self, param ): print ('call 函数' , param) a = A()
输出:
分析:a=A()
进行了类的实例化,会自动调用__init__()
方法。
__call__
1 2 3 4 5 6 7 8 9 10 class A : def __init__ (self ): print ('init函数' ) def __call__ (self, param ): print ('call 函数' , param) a = A() a(1 )
输出:
分析:a
是类A
的实例对象,a(1)
相当于调用了实例(不知道这么说对不对,意思就是实例对象也可以被调用,后面加括号传参数),会自动调用__call__()
方法。
__call__()
中可以调用其它函数,如forward
函数:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 class A (): def __init__ (self ): print ('init函数' ) def __call__ (self, param ): print ('call 函数' , param) res = self.forward(param) return res def forward (self, input ): print ('forward 函数' , input ) output = input + 1 return output a = A() b = a(1 )print ('结果b =' , b)
输出:
到这就有nn.Module
那味了,下面这个例子更接近文章开头展示的内容:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 import torchclass A (torch.nn.Module): def __init__ (self ): super ().__init__() print ('init函数' ) def forward (self, input ): print ('forward 函数' , input ) output = input + 1 return output a = A() b = a(1 )print ('结果b =' , b)
输出:
这里并没有调用__call__()
(甚至我们都没有实现),还是调用了forward()
方法,原因是因为父类nn.Module
实现了__call__()
方法。
我们可以重写__call__()
方法,让其不调用forward
:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 import torch.nnclass A (torch.nn.Module): def __init__ (self ): super ().__init__() print ('init函数' ) def forward (self, input ): print ('forward 函数' , input ) output = input + 1 return output def __call__ (self, input ): print ('重写 call 函数' , input ) return input a = A() b = a(1 )print ('结果b =' , b)
输出:
参考链接:https://blog.csdn.net/qq_43745026/article/details/125537774