Pytorch提取模型特征向量保存至csv的例子
作者:朴素.无恙 发布时间:2022-09-28 00:41:17
标签:Pytorch,特征,向量,csv
Pytorch提取模型特征向量
# -*- coding: utf-8 -*-
"""
dj
"""
import torch
import torch.nn as nn
import os
from torchvision import models, transforms
from torch.autograd import Variable
import numpy as np
from PIL import Image
import torchvision.models as models
import pretrainedmodels
import pandas as pd
class FCViewer(nn.Module):
def forward(self, x):
return x.view(x.size(0), -1)
class M(nn.Module):
def __init__(self, backbone1, drop, pretrained=True):
super(M,self).__init__()
if pretrained:
img_model = pretrainedmodels.__dict__[backbone1](num_classes=1000, pretrained='imagenet')
else:
img_model = pretrainedmodels.__dict__[backbone1](num_classes=1000, pretrained=None)
self.img_encoder = list(img_model.children())[:-2]
self.img_encoder.append(nn.AdaptiveAvgPool2d(1))
self.img_encoder = nn.Sequential(*self.img_encoder)
if drop > 0:
self.img_fc = nn.Sequential(FCViewer())
else:
self.img_fc = nn.Sequential(
FCViewer())
def forward(self, x_img):
x_img = self.img_encoder(x_img)
x_img = self.img_fc(x_img)
return x_img
model1=M('resnet18',0,pretrained=True)
features_dir = '/home/cc/Desktop/features'
transform1 = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor()])
file_path='/home/cc/Desktop/picture'
names = os.listdir(file_path)
print(names)
for name in names:
pic=file_path+'/'+name
img = Image.open(pic)
img1 = transform1(img)
x = Variable(torch.unsqueeze(img1, dim=0).float(), requires_grad=False)
y = model1(x)
y = y.data.numpy()
y = y.tolist()
#print(y)
test=pd.DataFrame(data=y)
#print(test)
test.to_csv("/home/cc/Desktop/features/3.csv",mode='a+',index=None,header=None)
jiazaixunlianhaodemoxing
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import argparse
class ResidualBlock(nn.Module):
def __init__(self, inchannel, outchannel, stride=1):
super(ResidualBlock, self).__init__()
self.left = nn.Sequential(
nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
nn.BatchNorm2d(outchannel),
nn.ReLU(inplace=True),
nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(outchannel)
)
self.shortcut = nn.Sequential()
if stride != 1 or inchannel != outchannel:
self.shortcut = nn.Sequential(
nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(outchannel)
)
def forward(self, x):
out = self.left(x)
out += self.shortcut(x)
out = F.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, ResidualBlock, num_classes=10):
super(ResNet, self).__init__()
self.inchannel = 64
self.conv1 = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(),
)
self.layer1 = self.make_layer(ResidualBlock, 64, 2, stride=1)
self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2)
self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=2)
self.layer4 = self.make_layer(ResidualBlock, 512, 2, stride=2)
self.fc = nn.Linear(512, num_classes)
def make_layer(self, block, channels, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1) #strides=[1,1]
layers = []
for stride in strides:
layers.append(block(self.inchannel, channels, stride))
self.inchannel = channels
return nn.Sequential(*layers)
def forward(self, x):
out = self.conv1(x)
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = F.avg_pool2d(out, 4)
out = out.view(out.size(0), -1)
out = self.fc(out)
return out
def ResNet18():
return ResNet(ResidualBlock)
import os
from torchvision import models, transforms
from torch.autograd import Variable
import numpy as np
from PIL import Image
import torchvision.models as models
import pretrainedmodels
import pandas as pd
class FCViewer(nn.Module):
def forward(self, x):
return x.view(x.size(0), -1)
class M(nn.Module):
def __init__(self, backbone1, drop, pretrained=True):
super(M,self).__init__()
if pretrained:
img_model = pretrainedmodels.__dict__[backbone1](num_classes=1000, pretrained='imagenet')
else:
img_model = ResNet18()
we='/home/cc/Desktop/dj/model1/incption--7'
# 模型定义-ResNet
#net = ResNet18().to(device)
img_model.load_state_dict(torch.load(we))#diaoyong
self.img_encoder = list(img_model.children())[:-2]
self.img_encoder.append(nn.AdaptiveAvgPool2d(1))
self.img_encoder = nn.Sequential(*self.img_encoder)
if drop > 0:
self.img_fc = nn.Sequential(FCViewer())
else:
self.img_fc = nn.Sequential(
FCViewer())
def forward(self, x_img):
x_img = self.img_encoder(x_img)
x_img = self.img_fc(x_img)
return x_img
model1=M('resnet18',0,pretrained=None)
features_dir = '/home/cc/Desktop/features'
transform1 = transforms.Compose([
transforms.Resize(56),
transforms.CenterCrop(32),
transforms.ToTensor()])
file_path='/home/cc/Desktop/picture'
names = os.listdir(file_path)
print(names)
for name in names:
pic=file_path+'/'+name
img = Image.open(pic)
img1 = transform1(img)
x = Variable(torch.unsqueeze(img1, dim=0).float(), requires_grad=False)
y = model1(x)
y = y.data.numpy()
y = y.tolist()
#print(y)
test=pd.DataFrame(data=y)
#print(test)
test.to_csv("/home/cc/Desktop/features/3.csv",mode='a+',index=None,header=None)
来源:https://blog.csdn.net/weixin_40123108/article/details/90678916


猜你喜欢
- 这篇文章主要介绍了django-多对多表的创建和插入代码实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需
- 最终效果如下图,右侧灰边看相对位置,版权所有谨防假冒:去年曾针对有时间先后的翻页记录了思考片段。之后没来得及调整一直是默认和插件并用,虽然难
- 一.WITH AS的含义 WITH AS短语,也叫做子查询部分(subquery factoring),可以让你做很多事情,定义一个SQL片
- 痛点引出在平时的开发当中,渲染侧边栏导航菜单有时会遇到过需要侧边栏有多层甚至无限层级的问题。此时更优雅的方式便是使用递归组件<el-m
- 本文实例讲述了python获取文件后缀名及批量更新目录下文件后缀名的方法。分享给大家供大家参考。具体实现方法如下:1. 获取文件后缀名:#!
- 本文实例为大家分享了python+opencv实现堆叠图片的具体代码,供大家参考,具体内容如下# import cv2# import nu
- 使用Python方法比用各种命令方便,可以设置超时时间,到底通不通,端口是否开放一眼能看出来。命令和返回完整权限,可以ping通,端口开放,
- 本文主要是用PyTorch来实现一个简单的回归任务。 编辑器:spyder1.引入相应的包及生成伪数据import torchimport
- 在windows下写bat的时候,通过pause命令,可以暂停程序运行,例如经常见的程序会在终端提示”按任意键继续……”,用户在终端回车后程
- 1、find(sub[, start[, end]])在索引start和end之间查找字符串sub找到,则返回最左端的索引值,未找到,则返回
- 在我们关于SQL服务器安全系列的这文章里,我们的目标是向你提供安全安装SQL服务器所需要的工具和信心,这样的话,你有价值的数据就会受到保护,
- 身份证号码的编排规则前1、2位数字表示:所在省份的代码;第3、4位数字表示:所在城市的代码;第5、6位数字表示:所在区县的代码;第7~14位
- 栈(stack)又名堆栈,它是一种运算受限的线性表。其限制是仅允许在表的一端进行插入和删除运算。这一端被称为栈顶,相对地,把另一端称为栈底。
- 本文内容速览1、绘图数据准备还是使用鸢尾花iris数据集#导入本帖要用到的库,声明如下:import matplotlib.pyplot a
- 谈到“登录”,大多数人脑海中会立刻浮现出那个“两小框:一用户名,一密码,外加一按钮”的经典豆腐块, 这样的功能模块在互联网上屡见不鲜, 成为
- MyBatis-Plus实现数据库curd操作1.mp是什么MyBatis-Plus(简称MP)是一个MyBatis 的增强工具,在MyBa
- MongoDB安装模块pip install pymongo连接数据库import pymongoclient = pymongo.Mong
- 本文实例讲述了Go语言中的匿名结构体用法。分享给大家供大家参考。具体实现方法如下:package main
- 在python-numpy使用中,可以用双层 for循环对数组元素进行访问,也可以切片成每一行后进行一维数组的遍历。代码如下:import
- 1 lambda函数函数格式是lambda keys:express 匿名函数lambda是一个表达式函数,接受ke