pytorch版本PSEnet训练并部署方式
作者:__JDM__ 发布时间:2021-01-06 09:41:18
标签:pytorch,PSEnet,训练,部署
概述
源码地址
torch版本
训练环境没有按照torch的readme一样的环境,自己部署环境为:
torch==1.9.1
torchvision==0.10.1
python==3.8.0
cuda==10.2
mmcv==0.2.12
editdistance==0.5.3
Polygon3==3.0.9.1
pyclipper==1.3.0
opencv-python==3.4.2.17
Cython==0.29.24
./compile.sh
制作数据集
1、训练的数据集
采用的是rolabelimg进行标注,需要转换为ic2015格式的数据。
转换代码:
import os
from lxml import etree
import numpy as np
import math
src_xml = "ANN"
txt_dir = "gt"
xml_listdir = os.listdir(src_xml)
xml_listpath = [os.path.join(src_xml,xml_listdir1) for xml_listdir1 in xml_listdir]
def xml_out(xml_path):
gt_lines = []
ET = etree.parse(xml_path)
objs = ET.findall("object")
for ix,obj in enumerate(objs):
name = obj.find("name").text
robox = obj.find("robndbox")
cx = int(float(robox.find("cx").text))
cy = int(float(robox.find("cy").text))
w = int(float(robox.find("w").text))
h = int(float(robox.find("h").text))
angle = float(robox.find("angle").text)
# angle = math.degrees(angle1)
wx1 = cx - int(0.5 * w)
wy1 = cy - int(0.5 * h)
wx2 = cx + int(0.5 * w)
wy2 = cy - int(0.5 * h)
wx3 = cx - int(0.5 * w)
wy3 = cy + int(0.5 * h)
wx4 = cx + int(0.5 * w)
wy4 = cy + int(0.5 * h)
x1 = int((wx1 - cx) * np.cos(angle) - (wy1 - cy) * np.sin(angle) + cx)
y1 = int((wx1 - cx) * np.sin(angle) - (wy1 - cy) * np.cos(angle) + cy)
x2 = int((wx2 - cx) * np.cos(angle) - (wy2 - cy) * np.sin(angle) + cx)
y2 = int((wx2 - cx) * np.sin(angle) - (wy2 - cy) * np.cos(angle) + cy)
x3 = int((wx3 - cx) * np.cos(angle) - (wy3 - cy) * np.sin(angle) + cx)
y3 = int((wx3 - cx) * np.sin(angle) - (wy3 - cy) * np.cos(angle) + cy)
x4 = int((wx4 - cx) * np.cos(angle) - (wy4 - cy) * np.sin(angle) + cx)
y4 = int((wx4 - cx) * np.sin(angle) - (wy4 - cy) * np.cos(angle) + cy)
lines = str(x1)+","+str(y1)+","+str(x2)+","+str(y2)+","+\
str(x3)+","+str(y3)+","+str(x4)+","+str(y4)+","+str(name)+"\n"
gt_lines.append(lines)
return gt_lines
def main():
count = 0
for xml_dir in xml_listdir:
gt_lines = xml_out(os.path.join(src_xml,xml_dir))
txt_path = "gt_" + xml_dir[:-4] + ".txt"
with open(os.path.join(txt_dir,txt_path),"a+") as fd:
fd.writelines(gt_lines)
count +=1
print("Write file %s" % str(count))
if __name__ == "__main__":
main()
rolabelimg标注后的xml文件和labelimg的xml有些区别,根据不同的标注软件,转换代码略有区别。
转换后的格式为x1,y1,x2,y2,x3,y3,x4,y4,"classes"
,此处classes为检测的类别,如果是模糊训练的话,classes为“###”。
但是重点,这个源代码对于模糊训练,loss一直为1。
2、将数据集分成训练集和测试集
这里可以按照源码路径存放数据集,也可以修改源码存放位置。
PSENet-python3\dataset\psenet\psenet_ic15.py
修改下述代码为自己文件夹
3、训练
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py config/psenet/psenet_r50_ic15_736.py
其中根据源码中的readme,
可以根据自己的需要,自行选择配置文件。
4、部署测试
import torch
import numpy as np
import argparse
import os
import os.path as osp
import sys
import time
import json
from mmcv import Config
import cv2
from torchvision import transforms
from dataset import build_data_loader
from models import build_model
from models.utils import fuse_module
from utils import ResultFormat, AverageMeter
def prepare_image(image, target_size):
"""Do image preprocessing before prediction on any data.
:param image: original image
:param target_size: target image size
:return:
preprocessed image
"""
#assert os.path.exists(img), 'file is not exists'
#img = cv2.imread(img)
img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# h, w = image.shape[:2]
# scale = long_size / max(h, w)
img = cv2.resize(img, target_size)
# 将图片由(w,h)变为(1,img_channel,h,w)
tensor = transforms.ToTensor()(img)
tensor = tensor.unsqueeze_(0)
tensor = tensor.to(torch.device("cuda:0"))
return tensor
def report_speed(outputs, speed_meters):
total_time = 0
for key in outputs:
if 'time' in key:
total_time += outputs[key]
speed_meters[key].update(outputs[key])
print('%s: %.4f' % (key, speed_meters[key].avg))
speed_meters['total_time'].update(total_time)
print('FPS: %.1f' % (1.0 / speed_meters['total_time'].avg))
def load_model(cfg):
model = build_model(cfg.model)
model = model.cuda()
model.eval()
checkpoint = "psenet_r50_ic15_1024_finetune/checkpoint_580ep.pth.tar"
if checkpoint is not None:
if os.path.isfile(checkpoint):
print("Loading model and optimizer from checkpoint '{}'".format(checkpoint))
sys.stdout.flush()
checkpoint = torch.load(checkpoint)
d = dict()
for key, value in checkpoint['state_dict'].items():
tmp = key[7:]
d[tmp] = value
model.load_state_dict(d)
else:
print("No checkpoint found at")
raise
# fuse conv and bn
model = fuse_module(model)
return model
if __name__ == '__main__':
src_dir = "testimg/"
save_dir = "test_save/"
if not os.path.exists(save_dir):
os.makedirs(save_dir)
cfg = Config.fromfile("PSENet/config/psenet/psenet_r50_ic15_1024_finetune.py")
for d in [cfg, cfg.data.test]:
d.update(dict(
report_speed=False
))
if cfg.report_speed:
speed_meters = dict(
backbone_time=AverageMeter(500),
neck_time=AverageMeter(500),
det_head_time=AverageMeter(500),
det_pse_time=AverageMeter(500),
rec_time=AverageMeter(500),
total_time=AverageMeter(500)
)
model = load_model(cfg)
model.eval()
count = 0
for img_name in os.listdir(src_dir):
img = cv2.imread(src_dir + img_name)
tensor = prepare_image(img, target_size=(1376, 1024))
data = dict()
img_metas = dict()
data['imgs'] = tensor
img_metas['org_img_size'] = torch.tensor([[img.shape[0], img.shape[1]]])
img_metas['img_size'] = torch.tensor([[1376, 1024]])
data['img_metas'] = img_metas
data.update(dict(
cfg=cfg
))
with torch.no_grad():
outputs = model(**data)
if cfg.report_speed:
report_speed(outputs, speed_meters)
for bboxes in outputs['bboxes']:
x1 = bboxes[0]
y1 = bboxes[1]
x2 = bboxes[4]
y2 = bboxes[5]
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 3)
count = count + 1
cv2.imwrite(save_dir + img_name, img)
print("img test:", count)
from dataset import build_data_loader
from models import build_model
from models.utils import fuse_module
from utils import ResultFormat, AverageMeter
训练代码里含有。
来源:https://blog.csdn.net/WSNjiang/article/details/120821227
0
投稿
猜你喜欢
- 今天彬Go将和大家一起讨论网页设计趋势中很重要的环节,那就是”勾引”用户的按钮。所谓”勾引”用户的按钮,其实对于Web设计师来说,就是如何设
- 研究网页编码很长时间了,因为最近要设计一个友情链接检测的VBS脚本,而与你链接的人的页面很可能是各种编码,以前采取的方法是:如果用GB231
- python库-密码学库pynacl什么是pynacl官方: https://pynacl.readthedocs.io/en/latest
- 一、 Scott用户下的表结构SCOTT。是在Oracle数据库中,一个示例用户的名称。其作用是为初学者提供一些简单的应用示例,不过其默认是
- Python中 join() 函数的使用函数:string.join()Python中有join()和os.path.join()两个函数,
- 创建游戏文件 2048.py首先导入需要的包:import cursesfrom random import randrange, choi
- Thinkphp5微信小程序获取用户信息接口的实例详解首先在官网下载示例代码, 选php的,这里有个坑 官方的php文件,编码是UTF-8+
- 1. 确认已经安装了NT/2000和SQL Server的最新补丁程序,不用说大家应该已经安装好了,但是我觉得最好还是在这里提醒一下。2.
- 有些时候我们发现一些模块没有提供pip install 命令和安装教程 , 只提供了一个setup.py文件 , 这个时候如何安装呢?步骤打
- 主题众所周知,django.forms极其强大,不少的框架也借鉴了这个模式,如Scrapy。在表单验证时,django.forms是一绝,也
- 运行下面的代码你就可以清楚的认识到这两个参数的用法,innerText只能动态的改变指定元素内的文本内容,而innerHTML则不仅仅可以改
- 一、流程分析分析发现密码加密,且发送POST请求时header必须携带x-csrftoken,否则是报403。而x-csrftoken是在第
- <%'asp事务处理。'测试数据库为sql server,服务器为本机,数据库名为test,表名为a,两个字段id(i
- python之循环遍历关于循环遍历大家都知道,不外乎for和while,今天我在这写点不一样的循环和遍历。在实践中有时会遇到删除列表中的元素
- 这篇文章主要介绍了opencv python Canny边缘提取实现过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的
- 本文实例讲述了Python使用dict.fromkeys()快速生成一个字典。分享给大家供大家参考,具体如下:>>> re
- 这篇文章主要介绍了Python动态声明变量赋值代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋
- 需求描述在利用numpy进行数据分析时,常有的一个需求是:根据已知的数组生成新数组。这个问题又可以分为两类:根据筛选条件生成子数组;根据变换
- 为了安全起见,最好还是给打开的文件对象指定一个名字,这样在完成操作之后可以迅速关闭文件,防止一些无用的文件对象占用内存。举个例子,对文本文件
- 一、垃圾还是经典网页技术更新很快,一个网站的界面设计寿命仅仅2-3年而已。不管是垃圾还是精品,都没有所谓的经典。经典只存在于是哪个首次成功创