1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
| import torch.nn as nn import torch
class Net(nn.Module):
    """Minimal module wrapping a single 2-D convolution.

    The convolution is ``nn.Conv2d(1, 2, 3)``: 1 input channel, 2 output
    channels, a 3x3 kernel, and PyTorch's default stride/padding/bias.
    """

    def __init__(self):
        super().__init__()
        # The attribute name ``conv`` fixes the state-dict keys
        # (``conv.weight`` / ``conv.bias``) expected by load_state_dict.
        self.conv = nn.Conv2d(1, 2, 3)

    def forward(self, x):
        """Return the convolution of ``x``."""
        return self.conv(x)
# A single-channel 3x3 image in unbatched (C, H, W) layout.
x = torch.tensor(
    [[[1, 2, 3], [4, 5, 6], [7, 8, 9]]],
    dtype=torch.float32,
)
net = Net()
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load trusted checkpoints. On torch >= 1.13 consider passing
# weights_only=True as well.
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
# machine; load_state_dict copies the values into net's existing (CPU)
# parameters either way.
net.load_state_dict(torch.load("test.pth", map_location="cpu"))

# Inspect the loaded module.
print(net.conv.training)
print(dir(net.conv))
print(net(x))
print(dir(net))
for module in net.modules():
    print(module)
for p in net.parameters():
    print(p.requires_grad)

# To regenerate the checkpoint: torch.save(net.state_dict(), 'test.pth')

# Hand-transcribed kernel and bias of output channel 0 -- presumably copied
# (rounded) from test.pth; verify against net.conv.weight / net.conv.bias.
w1 = torch.tensor(
    [[[[-0.263, -0.2638, -0.0838],
       [0.0267, -0.2375, -0.0575],
       [-0.005, -0.1557, -0.2331]]]],
    dtype=torch.float32,
)
b1 = -0.2392
# Reproduce output channel 0 by hand: the 3x3 kernel covers the whole 3x3
# input, so the result is one dot product plus the bias. Values are pulled
# out as Python floats and summed in row-major order.
pixels = x[0].flatten().tolist()
weights = w1[0][0].flatten().tolist()
s = sum(a * b for a, b in zip(pixels, weights))
s += b1
print(s)
|