逻辑回归吧

import torch
import matplotlib.pyplot as plt
import numpy as np
import torchvision
# train_set = torchvision.datasets.MNIST(root='../dataset/mnist', train=True, download=True)
# test_set = torchvision.datasets.MNIST(root='../dataset/mnist', train=False, download=True)

您指定的路径 …/dataset/mnist 是一个相对路径,表示将 MNIST 数据集下载到当前目录的上级目录中的 dataset/mnist 目录中。

具体来说,在您的文件系统中,如果您的当前工作目录是 /home/user/,那么相对路径 …/dataset/mnist 将会是 /home/dataset/mnist。

class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)
    
    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred

import torch.nn.functional as F
class LogisticRegressionModel(torch.nn.Module):
    def __init__(self):
        super(LogisticRegressionModel,self).__init__()
        self .linear = torch.nn.Linear(1,1)
        
    def forward(self,x):
        y_pred = F.sigmoid(self.linear(x))
        return y_pred
model = LogisticRegressionModel()
criterion = torch.nn.BCELoss(reduction = 'sum')
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])

for epoch in range(1000):
    y_pred = model(x_data)
    loss = criterion(y_pred,y_data)
    
    print(epoch,loss.item())
    plt.scatter(epoch,loss.data)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
x = np.linspace(0, 10, 200) # 每周学习时间
x_t = torch.Tensor(x).view((200, 1)) # 200行1列的矩阵
y_t = model(x_t)
y = y_t.data.numpy()
plt.scatter(x, y)
plt.plot([0, 10], [0.5, 0.5], c='r')
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()

相关推荐

  1. 逻辑回归

    2024-03-14 07:12:08       40 阅读
  2. 逻辑回归OvR策略

    2024-03-14 07:12:08       54 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-03-14 07:12:08       98 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-03-14 07:12:08       106 阅读
  3. 在Django里面运行非项目文件

    2024-03-14 07:12:08       87 阅读
  4. Python语言-面向对象

    2024-03-14 07:12:08       96 阅读

热门阅读

  1. 使用链表的优先级队列

    2024-03-14 07:12:08       41 阅读
  2. qt+ffmpeg 实现音视频播放(一)

    2024-03-14 07:12:08       38 阅读
  3. Qt如何保证控件调用时候的线程安全

    2024-03-14 07:12:08       40 阅读
  4. 22.5 RabbitMQ

    2024-03-14 07:12:08       36 阅读
  5. centos 7.x 上安装 AI insightface + pytorch + cuda

    2024-03-14 07:12:08       43 阅读