昇思MindSpore 应用学习-FCN图像语义分割-CSDN

日期

心得

昇思MindSpore 应用学习-FCN图像语义分割 (AI 代码解析)

全卷积网络(Fully Convolutional Networks,FCN)是UC Berkeley的Jonathan Long等人于2015年在Fully Convolutional Networks for Semantic Segmentation[1]一文中提出的用于图像语义分割的一种框架。
FCN是首个端到端(end to end)进行像素级(pixel level)预测的全卷积网络。

语义分割

在具体介绍FCN之前,首先介绍何为语义分割:
图像语义分割(semantic segmentation)是图像处理和机器视觉技术中关于图像理解的重要一环,AI领域中一个重要分支,常被应用于人脸识别、物体检测、医学影像、卫星图像分析、自动驾驶感知等领域。
语义分割的目的是对图像中每个像素点进行分类。与普通的分类任务只输出某个类别不同,语义分割任务输出与输入大小相同的图像,输出图像的每个像素对应了输入图像每个像素的类别。语义在图像领域指的是图像的内容,对图片意思的理解,下图是一些语义分割的实例:

模型简介

FCN主要用于图像分割领域,是一种端到端的分割方法,是深度学习应用在图像语义分割的开山之作。通过进行像素级的预测直接得出与原图大小相等的label map。因FCN丢弃全连接层替换为全卷积层,网络所有层均为卷积层,故称为全卷积网络。
全卷积神经网络主要使用以下三种技术:

  1. 卷积化(Convolutional)

使用VGG-16作为FCN的backbone。VGG-16的输入为224*224的RGB图像,输出为1000个预测值。VGG-16只能接受固定大小的输入,丢弃了空间坐标,产生非空间输出。VGG-16中共有三个全连接层,全连接层也可视为带有覆盖整个区域的卷积。将全连接层转换为卷积层能使网络输出由一维非空间输出变为二维矩阵,利用输出能生成输入图片映射的heatmap。

  1. 上采样(Upsample)

在卷积过程的卷积操作和池化操作会使得特征图的尺寸变小,为得到原图的大小的稠密图像预测,需要对得到的特征图进行上采样操作。使用双线性插值的参数来初始化上采样逆卷积的参数,后通过反向传播来学习非线性上采样。在网络中执行上采样,以通过像素损失的反向传播进行端到端的学习。

  1. 跳跃结构(Skip Layer)

利用上采样技巧对最后一层的特征图进行上采样得到原图大小的分割是步长为32像素的预测,称之为FCN-32s。由于最后一层的特征图太小,损失过多细节,采用skips结构将更具有全局信息的最后一层预测和更浅层的预测结合,使预测结果获取更多的局部细节。将底层(stride 32)的预测(FCN-32s)进行2倍的上采样得到原尺寸的图像,并与从pool4层(stride 16)进行的预测融合起来(相加),这一部分的网络被称为FCN-16s。随后将这一部分的预测再进行一次2倍的上采样并与从pool3层得到的预测融合起来,这一部分的网络被称为FCN-8s。 Skips结构将深层的全局信息与浅层的局部信息相结合。

网络特点

  1. 不含全连接层(fc)的全卷积(fully conv)网络,可适应任意尺寸输入。
  2. 增大数据尺寸的反卷积(deconv)层,能够输出精细的结果。
  3. 结合不同深度层结果的跳级(skip)结构,同时确保鲁棒性和精确性。

数据处理

开始实验前,需确保本地已经安装Python环境及MindSpore。

from download import download  # 从'download'库导入'download'函数

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset_fcn8s.tar"  # 定义要下载的文件的URL

# 调用'download'函数下载文件
download(
    url,                         # 文件的URL地址
    "./dataset",                 # 下载文件存储的本地目录
    kind="tar",                  # 文件的压缩格式是tar
    replace=True                 # 如果目标文件存在,是否替换
)
  1. 导入download函数
from download import download

这行代码从download库中导入了download函数,使得我们可以使用这个函数来下载文件。

  1. 定义文件的URL
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset_fcn8s.tar"

这行代码定义了一个变量url,其中存储了待下载文件的URL地址。

  1. 调用download函数
download(url, "./dataset", kind="tar", replace=True)
  • url:指定要下载的文件的URL地址。

  • ./dataset:指定下载文件后存储的本地目录。在这里,文件将会被下载到当前目录下的dataset目录中。

  • kind="tar":指定下载的文件是一个tar压缩包。

  • replace=True:如果目标路径已经存在同名文件,是否用新下载的文件替换它。True表示替换。

  • download** 函数**:

    • 来源download
    • 功能:从指定URL下载文件并存储到本地指定目录中。
    • 参数
      • url:需要下载文件的URL地址。
      • path:指定文件下载后的本地存储路径。
      • kind:文件的压缩类型(例如tar, zip等)。
      • replace:布尔值,指示是否在目标文件存在时进行替换。

数据预处理

由于PASCAL VOC 2012数据集中图像的分辨率大多不一致,无法放在一个tensor中,故输入前需做标准化处理。

数据加载

将PASCAL VOC 2012数据集与SDB数据集进行混合。
下面的代码定义了一个 SegDataset 类,用于处理语义分割数据集。该类使用 MindSpore 框架,并包含图像和标签的预处理方法,以及创建可用于训练神经网络的数据集对象的方法。以下是代码的详细解析:

import numpy as np  # 导入NumPy库,用于数组操作
import cv2  # 导入OpenCV库,用于图像处理
import mindspore.dataset as ds  # 从MindSpore导入数据集模块

class SegDataset:
    def __init__(self,
                 image_mean,
                 image_std,
                 data_file='',
                 batch_size=32,
                 crop_size=512,
                 max_scale=2.0,
                 min_scale=0.5,
                 ignore_label=255,
                 num_classes=21,
                 num_readers=2,
                 num_parallel_calls=4):
        # 初始化方法,接收并存储各种参数和属性
        self.data_file = data_file
        self.batch_size = batch_size
        self.crop_size = crop_size
        self.image_mean = np.array(image_mean, dtype=np.float32)
        self.image_std = np.array(image_std, dtype=np.float32)
        self.max_scale = max_scale
        self.min_scale = min_scale
        self.ignore_label = ignore_label
        self.num_classes = num_classes
        self.num_readers = num_readers
        self.num_parallel_calls = num_parallel_calls
        max_scale > min_scale  # 确保最大缩放比例大于最小缩放比例

    def preprocess_dataset(self, image, label):
        # 预处理方法,用于图像和标签的各种预处理操作
        image_out = cv2.imdecode(np.frombuffer(image, dtype=np.uint8), cv2.IMREAD_COLOR)
        label_out = cv2.imdecode(np.frombuffer(label, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
        sc = np.random.uniform(self.min_scale, self.max_scale)
        new_h, new_w = int(sc * image_out.shape[0]), int(sc * image_out.shape[1])
        image_out = cv2.resize(image_out, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
        label_out = cv2.resize(label_out, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

        image_out = (image_out - self.image_mean) / self.image_std
        out_h, out_w = max(new_h, self.crop_size), max(new_w, self.crop_size)
        pad_h, pad_w = out_h - new_h, out_w - new_w
        if pad_h > 0 or pad_w > 0:
            image_out = cv2.copyMakeBorder(image_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
            label_out = cv2.copyMakeBorder(label_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=self.ignore_label)
        offset_h = np.random.randint(0, out_h - self.crop_size + 1)
        offset_w = np.random.randint(0, out_w - self.crop_size + 1)
        image_out = image_out[offset_h: offset_h + self.crop_size, offset_w: offset_w + self.crop_size, :]
        label_out = label_out[offset_h: offset_h + self.crop_size, offset_w: offset_w+self.crop_size]
        if np.random.uniform(0.0, 1.0) > 0.5:
            image_out = image_out[:, ::-1, :]
            label_out = label_out[:, ::-1]
        image_out = image_out.transpose((2, 0, 1))
        image_out = image_out.copy()
        label_out = label_out.copy()
        label_out = label_out.astype("int32")
        return image_out, label_out

    def get_dataset(self):
        # 获取数据集方法,返回预处理后的数据集
        ds.config.set_numa_enable(True)
        dataset = ds.MindDataset(self.data_file, columns_list=["data", "label"],
                                 shuffle=True, num_parallel_workers=self.num_readers)
        transforms_list = self.preprocess_dataset
        dataset = dataset.map(operations=transforms_list, input_columns=["data", "label"],
                              output_columns=["data", "label"],
                              num_parallel_workers=self.num_parallel_calls)
        dataset = dataset.shuffle(buffer_size=self.batch_size * 10)
        dataset = dataset.batch(self.batch_size, drop_remainder=True)
        return dataset

# 定义创建数据集的参数
IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"

# 定义模型训练参数
train_batch_size = 4
crop_size = 512
min_scale = 0.5
max_scale = 2.0
ignore_label = 255
num_classes = 21

# 实例化Dataset
dataset = SegDataset(image_mean=IMAGE_MEAN,
                     image_std=IMAGE_STD,
                     data_file=DATA_FILE,
                     batch_size=train_batch_size,
                     crop_size=crop_size,
                     max_scale=max_scale,
                     min_scale=min_scale,
                     ignore_label=ignore_label,
                     num_classes=num_classes,
                     num_readers=2,
                     num_parallel_calls=4)

dataset = dataset.get_dataset()
  1. 导入必要的库
import numpy as np
import cv2
import mindspore.dataset as ds

这部分代码导入了NumPy库(用于数组操作)、OpenCV库(用于图像处理)和MindSpore的dataset模块(用于数据集操作)。

  1. 定义 SegDataset
class SegDataset:
    def __init__(self, image_mean, image_std, data_file='', batch_size=32, crop_size=512, max_scale=2.0, min_scale=0.5, ignore_label=255, num_classes=21, num_readers=2, num_parallel_calls=4):
        # 初始化方法,接收并存储各种参数和属性
        ...
    def preprocess_dataset(self, image, label):
        # 预处理方法,用于图像和标签的各种预处理操作
        ...
    def get_dataset(self):
        # 获取数据集方法,返回预处理后的数据集
        ...

定义了一个 SegDataset 类,其中包括初始化方法 __init__、图像和标签的预处理方法 preprocess_dataset 以及获取数据集的方法 get_dataset

  1. **预处理方法 **preprocess_dataset
def preprocess_dataset(self, image, label):
    ...
  • 图像和标签解码:使用OpenCV从缓冲区解码图像和标签。
  • 随机缩放:根据随机缩放比例调整图像和标签大小。
  • 归一化:使用预定义的均值和标准差对图像进行归一化处理。
  • 填充和裁剪:进行必要的填充,以确保图像和标签的尺寸至少达到裁剪大小,然后进行随机裁剪。
  • 水平翻转:有50%的概率对图像和标签进行水平翻转。
  • 转置和类型转换:将图像转置以匹配神经网络的输入格式,并进行类型转换。
  1. **获取数据集方法 **get_dataset
def get_dataset(self):
    ...
  • 配置NUMA支持:设置数据集配置以支持NUMA(非一致性内存访问)。
  • MindDataset初始化:初始化一个 MindDataset,指定文件路径、列名和并行读取数。
  • 映射预处理操作:使用 map 方法将预处理操作应用于数据集。
  • 打乱和批处理:对数据集进行打乱并按照指定的批次大小进行批处理。
  1. 使用示例
IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"

train_batch_size = 4
crop_size = 512
min_scale = 0.5
max_scale = 2.0
ignore_label = 255
num_classes = 21

dataset = SegDataset(image_mean=IMAGE_MEAN, image_std=IMAGE_STD, data_file=DATA_FILE, batch_size=train_batch_size, crop_size=crop_size, max_scale=max_scale, min_scale=min_scale, ignore_label=ignore_label, num_classes=num_classes, num_readers=2, num_parallel_calls=4)

dataset = dataset.get_dataset()

这部分定义了创建数据集和模型训练的参数,并实例化了一个 SegDataset 对象,最终调用 get_dataset 方法获取处理后的数据集。

该代码旨在使用MindSpore框架进行语义分割数据集的预处理和批处理操作,处理步骤包括图像解码、归一化、填充、裁剪、数据增强(如随机缩放和水平翻转)等,最终生成适合于神经网络输入的批处理数据集。

训练集可视化

运行以下代码观察载入的数据集图片(数据处理过程中已做归一化处理)。
下面是对你提供的代码进行详细解析:

import numpy as np
import matplotlib.pyplot as plt

plt.figure(figsize=(16, 8))  # 设置图形大小为16x8英寸

# 对训练集中的数据进行展示
for i in range(1, 9):
    plt.subplot(2, 4, i)  # 创建一个2行4列的子图,并选择第i个子图
    show_data = next(dataset.create_dict_iterator())  # 从数据集中获取下一个数据
    show_images = show_data["data"].asnumpy()  # 将数据转换为NumPy数组
    show_images = np.clip(show_images, 0, 1)  # 将数据进行裁剪,保证值在0到1之间

    # 将图片转换HWC格式后进行展示
    plt.imshow(show_images[0].transpose(1, 2, 0))  # 将图像从CHW格式转换为HWC格式,并显示
    plt.axis("off")  # 隐藏坐标轴
    plt.subplots_adjust(wspace=0.05, hspace=0)  # 调整子图之间的间距
plt.show()  # 显示图像
  1. 导入必要的库
import numpy as np
import matplotlib.pyplot as plt

这部分代码导入了NumPy库(用于数组操作)和Matplotlib库(用于图像展示)。

  1. 设置图像大小
plt.figure(figsize=(16, 8))

这行代码设置了整个图形的大小,尺寸为16x8英寸。

  1. 展示训练集中的数据
for i in range(1, 9):
    plt.subplot(2, 4, i)
    show_data = next(dataset.create_dict_iterator())
    show_images = show_data["data"].asnumpy()
    show_images = np.clip(show_images, 0, 1)

    # 将图片转换HWC格式后进行展示
    plt.imshow(show_images[0].transpose(1, 2, 0))
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0)
plt.show()
  • 循环展示:使用 for 循环展示训练集中的前8张图片。
    • plt.subplot(2, 4, i):创建一个2行4列的子图布局,并选择第i个子图。
    • show_data = next(dataset.create_dict_iterator()):从数据集中获取下一个数据。
    • show_images = show_data["data"].asnumpy():将数据转换为NumPy数组。
    • show_images = np.clip(show_images, 0, 1):将数据进行裁剪,保证值在0到1之间。
    • plt.imshow(show_images[0].transpose(1, 2, 0)):将图像从CHW(channels, height, width)格式转换为HWC(height, width, channels)格式,并显示。
    • plt.axis("off"):隐藏坐标轴。
    • plt.subplots_adjust(wspace=0.05, hspace=0):调整子图之间的间距。
  • plt.show():显示图像。

该代码的主要功能是从训练集的数据集中获取图像数据,并展示前8张图像。每张图像从CHW格式转换为HWC格式后使用Matplotlib进行展示。通过调整子图的间距,确保图像之间没有过多的空白。

网络构建

网络流程

FCN网络的流程如下图所示:

  1. 输入图像image,经过pool1池化后,尺寸变为原始尺寸的1/2。
  2. 经过pool2池化,尺寸变为原始尺寸的1/4。
  3. 接着经过pool3、pool4、pool5池化,大小分别变为原始尺寸的1/8、1/16、1/32。
  4. 经过conv6-7卷积,输出的尺寸依然是原图的1/32。
  5. FCN-32s是最后使用反卷积,使得输出图像大小与输入图像相同。
  6. FCN-16s是将conv7的输出进行反卷积,使其尺寸扩大两倍至原图的1/16,并将其与pool4输出的特征图进行融合,后通过反卷积扩大到原始尺寸。
  7. FCN-8s是将conv7的输出进行反卷积扩大4倍,将pool4输出的特征图反卷积扩大2倍,并将pool3输出特征图拿出,三者融合后通反卷积扩大到原始尺寸。

下面是对你提供的 FCN8s 类的详细解析:

代码

import mindspore.nn as nn

class FCN8s(nn.Cell):
    def __init__(self, n_class):
        super().__init__()
        self.n_class = n_class
        # 第一层卷积块
        self.conv1 = nn.SequentialCell(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # 第二层卷积块
        self.conv2 = nn.SequentialCell(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # 第三层卷积块
        self.conv3 = nn.SequentialCell(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # 第四层卷积块
        self.conv4 = nn.SequentialCell(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # 第五层卷积块
        self.conv5 = nn.SequentialCell(
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # 全连接层1
        self.conv6 = nn.SequentialCell(
            nn.Conv2d(in_channels=512, out_channels=4096, kernel_size=7, weight_init='xavier_uniform'),
            nn.BatchNorm2d(4096),
            nn.ReLU(),
        )

        # 全连接层2
        self.conv7 = nn.SequentialCell(
            nn.Conv2d(in_channels=4096, out_channels=4096, kernel_size=1, weight_init='xavier_uniform'),
            nn.BatchNorm2d(4096),
            nn.ReLU(),
        )

        # 最终评分层(分类层)
        self.score_fr = nn.Conv2d(in_channels=4096, out_channels=self.n_class, kernel_size=1, weight_init='xavier_uniform')

        # 上采样层
        self.upscore2 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class, kernel_size=4, stride=2, weight_init='xavier_uniform')
        self.score_pool4 = nn.Conv2d(in_channels=512, out_channels=self.n_class, kernel_size=1, weight_init='xavier_uniform')
        self.upscore_pool4 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class, kernel_size=4, stride=2, weight_init='xavier_uniform')
        self.score_pool3 = nn.Conv2d(in_channels=256, out_channels=self.n_class, kernel_size=1, weight_init='xavier_uniform')
        self.upscore8 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class, kernel_size=16, stride=8, weight_init='xavier_uniform')

    def construct(self, x):
        x1 = self.conv1(x)
        p1 = self.pool1(x1)
        
        x2 = self.conv2(p1)
        p2 = self.pool2(x2)
        
        x3 = self.conv3(p2)
        p3 = self.pool3(x3)
        
        x4 = self.conv4(p3)
        p4 = self.pool4(x4)
        
        x5 = self.conv5(p4)
        p5 = self.pool5(x5)
        
        x6 = self.conv6(p5)
        
        x7 = self.conv7(x6)
        
        sf = self.score_fr(x7)
        
        u2 = self.upscore2(sf)
        
        s4 = self.score_pool4(p4)
        f4 = s4 + u2
        
        u4 = self.upscore_pool4(f4)
        
        s3 = self.score_pool3(p3)
        f3 = s3 + u4
        
        out = self.upscore8(f3)
        
        return out

解析

1. 导入必要的库
import mindspore.nn as nn

导入MindSpore库中的神经网络模块。

2. 定义 FCN8s
class FCN8s(nn.Cell):
    def __init__(self, n_class):
        super().__init__()
        self.n_class = n_class
  • 定义了 FCN8s 类,继承自 nn.Cell,用于构建全卷积网络(Fully Convolutional Network, FCN-8s)。
  • n_class 为分类数。
3. 定义卷积层和池化层
  • 第一层卷积块
self.conv1 = nn.SequentialCell(
    nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, weight_init='xavier_uniform'),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, weight_init='xavier_uniform'),
    nn.BatchNorm2d(64),
    nn.ReLU()
)
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
  • 包含两个卷积层、批归一化层和ReLU激活函数,并使用最大池化层进行下采样。
  • 第二层卷积块
self.conv2 = nn.SequentialCell(
    nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, weight_init='xavier_uniform'),
    nn.BatchNorm2d(128),
    nn.ReLU(),
    nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, weight_init='xavier_uniform'),
    nn.BatchNorm2d(128),
    nn.ReLU()
)
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
  • 第三层卷积块
self.conv3 = nn.SequentialCell(
    nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, weight_init='xavier_uniform'),
    nn.BatchNorm2d(256),
    nn.ReLU(),
    nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, weight_init='xavier_uniform'),
    nn.BatchNorm2d(256),
    nn.ReLU(),
    nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, weight_init='xavier_uniform'),
    nn.BatchNorm2d(256),
    nn.ReLU()
)
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
  • 第四层卷积块
self.conv4 = nn.SequentialCell(
    nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, weight_init='xavier_uniform'),
    nn.BatchNorm2d(512),
    nn.ReLU(),
    nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, weight_init='xavier_uniform'),
    nn.BatchNorm2d(512),
    nn.ReLU(),
    nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, weight_init='xavier_uniform'),
    nn.BatchNorm2d(512),
    nn.ReLU()
)
self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
  • 第五层卷积块
self.conv5 = nn.SequentialCell(
    nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, weight_init='xavier_uniform'),
    nn.BatchNorm2d(512),
    nn.ReLU(),
    nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, weight_init='xavier_uniform'),
    nn.BatchNorm2d(512),
    nn.ReLU(),
    nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, weight_init='xavier_uniform'),
    nn.BatchNorm2d(512),
    nn.ReLU()
)
self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
4. 定义全连接层(卷积方式)
  • 全连接层1
self.conv6 = nn.SequentialCell(
    nn.Conv2d(in_channels=512, out_channels=4096, kernel_size=7, weight_init='xavier_uniform'),
    nn.BatchNorm2d(4096),
    nn.ReLU(),
)
  • 全连接层2
self.conv7 = nn.SequentialCell(
    nn.Conv2d(in_channels=4096, out_channels=4096, kernel_size=1, weight_init='xavier_uniform'),
    nn.BatchNorm2d(4096),
    nn.ReLU(),
)
5. 定义评分层和上采样层
  • 评分层
self.score_fr = nn.Conv2d(in_channels=4096, out_channels=self.n_class, kernel_size=1, weight_init='xavier_uniform')
  • 上采样层
self.upscore2 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class, kernel_size=4, stride=2, weight_init='xavier_uniform')
self.score_pool4 = nn.Conv2d(in_channels=512, out_channels=self.n_class, kernel_size=1, weight_init='xavier_uniform')
self.upscore_pool4 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class, kernel_size=4, stride=2, weight_init='xavier_uniform')
self.score_pool3 = nn.Conv2d(in_channels=256, out_channels=self.n_class, kernel_size=1, weight_init='xavier_uniform')
self.upscore8 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class, kernel_size=16, stride=8, weight_init='xavier_uniform')
6. 构建前向传播路径
def construct(self, x):
    x1 = self.conv1(x)
    p1 = self.pool1(x1)
    
    x2 = self.conv2(p1)
    p2 = self.pool2(x2)
    
    x3 = self.conv3(p2)
    p3 = self.pool3(x3)
    
    x4 = self.conv4(p3)
    p4 = self.pool4(x4)
    
    x5 = self.conv5(p4)
    p5 = self.pool5(x5)
    
    x6 = self.conv6(p5)
    
    x7 = self.conv7(x6)
    
    sf = self.score_fr(x7)
    
    u2 = self.upscore2(sf)
    
    s4 = self.score_pool4(p4)
    f4 = s4 + u2
    
    u4 = self.upscore_pool4(f4)
    
    s3 = self.score_pool3(p3)
    f3 = s3 + u4
    
    out = self.upscore8(f3)
    
    return out
  • 依次通过各个卷积块和池化层,提取特征。
  • 通过全连接层进行进一步特征提取。
  • 通过评分层和上采样层进行分类和逐步上采样,以恢复到输入图像的分辨率。

总结

该代码实现了FCN-8s模型,适用于语义分割任务。通过多层卷积和池化层提取图像特征,再通过全连接层和上采样层进行分类和恢复图像分辨率。最终输出的 out 是与输入图像同样大小的分类结果。

训练准备

导入VGG-16部分预训练权重

FCN使用VGG-16作为骨干网络,用于实现图像编码。使用下面代码导入VGG-16预训练模型的部分预训练权重。
下面是对你提供的代码进行详细解析:

代码

from download import download
from mindspore import load_checkpoint, load_param_into_net

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/fcn8s_vgg16_pretrain.ckpt"
download(url, "fcn8s_vgg16_pretrain.ckpt", replace=True)

def load_vgg16():
    ckpt_vgg16 = "fcn8s_vgg16_pretrain.ckpt"
    param_vgg = load_checkpoint(ckpt_vgg16)
    load_param_into_net(net, param_vgg)

解析

1. 导入必要的库
from download import download
from mindspore import load_checkpoint, load_param_into_net
  • download:从 download 模块导入 download 函数,用于下载文件。
  • load_checkpoint** and **load_param_into_net:从 MindSpore 导入函数,用于加载预训练模型参数。
2. 下载预训练模型
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/fcn8s_vgg16_pretrain.ckpt"
download(url, "fcn8s_vgg16_pretrain.ckpt", replace=True)
  • url:预训练模型的下载链接。
  • download(url, "fcn8s_vgg16_pretrain.ckpt", replace=True):将预训练模型下载到本地并存储为 fcn8s_vgg16_pretrain.ckpt。如果该文件已存在,则替换它。
3. 定义函数 load_vgg16
def load_vgg16():
    ckpt_vgg16 = "fcn8s_vgg16_pretrain.ckpt"
    param_vgg = load_checkpoint(ckpt_vgg16)
    load_param_into_net(net, param_vgg)
  • 加载检查点文件
    • ckpt_vgg16 = "fcn8s_vgg16_pretrain.ckpt":定义预训练模型文件的路径。
    • param_vgg = load_checkpoint(ckpt_vgg16):加载检查点文件,获取预训练模型的参数。
  • 加载参数到网络中
    • load_param_into_net(net, param_vgg):将预训练模型的参数加载到网络 net 中。注意,这里 net 需要在函数外部定义并传递给 load_param_into_net

总结

这段代码实现了以下功能:

  1. 从指定的URL下载预训练的VGG-16模型检查点文件。
  2. 定义了一个 load_vgg16 函数,用于将下载的预训练模型参数加载到神经网络中。

这段代码依赖于 download 模块和预先定义的 net 网络模型。在实际使用中,确保 net 已经定义,并且网络结构与预训练模型的参数匹配。

损失函数

在语义分割任务中,每个像素点都需要进行分类,因此损失函数选择交叉熵损失函数来计算FCN网络输出与真实标签(mask)之间的交叉熵损失。MindSpore 提供了 mindspore.nn.CrossEntropyLoss 来作为损失函数。

import mindspore.nn as nn

# 定义交叉熵损失函数
loss_fn = nn.CrossEntropyLoss()

自定义评价指标 Metrics

为了评估训练出来的模型效果,我们可以使用以下几个评价指标:

1. Pixel Accuracy (PA, 像素精度)

像素精度是标记正确的像素占总像素的比例。公式如下:
外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

2. Mean Pixel Accuracy (MPA, 均像素精度)

均像素精度是每个类内被正确分类像素数的比例,之后求所有类的平均。公式如下:
外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

3. Mean Intersection over Union (MIoU, 均交并比)

均交并比是语义分割的标准度量,计算两个集合(真实值和预测值)的交集和并集之比。在每个类上计算IoU,之后平均。公式如下:
外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

4. Frequency Weighted Intersection over Union (FWIoU, 频权交并比)

频权交并比是对均交并比的一种提升,根据每个类出现的频率为其设置权重。公式如下:
外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

5. 代码实现

以下是如何计算这些评价指标的示例代码。假设我们有一个混淆矩阵 confusion_matrix,其中 confusion_matrix[i][j] 表示真实类为 i 但被预测为 j 的像素数量。

import numpy as np

def compute_metrics(confusion_matrix):
    # 初始化指标
    PA = 0
    MPA = 0
    MIoU = 0
    FWIoU = 0
    
    # 计算总像素数
    total_pixels = np.sum(confusion_matrix)
    
    # PA: Pixel Accuracy
    PA = np.trace(confusion_matrix) / total_pixels
    
    # MPA: Mean Pixel Accuracy
    class_accuracies = np.diag(confusion_matrix) / np.sum(confusion_matrix, axis=1)
    MPA = np.mean(class_accuracies)
    
    # MIoU: Mean Intersection over Union
    union = np.sum(confusion_matrix, axis=1) + np.sum(confusion_matrix, axis=0) - np.diag(confusion_matrix)
    IoU = np.diag(confusion_matrix) / union
    MIoU = np.mean(IoU)
    
    # FWIoU: Frequency Weighted Intersection over Union
    freq = np.sum(confusion_matrix, axis=1) / total_pixels
    FWIoU = np.sum(freq * IoU)
    
    return PA, MPA, MIoU, FWIoU

# 示例混淆矩阵
confusion_matrix = np.array([[50, 2, 1], 
                             [10, 30, 5], 
                             [3, 5, 60]])

# 计算指标
PA, MPA, MIoU, FWIoU = compute_metrics(confusion_matrix)
print(f"Pixel Accuracy (PA): {PA}")
print(f"Mean Pixel Accuracy (MPA): {MPA}")
print(f"Mean Intersection over Union (MIoU): {MIoU}")
print(f"Frequency Weighted Intersection over Union (FWIoU): {FWIoU}")

通过上述方法,我们定义了交叉熵损失函数用于训练语义分割模型,并实现了多种评价指标来评估模型的性能。评价指标包括像素精度(PA)、均像素精度(MPA)、均交并比(MIoU)和频权交并比
(FWIoU)。这些指标可以帮助我们全面了解模型在语义分割任务中的表现。

定义损失函数

import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.train as train

class PixelAccuracy(train.Metric):
    def __init__(self, num_class=21):
        super(PixelAccuracy, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def eval(self):
        pixel_accuracy = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
        return pixel_accuracy


class PixelAccuracyClass(train.Metric):
    def __init__(self, num_class=21):
        super(PixelAccuracyClass, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        mean_pixel_accuracy = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
        mean_pixel_accuracy = np.nanmean(mean_pixel_accuracy)
        return mean_pixel_accuracy


class MeanIntersectionOverUnion(train.Metric):
    def __init__(self, num_class=21):
        super(MeanIntersectionOverUnion, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        mean_iou = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))
        mean_iou = np.nanmean(mean_iou)
        return mean_iou


class FrequencyWeightedIntersectionOverUnion(train.Metric):
    def __init__(self, num_class=21):
        super(FrequencyWeightedIntersectionOverUnion, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
        iu = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))

        frequency_weighted_iou = (freq[freq > 0] * iu[freq > 0]).sum()
        return frequency_weighted_iou
导入必要的库
import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.train as train
  • numpy:用于数值计算。
  • mindspore:MindSpore框架的核心库。
  • mindspore.nn:包含神经网络相关的模块。
  • mindspore.train:包含训练相关的模块,例如 Metric 类。
PixelAccuracy 类
class PixelAccuracy(train.Metric):
    def __init__(self, num_class=21):
        super(PixelAccuracy, self).__init__()
        self.num_class = num_class
  • __init__:初始化方法,设置类别数量 num_class,默认值为21。
    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix
  • _generate_matrix:生成混淆矩阵方法。使用掩码筛选有效像素,根据真实值和预测值生成标签,计算标签的频率并生成混淆矩阵。
    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)
  • clear:重置混淆矩阵。
    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)
  • update:更新混淆矩阵。预测结果转换为标签形式,并调用 _generate_matrix 更新混淆矩阵。
    def eval(self):
        pixel_accuracy = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
        return pixel_accuracy
  • eval:计算并返回像素精度(PA)。
PixelAccuracyClass 类
class PixelAccuracyClass(train.Metric):
    def __init__(self, num_class=21):
        super(PixelAccuracyClass, self).__init__()
        self.num_class = num_class
  • __init__:初始化方法,设置类别数量 num_class,默认值为21。
    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix
  • _generate_matrix:生成混淆矩阵方法。使用掩码筛选有效像素,根据真实值和预测值生成标签,计算标签的频率并生成混淆矩阵。
    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)
  • update:更新混淆矩阵。预测结果转换为标签形式,并调用 _generate_matrix 更新混淆矩阵。
    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)
  • clear:重置混淆矩阵。
    def eval(self):
        mean_pixel_accuracy = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
        mean_pixel_accuracy = np.nanmean(mean_pixel_accuracy)
        return mean_pixel_accuracy
  • eval:计算并返回均像素精度(MPA)。
MeanIntersectionOverUnion 类
class MeanIntersectionOverUnion(train.Metric):
    def __init__(self, num_class=21):
        super(MeanIntersectionOverUnion, self).__init__()
        self.num_class = num_class
  • __init__:初始化方法,设置类别数量 num_class,默认值为21。
    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix
  • _generate_matrix:生成混淆矩阵方法。使用掩码筛选有效像素,根据真实值和预测值生成标签,计算标签的频率并生成混淆矩阵。
    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)
  • update:更新混淆矩阵。预测结果转换为标签形式,并调用 _generate_matrix 更新混淆矩阵。
    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)
  • clear:重置混淆矩阵。
    def eval(self):
        mean_iou = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))
        mean_iou = np.nanmean(mean_iou)
        return mean_iou
  • eval:计算并返回均交并比(MIoU)。
FrequencyWeightedIntersectionOverUnion 类
class FrequencyWeightedIntersectionOverUnion(train.Metric):
    def __init__(self, num_class=21):
        super(FrequencyWeightedIntersectionOverUnion, self).__init__()
        self.num_class = num_class
  • __init__:初始化方法,设置类别数量 num_class,默认值为21。
    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix
  • _generate_matrix:生成混淆矩阵方法。使用掩码筛选有效像素,根据真实值和预测值生成标签,计算标签的频率并生成混淆矩阵。
    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)
  • update:更新混淆矩阵。预测结果转换为标签形式,并调用 _generate_matrix 更新混淆矩阵。
    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)
  • clear:重置混淆矩阵。
    def eval(self):
        freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
        iu = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))

        frequency_weighted_iou = (freq[freq > 0] * iu[freq > 0]).sum()
        return frequency_weighted_iou
  • eval:计算并返回频权交并比(FWIoU)。

模型训练

导入VGG-16预训练参数后,实例化损失函数、优化器,使用Model接口编译网络,训练FCN-8s网络。

import mindspore
from mindspore import Tensor
import mindspore.nn as nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor, Model

# 设置运行环境
device_target = "GPU"
mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target=device_target)

# 设置训练参数
train_batch_size = 4
num_classes = 21

# 初始化模型结构
net = FCN8s(n_class=21)

# 导入vgg16预训练参数
load_vgg16()

# 计算学习率
min_lr = 0.0005
base_lr = 0.05
train_epochs = 1
iters_per_epoch = dataset.get_dataset_size()
total_step = iters_per_epoch * train_epochs

# 学习率调度器
lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,
                                            base_lr,
                                            total_step,
                                            iters_per_epoch,
                                            decay_epoch=2)
lr = Tensor(lr_scheduler[-1])

# 定义损失函数
loss = nn.CrossEntropyLoss(ignore_index=255)

# 定义优化器
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001)

# 定义动态损失缩放管理器
scale_factor = 4
scale_window = 3000
loss_scale_manager = ms.amp.DynamicLossScaleManager(scale_factor, scale_window)

# 初始化模型
if device_target == "Ascend":
    model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale_manager,
                  metrics={"pixel accuracy": PixelAccuracy(num_classes), "mean pixel accuracy": PixelAccuracyClass(num_classes),
                           "mean IoU": MeanIntersectionOverUnion(num_classes), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion(num_classes)})
else:
    model = Model(net, loss_fn=loss, optimizer=optimizer,
                  metrics={"pixel accuracy": PixelAccuracy(num_classes), "mean pixel accuracy": PixelAccuracyClass(num_classes),
                           "mean IoU": MeanIntersectionOverUnion(num_classes), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion(num_classes)})

# 设置检查点保存参数
time_callback = TimeMonitor(data_size=iters_per_epoch)
loss_callback = LossMonitor()
callbacks = [time_callback, loss_callback]

save_steps = 330
keep_checkpoint_max = 5
config_ckpt = CheckpointConfig(save_checkpoint_steps=save_steps, keep_checkpoint_max=keep_checkpoint_max)
ckpt_callback = ModelCheckpoint(prefix="FCN8s", directory="./ckpt", config=config_ckpt)
callbacks.append(ckpt_callback)

# 训练模型
model.train(train_epochs, dataset, callbacks=callbacks)

解析

设置运行环境
device_target = "GPU"
mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target=device_target)
  • 设置 MindSpore 运行模式为 PYNATIVE_MODE,并指定运行设备为 GPU。
设置训练参数
train_batch_size = 4
num_classes = 21
  • 设置训练批次大小为 4,类别数量为 21。
初始化模型结构
net = FCN8s(n_class=21)
  • 初始化 FCN8s 模型,类别数量为 21。
导入 VGG16 预训练参数
load_vgg16()
  • 导入 VGG16 预训练参数(假设存在 load_vgg16 函数)。
计算学习率
min_lr = 0.0005
base_lr = 0.05
train_epochs = 1
iters_per_epoch = dataset.get_dataset_size()
total_step = iters_per_epoch * train_epochs

lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,
                                            base_lr,
                                            total_step,
                                            iters_per_epoch,
                                            decay_epoch=2)
lr = Tensor(lr_scheduler[-1])
  • 设置最小学习率为 0.0005,基础学习率为 0.05,训练轮数为 1。
  • 计算每个 epoch 的迭代次数。
  • 使用 cosine decay 学习率调度器生成学习率,并将最后的学习率转换为 Tensor。
定义损失函数
loss = nn.CrossEntropyLoss(ignore_index=255)
  • 定义交叉熵损失函数,并忽略索引为 255 的像素。
定义优化器
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001)
  • 使用 Momentum 优化器,设置学习率、动量和权重衰减。
定义动态损失缩放管理器
scale_factor = 4
scale_window = 3000
loss_scale_manager = ms.amp.DynamicLossScaleManager(scale_factor, scale_window)
  • 定义动态损失缩放管理器,设置缩放因子和窗口大小。
初始化模型
if device_target == "Ascend":
    model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale_manager,
                  metrics={"pixel accuracy": PixelAccuracy(num_classes), "mean pixel accuracy": PixelAccuracyClass(num_classes),
                           "mean IoU": MeanIntersectionOverUnion(num_classes), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion(num_classes)})
else:
    model = Model(net, loss_fn=loss, optimizer=optimizer,
                  metrics={"pixel accuracy": PixelAccuracy(num_classes), "mean pixel accuracy": PixelAccuracyClass(num_classes),
                           "mean IoU": MeanIntersectionOverUnion(num_classes), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion(num_classes)})
  • 初始化 Model 类,对于 Ascend 设备,使用损失缩放管理器。
  • 设置评价指标,包括像素精度、均像素精度、均交并比和频权交并比。
设置检查点保存参数
time_callback = TimeMonitor(data_size=iters_per_epoch)
loss_callback = LossMonitor()
callbacks = [time_callback, loss_callback]

save_steps = 330
keep_checkpoint_max = 5
config_ckpt = CheckpointConfig(save_checkpoint_steps=save_steps, keep_checkpoint_max=keep_checkpoint_max)
ckpt_callback = ModelCheckpoint(prefix="FCN8s", directory="./ckpt", config=config_ckpt)
callbacks.append(ckpt_callback)
  • 设置时间监控和损失监控回调,并将它们添加到回调列表中。
  • 配置检查点保存参数,包括保存间隔和最多保留的检查点数量。
  • 创建检查点回调,并将其添加到回调列表中。
训练模型
model.train(train_epochs, dataset, callbacks=callbacks)
  • 使用配置的模型和回调进行模型训练。

模型评估

IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"

# 下载已训练好的权重文件
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/FCN8s.ckpt"
download(url, "FCN8s.ckpt", replace=True)
net = FCN8s(n_class=num_classes)

# 加载权重文件到网络
ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)

# 初始化模型
if device_target == "Ascend":
    model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale_manager,
                  metrics={"pixel accuracy": PixelAccuracy(num_classes), "mean pixel accuracy": PixelAccuracyClass(num_classes),
                           "mean IoU": MeanIntersectionOverUnion(num_classes), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion(num_classes)})
else:
    model = Model(net, loss_fn=loss, optimizer=optimizer,
                  metrics={"pixel accuracy": PixelAccuracy(num_classes), "mean pixel accuracy": PixelAccuracyClass(num_classes),
                           "mean IoU": MeanIntersectionOverUnion(num_classes), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion(num_classes)})

# 实例化Dataset
dataset = SegDataset(image_mean=IMAGE_MEAN,
                     image_std=IMAGE_STD,
                     data_file=DATA_FILE,
                     batch_size=train_batch_size,
                     crop_size=crop_size,
                     max_scale=max_scale,
                     min_scale=min_scale,
                     ignore_label=ignore_label,
                     num_classes=num_classes,
                     num_readers=2,
                     num_parallel_calls=4)
dataset_eval = dataset.get_dataset()

# 评估模型
model.eval(dataset_eval)

解析

定义图像均值和标准差
IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
  • 定义图像预处理所需的均值和标准差。
下载已训练好的权重文件
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/FCN8s.ckpt"
download(url, "FCN8s.ckpt", replace=True)
  • 从指定 URL 下载预训练的权重文件,并保存为 FCN8s.ckpt
加载权重文件到网络
ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)
  • 加载权重文件到网络模型中。
初始化模型
if device_target == "Ascend":
    model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale_manager,
                  metrics={"pixel accuracy": PixelAccuracy(num_classes), "mean pixel accuracy": PixelAccuracyClass(num_classes),
                           "mean IoU": MeanIntersectionOverUnion(num_classes), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion(num_classes)})
else:
    model = Model(net, loss_fn=loss, optimizer=optimizer,
                  metrics={"pixel accuracy": PixelAccuracy(num_classes), "mean pixel accuracy": PixelAccuracyClass(num_classes),
                           "mean IoU": MeanIntersectionOverUnion(num_classes), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion(num_classes)})
  • 根据设备类型初始化模型,设置损失函数、优化器和评价指标。
实例化Dataset
dataset = SegDataset(image_mean=IMAGE_MEAN,
                     image_std=IMAGE_STD,
                     data_file=DATA_FILE,
                     batch_size=train_batch_size,
                     crop_size=crop_size,
                     max_scale=max_scale,
                     min_scale=min_scale,
                     ignore_label=ignore_label,
                     num_classes=num_classes,
                     num_readers=2,
                     num_parallel_calls=4)
dataset_eval = dataset.get_dataset()
  • 实例化数据集,设置图像均值、标准差、数据文件路径、批次大小、裁剪大小、缩放范围、忽略标签、类别数量、读取器数量和并行调用数量。
评估模型
model.eval(dataset_eval)
  • 使用评估数据集对模型进行评估。

模型推理

使用训练的网络对模型推理结果进行展示。

import cv2
import matplotlib.pyplot as plt
import numpy as np
from mindspore import load_checkpoint, load_param_into_net
from mindspore.dataset import Dataset

# 初始化网络
net = FCN8s(n_class=num_classes)

# 加载预训练模型参数
ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)

# 设置评估批次大小
eval_batch_size = 4

# 初始化图像和预测结果列表
img_lst = []
mask_lst = []
res_lst = []

# 生成推理数据
show_data = next(dataset_eval.create_dict_iterator())
show_images = show_data["data"].asnumpy()
mask_images = show_data["label"].reshape([eval_batch_size, 512, 512])
show_images = np.clip(show_images, 0, 1)

# 保存输入图像和标签掩码
for i in range(eval_batch_size):
    img_lst.append(show_images[i])
    mask_lst.append(mask_images[i])

# 进行推理
res = net(show_data["data"]).asnumpy().argmax(axis=1)

# 可视化输入图像和推理结果
plt.figure(figsize=(8, 5))
for i in range(eval_batch_size):
    # 上方显示输入图片
    plt.subplot(2, eval_batch_size, i + 1)
    plt.imshow(img_lst[i].transpose(1, 2, 0))  # 将CHW格式转换为HWC格式
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
    
    # 下方显示推理结果
    plt.subplot(2, eval_batch_size, i + eval_batch_size + 1)
    plt.imshow(res[i])
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)

plt.show()

解析

初始化网络
net = FCN8s(n_class=num_classes)
  • 初始化一个FCN8s模型,指定类别数目。
加载预训练模型参数
ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)
  • 加载预训练模型参数,将其加载到网络中。
设置评估批次大小
eval_batch_size = 4
  • 设置评估批次大小为4。
初始化图像和预测结果列表
img_lst = []
mask_lst = []
res_lst = []
  • 初始化存储输入图像、标签掩码和推理结果的列表。
生成推理数据
show_data = next(dataset_eval.create_dict_iterator())
show_images = show_data["data"].asnumpy()
mask_images = show_data["label"].reshape([eval_batch_size, 512, 512])
show_images = np.clip(show_images, 0, 1)
  • 从评估数据集中获取一个批次的数据,并将其转换为numpy数组。
  • 将标签掩码重塑为形状 [eval_batch_size, 512, 512]
  • 将图像数据剪裁到 [0, 1] 范围内。
保存输入图像和标签掩码
for i in range(eval_batch_size):
    img_lst.append(show_images[i])
    mask_lst.append(mask_images[i])
  • 将输入图像和标签掩码保存到列表中。
进行推理
res = net(show_data["data"]).asnumpy().argmax(axis=1)
  • 使用网络进行推理,并将结果转换为numpy数组。
  • 使用 argmax 获取每个像素的类别。
可视化输入图像和推理结果
plt.figure(figsize=(8, 5))
for i in range(eval_batch_size):
    # 上方显示输入图片
    plt.subplot(2, eval_batch_size, i + 1)
    plt.imshow(img_lst[i].transpose(1, 2, 0))  # 将CHW格式转换为HWC格式
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
    
    # 下方显示推理结果
    plt.subplot(2, eval_batch_size, i + eval_batch_size + 1)
    plt.imshow(res[i])
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)

plt.show()
  • 使用matplotlib可视化输入图像和推理结果。
  • 上方显示输入图像,下方显示推理结果。
  • 调整子图间的间距。

总结

FCN的核心贡献在于提出使用全卷积层,通过学习让图片实现端到端分割。与传统使用CNN进行图像分割的方法相比,FCN有两大明显的优点:一是可以接受任意大小的输入图像,无需要求所有的训练图像和测试图像具有固定的尺寸。二是更加高效,避免了由于使用像素块而带来的重复存储和计算卷积的问题。
同时FCN网络也存在待改进之处:
一是得到的结果仍不够精细。进行8倍上采样虽然比32倍的效果好了很多,但是上采样的结果仍比较模糊和平滑,尤其是边界处,网络对图像中的细节不敏感。 二是对各个像素进行分类,没有充分考虑像素与像素之间的关系(如不连续性和相似性)。忽略了在通常的基于像素分类的分割方法中使用的空间规整(spatial regularization)步骤,缺乏空间一致性。

引用

[1]Long, Jonathan, Evan Shelhamer, and Trevor Darrell. “Fully convolutional networks for Semantic Segmentation.” Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2015.

整体代码

代码解析

1. 数据下载与预处理
from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset_fcn8s.tar"
download(url, "./dataset", kind="tar", replace=True)
  • 功能:从指定URL下载数据集并解压到./dataset目录。
  • APIdownload函数用于下载文件,kind="tar"表示下载的文件是tar压缩包。
2. 数据集类定义
import numpy as np
import cv2
import mindspore.dataset as ds

class SegDataset:
    def __init__(self, image_mean, image_std, data_file='', batch_size=32, crop_size=512, max_scale=2.0, min_scale=0.5, ignore_label=255, num_classes=21, num_readers=2, num_parallel_calls=4):
        self.data_file = data_file
        self.batch_size = batch_size
        self.crop_size = crop_size
        self.image_mean = np.array(image_mean, dtype=np.float32)
        self.image_std = np.array(image_std, dtype=np.float32)
        self.max_scale = max_scale
        self.min_scale = min_scale
        self.ignore_label = ignore_label
        self.num_classes = num_classes
        self.num_readers = num_readers
        self.num_parallel_calls = num_parallel_calls
        max_scale > min_scale

    def preprocess_dataset(self, image, label):
        image_out = cv2.imdecode(np.frombuffer(image, dtype=np.uint8), cv2.IMREAD_COLOR)
        label_out = cv2.imdecode(np.frombuffer(label, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
        sc = np.random.uniform(self.min_scale, self.max_scale)
        new_h, new_w = int(sc * image_out.shape[0]), int(sc * image_out.shape[1])
        image_out = cv2.resize(image_out, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
        label_out = cv2.resize(label_out, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

        image_out = (image_out - self.image_mean) / self.image_std
        out_h, out_w = max(new_h, self.crop_size), max(new_w, self.crop_size)
        pad_h, pad_w = out_h - new_h, out_w - new_w
        if pad_h > 0 or pad_w > 0:
            image_out = cv2.copyMakeBorder(image_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
            label_out = cv2.copyMakeBorder(label_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=self.ignore_label)
        offset_h = np.random.randint(0, out_h - self.crop_size + 1)
        offset_w = np.random.randint(0, out_w - self.crop_size + 1)
        image_out = image_out[offset_h: offset_h + self.crop_size, offset_w: offset_w + self.crop_size, :]
        label_out = label_out[offset_h: offset_h + self.crop_size, offset_w: offset_w+self.crop_size]
        if np.random.uniform(0.0, 1.0) > 0.5:
            image_out = image_out[:, ::-1, :]
            label_out = label_out[:, ::-1]
        image_out = image_out.transpose((2, 0, 1))
        image_out = image_out.copy()
        label_out = label_out.copy()
        label_out = label_out.astype("int32")
        return image_out, label_out

    def get_dataset(self):
        ds.config.set_numa_enable(True)
        dataset = ds.MindDataset(self.data_file, columns_list=["data", "label"],
                                 shuffle=True, num_parallel_workers=self.num_readers)
        transforms_list = self.preprocess_dataset
        dataset = dataset.map(operations=transforms_list, input_columns=["data", "label"],
                              output_columns=["data", "label"],
                              num_parallel_workers=self.num_parallel_calls)
        dataset = dataset.shuffle(buffer_size=self.batch_size * 10)
        dataset = dataset.batch(self.batch_size, drop_remainder=True)
        return dataset
  • 功能:定义了一个数据集类SegDataset,用于加载和预处理数据。
  • API
    • cv2.imdecode:解码图像数据。
    • cv2.resize:调整图像大小。
    • cv2.copyMakeBorder:添加图像边界。
    • mindspore.dataset.MindDataset:加载MindSpore格式的数据集。
    • dataset.map:应用数据预处理操作。
    • dataset.shuffle:打乱数据顺序。
    • dataset.batch:将数据分批。
3. 数据集实例化与展示
# 定义创建数据集的参数
IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"

# 定义模型训练参数
train_batch_size = 4
crop_size = 512
min_scale = 0.5
max_scale = 2.0
ignore_label = 255
num_classes = 21

# 实例化Dataset
dataset = SegDataset(image_mean=IMAGE_MEAN,
                     image_std=IMAGE_STD,
                     data_file=DATA_FILE,
                     batch_size=train_batch_size,
                     crop_size=crop_size,
                     max_scale=max_scale,
                     min_scale=min_scale,
                     ignore_label=ignore_label,
                     num_classes=num_classes,
                     num_readers=2,
                     num_parallel_calls=4)

dataset = dataset.get_dataset()

import numpy as np
import matplotlib.pyplot as plt

plt.figure(figsize=(16, 8))

# 对训练集中的数据进行展示
for i in range(1, 9):
    plt.subplot(2, 4, i)
    show_data = next(dataset.create_dict_iterator())
    show_images = show_data["data"].asnumpy()
    show_images = np.clip(show_images, 0, 1)
    plt.imshow(show_images[0].transpose(1, 2, 0))
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0)
plt.show()
  • 功能:实例化SegDataset并展示部分训练数据。
  • API
    • dataset.create_dict_iterator:创建数据集迭代器。
    • plt.imshow:显示图像。
    • plt.axis("off"):关闭坐标轴。
4. 模型定义
import mindspore.nn as nn

class FCN8s(nn.Cell):
    def __init__(self, n_class):
        super().__init__()
        self.n_class = n_class
        self.conv1 = nn.SequentialCell(
            nn.Conv2d(in_channels=3, out_channels=64,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.SequentialCell(
            nn.Conv2d(in_channels=64, out_channels=128,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.SequentialCell(
            nn.Conv2d(in_channels=128, out_channels=256,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = nn.SequentialCell(
            nn.Conv2d(in_channels=256, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv5 = nn.SequentialCell(
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv6 = nn.SequentialCell(
            nn.Conv2d(in_channels=512, out_channels=4096,
                      kernel_size=7, weight_init='xavier_uniform'),
            nn.BatchNorm2d(4096),
            nn.ReLU(),
        )
        self.conv7 = nn.SequentialCell(
            nn.Conv2d(in_channels=4096, out_channels=4096,
                      kernel_size=1, weight_init='xavier_uniform'),
            nn.BatchNorm2d(4096),
            nn.ReLU(),
        )
        self.score_fr = nn.Conv2d(in_channels=4096, out_channels=self.n_class,
                                  kernel_size=1, weight_init='xavier_uniform')
        self.upscore2 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
                                           kernel_size=4, stride=2, weight_init='xavier_uniform')
        self.score_pool4 = nn.Conv2d(in_channels=512, out_channels=self.n_class,
                                     kernel_size=1, weight_init='xavier_uniform')
        self.upscore_pool4 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
                                                kernel_size=4, stride=2, weight_init='xavier_uniform')
        self.score_pool3 = nn.Conv2d(in_channels=256, out_channels=self.n_class,
                                     kernel_size=1, weight_init='xavier_uniform')
        self.upscore8 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
                                           kernel_size=16, stride=8, weight_init='xavier_uniform')

    def construct(self, x):
        x1 = self.conv1(x)
        p1 = self.pool1(x1)
        x2 = self.conv2(p1)
        p2 = self.pool2(x2)
        x3 = self.conv3(p2)
        p3 = self.pool3(x3)
        x4 = self.conv4(p3)
        p4 = self.pool4(x4)
        x5 = self.conv5(p4)
        p5 = self.pool5(x5)
        x6 = self.conv6(p5)
        x7 = self.conv7(x6)
        sf = self.score_fr(x7)
        u2 = self.upscore2(sf)
        s4 = self.score_pool4(p4)
        f4 = s4 + u2
        u4 = self.upscore_pool4(f4)
        s3 = self.score_pool3(p3)
        f3 = s3 + u4
        out = self.upscore8(f3)
        return out
  • 功能:定义了FCN8s模型结构。
  • API
    • nn.SequentialCell:顺序容器。
    • nn.Conv2d:二维卷积层。
    • nn.BatchNorm2d:二维批归一化层。
    • nn.ReLU:ReLU激活函数。
    • nn.MaxPool2d:二维最大池化层。
    • nn.Conv2dTranspose:二维反卷积层。

代码解析 (继续)

5. 模型加载预训练参数
from download import download
from mindspore import load_checkpoint, load_param_into_net

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/fcn8s_vgg16_pretrain.ckpt"
download(url, "fcn8s_vgg16_pretrain.ckpt", replace=True)

def load_vgg16():
    ckpt_vgg16 = "fcn8s_vgg16_pretrain.ckpt"
    param_vgg = load_checkpoint(ckpt_vgg16)
    load_param_into_net(net, param_vgg)
  • 功能:下载并加载预训练的VGG16模型参数。
  • API
    • download:下载模型文件。
    • load_checkpoint:加载checkpoint文件中的模型参数。
    • load_param_into_net:将加载的参数导入网络模型。
6. 评估指标定义
import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.train as train

class PixelAccuracy(train.Metric):
    def __init__(self, num_class=21):
        super(PixelAccuracy, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def eval(self):
        pixel_accuracy = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
        return pixel_accuracy
  • 功能:定义像素准确率(Pixel Accuracy)评估指标。
  • API
    • train.Metric:定义自定义评估指标的基类。
    • np.bincount:统计每个值的出现次数。
    • np.diag:提取矩阵的对角线元素。

同理,其他三个评估指标也类似,只是计算方法不同:

class PixelAccuracyClass(train.Metric):
    def __init__(self, num_class=21):
        super(PixelAccuracyClass, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        mean_pixel_accuracy = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
        mean_pixel_accuracy = np.nanmean(mean_pixel_accuracy)
        return mean_pixel_accuracy

class MeanIntersectionOverUnion(train.Metric):
    def __init__(self, num_class=21):
        super(MeanIntersectionOverUnion, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        mean_iou = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))
        mean_iou = np.nanmean(mean_iou)
        return mean_iou

class FrequencyWeightedIntersectionOverUnion(train.Metric):
    def __init__(self, num_class=21):
        super(FrequencyWeightedIntersectionOverUnion, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
        iu = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))

        frequency_weighted_iou = (freq[freq > 0] * iu[freq > 0]).sum()
        return frequency_weighted_iou
7. 模型训练与保存
import mindspore
from mindspore import Tensor
import mindspore.nn as nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor, Model

device_target = "GPU"
mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target=device_target)

train_batch_size = 4
num_classes = 21
# 初始化模型结构
net = FCN8s(n_class=21)
# 导入vgg16预训练参数
load_vgg16()
# 计算学习率
min_lr = 0.0005
base_lr = 0.05
train_epochs = 1
iters_per_epoch = dataset.get_dataset_size()
total_step = iters_per_epoch * train_epochs

lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr, base_lr, total_step, iters_per_epoch, decay_epoch=2)
lr = Tensor(lr_scheduler[-1])

# 定义损失函数
loss = nn.CrossEntropyLoss(ignore_index=255)
# 定义优化器
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001)
# 定义loss_scale
scale_factor = 4
scale_window = 3000
loss_scale_manager = ms.amp.DynamicLossScaleManager(scale_factor, scale_window)

# 初始化模型
if device_target == "Ascend":
    model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale_manager, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
else:
    model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})

# 设置ckpt文件保存的参数
time_callback = TimeMonitor(data_size=iters_per_epoch)
loss_callback = LossMonitor()
callbacks = [time_callback, loss_callback]
save_steps = 330
keep_checkpoint_max = 5
config_ckpt = CheckpointConfig(save_checkpoint_steps=10, keep_checkpoint_max=keep_checkpoint_max)
ckpt_callback = ModelCheckpoint(prefix="FCN8s", directory="./ckpt", config=config_ckpt)
callbacks.append(ckpt_callback)

model.train(train_epochs, dataset, callbacks=callbacks)
  • 功能:定义并训练模型,同时保存模型参数。
  • API
    • mindspore.set_context:设置MindSpore的运行环境。
    • mindspore.nn.cosine_decay_lr:生成余弦退火学习率。
    • ModelCheckpoint:用于保存训练过程中的检查点。
    • TimeMonitor:记录训练时间。
    • LossMonitor:记录训练损失。
    • Model:MindSpore的模型类,用于训练和评估模型。
8. 加载训练好的模型并进行评估
IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"

# 下载已训练好的权重文件
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/FCN8s.ckpt"
download(url, "FCN8s.ckpt", replace=True)
net = FCN8s(n_class=num_classes)

ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)

if device_target == "Ascend":
    model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale_manager, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
else:
    model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})

# 实例化Dataset
dataset = SegDataset(image_mean=IMAGE_MEAN, image_std=IMAGE_STD, data_file=DATA_FILE, batch_size=train_batch_size, crop_size=crop_size, max_scale=max_scale, min_scale=min_scale, ignore_label=ignore_label, num_classes=num_classes, num_readers=2, num_parallel_calls=4)
dataset_eval = dataset.get_dataset()
model.eval(dataset_eval)
  • 功能:下载已训练好的模型权重文件,加载并进行评估。
  • API
    • load_checkpoint:加载checkpoint文件中的模型参数。
    • load_param_into_net:将加载的参数导入网络模型。
    • model.eval:对评估数据集进行评估。
9. 推理并显示结果
import cv2
import matplotlib.pyplot as plt

net = FCN8s(n_class=num_classes)
# 设置超参
ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)
eval_batch_size = 4
img_lst = []
mask_lst = []
res_lst = []

# 推理效果展示(上方为输入图片,下方为推理效果图片)
plt.figure(figsize=(8, 5))
show_data = next(dataset_eval.create_dict_iterator())
show_images = show_data["data"].asnumpy()
mask_images = show_data["label"].reshape([4, 512, 512])
show_images = np.clip(show_images, 0, 1)
for i in range(eval_batch_size):
    img_lst.append(show_images[i])
    mask_lst.append(mask_images[i])
res = net(show_data["data"]).asnumpy().argmax(axis=1)
for i in range(eval_batch_size):
    plt.subplot(2, 4, i + 1)
    plt.imshow(img_lst[i].transpose(1, 2, 0))
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
    plt.subplot(2, 4, i + 5)
    plt.imshow(res[i])
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()
  • 功能:使用训练好的模型进行推理,并展示输入图片与推理结果。
  • API
    • dataset_eval.create_dict_iterator:创建评估数据集的字典迭代器。
    • plt.imshow:显示图像。
    • plt.subplots_adjust:调整子图之间的间距。

总结

这段代码实现了FCN8s模型在语义分割任务上的训练和推理过程,使用了MindSpore框架,并定义了相关的数据处理、模型网络、评估指标和训练参数等。以下是主要步骤:

  1. 下载并处理数据集。
  2. 定义数据集类并进行数据预处理。
  3. 定义FCN8s模型结构。
  4. 加载预训练的VGG16参数。
  5. 定义评估指标。
  6. 设置训练参数并进行模型训练。
  7. 加载训练好的模型并进行评估。
  8. 使用训练好的模型进行推理并展示结果。

相关推荐

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-20 00:32:02       52 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-20 00:32:02       54 阅读
  3. 在Django里面运行非项目文件

    2024-07-20 00:32:02       45 阅读
  4. Python语言-面向对象

    2024-07-20 00:32:02       55 阅读

热门阅读

  1. 新手教程---python-函数(新添加)

    2024-07-20 00:32:02       19 阅读
  2. Leetcode226.翻转二叉树

    2024-07-20 00:32:02       17 阅读