Pytorch学习-调整torchvision.models中模型输出类别数

假设你的类别只有10个,而torchvision.models中Vgg16的输出类别为1000,这时应该如何调整呢?

方法一,直接修改模型中类别的输出。

from torch.nn import Linear
import torchvision
import torch

Vgg16=torchvision.models.vgg16(pretrained=True)
Vgg16.classifier[6]=Linear(in_features=4096,out_features=10)
if torch.cuda.is_available():
    T=Vgg16.cuda()

方法二,再模型的最后增加全连接层,改变输出类别。

from torch.nn import Linear
import torchvision
import torch

res=torchvision.models.resnet101(pretrained=True,progress=True)
res.fc.add_module('linelayer',Linear(in_features=1000,out_features=10))
if torch.cuda.is_available():
    T=res.cuda()

 

相关推荐

  1. Pytorch学习-调整torchvision.models模型输出类别

    2024-05-15 21:50:11       11 阅读
  2. pytorch模型训练的学习率动态调整

    2024-05-15 21:50:11       11 阅读
  3. bert pytorch模型转onnx,并改变输入输出

    2024-05-15 21:50:11       28 阅读
  4. PyTorch模块、类和函数的命名和调用

    2024-05-15 21:50:11       33 阅读
  5. pytorch通道不一样怎么办?

    2024-05-15 21:50:11       18 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-05-15 21:50:11       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-05-15 21:50:11       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-05-15 21:50:11       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-05-15 21:50:11       20 阅读

热门阅读

  1. 代码随想录Day28

    2024-05-15 21:50:11       9 阅读
  2. 绘制奇迹:Processing中的动态图形与动画

    2024-05-15 21:50:11       9 阅读
  3. 深度学习实战:定制化智能狗门的迁移学习之旅

    2024-05-15 21:50:11       6 阅读
  4. 机器学习_朴素贝叶斯

    2024-05-15 21:50:11       10 阅读
  5. 论文合集整理推荐2024.5.15

    2024-05-15 21:50:11       12 阅读
  6. 如何在 Ubuntu 14.04 上为 Nginx 创建 SSL 证书

    2024-05-15 21:50:11       8 阅读
  7. 《IT行业的未来:趋势与展望》

    2024-05-15 21:50:11       8 阅读
  8. scanf、printf、string函数族

    2024-05-15 21:50:11       12 阅读
  9. linux的知识点分享

    2024-05-15 21:50:11       11 阅读