把vgg-face.mat权重迁移到pytorch模型示例
作者:美利坚节度使 发布时间:2021-11-03 16:29:20
标签:vgg-face,mat,权重,pytorch
最近使用pytorch时,需要用到一个预训练好的人脸识别模型来提取人脸ID特征。很多人都在用vgg-face,但vgg-face没有pytorch版本的模型,于是写了一段把vgg-face.mat转换为pytorch模型的代码。
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Thu May 10 10:41:40 2018
@author: hy
"""
import torch
import math
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
from scipy.io import loadmat
import scipy.misc as sm
import matplotlib.pyplot as plt
class vgg16_face(nn.Module):
    """VGG-16 face network whose layers mirror MatConvNet's vgg-face.mat.

    Attribute names (conv1_1 ... conv5_3, fc6, fc7, fc8) follow the MatConvNet
    layer names so converted weights can be copied layer by layer.  fc6 takes
    512 * 7 * 7 = 25088 inputs, i.e. a 224 x 224 image after the five 2x2
    max-pools.

    Args:
        num_classes: number of outputs of the final classifier fc8
            (default 2622).
    """

    def __init__(self, num_classes=2622):
        super(vgg16_face, self).__init__()
        inplace = True
        # Block 1
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.relu1_1 = nn.ReLU(inplace)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.relu1_2 = nn.ReLU(inplace)
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)
        # Block 2
        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.relu2_1 = nn.ReLU(inplace)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.relu2_2 = nn.ReLU(inplace)
        self.pool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)
        # Block 3
        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.relu3_1 = nn.ReLU(inplace)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.relu3_2 = nn.ReLU(inplace)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.relu3_3 = nn.ReLU(inplace)
        self.pool3 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)
        # Block 4
        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.relu4_1 = nn.ReLU(inplace)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.relu4_2 = nn.ReLU(inplace)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.relu4_3 = nn.ReLU(inplace)
        self.pool4 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)
        # Block 5
        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.relu5_1 = nn.ReLU(inplace)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.relu5_2 = nn.ReLU(inplace)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.relu5_3 = nn.ReLU(inplace)
        self.pool5 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)
        # Classifier.  drop6/drop7 are registered (they appear in the module
        # tree) but forward() never calls them, matching the MatConvNet layer
        # list, which has no dropout layers.
        self.fc6 = nn.Linear(in_features=25088, out_features=4096, bias=True)
        self.relu6 = nn.ReLU(inplace)
        self.drop6 = nn.Dropout(p=0.5)
        self.fc7 = nn.Linear(in_features=4096, out_features=4096, bias=True)
        self.relu7 = nn.ReLU(inplace)
        self.drop7 = nn.Dropout(p=0.5)
        self.fc8 = nn.Linear(in_features=4096, out_features=num_classes, bias=True)
        self._initialize_weights()

    def forward(self, x):
        """Run the network.

        Args:
            x: image batch laid out as (N, 3, H, W); H = W = 224 is needed for
                pool5 to flatten to fc6's 25088 inputs.

        Returns:
            Tuple ``(scores, x_pool1, x_pool2, x_pool3, x_pool4, x_pool5)``.
            ``scores`` is the raw fc8 output (no softmax is applied).  The
            pooled feature maps are returned for feature extraction.  Dropout is
            never applied, so results do not depend on train/eval mode.
        """
        out = self.relu1_1(self.conv1_1(x))
        out = self.relu1_2(self.conv1_2(out))
        x_pool1 = out = self.pool1(out)

        out = self.relu2_1(self.conv2_1(out))
        out = self.relu2_2(self.conv2_2(out))
        x_pool2 = out = self.pool2(out)

        out = self.relu3_1(self.conv3_1(out))
        out = self.relu3_2(self.conv3_2(out))
        out = self.relu3_3(self.conv3_3(out))
        x_pool3 = out = self.pool3(out)

        out = self.relu4_1(self.conv4_1(out))
        out = self.relu4_2(self.conv4_2(out))
        out = self.relu4_3(self.conv4_3(out))
        x_pool4 = out = self.pool4(out)

        out = self.relu5_1(self.conv5_1(out))
        out = self.relu5_2(self.conv5_2(out))
        out = self.relu5_3(self.conv5_3(out))
        x_pool5 = out = self.pool5(out)

        # Flatten channel-major (C, H, W) per sample for the fully-connected head.
        out = out.view(out.size(0), -1)
        out = self.relu6(self.fc6(out))
        out = self.relu7(self.fc7(out))
        out = self.fc8(out)
        return out, x_pool1, x_pool2, x_pool3, x_pool4, x_pool5

    def _initialize_weights(self):
        """Randomly initialise parameters (overwritten later by converted weights).

        Conv weights use N(0, sqrt(2 / (kH * kW * out_channels))).  Linear
        weights use N(0, 0.01).  Biases are zeroed, and BatchNorm (if any) is set
        to identity scale.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
def copy(vgglayers, dstlayer, idx):
    """Copy one MatConvNet layer's weights and bias into a PyTorch layer.

    Args:
        vgglayers: the ``layers`` cell array returned by ``loadmat`` for
            vgg-face.mat (the layer is read as ``vgglayers[0][idx]``).
        dstlayer: destination ``nn.Conv2d`` or ``nn.Linear``; its ``weight``
            and ``bias`` are overwritten in place.
        idx: index of the source layer inside ``vgglayers[0]``.

    MatConvNet stores filters as (height, width, in_channels, out_channels).
    PyTorch Conv2d weights are (out, in, height, width), so the permutation is
    (3, 2, 0, 1).  The previous (3, 2, 1, 0) swapped every kernel's height and
    width, which goes unnoticed for square kernels.  In MatConvNet the
    fully-connected layers are convolutions (fc6 uses a spatial filter, fc7/fc8
    use 1x1 filters).  Flattening the permuted filter to
    (out, in * height * width) matches the channel-major order produced by
    ``out.view(out.size(0), -1)`` on an (N, C, H, W) tensor.
    """
    layer = vgglayers[0][idx]
    kernel, bias = layer[0]['weights'][0][0]
    kernel = np.asarray(kernel)
    if kernel.ndim < 4:
        # MATLAB drops trailing singleton dimensions (e.g. out_channels == 1).
        kernel = kernel.reshape(kernel.shape + (1,) * (4 - kernel.ndim))
    # (H, W, Cin, Cout) -> (Cout, Cin, H, W)
    kernel = np.ascontiguousarray(np.transpose(kernel, (3, 2, 0, 1)))
    if isinstance(dstlayer, nn.Linear):
        kernel = kernel.reshape(kernel.shape[0], -1)  # (Cout, Cin*H*W)
    dstlayer.weight.data.copy_(torch.from_numpy(kernel))
    dstlayer.bias.data.copy_(torch.from_numpy(np.ascontiguousarray(bias).reshape(-1)))
def get_vggface(vgg_path):
    """Create a ``vgg16_face`` model and load the weights stored in vgg-face.mat.

    Args:
        vgg_path: path to MatConvNet's ``vgg-face.mat``
            (download from http://www.vlfeat.org/matconvnet/pretrained/).

    Returns:
        Tuple ``(model, class_names, average_image, image_size)``.  The last
        three are read from the file's ``meta.classes.description``,
        ``meta.normalization.averageImage`` and ``meta.normalization.imageSize``
        entries.
    """
    # Build the PyTorch network first; its random init is overwritten below.
    model = vgg16_face()

    # Read the MatConvNet file and pull out the preprocessing metadata.
    mat = loadmat(vgg_path)
    meta = mat['meta']
    class_names = meta['classes'][0][0]['description'][0][0]
    norm = meta['normalization'][0][0]
    average_image = np.squeeze(norm['averageImage'][0][0][0][0])
    image_size = np.squeeze(norm['imageSize'][0][0])
    layers = mat['layers']

    # Position in mat['layers'][0] of every layer that has parameters.  The
    # entries in between are relu*/pool* layers, and index 36 is 'prob'; none
    # of them carry weights.
    weighted_layers = (
        ('conv1_1', 0), ('conv1_2', 2),
        ('conv2_1', 5), ('conv2_2', 7),
        ('conv3_1', 10), ('conv3_2', 12), ('conv3_3', 14),
        ('conv4_1', 17), ('conv4_2', 19), ('conv4_3', 21),
        ('conv5_1', 24), ('conv5_2', 26), ('conv5_3', 28),
        ('fc6', 31), ('fc7', 33), ('fc8', 35),
    )
    for attr_name, mat_index in weighted_layers:
        copy(layers, getattr(model, attr_name), mat_index)

    return model, class_names, average_image, image_size
if __name__ == '__main__':
    # Smoke test: convert vgg-face.mat and run one face image through the net.
    vgg_path = "/home/hy/vgg-face.mat"  # download from http://www.vlfeat.org/matconvnet/pretrained/
    model, class_names, average_image, image_size = get_vggface(vgg_path)
    model.eval()
    height, width = int(image_size[0]), int(image_size[1])

    imgpath = "/home/hy/e/avg_face.jpg"
    # scipy.misc.imread / imresize were removed from SciPy, so decode with
    # matplotlib and resize with torch instead.
    raw = plt.imread(imgpath)
    # matplotlib returns PNGs as floats in [0, 1] and other formats (e.g. JPEG)
    # as integers.  Scale to 0-255, the range the original pipeline subtracted
    # average_image from.
    scale = 255.0 if np.issubdtype(raw.dtype, np.floating) else 1.0
    img = np.float32(raw) * scale  # h,w,c
    if img.ndim == 2:  # grayscale: replicate into 3 channels
        img = np.stack([img, img, img], axis=-1)
    img = img[:, :, :3]  # drop an alpha channel if present

    x = torch.from_numpy(np.ascontiguousarray(img.transpose((2, 0, 1))))  # c,h,w
    x = x.unsqueeze(0)  # 1,c,h,w
    x = nn.functional.interpolate(x, size=(height, width), mode='bilinear', align_corners=False)
    # Broadcast the per-channel mean rather than expanding to a hard-coded 224x224.
    avg = torch.from_numpy(np.float32(average_image)).view(1, 3, 1, 1)
    x = x - avg

    with torch.no_grad():
        out, x_pool1, x_pool2, x_pool3, x_pool4, x_pool5 = model(x)
    # plt.imshow(x_pool1.numpy()[0, 45])  # plot
来源:https://blog.csdn.net/ying86615791/article/details/80347761
0
投稿
猜你喜欢
- FTP即文件传输协议;它基于客户机-服务器模型体系结构,应用广泛。它有两个通道:一个命令通道和一个数据通道。命令通道用于控制通信,数据通道用
- 在Https页面中,如果iframe所引入页面是非https协议的页面,或者src属性不存在都可能导致浏览器弹出安全警告。本人在网上查找相关
- 在网站开发过程中,经常会遇到这样的需求:用户登陆系统才可以访问某些页面,如果用户没有登陆而直接访问就会跳转到登陆界面。要实现这样的需求其实很
- 目录生成器nextsendthrowclose使用场景大集合的生成简化代码结构协程与并发总结生成器如果在一个方法内,包含了 yield 关键
- 对于大多数web应用来说,数据库都是一个十分基础性的部分。如果你在使用PHP,那么你很可能也在使用MySQL—LAMP系列中举足轻重的一份子
- 之前在《首都机场的点烟器》中分析了一个软件系统所处的状态并且列举了不同的状态所需要的展示给用户的各类信息,我们先简单回顾一下:要设计一个软件
- 能评估使用方法性能评估模块提供了一系列用于模型性能评估的函数,这些函数在模型编译时由metrics关键字设置性能评估函数类似与目标函数, 只
- import介绍import语句作用就是用来导入模块的,它可以出现在程序中的任何位置。import语句语法使用import语句导入模块,im
- 前言有一天朋友A向我抱怨,他的老板要求他把几百份word填好的word表格简历信息整理到excel中,看着他一个个将姓名,年龄……从word
- 序列是Python中最基本的数据结构。序列中的每个元素都分配一个数字 - 它的位置,或索引,第一个索引是0,第二个索引是1,依此类推。Pyt
- 代码:import sysfrom PyQt5.QtWidgets import (QWidget, QHBoxLayout, QLabel
- 如下所示:import matplotlib.pyplot as pltimport numpy as npdef readfile(fil
- 1、最优化与线性规划最优化问题的三要素是决策变量、目标函数和约束条件。线性规划(Linear programming),是研究线性约束条件下
- 相信很多人在使用Ajax与后台php页面进行交互的时候都碰到过中文乱码的问题。JSON作为一种轻量级的数据交换格式,备受亲睐,但是用PHP作
- 以前经常吃公司旁边的食堂,人多,排队。夏天的时候,我们总要找一个靠窗口通风好的地方坐,没有空调只有风扇,风扇很多,开关都集中在一个地方,应该
- 前几天,为了增强本站的SEO,着手把另一个域名:www.aspxhome.com下的所有页面301转向到www.cidianwang.com
- 前言又要过年了,今年你不妨自己写一段代码来抢回家的火车票,是不是很Cool。下面话不多说了,来一起看看详细的介绍吧。先准备好:12306网站
- 上一课:ACCESS入门教程:初识Access 2000窗口接口简介 通过上一课的学习,你是否感觉Access的窗口和接口还有点搞不清楚,对
- 启动mysql server 失败,查看/var/log/mysqld.err080329 16:01:29 [ERROR] Can'
- 主题众所周知,django.forms极其强大,不少的框架也借鉴了这个模式,如Scrapy。在表单验证时,django.forms是一绝,也