Pytorch 使用 nii数据做输入数据的操作
作者:evanna-y 发布时间:2023-12-28 23:21:33
使用pix2pix-gan做医学图像合成的时候,如果把nii数据转成png格式会损失很多信息,以为png格式图像的灰度值有256阶,因此直接使用nii的医学图像做输入会更好一点。
但是Pythorch中的Dataloader是不能直接读取nii图像的,因此加一个CreateNiiDataset的类。
先来了解一下pytorch中读取数据的主要途径——Dataset类。在自己构建数据层时都要基于这个类,类似于C++中的虚基类。
自己构建的数据层包含三个部分
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
根据自己的需要编写CreateNiiDataset子类:
因为我是基于https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
做pix2pix-gan的实验,数据包含两个部分mr 和 ct,不需要标签,因此上面的 def getitem(self, index):中不需要index这个参数了,类似地,根据需要,加入自己的参数,去掉不需要的参数。
class CreateNiiDataset(Dataset):
def __init__(self, opt, transform = None, target_transform = None):
self.path1 = opt.dataroot # parameter passing
self.A = 'MR'
self.B = 'CT'
lines = os.listdir(os.path.join(self.path1, self.A))
lines.sort()
imgs = []
for line in lines:
imgs.append(line)
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
def crop(self, image, crop_size):
shp = image.shape
scl = [int((shp[0] - crop_size[0]) / 2), int((shp[1] - crop_size[1]) / 2)]
image_crop = image[scl[0]:scl[0] + crop_size[0], scl[1]:scl[1] + crop_size[1]]
return image_crop
def __getitem__(self, item):
file = self.imgs[item]
img1 = sitk.ReadImage(os.path.join(self.path1, self.A, file))
img2 = sitk.ReadImage(os.path.join(self.path1, self.B, file))
data1 = sitk.GetArrayFromImage(img1)
data2 = sitk.GetArrayFromImage(img2)
if data1.shape[0] != 256:
data1 = self.crop(data1, [256, 256])
data2 = self.crop(data2, [256, 256])
if self.transform is not None:
data1 = self.transform(data1)
data2 = self.transform(data2)
if np.min(data1)<0:
data1 = (data1 - np.min(data1))/(np.max(data1)-np.min(data1))
if np.min(data2)<0:
#data2 = data2 - np.min(data2)
data2 = (data2 - np.min(data2))/(np.max(data2)-np.min(data2))
data = {}
data1 = data1[np.newaxis, np.newaxis, :, :]
data1_tensor = torch.from_numpy(np.concatenate([data1,data1,data1], 1))
data1_tensor = data1_tensor.type(torch.FloatTensor)
data['A'] = data1_tensor # should be a tensor in Float Tensor Type
data2 = data2[np.newaxis, np.newaxis, :, :]
data2_tensor = torch.from_numpy(np.concatenate([data2,data2,data2], 1))
data2_tensor = data2_tensor.type(torch.FloatTensor)
data['B'] = data2_tensor # should be a tensor in Float Tensor Type
data['A_paths'] = [os.path.join(self.path1, self.A, file)] # should be a list, with path inside
data['B_paths'] = [os.path.join(self.path1, self.B, file)]
return data
def load_data(self):
return self
def __len__(self):
return len(self.imgs)
注意:最后输出的data是一个字典,里面有四个keys=[‘A',‘B',‘A_paths',‘B_paths'], 一定要注意数据要转成FloatTensor。
其次是data[‘A_paths'] 接收的值是一个list,一定要加[ ] 扩起来,要不然测试存图的时候会有问题,找这个问题找了好久才发现。
然后直接在train.py的主函数里面把数据加载那行改掉就好了
data_loader = CreateNiiDataset(opt)
dataset = data_loader.load_data()
Over!
补充知识:nii格式图像存为npy格式
我就废话不多说了,大家还是直接看代码吧!
import nibabel as nib
import os
import numpy as np
img_path = '/home/lei/train/img/'
seg_path = '/home/lei/train/seg/'
saveimg_path = '/home/lei/train/npy_img/'
saveseg_path = '/home/lei/train/npy_seg/'
img_names = os.listdir(img_path)
seg_names = os.listdir(seg_path)
for img_name in img_names:
print(img_name)
img = nib.load(img_path + img_name).get_data() #载入
img = np.array(img)
np.save(saveimg_path + str(img_name).split('.')[0] + '.npy', img) #保存
for seg_name in seg_names:
print(seg_name)
seg = nib.load(seg_path + seg_name).get_data()
seg = np.array(seg)
np.save(saveseg_path + str(seg_name).split('.')[0] + '.npy
来源:https://blog.csdn.net/sudakuang/article/details/94746715


猜你喜欢
- 以下代码可自动登录12306 - 包括输入用户名密码以及自动识别验证码并点击验证码登陆。该源码需要稍作修改:把 username
- 在类中每次实例化一个对象都会生产一个字典来保存一个对象的所有的实例属性,这样非常的有用处,可以使我们任意的去设置新的属性。每次实例化一个对象
- 本文实例讲述了PHP排序算法之冒泡排序(Bubble Sort)实现方法。分享给大家供大家参考,具体如下:基本思想:冒泡排序是一种交换排序,
- 一、下载xlsx插件npm i xlsx二、通过element-ui组件的upload组件上传文件<el-upload
- 直接上图,图文并茂,相信你很快就知道要干什么。A文件:B文件:可以发现,A文件中“汉字井号”这一列和B文件中“WELL”这一列的属性相同,以
- 假设我们有一幅图像,图像中的文本被旋转了一个未知的角度。为了对文字进行角度的校正,我们需要完成如下几个步骤:1、检测出图中的文本范围2、计算
- 首先这里声明一下,关于我测试浏览器的版本是chrome15.0.874.121 Firefox 8.01 IE9 IETester下面的代码
- 本文实例讲述了Python单体模式的几种常见实现方法。分享给大家供大家参考,具体如下:这里python实现的单体模式,参考了:https:/
- 没有展开时点击展开之后<div class="flashread_item_box_time">  
- 去掉html中的table代码 Function OutTable(str) dim a,re&nb
- python中email模块使得处理邮件变得比较简单,今天着重学习了一下发送邮件的具体做法,这里写写自己的的心得,也请高手给些指
- 最近将Pytorch程序迁移到GPU上去的一些工作和思考环境:Ubuntu 16.04.3Python版本:3.5.2Pytorch版本:0
- 一、图像二值化1.效果2.源码import cv2import numpy as npimport matplotlib.pyplot as
- 使用场景一:如果在一张表中ManayTOManay字段关联的是自身,也就是出项这样的代码:ManyToManyField(self)那么,你
- 导言到目前为止的讨论编辑DataList的教程里,没有包含任何验证用户的输入,即使是用户非法输入— 遗漏了product的name或者负的p
- 前言遥感影像分类图一般为特定数值对应一类地物,用Python绘制时,主要在颜色的映射和对应的图例生成。plt.matplotlib.colo
- 本章的前面讨论如何使用SQL向一个表中插入数据。但是,如果你需要向一个表中添加许多条记录,使用SQL语句输入数据是很不方便的。幸运的是,My
- javascript上下滑动广告效果 参数说明:客服果果(
- MySQL数据库应用广泛,尤其对于JAVA程序员,不会陌生。如果在不想采购云数据库的情况下,可以自行安装MySQL数据库。文章将介绍,手动在
- go语言里边的字符串处理和PHP还有java 的处理是不一样的,首先申明字符串和修改字符串package mainimport "