TensorFlow 2.x 实现：用于 model.fit() 的医学图像 Dataset

from tensorflow import keras
import SimpleITK as sitk
from scipy import ndimage
import numpy as np
import random
import math
import os


class Seg3DDataset(keras.utils.Sequence):
    """Keras ``Sequence`` yielding batches of 3-D volumes and one-hot masks.

    Directory layout read by ``__init__``::

        work_dir/
            JPEGImages/      <- image volumes
            Segmentations/   <- label volumes, same file names as the images

    Every file name listed in ``Segmentations`` is paired with the file of the
    same name in ``JPEGImages``; both are read with ``sitk.ReadImage``.

    Args:
        work_dir: Root directory containing ``JPEGImages`` and ``Segmentations``.
        num_classes: Number of one-hot channels. Label voxels whose value is
            not in ``range(num_classes)`` get an all-zero one-hot vector.
        batch_size: Volumes per batch; a trailing incomplete batch is dropped.
        hu_min_val: Intensity mapped to -1 (values below are clipped).
        hu_max_val: Intensity mapped to +1 (values above are clipped).
        mode: ``'train'`` enables random in-plane rotation augmentation; any
            other value disables it.

    Raises:
        ValueError: If ``batch_size`` is not positive or
            ``hu_max_val <= hu_min_val``.
    """

    def __init__(self, work_dir, num_classes, batch_size,
                 hu_min_val, hu_max_val,
                 mode='train'):
        # Keras 3 PyDataset expects the base initializer to run; harmless on Keras 2.
        super().__init__()
        if batch_size <= 0:
            raise ValueError(f'batch_size must be positive, got {batch_size}')
        if hu_max_val <= hu_min_val:
            # An empty/inverted window would divide by zero (NaN) in normalize_img.
            raise ValueError(
                f'hu_max_val ({hu_max_val}) must be greater than '
                f'hu_min_val ({hu_min_val})')

        self.work_dir = work_dir
        self.num_classes = num_classes
        self.batch_size = batch_size
        self.mode = mode
        self.hu_min_val = hu_min_val
        self.hu_max_val = hu_max_val

        images_dir = os.path.join(work_dir, "JPEGImages")
        labels_dir = os.path.join(work_dir, "Segmentations")
        file_names = os.listdir(labels_dir)
        random.shuffle(file_names)

        # Parallel lists: index i of both lists refers to the same case.
        self.images_path = [os.path.join(images_dir, name) for name in file_names]
        self.labels_path = [os.path.join(labels_dir, name) for name in file_names]

    def __len__(self):
        """Number of full batches per epoch (the last partial batch is dropped)."""
        return len(self.images_path) // self.batch_size

    def __getitem__(self, idx):
        """Load, optionally augment, normalize and one-hot encode batch ``idx``.

        Returns:
            ``(x, y)``: ``x`` is float32 with shape ``(batch, A, B, S, 1)`` and
            ``y`` is float32 with shape ``(batch, A, B, S, num_classes)``, where
            ``S`` is axis 0 of the SimpleITK array moved to the end.

        Raises:
            IndexError: If ``idx`` is outside ``range(len(self))``.
        """
        if not 0 <= idx < len(self):
            raise IndexError(f'batch index {idx} out of range [0, {len(self)})')

        start = idx * self.batch_size
        stop = start + self.batch_size
        class_ids = np.arange(self.num_classes)

        batch_x = []
        batch_y = []
        for x_path, y_path in zip(self.images_path[start:stop],
                                  self.labels_path[start:stop]):
            image = sitk.ReadImage(x_path)
            label = sitk.ReadImage(y_path)
            # Training-time augmentation: with probability 1/2, rotate the pair
            # by a random integer angle in [20, 90] degrees.
            if self.mode == 'train' and random.randint(0, 1) == 0:
                angle_ = random.randint(20, 90)
                image, label = self.rotate_image(image, label, angle=angle_)

            image_array = self.normalize_img(sitk.GetArrayFromImage(image))
            # Move array axis 0 (the slice axis in SimpleITK's z, y, x order)
            # to the end; the original author's example shape was [256, 256, 16].
            image_array = np.transpose(image_array, [1, 2, 0]).astype('float32')
            label_array = np.transpose(sitk.GetArrayFromImage(label), [1, 2, 0])

            # One-hot: compare every voxel against every class id at once.
            onehot_label = (label_array[..., np.newaxis] == class_ids).astype(np.float32)

            batch_x.append(image_array[..., np.newaxis])  # add channel axis
            batch_y.append(onehot_label)

        return np.stack(batch_x), np.stack(batch_y)

    def rotate_image(self, image, label, angle):
        """Rotate an aligned image/label pair in-plane by ``angle`` degrees.

        Both arrays are rotated in the plane of array axes 1 and 2 with
        nearest-neighbour interpolation (``order=0``) and the output size kept
        equal to the input size (``reshape=False``). Because nearest-neighbour
        never invents new values, label ids (including multi-class ids) are
        preserved. Spacing, origin and direction of ``image`` are copied onto
        both outputs. The original author notes the inputs are NIfTI images.

        Args:
            image: SimpleITK image.
            label: SimpleITK label image aligned with ``image``.
            angle: Rotation angle in degrees.

        Returns:
            ``(rotated_image, rotated_label)`` as SimpleITK images.

        Raises:
            ValueError: If ``image`` and ``label`` differ in spacing, origin,
                direction or array shape.
        """
        spacing = image.GetSpacing()
        origin = image.GetOrigin()
        direction = image.GetDirection()
        if spacing != label.GetSpacing():
            raise ValueError(f'image: {spacing}; label: {label.GetSpacing()}')
        if origin != label.GetOrigin():
            raise ValueError(f'image: {origin}; label: {label.GetOrigin()}')
        if direction != label.GetDirection():
            raise ValueError(f'image: {direction}; label: {label.GetDirection()}')

        image_array = sitk.GetArrayFromImage(image)
        label_array = sitk.GetArrayFromImage(label)
        if image_array.shape != label_array.shape:
            raise ValueError(
                f'image_array: {image_array.shape}, label_array: {label_array.shape}')

        rotate_kwargs = dict(axes=(1, 2), reshape=False, mode='nearest', order=0)
        image_rotated = ndimage.rotate(image_array, angle, **rotate_kwargs)
        # No binarization afterwards: order=0 keeps the original label values,
        # and forcing non-zero voxels to 1 would corrupt multi-class masks.
        label_rotated = ndimage.rotate(label_array, angle, **rotate_kwargs)

        image_rotate = sitk.GetImageFromArray(image_rotated)
        label_rotate = sitk.GetImageFromArray(label_rotated)
        for rotated in (image_rotate, label_rotate):
            rotated.SetSpacing(spacing)
            rotated.SetOrigin(origin)
            rotated.SetDirection(direction)

        return image_rotate, label_rotate

    def normalize_img(self, img: np.ndarray) -> np.ndarray:
        """Map ``[hu_min_val, hu_max_val]`` linearly to ``[-1, 1]`` and clip.

        Args:
            img: Intensity array.

        Returns:
            Array of the same shape with values in ``[-1, 1]``.
        """
        value_range = self.hu_max_val - self.hu_min_val
        norm_0_1 = (img - self.hu_min_val) / value_range
        return np.clip(2 * norm_0_1 - 1, -1, 1)

    def on_epoch_end(self):
        """Shuffle the image/label pairs after each epoch, keeping them paired."""
        # Shuffle zipped pairs with the shared RNG instead of re-seeding the
        # global `random` module: re-seeding with one of 100 seeds each epoch
        # would also collapse the randomness of the rotation augmentation.
        pairs = list(zip(self.images_path, self.labels_path))
        random.shuffle(pairs)
        self.images_path[:] = [img_path for img_path, _ in pairs]
        self.labels_path[:] = [lbl_path for _, lbl_path in pairs]

最近更新

  1. TCP协议是安全的吗?

    2024-01-30 17:06:02       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-01-30 17:06:02       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-01-30 17:06:02       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-01-30 17:06:02       20 阅读

热门阅读

  1. js读取json的固定数据的一种方法

    2024-01-30 17:06:02       36 阅读
  2. html表单添加默认创建时间

    2024-01-30 17:06:02       35 阅读
  3. vue数据绑定

    2024-01-30 17:06:02       39 阅读
  4. 基础算法-差分-一维数组

    2024-01-30 17:06:02       30 阅读
  5. 基于STM32F103的路灯监控系统设计

    2024-01-30 17:06:02       30 阅读
  6. 聊聊PowerJob的SystemInfoController

    2024-01-30 17:06:02       31 阅读
  7. 小程序的应用、页面、组件生命周期(超全版)

    2024-01-30 17:06:02       31 阅读
  8. 获取文件夹下所有文件路径

    2024-01-30 17:06:02       39 阅读
  9. 代码随想录算法训练营|day21

    2024-01-30 17:06:02       42 阅读
  10. 提高 Code Review 质量的最佳实践

    2024-01-30 17:06:02       33 阅读
  11. 【Vue】2-3、Vue 的基本使用

    2024-01-30 17:06:02       50 阅读