网络编程
位置:首页>> 网络编程>> Python编程>> pytorch 更改预训练模型网络结构的方法

pytorch 更改预训练模型网络结构的方法

作者:wayne980  发布时间:2022-04-02 04:41:19 

标签:pytorch,预训练,模型,网络结构

一个继承nn.module的model它包含一个叫做children()的函数,这个函数可以用来提取出model每一层的网络结构,在此基础上进行修改即可,修改方法如下(去除后两层):


resnet_layer = nn.Sequential(*list(model.children())[:-2])

那么,接下来就可以构建我们的网络了:


class Net(nn.Module):
 def __init__(self , model):
   super(Net, self).__init__()
   #取掉model的后两层
   self.resnet_layer = nn.Sequential(*list(model.children())[:-2])

self.transion_layer = nn.ConvTranspose2d(2048, 2048, kernel_size=14, stride=3)
   self.pool_layer = nn.MaxPool2d(32)
   self.Linear_layer = nn.Linear(2048, 8)

def forward(self, x):
   x = self.resnet_layer(x)

x = self.transion_layer(x)

x = self.pool_layer(x)

x = x.view(x.size(0), -1)

x = self.Linear_layer(x)

return x

最后,构建一个对象,并加载resnet预训练的参数就可以啦~


resnet = models.resnet50(pretrained=True)
model = Net(resnet)

来源:https://blog.csdn.net/wayne980/article/details/84026939

0
投稿

猜你喜欢

  • 这个可以说属性选择符的JS版,用来遴选元素是适合不过。在开始之前,我们复习一下CSS2的属性选择符,JQuery高手可以跳过。属性选择符:名
  • Python 操作文件时,我们一般要先判断指定的文件或目录是否存在,不然容易产生异常。1.文件# 是否存在import osos.path.
  • 前言 绝大多数的Oracle数据库性能问题都是由于数据库设计不合理造成的,只有少部分问题根植于Database Buffer、Share P
  • flask多进程会引起重复加载,解决方法:把耗资源的加载挪到函数里面或者类里面,就不会重复加载资源了。测试发现,不是flask引起的,是多进
  • 先举个例子,以前负责教育培训类网站的时候,曾经接到过这样一个项目,需求方希望做一个充满趣味性的新手入门频道,页面要炫,最好是flash,用户
  • 不知道大家有没发现DWMX中有一个和FW差不多的制作弹出菜单功能?这个功能允许用文字和图片做为主菜单,如果用文字的话要先做虚拟链接。下面简单
  • 下面开始构造HTTP数据包,IP层和TCP层使用python的Impacket库,http内容自行填写。#!/usr/bin/env pyt
  • 本文实例讲述了Python pymongo模块常用操作。分享给大家供大家参考,具体如下:环境:pymongo3.0.3,python3以下是
  • python脚本执行的3种方法:(找到自己能够使用的方法,能用的方法就是好方法)方法一:交互模式直接执行语句交互模式下直接编写执行 Pyth
  • PHP mysqli_rollback() 函数关闭自动提交,做一些查询,提交查询,然后回滚当前事务:<?php// 假定数据库用户名
  • 一、共享变量共享变量:当多个线程访问同一个变量的时候。会产生共享变量的问题。例子:import threadingsum = 0loopSu
  • 绿色在黄色和蓝色(冷暖)之间,属于较中庸的颜色,这样使得绿色的性格最为平和、安稳、大度、宽容。是一种柔顺、恬静、满足、优美、受欢迎之色。也是
  • 表单的验证一直是网页设计者头痛的问题,表单验证类 Validator就是为解决这个问题而写的,旨在使设计者从纷繁复杂的表单验证中解放出来,把
  • python闭包关于闭包, 很多blog中都这样解释 :对于一个嵌套定义的函数,外层的函数的返回值是内层函数,而在内层函数中又引用了外层函数
  • <script>  function isIPv6(str)  {  return str.mat
  • 什么是pyecharts?pyecharts 是一个用于生成 Echarts 图表的类库。echarts是百度开源的一个数据可视化 JS 库
  • 1.字母和数字键的键码值(keyCode) 按键 键码 按键 键码 按键 键码 按键 键码 A 65 J 74 S 83 1 49 B 66
  • 1、页签的表达。页签表达很清晰,当前页签突出,且层级包涵关系明确;看下图,一目了然的感觉,不用疑惑我在那部分里。不信?拿当当的对比一下,你感
  • 接下来,我们将实现微信朋友圈的爬取。如果直接用 Charles 或 mitmproxy 来监听微信朋友圈的接口数据,这是无法实现爬取的,因为
  • 本文仅仅梳理最基本的绘图方法。一、初始化假设已经安装了matplotlib工具包。利用matplotlib.figure.Figure创建一
手机版 网络编程 asp之家 www.aspxhome.com