网络编程
位置:首页>> 网络编程>> Python编程>> pytorch实现从本地加载 .pth 格式模型

pytorch实现从本地加载 .pth 格式模型

作者:卢开毅  发布时间:2021-07-01 18:32:03 

标签:pytorch,加载,pth

可以从官网加载预训练好的模型:


import torchvision.models as models

model = models.vgg16(pretrained = True)
print(model)

但是经常会出现因为下载速度太慢而出现requests.exceptions.ConnectionError: ('Connection aborted.', TimeoutError(10060, '由于连接方在一段时间后没有正确答复或连接的主机没有反应,连接尝试失败。', None, 10060, None))这种错误,因此需要我们手动去下载 .pth 文件(百度云也很慢,如果你是SVIP,当我没说;迅雷的速度也还可以),然后从本地加载。

从本地加载只需要把上面的代码换成如下:


import torchvision.models as models

model = models.vgg16(pretrained=False)
pre=torch.load(r'.\kaggle_dog_vs_cat\pretrain\vgg16-397923af.pth')
model.load_state_dict(pre)

如果你模型不是用的vgg16,而是用的vgg11或者vgg13,只需要修改语句 model = models.vgg16(pretrained=False) 为对应模型的函数即可。

来源:https://blog.csdn.net/TomorrowAndTuture/article/details/100219240

0
投稿

猜你喜欢

手机版 网络编程 asp之家 www.aspxhome.com