CLIP 图文检索,相似度计算

CLIP 是OpenAI提出的神经网络,它可以从自然语言监督中有效地学习视觉概念。
CLIP 可以应用于任何视觉分类基准,只需提供要识别的视觉类别的名称,类似于 GPT-2 和 GPT-3 的“零样本”功能。

相关paper
用法可以参考github

这里举几个使用CLIP的例子。

首先你需要安装pytorch, 还有matplotlib, opencv等,
然后安装clip

pip install git+https://github.com/openai/CLIP.git

1.零样本图像分类

这里的分类并不是直接让CLIP预测一个标签,而是你给出一些标签的候选项,它会给这些候选项预测概率。

比如这张图片
请添加图片描述
给CLIP一些标签:“a dog”, “a cat”,“a man”,“a tree”, “food”,它会给每个标签预测一个概率,概率最高的就是最后的label.
你会看到"a cat"的得分最高。

同样的,如果用多张图片去匹配一个提示标签,可以用下面代码的logits_per_text.

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model, preprocess = clip.load('ViT-B/32', device=device)

image = preprocess(Image.open('cat.png')).unsqueeze(0).to(device)
text = clip.tokenize(["a dog", "a cat","a man","a tree", "food"]).to(device)

with torch.no_grad():
   #计算图像和token之间的分数
   #logits_per_image是image和每个token的分数,Tensor(1,5)
   #logits_per_text是每个token和image的分数,Tensor(5,1)
   logits_per_image, logits_per_text = model(image, text)
   probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print('Label probs:', probs)
#output:Label probs: [[3.159e-03 9.927e-01 1.589e-03 3.490e-04 2.241e-03]]

2. 用提示语搜索图片

现在用COCO数据集的图片来做测试,用val_2017数据,大概5000张图片。
用一个提示语"a red bus"搜索最相近的3张图片,看能得到什么。
这里计算image feature和text feature的相似度时,用了余弦相似度。

data_folder = 'coco/val2017'
images = []
for root,dirs,files in os.walk(data_folder):
     for file in files:
         if file.endswith('jpg'):
             images.append(root + '/' + file)
text = clip.tokenize(['a red bus']).to(device)
text_features = model.encode_text(text)
result = {}
cos = nn.CosineSimilarity(dim=0)

for img in images:
    with torch.no_grad():
        image_preprocess = preprocess(Image.open(img)).unsqueeze(0).to(device)
        image_features = model.encode_image(image_preprocess)
        sim = cos(image_features[0], text_features[0]).item()
        sim = (sim+1)/2 #(-1,1) --> (0,1)
        result[img] = sim

sorted_value = sorted(result.items(), key=lambda x:x[1], reverse=True)
sorted_res = dict(sorted_value)
top_3 = dict(itertools.islice(sorted_res.items(),3))
print(top_3)
#
# fig,axs =plt.subplots(1,3)
#
# i=0
# for key in top_3:
#     key_img = cv2.cvtColor(cv2.imread(key),cv2.COLOR_BGR2RGB)
#     axs[i].imshow(key_img)
#     axs[i].set_title('sim='+"{:.3f}".format(top_3[key]))
#     axs[i].axis('off')
#     i=i+1
# fig.suptitle('a red bus')
#
# plt.show()

根据提示语,按相似度从高到低,检索出如下3张图片。

请添加图片描述

3.图片的相似度

给出两张图片,计算它们的相似度。
现在要比的是上面“a red bus"中左边2个图片的相似度。
是通过计算image feature的余弦相似度实现的,而image feature是通过CLIP的encode得到。

img1 = 'bus1.jpg'
img2 = 'bus2.jpg'
cos = nn.CosineSimilarity(dim=0)

img1_process = preprocess(Image.open(img1)).unsqueeze(0).to(device)
img2_process = preprocess(Image.open(img2)).unsqueeze(0).to(device)

img1_feature = model.encode_image(img1_process)
img2_feature = model.encode_image(img2_process)

sim = cos(img1_feature[0], img2_feature[0]).item()
sim = (sim+1)/2
print("similarity: ", sim)
#output: similarity:  0.844970703125

4.用图片检索图片

还是用这个红色的bus, 看看用它能从COCO数据中检索出什么。

请添加图片描述

img1='bus1.jpg'
input_image = preprocess(Image.open(img1)).unsqueeze(0).to(device)
input_image_features = model.encode_image(input_image)

result = {}
for img in images:
    with torch.no_grad():
        image_preprocess = preprocess(Image.open(img)).unsqueeze(0).to(device)
        image_features = model.encode_image( image_preprocess)
        cos = torch.nn.CosineSimilarity(dim=0)
        sim = cos(image_features[0],input_image_features[0]).item()
        sim = (sim+1)/2
        result[img]=sim


sorted_value = sorted(result.items(), key=lambda x:x[1], reverse=True)
sorted_res = dict(sorted_value)

top_3 = dict(itertools.islice(sorted_res.items(), 3))

print(top_3)

请添加图片描述

参考资料:
https://medium.com/@jeremy-k/unlocking-openai-clip-part-1-intro-to-zero-shot-classification-f81194f4dff7
https://medium.com/@jeremy-k/unlocking-openai-clip-part-2-image-similarity-bf0224ab5bb0

相关推荐

  1. 几种计算图像/向量相似的指标(实现)

    2024-04-03 02:20:03       12 阅读
  2. 数据的相似计算

    2024-04-03 02:20:03       36 阅读
  3. milvus 相似检索的底层原理

    2024-04-03 02:20:03       13 阅读
  4. python opencv比较图片相似

    2024-04-03 02:20:03       25 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-04-03 02:20:03       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-04-03 02:20:03       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-04-03 02:20:03       19 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-04-03 02:20:03       20 阅读

热门阅读

  1. k8s 常用指令

    2024-04-03 02:20:03       14 阅读
  2. Centos7安装Docker-Compose

    2024-04-03 02:20:03       20 阅读
  3. bash简化if-else

    2024-04-03 02:20:03       14 阅读
  4. P10086 [ROIR 2022 Day 1] 口算比赛

    2024-04-03 02:20:03       13 阅读
  5. radash 工具整理常用 API

    2024-04-03 02:20:03       14 阅读
  6. QT实现windows下获取CPU、内存及磁盘信息

    2024-04-03 02:20:03       14 阅读
  7. 浅谈数据治理之道 数据运用(四)

    2024-04-03 02:20:03       20 阅读
  8. Docker 容器如何访问外部网络以及端口映射原理?

    2024-04-03 02:20:03       18 阅读
  9. c语言之函数指针作形参

    2024-04-03 02:20:03       14 阅读
  10. Allegro许可分析工具

    2024-04-03 02:20:03       15 阅读
  11. AGI时代,LLM可以在AutoML哪些环节进行增强?

    2024-04-03 02:20:03       11 阅读