AutoEncoder Implementation

The implementation is split across four files: the model (AutoEncoder.py), the dataset wrappers (dataset.py), sampling utilities (utils.py), and the training script (train.py).

1. AutoEncoder.py

import torch
from torch import nn
import torch.nn.functional as F


class Encoder(nn.Module):
    """Compresses a flattened input vector into a low-dimensional latent code."""
    def __init__(self, input_size, hidden_size, latent_size):
        super(Encoder, self).__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, latent_size)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.linear2(x)  # no activation: the latent code is unconstrained
        return x


class Decoder(nn.Module):
    """Reconstructs the flattened input vector from a latent code."""
    def __init__(self, latent_size, hidden_size, output_size):
        super(Decoder, self).__init__()
        self.linear1 = nn.Linear(latent_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = self.linear2(x)  # linear output, paired with an MSE reconstruction loss
        return x


class AutoEncoder(nn.Module):
    """Chains the encoder and decoder: reconstruction = decoder(encoder(x))."""
    def __init__(self, input_size, hidden_size, latent_size, output_size):
        super(AutoEncoder, self).__init__()
        self.encoder = Encoder(input_size, hidden_size, latent_size)
        self.decoder = Decoder(latent_size, hidden_size, output_size)

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
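
The model works on flattened images, so a batch of 3×64×64 images must be reshaped to vectors of length 3*64*64 = 12288 before the forward pass. A minimal shape-check sketch (the sizes here are just examples):

import torch
from AutoEncoder import AutoEncoder

model = AutoEncoder(input_size=3 * 64 * 64, hidden_size=128, latent_size=64, output_size=3 * 64 * 64)
batch = torch.randn(8, 3, 64, 64).view(8, -1)  # flatten a fake image batch
recon = model(batch)
print(batch.shape, recon.shape)  # both torch.Size([8, 12288])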

2. dataset.py

from io import BytesIO

import lmdb
from PIL import Image
from torch.utils.data import Dataset

from imutils.paths import list_files


class LMDBDataset(Dataset):
    def __init__(self, path, transform, resolution=256, max_num=70000):
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)

        self.keys = []
        with self.env.begin(write=False) as txn:
            cursor = txn.cursor()
            for key, _ in cursor:
                self.keys.append(key)
                if len(self.keys) >= max_num:  # cap the dataset at max_num images
                    break

        self.length = len(self.keys)
        self.resolution = resolution
        self.transform = transform

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        with self.env.begin(write=False) as txn:
            key = self.keys[index]
            img_bytes = txn.get(key)

        buffer = BytesIO(img_bytes)
        # convert('RGB') guards against grayscale/RGBA images breaking the 3-channel Normalize
        img = Image.open(buffer).convert('RGB').resize((self.resolution, self.resolution))
        img = self.transform(img)

        return img


IMG_EXTENSIONS = ['.webp', '.png', '.jpg', '.jpeg', '.ppm', '.bmp', '.pgm', '.tif', '.tiff']


class NormalDataset(Dataset):
    def __init__(self, path, transform, resolution=256, max_num=70000):
        # Filter to image files first, then cap at max_num
        listed_files = sorted(list_files(path))
        self.files = [
            f for f in listed_files
            if any(f.lower().endswith(ext) for ext in IMG_EXTENSIONS)
        ][:max_num]

        self.resolution = resolution
        self.transform = transform
        self.length = len(self.files)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        img = Image.open(self.files[index]).convert('RGB').resize((self.resolution, self.resolution))
        img = self.transform(img)

        return img


def set_dataset(dataset_type, path, transform, resolution):
    # `dataset_type` avoids shadowing the built-in `type`
    if dataset_type == 'lmdb':
        datatype = LMDBDataset
    elif dataset_type == 'normal':
        datatype = NormalDataset
    else:
        raise NotImplementedError(dataset_type)
    return datatype(path, transform, resolution)
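
As a quick usage sketch (the directory path here is hypothetical), either dataset type can be smoke-tested before wiring it into the training script:

from torchvision import transforms
from dataset import set_dataset

transform = transforms.Compose([transforms.ToTensor()])
ds = set_dataset(dataset_type='normal', path='/path/to/images', transform=transform, resolution=64)
print(len(ds), ds[0].shape)  # image count and torch.Size([3, 64, 64])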

3. utils.py

from torch.utils import data


def data_sampler(dataset, shuffle):
    # Random order for training, deterministic order for evaluation
    if shuffle:
        return data.RandomSampler(dataset)
    else:
        return data.SequentialSampler(dataset)


def sample_data(loader):
    # Wrap a finite DataLoader into an endless stream of batches
    while True:
        for batch in loader:
            yield batch
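
sample_data turns a finite DataLoader into an endless generator, which suits step-based training loops and the per-epoch test sampling in train.py. A minimal illustration (the toy dataset is made up):

import torch
from torch.utils import data
from utils import sample_data

loader = data.DataLoader(data.TensorDataset(torch.arange(10.0)), batch_size=4)
stream = sample_data(loader)
for _ in range(5):           # more draws than one pass over the loader contains
    (batch,) = next(stream)  # silently restarts from the top when exhausted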

4. train.py

from torch import nn, optim
from tqdm import tqdm
from AutoEncoder import AutoEncoder
from torch.utils import data
from torchvision import transforms
from dataset import set_dataset
import argparse
from utils import data_sampler, sample_data
import torch
import os
import matplotlib.pyplot as plt


# Training loop
def train(args, dataloader_train, dataloader_test, model, criterion, optimizer):
    start_epoch = getattr(args, 'start_epoch', 0)  # resume from a checkpoint if set
    for epoch in range(start_epoch, args.epochs):
        model.train()
        train_loss = 0
        train_sample = 0
        t = tqdm(dataloader_train, desc=f'[{epoch}/{args.epochs}]')
        for i, x in enumerate(t):
            # Flatten images; x.size(0) keeps the last, possibly smaller, batch working
            x = x.to(args.device).view(x.size(0), -1)
            output = model(x)
            loss = criterion(output, x)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Track the running average loss
            train_loss += loss.item()
            train_sample += x.size(0)
            t.set_postfix({"loss": train_loss / train_sample})

        torch.save({"model_state_dict": model.state_dict(), "epoch": epoch+1}, f'{args.ckpt_dir}/ckpt_{epoch}.pt')
        torch.save({"model_state_dict": model.state_dict(), "epoch": epoch+1}, f'{args.ckpt_dir}/ckpt.pt')

        # Evaluation: reconstruct one test batch and save a side-by-side comparison
        model.eval()
        with torch.no_grad():
            imgs = next(sample_data(dataloader_test))
            imgs = imgs.to(args.device).view(imgs.size(0), -1)
            test_output = model(imgs)

        imgs = imgs[0].view(3, args.image_size, args.image_size)
        imgs = imgs.permute(1, 2, 0) * 0.5 + 0.5  # undo the Normalize back to [0, 1]
        test_output = test_output[0].view(3, args.image_size, args.image_size)
        test_output = test_output.permute(1, 2, 0) * 0.5 + 0.5
        concat = torch.cat((imgs, test_output), 1).clamp(0, 1)  # original | reconstruction
        plt.matshow(concat.cpu().numpy())  # RGB image, so no cmap is needed
        plt.savefig(f"{args.sample_dir}/test_{epoch}.png")
        plt.show()


if __name__ == '__main__':
    # Training configuration
    args = {
        "exp_name": "test2",
        "dataset_type": "normal",
        "dataset_path_train": r"D:\MyFiles\Papers\Codes\datasets\celebahq_train",
        "dataset_path_test": r"D:\MyFiles\Papers\Codes\datasets\celebahq_test",
        "image_size": 64,
        "batch_size": 32,
        "hidden_size": 128,
        "latent_size": 64,
        "learning_rate": 0.001,
        "epochs": 30,
        "device": "cuda",
        "log_interval": 100,
        "save_interval": 500,
        "ckpt": "ckpt.pt",
        "resume": False,
    }
    args = argparse.Namespace(**args)

    # Create folders for working directory
    base_dir = f"experiments/{args.exp_name}"
    ckpt_dir = f"{base_dir}/checkpoints"
    sample_dir = f"{base_dir}/samples"
    os.makedirs(ckpt_dir, exist_ok=True)
    os.makedirs(sample_dir, exist_ok=True)
    args.ckpt_dir = ckpt_dir
    args.sample_dir = sample_dir

    # Init transforms
    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

    # Init dataset and dataloader
    dataset_train = set_dataset(
        dataset_type=args.dataset_type,
        path=args.dataset_path_train,
        transform=transform,
        resolution=args.image_size
    )

    loader_train = data.DataLoader(
        dataset=dataset_train,
        batch_size=args.batch_size,
        sampler=data_sampler(dataset=dataset_train, shuffle=True)
    )

    dataset_test = set_dataset(
        dataset_type=args.dataset_type,
        path=args.dataset_path_test,
        transform=transform,
        resolution=args.image_size
    )

    loader_test = data.DataLoader(
        dataset=dataset_test,
        batch_size=args.batch_size,
        sampler=data_sampler(dataset=dataset_test, shuffle=False)
    )

    input_size = output_size = args.image_size * args.image_size * 3

    # Build the model
    model = AutoEncoder(input_size, args.hidden_size, args.latent_size, output_size).to(args.device)

    # Optionally resume from a checkpoint
    if args.resume:
        print("Loading checkpoint:", args.ckpt)
        ckpt = torch.load(f"{ckpt_dir}/{args.ckpt}", map_location=lambda storage, loc: storage)
        args.start_epoch = ckpt['epoch']  # picked up by train() so finished epochs are not repeated
        model.load_state_dict(ckpt['model_state_dict'])

    # Reconstruction loss
    criterion = nn.MSELoss()

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

    train(args, loader_train, loader_test, model, criterion, optimizer)
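
After training, the saved checkpoint can be reused for standalone reconstruction. A minimal inference sketch (it mirrors the hyperparameters above; the image path is hypothetical):

import torch
from PIL import Image
from torchvision import transforms
from AutoEncoder import AutoEncoder

image_size = 64
model = AutoEncoder(3 * image_size ** 2, 128, 64, 3 * image_size ** 2)
ckpt = torch.load('experiments/test2/checkpoints/ckpt.pt', map_location='cpu')
model.load_state_dict(ckpt['model_state_dict'])
model.eval()

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
img = Image.open('/path/to/face.png').convert('RGB').resize((image_size, image_size))
with torch.no_grad():
    recon = model(transform(img).view(1, -1))
# back to HWC in [0, 1] for saving or plotting
recon = (recon.view(3, image_size, image_size).permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1)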
