怎么把VMamba作为Feature Extractor集成到现有模型

首先克隆VMamba的项目代码

VMamba官方代码
安装相关依赖

pip install -r requirements.txt

然后自定义特征提取的代码

先把 Mamba官方代码仓库这个文件夹VMamba/classification/models下面的所有文件复制到你需要自定义特征提取的代码的目录下,然后把Mamba官方代码仓库 VMamba/kernels 文件夹复制到相同目录下,

然后执行这个命令

cd kernels/selective_scan && pip install .

然后你需要自定义编写一个 FeatureExtractor.py 进行测试

import torch
import torch.nn as nn
from VMamba import Backbone_VSSM
# Backbone_VSSM 根据不同的参数配置构建不同的模型变体

'''
# 实例化模型 vmamba_tiny_s2l5
model = Backbone_VSSM(
    out_indices=(0, 1, 2, 3), # 选择输出的stage
    pretrained=None,          # 不加载预训练权重
    norm_layer="ln",          # 使用LayerNorm
    depths=[2, 2, 5, 2],      # 每个stage的深度
    dims=96,                  # 每个stage的通道数
    drop_path_rate=0.2,       
    patch_size=4,
    in_chans=3,
    num_classes=1000,
    ssm_d_state=1,
    ssm_ratio=2.0,
    ssm_dt_rank="auto",
    ssm_act_layer="silu",
    ssm_conv=3,
    ssm_conv_bias=False,
    ssm_drop_rate=0.0,
    ssm_init="v0",
    forward_type="v05_noz",
    mlp_ratio=4.0,
    mlp_act_layer="gelu",
    mlp_drop_rate=0.0,
    gmlp=False,
    patch_norm=True,
    downsample_version="v3",
    patchembed_version="v2",
    use_checkpoint=False,
    posembed=False,
    imgsize=224
)
'''
# 实例化模型 vmamba_small_s2l15
model = Backbone_VSSM(
    out_indices=(0, 1, 2, 3), # 选择输出的stage
    pretrained=None,          # 不加载预训练权重
    norm_layer="ln2d",        # 使用LayerNorm2d
    depths=[2, 2, 15, 2],     # 每个stage的深度
    dims=96,                  # 每个stage的初始通道数
    drop_path_rate=0.3,       # drop path 概率
    patch_size=4,
    in_chans=3,
    num_classes=1000,
    ssm_d_state=1,
    ssm_ratio=2.0,
    ssm_dt_rank="auto",
    ssm_act_layer="silu",
    ssm_conv=3,
    ssm_conv_bias=False,
    ssm_drop_rate=0.0,
    ssm_init="v0",
    forward_type="v05_noz",
    mlp_ratio=4.0,
    mlp_act_layer="gelu",
    mlp_drop_rate=0.0,
    gmlp=False,
    patch_norm=True,
    downsample_version="v3",
    patchembed_version="v2",
    use_checkpoint=False,
    posembed=False,
    imgsize=224
)

# 加载预训练权重
pretrained_path = 'lib/vssm_small_0229_ckpt_epoch_222.pth'
checkpoint = torch.load(pretrained_path, map_location='cpu')

# 检查权重文件包含的键
# print(checkpoint.keys())  # 可能需要调整,根据实际权重文件内容

# 提取 'model' 键下的子字典
state_dict = checkpoint['model']

# 加载权重到模型
model.load_state_dict(state_dict, strict=False)

# 移动模型到GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# 切换到评估模式
model.eval()

# 生成随机图像,并移动到GPU
random_image = torch.randn(1, 3, 224, 224).to(device)

# 前向传播,获取不同阶段的特征图
with torch.no_grad():
    feature_maps = model(random_image)

# 打印每个阶段的特征图形状
for i, feature_map in enumerate(feature_maps):
    print(f"Stage {i} feature map shape: {feature_map.shape}")

相关推荐

  1. 怎么VMamba作为Feature Extractor集成现有模型

    2024-07-17 11:16:03       25 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-17 11:16:03       67 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-17 11:16:03       71 阅读
  3. 在Django里面运行非项目文件

    2024-07-17 11:16:03       58 阅读
  4. Python语言-面向对象

    2024-07-17 11:16:03       69 阅读

热门阅读

  1. AI时代的技术应用与创新:探索未来

    2024-07-17 11:16:03       17 阅读
  2. 新版本 Android Studio 没有BuildConfig ?

    2024-07-17 11:16:03       29 阅读
  3. 前缀匹配工具之IP-Prefix

    2024-07-17 11:16:03       28 阅读
  4. 高精度减法(C++)

    2024-07-17 11:16:03       24 阅读
  5. 谈人工智能在电子档案系统的应用

    2024-07-17 11:16:03       18 阅读
  6. Android 音频通道切换HDMI,蓝牙,喇叭

    2024-07-17 11:16:03       26 阅读
  7. C#拆分单页PDF

    2024-07-17 11:16:03       25 阅读
  8. TCP/IP、UDP、HTTP 协议介绍比较和总结

    2024-07-17 11:16:03       22 阅读
  9. js | 原型链

    2024-07-17 11:16:03       23 阅读