网络编程
位置:首页>> 网络编程>> Python编程>> pytorch模型存储的2种实现方法

pytorch模型存储的2种实现方法

作者:慢行厚积  发布时间:2023-10-06 11:37:24 

标签:pytorch,模型,存储

1、保存整个网络结构信息和模型参数信息:

torch.save(model_object, './model.pth')

直接加载即可使用:

model = torch.load('./model.pth')

2、只保存网络的模型参数-推荐使用

torch.save(model_object.state_dict(), './params.pth')

加载则要先从本地网络模块导入网络,然后再加载参数:


from models import AgeModel
model = AgeModel()
model.load_state_dict(torch.load('./params.pth'))

来源:https://www.cnblogs.com/wanghui-garcia/p/11236382.html

0
投稿

猜你喜欢

手机版 网络编程 asp之家 www.aspxhome.com