Pytorch 使用 nii数据做输入数据的操作
作者:evanna-y 发布时间:2023-12-28 23:21:33
使用pix2pix-gan做医学图像合成的时候,如果把nii数据转成png格式会损失很多信息,以为png格式图像的灰度值有256阶,因此直接使用nii的医学图像做输入会更好一点。
但是Pythorch中的Dataloader是不能直接读取nii图像的,因此加一个CreateNiiDataset的类。
先来了解一下pytorch中读取数据的主要途径——Dataset类。在自己构建数据层时都要基于这个类,类似于C++中的虚基类。
自己构建的数据层包含三个部分
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
根据自己的需要编写CreateNiiDataset子类:
因为我是基于https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
做pix2pix-gan的实验,数据包含两个部分mr 和 ct,不需要标签,因此上面的 def getitem(self, index):中不需要index这个参数了,类似地,根据需要,加入自己的参数,去掉不需要的参数。
class CreateNiiDataset(Dataset):
def __init__(self, opt, transform = None, target_transform = None):
self.path1 = opt.dataroot # parameter passing
self.A = 'MR'
self.B = 'CT'
lines = os.listdir(os.path.join(self.path1, self.A))
lines.sort()
imgs = []
for line in lines:
imgs.append(line)
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
def crop(self, image, crop_size):
shp = image.shape
scl = [int((shp[0] - crop_size[0]) / 2), int((shp[1] - crop_size[1]) / 2)]
image_crop = image[scl[0]:scl[0] + crop_size[0], scl[1]:scl[1] + crop_size[1]]
return image_crop
def __getitem__(self, item):
file = self.imgs[item]
img1 = sitk.ReadImage(os.path.join(self.path1, self.A, file))
img2 = sitk.ReadImage(os.path.join(self.path1, self.B, file))
data1 = sitk.GetArrayFromImage(img1)
data2 = sitk.GetArrayFromImage(img2)
if data1.shape[0] != 256:
data1 = self.crop(data1, [256, 256])
data2 = self.crop(data2, [256, 256])
if self.transform is not None:
data1 = self.transform(data1)
data2 = self.transform(data2)
if np.min(data1)<0:
data1 = (data1 - np.min(data1))/(np.max(data1)-np.min(data1))
if np.min(data2)<0:
#data2 = data2 - np.min(data2)
data2 = (data2 - np.min(data2))/(np.max(data2)-np.min(data2))
data = {}
data1 = data1[np.newaxis, np.newaxis, :, :]
data1_tensor = torch.from_numpy(np.concatenate([data1,data1,data1], 1))
data1_tensor = data1_tensor.type(torch.FloatTensor)
data['A'] = data1_tensor # should be a tensor in Float Tensor Type
data2 = data2[np.newaxis, np.newaxis, :, :]
data2_tensor = torch.from_numpy(np.concatenate([data2,data2,data2], 1))
data2_tensor = data2_tensor.type(torch.FloatTensor)
data['B'] = data2_tensor # should be a tensor in Float Tensor Type
data['A_paths'] = [os.path.join(self.path1, self.A, file)] # should be a list, with path inside
data['B_paths'] = [os.path.join(self.path1, self.B, file)]
return data
def load_data(self):
return self
def __len__(self):
return len(self.imgs)
注意:最后输出的data是一个字典,里面有四个keys=[‘A',‘B',‘A_paths',‘B_paths'], 一定要注意数据要转成FloatTensor。
其次是data[‘A_paths'] 接收的值是一个list,一定要加[ ] 扩起来,要不然测试存图的时候会有问题,找这个问题找了好久才发现。
然后直接在train.py的主函数里面把数据加载那行改掉就好了
data_loader = CreateNiiDataset(opt)
dataset = data_loader.load_data()
Over!
补充知识:nii格式图像存为npy格式
我就废话不多说了,大家还是直接看代码吧!
import nibabel as nib
import os
import numpy as np
img_path = '/home/lei/train/img/'
seg_path = '/home/lei/train/seg/'
saveimg_path = '/home/lei/train/npy_img/'
saveseg_path = '/home/lei/train/npy_seg/'
img_names = os.listdir(img_path)
seg_names = os.listdir(seg_path)
for img_name in img_names:
print(img_name)
img = nib.load(img_path + img_name).get_data() #载入
img = np.array(img)
np.save(saveimg_path + str(img_name).split('.')[0] + '.npy', img) #保存
for seg_name in seg_names:
print(seg_name)
seg = nib.load(seg_path + seg_name).get_data()
seg = np.array(seg)
np.save(saveseg_path + str(seg_name).split('.')[0] + '.npy
来源:https://blog.csdn.net/sudakuang/article/details/94746715
猜你喜欢
- Python应用编程需要用到的针对不同数据库引擎的数据库接口:http://wiki.python.org/moin/DatabaseInt
- 在国内利用Python从Internet上爬取数据时,有些网站或API接口被限速或屏蔽,这时使用代理可以加速爬取过程,减少请求失败,Pyth
- python2和python3对于字符串的处理有很大的区别熟悉了python2的写法用python3时真的会遇到很多问题啊……区别pytho
- 1、卓越亚马逊的首页轮换图片,每刷新一次,都是随机不同的顺序显示,这样的设计解决了对于较多图片轮换而靠后的图片信息很少被看到的问题,这点对于
- 序言哈喽兄弟们,今天咱们来了解一下 fileinput 。说到fileinput,可能90%的码农表示没用过,甚至没有听说过。这不奇怪,因为
- 目录简单的验证码简单的登录页面我们经常在登录一个网站,或者注册的时候需要输入一个验证码,有时候觉得很烦,因为有些验证码不仅复杂还看不清,许多
- 这篇文章主要介绍了深入了解如何基于Python读写Kafka,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需
- 阅读上一篇:定义网页的语言编码 用web标准设计网站,过渡的方法主要是采用XHTML+CSS,css样式表是必不可少的。这就要求所有网页设计
- 用pytesseract识别图片中的数字Win 平台 使用步骤一、安装包。二、找个图片,运行如下识别程序。示例程序:import pytes
- 问题你想创建一个内嵌变量的字符串,变量被它的值所表示的字符串替换掉。解决方案Python并没有对在字符串中简单替换变量值提供直接的支持。 但
- 本文是小编针对js保留两位小数这个大家经常遇到的经典问题整理了在各种情况下的函数写法以及遇到问题的分析,以下是全部内容:一、我们首先从经典的
- PyTorch创建自己的数据集图片文件在同一的文件夹下思路是继承 torch.utils.data.Dataset,并重点重写其 __get
- 在命令行中使用 Python 时,它可以接收大约 20 个选项(option),语法格式如下:python [-bBdEhiIOqsSuvV
- 今天偶尔在知乎上看到某大佬用Python写的ATM系统案例,然后观摩了下他的实现思路和源码,感觉受益颇多,于是就根据自己的思路和目前掌握的P
- 数据类型是一种值的集合以及定义在这种值上的一组操作。一切语言的基础都是数据结构,所以打好基础对于后面的学习会有百利而无一害的作用。pytho
- Tensorflow训练模型默认占满所有GPU问题在使用gpu服务器训练tensorflow模型时,总是占满显存!TensorFlow默认的
- 最简单的方法:取整后判断是否和原值相等!javascript的取整函数是:parseIntif(parseInt(value)==value
- 0.前言Telnet协议属于TCP/IP协议族里的一种,对于我们这些网络攻城狮来说,再熟悉不过了,常用于远程登陆到网络设备进行操作,但是,它
- 一、出错情况 有些时候当你重启了数据库服务,会发现有些数据库变成了正在恢复、置疑、可疑等情况,这个时候DBA就会很紧张了,下面是一些在实践中
- 在IDLE下清屏:#网上有些先定义函数,再?print("\n" * 100)输出一百个换行的方法有点扯淡,跟连按回车没