pytorch实现用Resnet提取特征并保存为txt文件的方法
作者:qq_32464407 发布时间:2023-04-10 17:21:09
标签:pytorch,Resnet,特征
接触pytorch一天,发现pytorch上手的确比TensorFlow更快。可以更方便地实现用预训练的网络提特征。
以下是提取一张jpg图像的特征的程序:
# -*- coding: utf-8 -*-
import os.path
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.autograd import Variable
import numpy as np
from PIL import Image
features_dir = './features'
img_path = "hymenoptera_data/train/ants/0013035.jpg"
file_name = img_path.split('/')[-1]
feature_path = os.path.join(features_dir, file_name + '.txt')
transform1 = transforms.Compose([
transforms.Scale(256),
transforms.CenterCrop(224),
transforms.ToTensor() ]
)
img = Image.open(img_path)
img1 = transform1(img)
#resnet18 = models.resnet18(pretrained = True)
resnet50_feature_extractor = models.resnet50(pretrained = True)
resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
for param in resnet50_feature_extractor.parameters():
param.requires_grad = False
#resnet152 = models.resnet152(pretrained = True)
#densenet201 = models.densenet201(pretrained = True)
x = Variable(torch.unsqueeze(img1, dim=0).float(), requires_grad=False)
#y1 = resnet18(x)
y = resnet50_feature_extractor(x)
y = y.data.numpy()
np.savetxt(feature_path, y, delimiter=',')
#y3 = resnet152(x)
#y4 = densenet201(x)
y_ = np.loadtxt(feature_path, delimiter=',').reshape(1, 2048)
以下是提取一个文件夹下所有jpg、jpeg图像的程序:
# -*- coding: utf-8 -*-
import os, torch, glob
import numpy as np
from torch.autograd import Variable
from PIL import Image
from torchvision import models, transforms
import torch.nn as nn
import shutil
data_dir = './hymenoptera_data'
features_dir = './features'
shutil.copytree(data_dir, os.path.join(features_dir, data_dir[2:]))
def extractor(img_path, saved_path, net, use_gpu):
transform = transforms.Compose([
transforms.Scale(256),
transforms.CenterCrop(224),
transforms.ToTensor() ]
)
img = Image.open(img_path)
img = transform(img)
x = Variable(torch.unsqueeze(img, dim=0).float(), requires_grad=False)
if use_gpu:
x = x.cuda()
net = net.cuda()
y = net(x).cpu()
y = y.data.numpy()
np.savetxt(saved_path, y, delimiter=',')
if __name__ == '__main__':
extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
files_list = []
sub_dirs = [x[0] for x in os.walk(data_dir) ]
sub_dirs = sub_dirs[1:]
for sub_dir in sub_dirs:
for extention in extensions:
file_glob = os.path.join(sub_dir, '*.' + extention)
files_list.extend(glob.glob(file_glob))
resnet50_feature_extractor = models.resnet50(pretrained = True)
resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
for param in resnet50_feature_extractor.parameters():
param.requires_grad = False
use_gpu = torch.cuda.is_available()
for x_path in files_list:
print(x_path)
fx_path = os.path.join(features_dir, x_path[2:] + '.txt')
extractor(x_path, fx_path, resnet50_feature_extractor, use_gpu)
另外最近发现一个很简单的提取不含FC层的网络的方法:
resnet = models.resnet152(pretrained=True)
modules = list(resnet.children())[:-1] # delete the last fc layer.
convnet = nn.Sequential(*modules)
另一种更简单的方法:
resnet = models.resnet152(pretrained=True)
del resnet.fc
来源:https://blog.csdn.net/qq_32464407/article/details/79190197
0
投稿
猜你喜欢
- 安装selenium打开命令控制符输入:pip install -U selenium火狐浏览器安装firebug:www.firebug.
- 以前我一直用os.system()处理一些系统管理任务,因为我认为那是运行linux命令最简单的方式.我们能从Python官方文档里读到应该
- 本文实例讲述了Python迭代器与生成器基本用法。分享给大家供大家参考,具体如下:迭代器可以进行for循环的数据类型包括以下两种:1. 集合
- 对数据库的管理常规就是进行预防性的维护,以及修复那些出现问题的内容。进行检查和修复通常具有四个主要的任务:1. 对表进行优化2. 对表进行分
- requests相比urllib,第三方库requests更加简单人性化,是爬虫工作中常用的库requests安装初级爬虫的开始主要是使用r
- 废话少说,直接上代码:<?php/** * Note:for octet-stream upload * 这个是流式上传PHP文件 *
- EXEC SQL WHENEVER SQLERROR CONTINUE; sqlglm(msg_buffer, &buf
- Python中使用ElementTree可以很方便的处理XML,但是产生的XML文件内容会合并在一行,难以看清楚。如下格式:<root
- 本文实例讲述了python根据文件大小打log日志的方法,分享给大家供大家参考。具体方法如下:import glob import logg
- 前言需要导入以下包,没有的通过pip安装import matplotlib.pyplot as pltimport cv2from PIL
- 目录1.列表2.使用格式3.一些很有用的函数4.元组 tuple5.元组的常用函数1.列表python没有数组,而是引入了列表(list),
- python 列表和链表的区别python 中的 list 并不是我们传统意义上的列表,传统列表——通常也叫作链表(linked list)
- 很棒的新闻发布系统分享给大家,希望大家喜欢。下面就让我们来说一说基于jsp的新闻发布系统,其中使用的技术有JavaBean、fillter、
- python装饰器在平常的python编程中用到的还是很多的,在本篇文章中我们先来介绍一下python中最常使用的@staticmethod
- 这篇文章主要介绍了如何给Python代码进行加密,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以
- 表一、运算符与特殊字符 运算符描述/选择子元素,返回左侧元素的直接子元素;如果"/"位于最左侧表示选择根结点的直接子元素
- 1、设置字体、风格代码主题选择Monokai会是彩色的代码。2、配置CI代码提示<1>下载代码提示项目:https://gith
- 问题问题1Python是一种动态语言,不支持类型检查。当需要对一个对象执行类型检查时,可能会采用下面的方式:class Foo(object
- 在Web开发中,JavaScript的一个很重要的作用就是对DOM进行操作,可你知道么?对DOM的操作是非常昂贵的,因为这会导致浏览器执行回
- 清除浮动这个问题的提出,在现在来说应该算是一个非常古老的问题了,很多人对解决办法估计也能烂记于心了,但是我这个落后了不少的前端开发程序员,太