# 小试 PyTorch 的 Module 类用法 (a quick experiment with PyTorch's nn.Module class)

# (Line-number gutter 1-31 left over from the original code listing removed.)
import torch.nn as nn
import torch

class Net(nn.Module):
    """Minimal module that wraps a single 2-D convolution.

    The only layer is ``nn.Conv2d(1, 2, 3)``: one input channel, two output
    channels and a 3x3 kernel, with every other argument left at its default.
    """

    def __init__(self) -> None:
        super().__init__()
        # Registered as the submodule "conv", so its weight and bias show up
        # in net.parameters() / net.state_dict() as conv.weight and conv.bias.
        self.conv = nn.Conv2d(1, 2, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run *x* through the convolution and return the result as-is."""
        return self.conv(x)

# One single-channel 3x3 sample holding the values 1..9; tensor shape (1, 3, 3).
x = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], dtype=torch.float32)

net = Net()
# Requires 'test.pth' to already exist in the working directory; the
# commented-out torch.save line further down is what produces it.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
net.load_state_dict(torch.load('test.pth'))

# Inspect the module: training flag, available attributes, forward output.
print(net.conv.training)
print(dir(net.conv))
print(net(x))
print(dir(net))

# modules() yields the net itself followed by its submodules.
for module in net.modules():
    print(module)

for p in net.parameters():
    print(p.requires_grad)

# Uncomment once to write the checkpoint that load_state_dict reads above.
# torch.save(net.state_dict(),'test.pth')

# Hand-computed check of a single convolution output: multiply the 3x3 input
# element-wise with a hard-coded 3x3 kernel, sum the products, add the bias.
# NOTE(review): w1 / b1 are presumably the (rounded) first-filter weight and
# bias stored in test.pth -- confirm against the checkpoint.
w1 = torch.tensor(
    [[[[-0.263, -0.2638, -0.0838],
       [0.0267, -0.2375, -0.0575],
       [-0.005, -0.1557, -0.2331]]]],
    dtype=torch.float32,
)
b1 = -0.2392
s = 0
for i in range(9):
    # Walk the 3x3 window in row-major order.
    row, col = divmod(i, 3)
    s += x[0][row][col].item() * w1[0][0][row][col].item()
s += b1
print(s)

# 创建于 2405151046,修改于 2405151047 (created at 2405151046, modified at 2405151047)