pytorch实现用Resnet提取特征并保存为txt文件的方法
作者:qq_32464407 发布时间:2023-04-10 17:21:09
标签:pytorch,Resnet,特征
接触pytorch一天,发现pytorch上手的确比TensorFlow更快。可以更方便地实现用预训练的网络提特征。
以下是提取一张jpg图像的特征的程序:
# -*- coding: utf-8 -*-
import os.path
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.autograd import Variable
import numpy as np
from PIL import Image
features_dir = './features'
img_path = "hymenoptera_data/train/ants/0013035.jpg"
file_name = img_path.split('/')[-1]
feature_path = os.path.join(features_dir, file_name + '.txt')
transform1 = transforms.Compose([
transforms.Scale(256),
transforms.CenterCrop(224),
transforms.ToTensor() ]
)
img = Image.open(img_path)
img1 = transform1(img)
#resnet18 = models.resnet18(pretrained = True)
resnet50_feature_extractor = models.resnet50(pretrained = True)
resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
for param in resnet50_feature_extractor.parameters():
param.requires_grad = False
#resnet152 = models.resnet152(pretrained = True)
#densenet201 = models.densenet201(pretrained = True)
x = Variable(torch.unsqueeze(img1, dim=0).float(), requires_grad=False)
#y1 = resnet18(x)
y = resnet50_feature_extractor(x)
y = y.data.numpy()
np.savetxt(feature_path, y, delimiter=',')
#y3 = resnet152(x)
#y4 = densenet201(x)
y_ = np.loadtxt(feature_path, delimiter=',').reshape(1, 2048)
以下是提取一个文件夹下所有jpg、jpeg图像的程序:
# -*- coding: utf-8 -*-
import os, torch, glob
import numpy as np
from torch.autograd import Variable
from PIL import Image
from torchvision import models, transforms
import torch.nn as nn
import shutil
data_dir = './hymenoptera_data'
features_dir = './features'
shutil.copytree(data_dir, os.path.join(features_dir, data_dir[2:]))
def extractor(img_path, saved_path, net, use_gpu):
transform = transforms.Compose([
transforms.Scale(256),
transforms.CenterCrop(224),
transforms.ToTensor() ]
)
img = Image.open(img_path)
img = transform(img)
x = Variable(torch.unsqueeze(img, dim=0).float(), requires_grad=False)
if use_gpu:
x = x.cuda()
net = net.cuda()
y = net(x).cpu()
y = y.data.numpy()
np.savetxt(saved_path, y, delimiter=',')
if __name__ == '__main__':
extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
files_list = []
sub_dirs = [x[0] for x in os.walk(data_dir) ]
sub_dirs = sub_dirs[1:]
for sub_dir in sub_dirs:
for extention in extensions:
file_glob = os.path.join(sub_dir, '*.' + extention)
files_list.extend(glob.glob(file_glob))
resnet50_feature_extractor = models.resnet50(pretrained = True)
resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
for param in resnet50_feature_extractor.parameters():
param.requires_grad = False
use_gpu = torch.cuda.is_available()
for x_path in files_list:
print(x_path)
fx_path = os.path.join(features_dir, x_path[2:] + '.txt')
extractor(x_path, fx_path, resnet50_feature_extractor, use_gpu)
另外最近发现一个很简单的提取不含FC层的网络的方法:
resnet = models.resnet152(pretrained=True)
modules = list(resnet.children())[:-1] # delete the last fc layer.
convnet = nn.Sequential(*modules)
另一种更简单的方法:
resnet = models.resnet152(pretrained=True)
del resnet.fc
来源:https://blog.csdn.net/qq_32464407/article/details/79190197


猜你喜欢
- 本文介绍一个用python结合xlsxwriter自动生成业务报表的程序。这里的业务数据采用的是指定的值,真实情况下需要其他程序来接入数据。
- 效果基于Python3。在自己写小工具的时候因为这个功能纠结了一会儿,这里写个小例子,供有需要的参考。小例子,就是点击按钮打开路径选择窗口,
- 老外真是聪明,这个方法也想得到,有兴趣的不妨试试,但是如果对方的服务器安全搞的很好的话,这个代码也许就不能用了,但不管怎么样,学习一下也是好
- 前面已经介绍过几种基本语句(print,import,赋值语句),下面我们来介绍条件语句,循环语句。一. print和import的更多信息
- CSS命名规范一.文件命名规范全局样式:global.css;框架布局:layout.css;字体样式:font.css;链接样式:link
- 背景刚入行的同学,看到在SQL语句中出现where 1 = 1这样的条件可能会有所困惑,而长时间这样使用的朋友可能又习以为常。那么,你是否还
- 目录1、mysqldump执行过程:特点2、导出 CSV 文件(最灵活)执行过程特点3、物理拷贝(最快)过程局限总结1、mysqldump执
- 如何查看cpu的核数代码: from multiprocessing import cpu_count print(&q
- 1、转化成时间格式seconds =35400m, s = divmod(seconds, 60)h, m = divmod(m, 60)p
- 程序流程分析图:传播过程:代码展示:创建环境使用<pip install+包名>来下载torch,torchvision包准备数
- 前言schedule是一个第三方轻量级的任务调度模块,可以按照秒,分,小时,日期或者自定义事件执行时间。如果想执行多个任务,也可以添加多个t
- 背景:路由结构/video/1.mp4,即/video是父路由,/1.mp4是/video的动态子路由,在/video父路由中会通过url的
- 本文分析了Python出现segfault错误解决方法。分享给大家供大家参考,具体如下:最近python程序在运行过程中偶尔会引发系统seg
- --查出表中有重复的id的记录,并计算相同id的数量select id,count(id) from @table group by id
- 📚引言泰坦尼克号的沉没是历史上最惨痛的沉船事件之一。1912年4月15日,泰坦尼克号在其处女航中与冰山相撞后沉没,2224名乘客和船员中的1
- 概述SQL Server的主要性能取决于磁盘I/O效率,SQL Server 。2008提供了数据压缩功能来提高磁盘I/O效率。表压缩意味着
- 分组取TOP数据是T-SQL中的常用查询, 如学生信息管理系统中取出每个学科前3名的学生。这种查询在SQL Server 2005之前,写起
- 相信很多与页面打过交道的同学都对 Yahoo 的 Best Practices for Speeding Up Your Web Site
- 信息图表设计(Inforgraphic Design),是信息设计(Information Design)学科的一个分支,它兴起于20世纪末
- 1. 题目编写程序, 4名牌手打牌,计算机随机将52张牌(不含大小鬼)发给4名牌手,在屏幕上显示每位牌手的牌。提示:设计出3个类:Card类