OpenCV与机器学习:使用opencv实现KNN


前言

KNN算法,即K-Nearest Neighbor,是一种简单的机器学习算法。它的基本思想是:如果一个样本在特征空间中的K个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。KNN算法通过计算待分类样本与已知类别样本的距离,找出距离最近的K个样本,根据这K个样本的类别,通过投票的方式确定待分类样本的类别。

在KNN算法中,距离的度量方式主要有欧式距离、曼哈顿距离、余弦相似度等。其中,欧式距离是最常用的距离度量方式,计算的是两点之间的直线距离。

KNN算法的优点包括简单直观、易于实现、对异常点不敏感等。但同时,KNN算法也存在一些局限性,如对数据分布敏感、K值选择主观等。此外,对于大规模数据集和高维数据,KNN算法的效率可能会较低。

总的来说,KNN算法是一种基础且重要的机器学习算法,在数据分类、回归分析、聚类等领域都有广泛的应用。

OpenCV实现KNN

生成随机数据

import numpy as np

def generate_data(num_samples, num_features=2):
    data_size = (num_samples, num_features)
    train_data = np.random.randint(0, 100, size=data_size)
    labels_size = (num_samples, 1)
    labels = np.random.randint(0, 2, size=labels_size)
    return train_data.astype(np.float32), labels

我们构建了一个函数来生成数据集,num_samples代表的是数据量,num_features默认为2代表的是数据有两个特征。通过随机数生成数据和标签。

train_data, labels = generate_data(11)
train_data

生成的数据如下:
在这里插入图片描述

  • 注:由于随机种子的不同生成的数据很有可能不同,以实际运行结果为准(下面的画图也一样)

下面以第0个数据为例看一下数据和标签的对应

train_data[0], labels[0]

在这里插入图片描述

使用matplotlib绘制数据

import matplotlib.pyplot as plt

plt.style.use('ggplot')
plt.plot(train_data[0, 0], train_data[0, 1], color='r', marker='^', markersize=10)
plt.xlabel('x coordinate')
plt.ylabel('y coordinate')

在代码中plt.style.use(‘ggplot’)的作用是为了使得画图更加好看。随后随便画了数据中的一个点的效果如下:
在这里插入图片描述
为了方便展示数据,我们在下面构建一个绘图函数用不同的颜色和不同的形状表示不同类别的数据。

def plot_data(all_blue, all_red):
    plt.figure(figsize=(10, 6))
    plt.scatter(all_blue[:, 0], all_blue[:, 1], c='b', marker='s', s=180)
    plt.scatter(all_red[:, 0], all_red[:, 1], c='r', marker='^', s=180)
    plt.xlabel('x coordinate')
    plt.ylabel('y coordinate')

现在我们有了数据和展示函数,那么如何区分数据的类别呢,我们可以根据标签(labels)来选择不同类别的数据。不过生成的标签使用其实并不方便,在这里我们使用ravel()将数据展平。效果对比如下:

labels, labels.ravel()

在这里插入图片描述
现在我们已经有了数据的选择方式和画图方式,现在我们来画一下不同类别的数据。

blue = train_data[labels.ravel()==0]
red = train_data[labels.ravel()==1]
plot_data(blue, red)

在这里插入图片描述

OpenCV用于KNN训练

在cv2中的ml模块中有一些常见的机器学习方法,使用KNearest_create可以创建一个KNN的分类器,训练时需要传入的数据有数据、数据组织方式、标签

import cv2

knn = cv2.ml.KNearest_create()
knn.train(train_data, cv2.ml.ROW_SAMPLE, labels)

在代码中使用的数据组织方式为cv2.ml.ROW_SAMPLE,即每一行就是一条数据。

生成新数据并观察

训练完之后我们再生成一个不含标签的数据

newcomer, _ = generate_data(1)
newcomer

在这里插入图片描述
画出原始数据和新生成的不含标签的数据,一起进行对比。

plot_data(blue, red)
plt.plot(newcomer[0, 0], newcomer[0, 1], 'go', markersize=14)

在下面的图中,绿色的圆形就是新的数据,这个数据不含标签。我们观察到这个数据和蓝色方块很近,所以大概率是属于蓝色方块的。即属于第0类。
在这里插入图片描述

预测

我们先来使用最近的一个点来进行预测,这里findNearest第二个参数是1

ret, results, neighbor, dist = knn.findNearest(newcomer, 1)
results, neighbor, dist

在这里插入图片描述
然后使用最近的三个点来进行预测,这里findNearest第二个参数是3

ret, results, neighbor, dist = knn.findNearest(newcomer, 3)
results, neighbor, dist

在这里插入图片描述
可以看到预测结果是正确的

与此同时我们也可以使用predict方法进行预测。但是,首先我们需要设置k,我们要根据最近的k个
数据点来判断新数据具体属于哪一类。

使用方式如下:

knn.setDefaultK(1)
knn.predict(newcomer)

在这里插入图片描述

knn.setDefaultK(3)
knn.predict(newcomer)

在这里插入图片描述

相关推荐

  1. opencv学习 机器学习

    2024-02-04 15:08:02       34 阅读
  2. 机器学习---KNN案例

    2024-02-04 15:08:02       39 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-02-04 15:08:02       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-02-04 15:08:02       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-02-04 15:08:02       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-02-04 15:08:02       20 阅读

热门阅读

  1. Docker 第九章 : Docker 容器的互联(linking)

    2024-02-04 15:08:02       29 阅读
  2. QT总结-使用QOpenGLWidget后界面刷新卡死

    2024-02-04 15:08:02       36 阅读
  3. MySQL系统配置

    2024-02-04 15:08:02       34 阅读
  4. 2024/2/3 备战蓝桥杯 4-2 排序

    2024-02-04 15:08:02       36 阅读
  5. 按照特定的规则对 HTML 标签进行分类的方法

    2024-02-04 15:08:02       31 阅读
  6. C#(C Sharp)学习笔记_运算符【四】

    2024-02-04 15:08:02       31 阅读
  7. 二级C语言笔试2

    2024-02-04 15:08:02       24 阅读
  8. NLP深入学习(十三):AdaBoost 算法

    2024-02-04 15:08:02       32 阅读
  9. serialVersionUID详解

    2024-02-04 15:08:02       29 阅读