从零实现 softmax 回归【基于 PyTorch】

参考资料:李沐《动手学深度学习》

import torch
import torchvision
from matplotlib import pyplot as plt
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
from IPython import display

def get_dataloader_workers():  #@save
    """Return the number of worker processes to use when reading data."""
    # Kept at 0: the original author suspected that any value above 0 causes
    # a deadlock (marked as unconfirmed in the original comment).
    num_workers = 0
    return num_workers

# ===================== 下载数据集 ==================
def load_data_fashion_mnist(batch_size, resize=None,
                            root=r"F:\MnistDataSet"):  #@save
    """Download Fashion-MNIST (if needed) and wrap both splits in DataLoaders.

    Args:
        batch_size: Number of samples per mini-batch, used for both loaders.
        resize: Optional size handed to ``transforms.Resize``. When falsy, no
            resizing step is added.
        root: Directory the dataset is downloaded to / read from. Defaults to
            the location that used to be hard-coded in this function.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Only the training loader
        shuffles its samples.
    """
    # The default path is a raw string: "\M" is not a valid escape sequence,
    # and a plain literal triggers a warning on recent Python versions. The
    # resulting string value is unchanged.
    steps = []
    if resize:
        # Resize stays ahead of ToTensor, matching the original ordering.
        steps.append(transforms.Resize(resize))
    steps.append(transforms.ToTensor())
    trans = transforms.Compose(steps)

    mnist_train = torchvision.datasets.FashionMNIST(
        root=root, train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root=root, train=False, transform=trans, download=True)

    num_workers = get_dataloader_workers()
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=num_workers),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=num_workers))


# ================== 从零实现softmax ===================
# 初始化模型参数
batch_size = 256
# Pass the configured batch size instead of the stray literal 32 that was used
# here before, so that changing batch_size above actually takes effect.
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

num_inputs = 784   # each image is flattened: 28 * 28 pixels
num_outputs = 10   # one output score per class

# Weights start as small Gaussian noise, the bias as zeros; both track gradients
# because they are updated manually during training.
w = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)

# 定义softmax操作
def softmax(x):
    """Compute the row-wise softmax of a 2-D tensor of scores.

    Args:
        x: Tensor whose rows are per-sample score vectors (reduction runs
            over dimension 1).

    Returns:
        Tensor of the same shape as ``x`` whose rows are non-negative and sum
        to 1.
    """
    # Subtract each row's maximum before exponentiating. Softmax is invariant
    # to a per-row shift, so the result is mathematically unchanged, but
    # exp() can no longer overflow to inf (which previously produced NaN for
    # large scores). The shift is detached: it is a constant offset and does
    # not need to take part in the gradient.
    row_max = x.max(dim=1, keepdim=True).values.detach()
    x_exp = torch.exp(x - row_max)
    partition = x_exp.sum(1, keepdim=True)
    # keepdim=True lets the division broadcast across each row.
    return x_exp / partition

# 定义模型
def net(X):
    """Apply the softmax-regression model to a batch of inputs.

    ``X`` is flattened so that each row has ``w.shape[0]`` features, then
    multiplied by the module-level weights ``w``, offset by the bias ``b`` and
    passed through ``softmax``.
    """
    flat = X.reshape((-1, w.shape[0]))
    scores = torch.matmul(flat, w) + b
    return softmax(scores)

# 定义损失函数
def cross_entropy(y_hat, y):
    """Return the per-sample cross-entropy loss.

    Args:
        y_hat: 2-D tensor indexed as ``[sample, class]``.
        y: Integer class index for every sample.

    Returns:
        1-D tensor holding ``-log`` of the value ``y_hat`` assigns to each
        sample's true class.
    """
    rows = range(len(y_hat))
    # Pick, for every row, the entry in the column given by its label.
    picked = y_hat[rows, y]
    return torch.log(picked).neg()


# 计算预测正确的数量(精度)
def accuracy(y_hat, y):
    """Count how many predictions match the labels.

    Args:
        y_hat: Either a 2-D tensor with more than one column (the column index
            of each row's maximum is taken as the prediction) or a tensor that
            already holds the predicted classes.
        y: Tensor of true labels.

    Returns:
        The number of correct predictions, as a Python float.
    """
    has_class_columns = len(y_hat.shape) > 1 and y_hat.shape[1] > 1
    predicted = y_hat.argmax(axis=1) if has_class_columns else y_hat
    # Cast to the label dtype before comparing, then again before summing so
    # the boolean matches can be added up.
    hits = predicted.type(y.dtype) == y
    n_correct = hits.type(y.dtype).sum()
    return float(n_correct)

# 对于任意数据迭代器data_iter可访问的数据集, 我们可以评估在任意模型net的精度
def evaluate_accuracy(net, data_iter):  #@save
    """Compute the accuracy of ``net`` over every batch in ``data_iter``.

    Args:
        net: Callable model; when it is a ``torch.nn.Module`` it is switched
            to evaluation mode first.
        data_iter: Iterable yielding ``(X, y)`` batches.

    Returns:
        Correct predictions divided by the total number of labels seen.
    """
    if isinstance(net, torch.nn.Module):
        net.eval()
    totals = Accumulator(2)  # [correct predictions, total predictions]
    # No gradients are needed while only measuring accuracy.
    with torch.no_grad():
        for features, labels in data_iter:
            n_correct = accuracy(net(features), labels)
            totals.add(n_correct, labels.numel())
    correct, seen = totals[0], totals[1]
    return correct / seen

class Accumulator:  #@save
    """Keep running sums for ``n`` separate quantities."""

    def __init__(self, n):
        # One float slot per tracked quantity, all starting at zero.
        self.data = [0.0 for _ in range(n)]

    def add(self, *args):
        """Add each positional argument (converted to float) to its slot."""
        updated = []
        for total, increment in zip(self.data, args):
            updated.append(total + float(increment))
        self.data = updated

    def reset(self):
        """Set every slot back to zero, keeping the number of slots."""
        self.data = [0.0 for _ in self.data]

    def __getitem__(self, idx):
        """Return the running sum stored at position ``idx``."""
        return self.data[idx]
class Animator:  #@save
    """Plot several curves incrementally, redrawing after every update."""

    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
                 figsize=(3.5, 2.5)):
        """Create the figure and remember how its first axes is configured."""
        legend = [] if legend is None else legend
        d2l.use_svg_display()
        fig, axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)
        self.fig = fig
        # With a single panel the axes object is wrapped in a list so that
        # self.axes[0] works in every case.
        self.axes = [axes, ] if nrows * ncols == 1 else axes

        def configure_first_axes():
            # Closure over the constructor arguments; re-applied after every
            # redraw because add() clears the axes first.
            return d2l.set_axes(
                self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale,
                legend)

        self.config_axes = configure_first_axes
        self.X, self.Y, self.fmts = None, None, fmts

    def add(self, x, y):
        """Append one point per curve and redraw the figure.

        ``y`` may be a scalar or a sequence with one value per curve; a scalar
        ``x`` is reused for every curve. Pairs containing ``None`` are skipped.
        """
        if not hasattr(y, "__len__"):
            y = [y]
        n_curves = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n_curves
        # Create the per-curve point lists on first use.
        if not self.X:
            self.X = [[] for _ in range(n_curves)]
        if not self.Y:
            self.Y = [[] for _ in range(n_curves)]
        for idx, (x_val, y_val) in enumerate(zip(x, y)):
            if x_val is None or y_val is None:
                continue
            self.X[idx].append(x_val)
            self.Y[idx].append(y_val)
        # Redraw from scratch: clear, replot every curve, restore axes config.
        first_axes = self.axes[0]
        first_axes.cla()
        for xs, ys, fmt in zip(self.X, self.Y, self.fmts):
            first_axes.plot(xs, ys, fmt)
        self.config_axes()
        display.display(self.fig)
        display.clear_output(wait=True)

# 训练
def train_epoch_ch3(net, train_iter, loss, updater):  #@save
    """Train the model for one epoch.

    Args:
        net: Callable model; a ``torch.nn.Module`` is put into training mode.
        train_iter: Iterable yielding ``(X, y)`` training batches.
        loss: Callable returning the loss for ``(y_hat, y)``.
        updater: Either a ``torch.optim.Optimizer`` or a callable that takes
            the batch size and performs the parameter update.

    Returns:
        ``(average loss per sample, training accuracy)`` for this epoch.
    """
    if isinstance(net, torch.nn.Module):
        net.train()
    uses_torch_optimizer = isinstance(updater, torch.optim.Optimizer)
    # Slots: summed loss, number of correct predictions, number of samples.
    totals = Accumulator(3)
    for X, y in train_iter:
        y_hat = net(X)
        l = loss(y_hat, y)
        if uses_torch_optimizer:
            # Built-in optimizer: step on the mean loss.
            updater.zero_grad()
            l.mean().backward()
            updater.step()
        else:
            # Custom updater: backprop the summed loss and hand over the batch
            # size. Gradients are not reset here, so presumably the updater
            # takes care of that itself -- confirm against its implementation.
            l.sum().backward()
            updater(X.shape[0])
        totals.add(float(l.sum()), accuracy(y_hat, y), y.numel())
    loss_sum, n_correct, n_samples = totals[0], totals[1], totals[2]
    return loss_sum / n_samples, n_correct / n_samples

def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater):  #@save
    """Train for ``num_epochs`` epochs, plotting progress after each one.

    After every epoch the training loss, training accuracy and test accuracy
    are added to an ``Animator`` plot. When training finishes, the metrics of
    the last epoch are checked against fixed sanity thresholds.
    """
    legend = ['train loss', 'train acc', 'test acc']
    animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],
                        legend=legend)
    for epoch in range(1, num_epochs + 1):
        train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
        test_acc = evaluate_accuracy(net, test_iter)
        # Epochs are plotted 1-based to line up with xlim.
        animator.add(epoch, train_metrics + (test_acc,))
    train_loss, train_acc = train_metrics
    # Sanity checks on the final epoch; the failing value is the message.
    assert train_loss < 0.5, train_loss
    assert 0.7 < train_acc <= 1, train_acc
    assert 0.7 < test_acc <= 1, test_acc

# Training hyper-parameters.
num_epochs = 10
learning_rate = 0.1


def updater(batch_size):
    """Run ``d2l.sgd`` on the model parameters ``w`` and ``b``.

    The module-level ``learning_rate`` and the given ``batch_size`` are
    forwarded unchanged.
    """
    params = [w, b]
    return d2l.sgd(params, learning_rate, batch_size)


# Kick off training with the hand-written model, loss and updater.
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)

def predict_ch3(net, test_iter, n=6):  #@save
    """Show up to ``n`` images from the first test batch with their labels.

    Each image title has the true label on the first line and the predicted
    label on the second.

    Args:
        net: Callable model returning per-class scores for a batch.
        test_iter: Iterable yielding ``(x, y)`` batches; only the first batch
            is used.
        n: Maximum number of images to display.

    Raises:
        ValueError: If ``test_iter`` yields no batches.
    """
    # Fetch just the first batch. The previous for/break form left x and y
    # unbound (NameError) when the iterator was empty.
    first_batch = next(iter(test_iter), None)
    if first_batch is None:
        raise ValueError("test_iter yielded no batches")
    x, y = first_batch
    # Never ask for more images than the batch holds; the fixed reshape to
    # (n, 28, 28) used to fail for batches smaller than n.
    n = min(n, len(x))
    # Only the displayed samples need to be run through the model.
    x, y = x[:n], y[:n]
    trues = d2l.get_fashion_mnist_labels(y)
    preds = d2l.get_fashion_mnist_labels(net(x).argmax(axis=1))
    titles = [true + '\n' + pred for true, pred in zip(trues, preds)]
    d2l.show_images(x.reshape((n, 28, 28)), 1, n, titles=titles)
    plt.show()
predict_ch3(net, test_iter)
# print(net, test_iter)

相关推荐

  1. 实现softmax回归基于Pytorch

    2024-02-19 17:20:03       30 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-02-19 17:20:03       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-02-19 17:20:03       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-02-19 17:20:03       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-02-19 17:20:03       18 阅读

热门阅读

  1. 使用docker搭建php开发环境

    2024-02-19 17:20:03       38 阅读
  2. 怎么测试阿里云香港服务器是不是cn2?

    2024-02-19 17:20:03       34 阅读
  3. C/C++与汇编交互总结

    2024-02-19 17:20:03       29 阅读
  4. (力扣记录)199.二叉树的右视图

    2024-02-19 17:20:03       24 阅读
  5. Linux中精简卷对Oracle的影响

    2024-02-19 17:20:03       33 阅读
  6. Oracle触发器

    2024-02-19 17:20:03       27 阅读
  7. 索引失效的 12 种情况

    2024-02-19 17:20:03       30 阅读
  8. C++/Python/MATLAB检查内存使用情况

    2024-02-19 17:20:03       35 阅读
  9. Python爬虫开发:Scrapy框架与Requests库

    2024-02-19 17:20:03       30 阅读
  10. 力扣_字符串10—重复的DNA序列

    2024-02-19 17:20:03       25 阅读