[Paper Reproduction] ZoeDepth pitfalls

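Watch the model I/O:
make sure the precision (value range) and dtype of the inputs and outputs match the target you are reproducing.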

Model inference code

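import cv2
import torch
from PIL import Image
from torchvision import transforms


def image_to_tensor(img_path, unsqueeze=True):
    # ToTensor rescales uint8 images to [0, 1], but 16-bit images
    # (PIL mode 'I' / 'I;16') come back as unscaled integers
    rgb = transforms.ToTensor()(Image.open(img_path))
    if unsqueeze:
        rgb = rgb.unsqueeze(0)  # [C,H,W] -> [1,C,H,W]
    return rgb


def disparity_to_tensor(disp_path, unsqueeze=True):
    # cv2.IMREAD_UNCHANGED (== -1) keeps the 16-bit values; divide to get [0, 1]
    disp = cv2.imread(disp_path, cv2.IMREAD_UNCHANGED) / (2 ** 16 - 1)
    # disp = cv2.imread(disp_path, -1) / (2 ** 8 - 1)  # for 8-bit disparity maps
    disp = torch.from_numpy(disp)[None, ...]  # [H,W] -> [1,H,W]
    if unsqueeze:
        disp = disp.unsqueeze(0)
    return disp.float()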
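disparity_to_tensor assumes the file stores disparity scaled by 2 ** 16 - 1. A minimal NumPy sketch of the matching encode/decode pair, useful when saving your own reproduction outputs in the same format; the helper names are hypothetical, not part of ZoeDepth. Writing the uint16 array with cv2.imwrite to a .png would give the 16-bit file that disparity_to_tensor reads back.

```python
import numpy as np

MAX_U16 = 2 ** 16 - 1  # 65535, the same scale disparity_to_tensor divides by


def encode_disparity_u16(disp: np.ndarray) -> np.ndarray:
    """Map a float disparity in [0, 1] to uint16 for saving as a 16-bit image."""
    disp = np.clip(disp, 0.0, 1.0)
    return np.round(disp * MAX_U16).astype(np.uint16)


def decode_disparity_u16(raw: np.ndarray) -> np.ndarray:
    """Inverse of encode_disparity_u16; mirrors imread(..., -1) / (2 ** 16 - 1)."""
    return raw.astype(np.float32) / MAX_U16


disp = np.array([[0.0, 0.25], [0.5, 1.0]], dtype=np.float32)
raw = encode_disparity_u16(disp)
restored = decode_disparity_u16(raw)
print(raw.dtype, raw.max())  # uint16 65535
```

The round trip only loses quantization error below one step (1 / 65535), which is why 16 bits is preferred over 8 for disparity.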

Input image: uint8, 960×1280

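# load input
try:
    image = image_to_tensor(img_path).cuda()  # [1,3,h,w]
except FileNotFoundError:
    # fall back to the .jpg with the same name (assumes a 3-letter extension)
    image = image_to_tensor(img_path[:-3] + 'jpg').cuda()  # [1,3,h,w]

# image = image.type(torch.float32)/255/255
if image.shape[1] == 1:
    # grayscale: repeat the single channel to get 3 channels
    image = torch.tile(image, dims=(1, 3, 1, 1))
# image = image.float() / (2 ** 16 - 1)  # needed for 16-bit inputs
image = image[:, 0:3, ...]  # drop an alpha channel if present (RGBA -> RGB)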
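The img_path[:-3] + 'jpg' fallback only works when the original extension has exactly three letters. A minimal sketch of a more general fallback using pathlib; resolve_image_path is a hypothetical helper, not part of ZoeDepth.

```python
from pathlib import Path


def resolve_image_path(img_path: str, fallback_suffix: str = ".jpg") -> Path:
    """Return img_path if it exists, otherwise the same stem with fallback_suffix.

    Path.with_suffix replaces the extension whatever its length
    (.png, .tiff, ...), unlike slicing off the last three characters.
    """
    path = Path(img_path)
    if path.exists():
        return path
    alt = path.with_suffix(fallback_suffix)
    if alt.exists():
        return alt
    raise FileNotFoundError(f"neither {path} nor {alt} exists")
```

With it, the try/except becomes a single call such as image = image_to_tensor(str(resolve_image_path(img_path))).cuda(), and a genuinely missing file still raises instead of being silently swallowed.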

When the test image is loaded through NumPy (cv2) instead, it must be normalized to [0, 1] manually:

import cv2
import torch

image_np = cv2.imread(input_pic, cv2.IMREAD_GRAYSCALE)  # uint8, [H, W]

# if len(image_np.shape) == 2:
#     image_np = np.repeat(image_np[:, :, np.newaxis], 3, axis=2)
# [1, 1, H, W] float, but the values are still in 0-255
pic_int = torch.from_numpy(image_np).cuda().unsqueeze(0).unsqueeze(0).float()
if pic_int.shape[1] == 1:
    pic_int = torch.tile(pic_int, dims=(1, 3, 1, 1))  # [1, 3, H, W]
self.zoe(pic_int)

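Normalization comparison

Fails: the raw 0-255 values are fed to the model directly (the snippet above).

Succeeds: the input is scaled to [0, 1] before self.zoe(pic_int) is called:

pic_int = pic_int / 255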
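The working preprocessing for a cv2-loaded grayscale image can be sketched in plain NumPy as below; gray_u8_to_model_input is a hypothetical helper name, and the result would still need torch.from_numpy(...).cuda() before being passed to the model.

```python
import numpy as np


def gray_u8_to_model_input(gray: np.ndarray) -> np.ndarray:
    """Turn an HxW uint8 grayscale image into a float32 [1, 3, H, W] array in [0, 1]."""
    if gray.dtype != np.uint8:
        raise TypeError(f"expected uint8, got {gray.dtype}")
    x = gray.astype(np.float32) / 255.0  # the normalization that fixes the failure
    x = x[np.newaxis, np.newaxis, ...]   # [1, 1, H, W]
    return np.repeat(x, 3, axis=1)       # [1, 3, H, W], like torch.tile(x, (1, 3, 1, 1))


gray = np.array([[0, 128], [255, 64]], dtype=np.uint8)
x = gray_u8_to_model_input(gray)
print(x.shape, x.min(), x.max())  # (1, 3, 2, 2) 0.0 1.0
```

Rejecting non-uint8 input is a deliberate choice in this sketch: dividing a 16-bit image by 255 would silently produce values far above 1.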

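In some datasets the images themselves have a much larger value range, for example when they are saved in a 16-bit format. After loading, such an image looks like this: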

tensor([[[[13629, 13629, 14012,  ..., 21654, 21017, 20635],
          [13629, 12993, 13247,  ..., 21654, 21781, 21399],
          [12993, 12865, 12738,  ..., 21145, 21017, 21272],
          ...,
          [17196, 17069, 16941,  ..., 21399, 20890, 22291],
          [17069, 17196, 16814,  ..., 21399, 21909, 22291],
          [17196, 17705, 16686,  ..., 21527, 22036, 65535]]]], device='cuda:0',
       dtype=torch.int32)

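Such values were not rescaled on loading (note dtype=torch.int32 and the maximum of 65535), so a second normalization is needed, e.g. dividing by 2 ** 16 - 1.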
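A minimal sketch of that second normalization in NumPy, choosing the divisor from the bit depth. normalize_to_unit is a hypothetical helper, and guessing the bit depth from the maximum value is an assumption of this sketch, not something ZoeDepth does.

```python
from typing import Optional

import numpy as np


def normalize_to_unit(img: np.ndarray, bit_depth: Optional[int] = None) -> np.ndarray:
    """Scale integer image data to float32 in [0, 1] by its bit depth.

    If bit_depth is None it is guessed (a heuristic): data whose maximum is
    <= 255 is treated as 8-bit, anything larger as 16-bit. A very dark 16-bit
    image would be mistaken for 8-bit, so pass bit_depth when it is known.
    """
    if bit_depth is None:
        bit_depth = 8 if int(img.max()) <= 255 else 16
    return img.astype(np.float32) / float(2 ** bit_depth - 1)


# Values taken from the int32 tensor printed above
raw = np.array([[13629, 21654, 65535]], dtype=np.int32)
unit = normalize_to_unit(raw)
print(unit.max())  # 1.0
```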

Normalization check script

import torch
import warnings

def check_tensor_values(tensor):
    """
    Check the maximum and minimum values of a tensor.
    Issues a warning if the maximum value is greater than 1 or the minimum value is less than 0.001.

    Parameters:
    tensor (torch.Tensor): The tensor to check.
    """
    max_value = torch.max(tensor)
    min_value = torch.min(tensor)

    # Check for the maximum value condition
    if max_value > 1:
        warnings.warn("The maximum value is greater than 1!", UserWarning)

    # Check for the minimum value condition
    if min_value < 0.001:
        warnings.warn("The minimum value is less than 0.001!", UserWarning)

Converting an image to 16-bit

from PIL import Image
import numpy as np

# Open the image
image = Image.open('lutao_exp/thermal/left_thermal_darkpre_ft_0_epoch_out_depth_colored.png')

# Convert to an 8-bit grayscale image, then scale it up to 16 bits
image_gray = image.convert('L')  # 8-bit grayscale (mode 'L')
image_gray_16bit = np.array(image_gray, dtype=np.uint16) * 256  # 0-255 -> 0-65280

# Create a new Pillow image object from the 16-bit array
image_16bit_pillow = Image.fromarray(image_gray_16bit, mode='I;16')

# Save the 16-bit image
image_16bit_pillow.save('lutao_exp/thermal/left_thermal_darkpre_ft_0_epoch_out_depth_colored16.png')
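Multiplying by 256 maps 255 to 65280, so after the usual division by 2 ** 16 - 1 the brightest pixel becomes about 0.996 rather than 1.0. The worked arithmetic below compares it with multiplying by 257, which maps 0-255 exactly onto 0-65535; which one to use depends on what the reproduction target expects.

```python
import numpy as np

v8 = np.array([0, 1, 128, 255], dtype=np.uint8)

times_256 = v8.astype(np.uint16) * 256  # what the script above does: 255 -> 65280
times_257 = v8.astype(np.uint16) * 257  # 255 -> 65535, the full 16-bit range

print(times_256.tolist())  # [0, 256, 32768, 65280]
print(times_257.tolist())  # [0, 257, 32896, 65535]

# Both mappings can be inverted exactly back to 8 bits
assert (times_256 >> 8).astype(np.uint8).tolist() == v8.tolist()
assert (times_257 // 257).astype(np.uint8).tolist() == v8.tolist()
```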
