网络编程
位置:首页>> 网络编程>> Python编程>> 自己搭建resnet18网络并加载torchvision自带权重的操作

自己搭建resnet18网络并加载torchvision自带权重的操作

作者:找不到服务器1703  发布时间:2021-11-28 12:24:33 

标签:resnet18,网络,加载,torchvision,权重

直接搭建网络必须与torchvision自带的网络的权重也就是pth文件的结构、尺寸和变量命名完全一致,否则无法加载权重文件。

此时可比较2个字典逐一加载,详见

pytorch加载预训练模型与自己模型不匹配的解决方案


import torch
import torchvision
import cv2 as cv
from utils.utils import letter_box
from model.backbone import ResNet18

model1 = ResNet18(1)
model2 = torchvision.models.resnet18(progress=False)
fc = model2.fc
model2.fc = torch.nn.Linear(512, 1)
# print(model)
model_dict1 = model1.state_dict()
model_dict2 = torch.load('resnet18.pth')
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
   if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
       continue
   model_dict1[model_list1[n]] = model_dict2[model_list2[n]]
model1.load_state_dict(model_dict1)
missing, unspected = model2.load_state_dict(model_dict2)
image = cv.imread('zhn1.jpg')
image = letter_box(image, 224)
image = image[:, :, ::-1].transpose(2, 0, 1)
print('Network loading complete.')
model1.eval()
model2.eval()
with torch.no_grad():
   image = torch.tensor(image/256, dtype=torch.float32).unsqueeze(0)
   predict1 = model1(image)
   predict2 = model2(image)
print('finished')
# torch.save(model.state_dict(), 'resnet18.pth')

来源:https://blog.csdn.net/qq_34288751/article/details/114163057

0
投稿

猜你喜欢

手机版 网络编程 asp之家 www.aspxhome.com