pytorch实现水果2分类(蓝莓,苹果)

1.数据集的路径,结构

dataset.py

目的:

        输入:没有输入,路径是写死了的。

        输出:返回的是一个对象,里面有self.data。self.data是一个列表,里面是(图片路径.jpg,标签)

        -data[item]返回的是(img_tensor , one-hot编码)。one-hot编码是[0,1]或者[1,0]

import glob
import os.path

import cv2
import torch
from torch.utils.data import Dataset
from torchvision import transforms

class DtataAndLabel(Dataset):
    def __init__(self,path='fruits',is_train=True):
        self.tran=transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(size=(88,88))
        ])
        is_train='train' if True else 'test'
        self.data=[]
        path=os.path.join(path,is_train)
        print('path=',path)
        print(os.path.join(path, '*', '*'))
        img_paths=glob.glob(os.path.join(path,'*','*'))
        for img_path in img_paths:
            label=0 if img_path.split('\\')[-2]=='blueberry' else 1
            self.data.append((img_path,label))
    def __getitem__(self, idx):
        #每一张图片返回一个img_tensor,one_hot
        img_path,label =self.data[idx]
        img=cv2.imread(img_path)
        # img_gray=cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
        img_tensor=self.tran(img)
        img_tensor=img_tensor/255
        img_tensor=torch.flatten(img_tensor)
        one_hot=torch.zeros(2)
        one_hot[label]=1
        return img_tensor,one_hot
    def __len__(self):
        return len(self.data)

if __name__ == '__main__':
    # 测试
    data=DtataAndLabel()
    print(data[1][0].shape)
    print(data[1][1])

net.py

目的:将输入维度(k(k是加载进去的图片数),88,88,3)三通道的宽高是88,88,通过网络变化为(k,2)。

import torch.nn
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(88*88*3, 800),
            nn.ReLU(),
            nn.Linear(800, 500),
            nn.ReLU(),
            nn.Linear(500, 800),
            nn.ReLU(),
            nn.Linear(800, 200),
            nn.ReLU(),
            nn.Linear(200, 2),

        )
        self.softmax=nn.Softmax(dim=1)
    def forward(self,x):
        x=self.model(x)
        x=self.softmax(x)
        return x
if __name__ == '__main__':
    net=Net()
    #测试一下
    x=torch.randn(1,100*100)
    out=net(x)
    print(out.shape)

test_train.py

目的:将图像丢进模型,然后训练出最优模型

步骤:

       1.定义初始化

                -定义拿到data对象

                -定义加载器分批加载,这里可以变换维度

                -定义初始化网络

                -定义损失函数,这里采用了均方差函数

                -定义优化器

        2.实现训练

                -将每一批数据丢给网络,此时维度发生了变化,产生了升维

                -使用优化器        

                        ---自动梯度清0

                        ---自动求导更新参数

                -计算损失值和准确度

        ·~自己建一个文件夹

import torch.optim
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from net import Net
from dataset import DtataAndLabel
import torch.nn as nn
class TrainAndTest():
    def __init__(self):
        self.writer = SummaryWriter("logs")
        self.train_data=DtataAndLabel(is_train=True)
        self.test_data=DtataAndLabel(is_train=False)
        #使用加载器分批加载
        self.train_loader=DataLoader(self.train_data,batch_size=10,shuffle=True)
        self.test_loader=DataLoader(self.test_data,batch_size=10,shuffle=True)
        #初始化网络
        #损失函数
        #优化器

        net=Net()
        self.net=net
        self.loss=nn.MSELoss()
        self.opt=torch.optim.Adam(net.parameters(),lr=0.001)
        self.min_loss=100.0
        self.weight_path='weight/best.pt'

    def train(self,epoch):
        sum_loss = 0
        sum_acc = 0
        for img_tensors, targets in tqdm(self.train_loader, desc="train...", total=len(self.train_loader)):
            out = self.net(img_tensors)
            loss = self.loss(out, targets)

            self.opt.zero_grad()
            loss.backward()
            self.opt.step()

            sum_loss += loss.item()
            pred_cls = torch.argmax(out, dim=1)
            target_cls = torch.argmax(targets, dim=1)
            accuracy = torch.mean(torch.eq(pred_cls, target_cls).to(torch.float32))
            sum_acc += accuracy.item()
        avg_loss = sum_loss / len(self.train_loader)
        avg_acc = sum_acc / len(self.train_loader)
        print(f'train:loss{round(avg_loss, 3)} acc:{round(avg_acc, 3)}')
        self.writer.add_scalars("loss", {"train_avg_loss": avg_loss}, epoch)
        self.writer.add_scalars("acc", {"train_avg_acc": avg_acc}, epoch)

    def test(self,epoch):
        sum_loss = 0
        sum_acc = 0
        for img_tensors, targets in tqdm(self.test_loader, desc="test...", total=len(self.test_loader)):
            out = self.net(img_tensors)
            loss = self.loss(out, targets)
            sum_loss += loss.item()
            pred_cls = torch.argmax(out, dim=1)
            target_cls = torch.argmax(targets, dim=1)
            accuracy = torch.mean(torch.eq(pred_cls, target_cls).to(torch.float32))
            sum_acc += accuracy.item()
        avg_loss = sum_loss / len(self.test_loader)
        avg_acc = sum_acc / len(self.test_loader)
        print(f'test:loss{round(avg_loss, 3)} acc:{round(avg_acc, 3)}')
        self.writer.add_scalars("loss", {"test_avg_loss": avg_loss}, epoch)
        self.writer.add_scalars("acc", {"test_avg_acc": avg_acc}, epoch)
        if avg_loss<self.min_loss:
            self.min_loss=min(self.min_loss,avg_loss)
            torch.save(self.net.state_dict(), self.weight_path)
    def run(self):
        for epo in range(100):
            self.train(epo)
            self.test(epo)

if __name__ == '__main__':
    trainer=TrainAndTest()
    trainer.run()



精度的计算:

                比如通过网络出现的维度是(1,2),其数值是[[0.9 , 0.1]](0.9与0.1表示预测的两个类别的概率)。我们通过maxarg取到其中最大的索引0,与之前真实的标签0或者1做比较。从而可以得出结果

 

相关推荐

  1. pytorch bert实现文本分类

    2024-07-11 21:18:03       51 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-11 21:18:03       70 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-11 21:18:03       74 阅读
  3. 在Django里面运行非项目文件

    2024-07-11 21:18:03       62 阅读
  4. Python语言-面向对象

    2024-07-11 21:18:03       72 阅读

热门阅读

  1. 减法原则的定义

    2024-07-11 21:18:03       19 阅读
  2. 实现基于Zookeeper的分布式协调服务

    2024-07-11 21:18:03       22 阅读
  3. ios的info.plist 配置

    2024-07-11 21:18:03       25 阅读
  4. iOS 开发中不常见的专业术语

    2024-07-11 21:18:03       18 阅读
  5. Onnx 1-深度学习-Operators

    2024-07-11 21:18:03       22 阅读
  6. Windows 32 汇编笔记(一):基础知识

    2024-07-11 21:18:03       18 阅读
  7. HarmonyOS学习之ArkTS语法补充学习

    2024-07-11 21:18:03       24 阅读
  8. Linux基础: 三. 相对路径和绝对路径

    2024-07-11 21:18:03       26 阅读
  9. Lemo 的 AGI 应用实战博文导航

    2024-07-11 21:18:03       20 阅读
  10. 音视频开发——FFmpeg 实现MP4转FLV文件 C语言实现

    2024-07-11 21:18:03       20 阅读