Python实现随机从图像中获取多个patch
作者:拜阳 发布时间:2021-09-19 13:39:56
经常有一些图像任务需要从一张大图中截取固定大小的patch来进行训练。这里面常常存在下面几个问题:
patch的位置尽可能随机,不然数据丰富性可能不够,容易引起过拟合
如果原图较大,读图带来的IO开销可能会非常大,影响训练速度,所以最好一次能够截取多个patch
我们经常不太希望因为随机性的存在而使得图像中某些区域没有被覆盖到,所以还需要注意patch位置的覆盖程度
基于以上问题,我们可以使用下面的策略从图像中获取位置随机的多个patch:
以固定的stride获取所有patch的左上角坐标
对左上角坐标进行随机扰动
对patch的左上角坐标加上宽和高得到右下角坐标
检查patch的坐标是否超出图像边界,如果超出则将其收进来,收的过程应保证patch尺寸不变
加入ROI(Region Of Interest)功能,也就是说patch不一定非要在整张图中获取,而是可以指定ROI区域
下面是实现代码和例子:
注意下面代码只是获取了patch的bounding box,并没有把patch截取出来。
# -*- coding: utf-8 -*-
import cv2
import numpy as np
def get_random_patch_bboxes(image, bbox_size, stride, jitter, roi_bbox=None):
"""
Generate random patch bounding boxes for a image around ROI region
Parameters
----------
image: image data read by opencv, shape is [H, W, C]
bbox_size: size of patch bbox, one digit or a list/tuple containing two
digits, defined by (width, height)
stride: stride between adjacent bboxes (before jitter), one digit or a
list/tuple containing two digits, defined by (x, y)
jitter: jitter size for evenly distributed bboxes, one digit or a
list/tuple containing two digits, defined by (x, y)
roi_bbox: roi region, defined by [xmin, ymin, xmax, ymax], default is whole
image region
Returns
-------
patch_bboxes: randomly distributed patch bounding boxes, n x 4 numpy array.
Each bounding box is defined by [xmin, ymin, xmax, ymax]
"""
height, width = image.shape[:2]
bbox_size = _process_geometry_param(bbox_size, min_value=1)
stride = _process_geometry_param(stride, min_value=1)
jitter = _process_geometry_param(jitter, min_value=0)
if bbox_size[0] > width or bbox_size[1] > height:
raise ValueError('box_size must be <= image size')
if roi_bbox is None:
roi_bbox = [0, 0, width, height]
# tl is for top-left, br is for bottom-right
tl_x, tl_y = _get_top_left_points(roi_bbox, bbox_size, stride, jitter)
br_x = tl_x + bbox_size[0]
br_y = tl_y + bbox_size[1]
# shrink bottom-right points to avoid exceeding image border
br_x[br_x > width] = width
br_y[br_y > height] = height
# shrink top-left points to avoid exceeding image border
tl_x = br_x - bbox_size[0]
tl_y = br_y - bbox_size[1]
tl_x[tl_x < 0] = 0
tl_y[tl_y < 0] = 0
# compute bottom-right points again
br_x = tl_x + bbox_size[0]
br_y = tl_y + bbox_size[1]
patch_bboxes = np.concatenate((tl_x, tl_y, br_x, br_y), axis=1)
return patch_bboxes
def _process_geometry_param(param, min_value):
"""
Process and check param, which must be one digit or a list/tuple containing
two digits, and its value must be >= min_value
Parameters
----------
param: parameter to be processed
min_value: min value for param
Returns
-------
param: param after processing
"""
if isinstance(param, (int, float)) or \
isinstance(param, np.ndarray) and param.size == 1:
param = int(np.round(param))
param = [param, param]
else:
if len(param) != 2:
raise ValueError('param must be one digit or two digits')
param = [int(np.round(param[0])), int(np.round(param[1]))]
# check data range using min_value
if not (param[0] >= min_value and param[1] >= min_value):
raise ValueError('param must be >= min_value (%d)' % min_value)
return param
def _get_top_left_points(roi_bbox, bbox_size, stride, jitter):
"""
Generate top-left points for bounding boxes
Parameters
----------
roi_bbox: roi region, defined by [xmin, ymin, xmax, ymax]
bbox_size: size of patch bbox, a list/tuple containing two digits, defined
by (width, height)
stride: stride between adjacent bboxes (before jitter), a list/tuple
containing two digits, defined by (x, y)
jitter: jitter size for evenly distributed bboxes, a list/tuple containing
two digits, defined by (x, y)
Returns
-------
tl_x: x coordinates of top-left points, n x 1 numpy array
tl_y: y coordinates of top-left points, n x 1 numpy array
"""
xmin, ymin, xmax, ymax = roi_bbox
roi_width = xmax - xmin
roi_height = ymax - ymin
# get the offset between the first top-left point of patch box and the
# top-left point of roi_bbox
offset_x = np.arange(0, roi_width, stride[0])[-1] + bbox_size[0]
offset_y = np.arange(0, roi_height, stride[1])[-1] + bbox_size[1]
offset_x = (offset_x - roi_width) // 2
offset_y = (offset_y - roi_height) // 2
# get the coordinates of all top-left points
tl_x = np.arange(xmin, xmax, stride[0]) - offset_x
tl_y = np.arange(ymin, ymax, stride[1]) - offset_y
tl_x, tl_y = np.meshgrid(tl_x, tl_y)
tl_x = np.reshape(tl_x, [-1, 1])
tl_y = np.reshape(tl_y, [-1, 1])
# jitter the coordinates of all top-left points
tl_x += np.random.randint(-jitter[0], jitter[0] + 1, size=tl_x.shape)
tl_y += np.random.randint(-jitter[1], jitter[1] + 1, size=tl_y.shape)
return tl_x, tl_y
if __name__ == '__main__':
image = cv2.imread('1.bmp')
patch_bboxes = get_random_patch_bboxes(
image,
bbox_size=[64, 96],
stride=[128, 128],
jitter=[32, 32],
roi_bbox=[500, 200, 1500, 800])
colors = [
(255, 0, 0),
(0, 255, 0),
(0, 0, 255),
(255, 255, 0),
(255, 0, 255),
(0, 255, 255)]
color_idx = 0
for bbox in patch_bboxes:
color_idx = color_idx % 6
pt1 = (bbox[0], bbox[1])
pt2 = (bbox[2], bbox[3])
cv2.rectangle(image, pt1, pt2, color=colors[color_idx], thickness=2)
color_idx += 1
cv2.namedWindow('image', 0)
cv2.imshow('image', image)
cv2.waitKey(0)
cv2.destroyAllWindows()
cv2.imwrite('image.png', image)
在实际应用中可以进一步增加一些简单的功能:
1.根据位置增加一些过滤功能。比如说太靠近边缘的给剔除掉,有些算法可能有比较严重的边缘效应,所以此时我们可能不太想要边缘的数据加入训练
2.也可以根据某些简单的算法策略进行过滤。比如在超分辨率这样的任务中,我们可能一般不太关心面积非常大的平坦区域,比如纯色墙面,大片天空等,此时可以使用方差进行过滤
3.设置最多保留数目。有时候原图像的大小可能有很大差异,此时利用上述方法得到的patch数量也就随之有很大的差异,然而为了保持训练数据的均衡性,我们可以设置最多保留数目,为了确保覆盖程度,一般需要在截取之前对patch进行shuffle,或者计算stride
来源:https://blog.csdn.net/bby1987/article/details/114296879
猜你喜欢
- 一、函数解释在torch/_C/_VariableFunctions.py的有该定义,意义就是实现一下公式:换句话说,就是需要传入5个参数,
- 一、ASP的平反想到ASP 很多人会说 “asp语言很蛋疼,不能面向对象,功能单一,很多东西实现不了” 等等诸如此类。 以上说法都是错误的,
- Oracle shutdown的时候突然断电,导致使用sql/plus启动时无法连接到数据库,具体描述为:connection can no
- 导语春节是中国特有的传统节日,中国结是中华民族特有的纯粹的文化精髓,富含丰富的文化底蕴,代表着我们对未来,对美好生活的向往和憧憬。新春佳节,
- 问题详情:使用pip install pyecharts 安装的是最新版,本人默认回车后安装1.1.0版本,出现如图问题:解决方法:(推荐第
- 定位原理很简单,故不赘述,直接上源码,内附注释。(如果对您的学习有所帮助,还请帮忙点个赞,谢谢了)#!/usr/bin/env python
- Airtest全称AirtestProject,是由网易游戏推出的一款自动化测试框架,在软件测试的时候使用到了该框架。这里记录一下安装、使用
- 从我们论坛中收集了这段HTML制作页面需要最大化、最小化时可以借鉴参考。最大化效果:<OBJECT id="max
- (一)深入浅出理解索引结构实际上,您可以把索引理解为一种特殊的目录。微软的SQL SERVER提供了两种索引:聚集索引(clustered
- Application对象 Application对象是个应用程序级的对象,用来在所有用户间共享信息,并可以在Web应用程序运行期间持久地保
- 因此,我们主要解决的思路是效验session ID的有效性. 以下为引用的内容: <?php if(!isset($_SESSION[
- 在工作之余抽了点时间写了一下这个,在ie6-ie7-ff下显示位置基本都一致了。(发现demo页面用栅格线做背景,调试还真的容易得多 。热力
- 问题你想将HTML或者XML实体如 &entity; 或 &#code; 替换为对应的文本。 再者,你需要转换文本 * 定的字
- python中的数字类型工具python中为更高级的工作提供很多高级数字编程支持和对象,其中数字类型的完整工具包括:1.整数与浮点型,2.复
- 圆形的绘制 :OpenCV中使用circle(img,center,radius,color,thickness=None,lineType
- ACCESS有个BUG,那就是在使用 like 搜索时如果遇到日文就会出现“内存溢出”的问题,提示“80040e14/内
- 注释:在大多数的情况下,修改MySQL是需要有mysql里的root权限的,所以一般用户无法更改密码,除非请求管理员。方法1使用phpmya
- 语法:Void header(string $string[,bool $replace=true [, int $http_respons
- 我们的规范到底做到哪一步算是发挥良好的价值?其实一件事物我们理解错根本目的会导致出大不一样的结果,直接反应在设计师到底要体现什么的价值。想想
- 1. 排序有什么用“排序”这个专业名词原本是来源于计算机程序操作中的,是一种很常见的算法设计,当然,对交互设计来说,探讨冒泡排序和堆排序之间