version = torch.__version__

# Build a fresh network; its weights are filled in from a checkpoint below.
model = ft_net(283, 0.5, 2)
model_dict = model.state_dict()

save_path = os.path.join('./model', 'ft_ResNet50', 'net_29_218_516.pth')
# NOTE: torch.load unpickles the file, so only load checkpoints from a
# trusted source.
# map_location='cpu' lets a checkpoint that was saved from GPU tensors be
# loaded on a machine without CUDA; copy_() below moves the data to whatever
# device each destination tensor lives on.
pretrained_dict = torch.load(save_path, map_location='cpu')

# Drop checkpoint entries the current model has no slot for.
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

# These entries are deliberately left at their fresh initialisation instead
# of being restored from the checkpoint.
skipped_keys = {"classifier.classifier.0.weight", "classifier.classifier.0.bias"}

# state_dict() tensors share storage with the module's parameters/buffers, so
# an in-place copy_() into model_dict updates the model itself; there is no
# need to rebuild the state dict on every iteration.
missing_keys = []
for name, target in model_dict.items():
    if name in skipped_keys:
        continue
    source = pretrained_dict.get(name)
    if source is None:
        # The checkpoint has no tensor for this entry: keep the model's
        # initial value rather than failing with a KeyError.
        missing_keys.append(name)
        continue
    target.copy_(source)

if missing_keys:
    print('Keys not found in checkpoint (left at initial values):', missing_keys)

print(model)