使用Transformers库中的模型提取图像特征遇到的问题

最近做一个图像相关的项目,需要提取图像特征,在使用Transformers库中的深度神经网络模型提取图像特征的过程中,遇到一些问题,记录一下。

下面是图像特征提取的简化代码及相应的中间输出,使用的是OpenAI的CLIP模型:

import os
import torch
import numpy as np

from PIL import Image
from transformers import AutoProcessor, CLIPModel

device = torch.device('cuda' if torch.cuda.is_available() else "cpu")

processor = AutoProcessor.from_pretrained('openai/clip-vit-base-patch32')
model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32').to(device)

img_dir = '/dir/to/img'

img_paths = [os.path.join(img_dir, filename)  for filename in os.listdir(img_dir)]
# ['/dir/to/img/a.tif', '/dir/to/img/b.jpg', '/dir/to/img/c.jpg', '/dir/to/img/d.tif']

pilimgs = [Image.open(img) for img in img_paths]
# [<PIL.TiffImagePlugin.TiffImageFile image mode=CMYK size=3008x2000>,  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=3702x2592>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=3024x1084>, <PIL.TiffImagePlugin.TiffImageFile image mode=CMYK size=3872x2592>]

nparrs = [np.array(img) for img in pilimgs]
# [arr.shape for arr in nparrs]
# [(2000, 3008, 4), (2592, 3702, 3), (1084, 3024, 3), (2592, 3872, 4)]

with torch.no_grad():
    inputs = processor(images=nparrs, return_tensors='pt')
    image_feature = model.get_image_features(**inputs)

通过中间输出可以看到,使用PIL.Image读取图片得到的对象的mode属性是不一样的,tif图片对应的mode为CMYK,jpg图片对应的mode为RGB,将这些对象转化为ndarray后,能够看到tif图片的通道数为4,而jpg图片的通道数为3,图片数据的这些差异会导致使用模型批量提取图片特征时出现问题,而且输入的数据类型和输入图片的顺序不同,出现的问题也不一样。

不同的模型输入及对应的结果

  • 直接输入PIL.Image.Image对象列表(如下所示),不会报错
with torch.no_grad():
    inputs = processor(images=pilimgs, return_tensors='pt')
    image_feature = model.get_image_features(**inputs)

直接传入PIL.Image.Image对象列表,无论列表中是相同mode的图片对象还是不同mode的图片对象,模型都能正确处理,不会报错。但是实际情况下,一般不会直接传入PIL.Image.Image列表,因为通常需要处理大量图片,在这种情况下,大量使用Image.open而没有进行正确的后续处理很容易造成内存泄漏。可以配合with语句将图片转化为ndarray

with open(img_path, 'rb') as f:
    img = Image.open(f)
    arr = np.array(img)
  • 输入np.ndarray列表,列表中含有通道数为4的图片且第一个元素的通道数为4或列表中图片的通道数都为4
    出现错误ValueError: Unable to infer channel dimension format
  • 输入np.ndarray列表,列表中图片的通道数不同且第一个元素的通道数为3
    出现错误ValueError: mean must have 4 elements if it is an iterable, got 3
  • 输入np.ndarray列表,列表中图片的通道数都为3,成功执行

可以看到,当传入给模型的数据类型为np.ndarray(或torch.Tensor)时,存在通道数为4(mode为CMYK或其他)的图片都会导致模型处理出现异常,因此需要将这些图片转化为RGB模式,也就是通道数为3:

with open(img_path, 'rb') as f:
    img = Image.open(f)
    if img.mode != 'RGB':
        img = img.convert('RGB')
    arr = np.array(img)

相关推荐

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-04-30 06:02:05       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-04-30 06:02:05       100 阅读
  3. 在Django里面运行非项目文件

    2024-04-30 06:02:05       82 阅读
  4. Python语言-面向对象

    2024-04-30 06:02:05       91 阅读

热门阅读

  1. 【 深度可分离卷积】

    2024-04-30 06:02:05       29 阅读
  2. 设计模式(四)、策略模式

    2024-04-30 06:02:05       33 阅读
  3. Python:将数组从一个范围等效到另一个范围

    2024-04-30 06:02:05       34 阅读
  4. github fork项目不带tag解决

    2024-04-30 06:02:05       28 阅读
  5. el-row中元素如何上下居中对齐?

    2024-04-30 06:02:05       35 阅读