手写kNN算法的实现-用余弦相似度来度量距离

设a为预测点,b为其中一个样本点,在向量空间里,它们的形成的夹角为θ,那么θ越小(cosθ的值越接近1),就说明a点越接近b点。所以我们可以通过考察余弦相似度来预测a点的类型。

在这里插入图片描述

在这里插入图片描述

from collections import Counter
import numpy as np

class MyKnn:
    def __init__(self,neighbors):
        self.k = neighbors
    
    def fit(self,X,Y):
        self.X = np.array(X)
        self.Y = np.array(Y)
        if self.X.ndim != 2 or self.Y.ndim != 1:
            raise Exception("dimensions are wrong!")
        
        if self.X.shape[0] != self.Y.shape[0]:
            raise Exception("input labels are not correct!")
    
    def predict(self,X_pre):
        
        pre = np.array(X_pre)
        if self.X.ndim != pre.ndim:
            raise Exception("input dimensions are wrong!")
        rs = []
        for p in pre:
            temp = []
            for a in self.X:
                cos = (p @ a)/np.linalg.norm(p)/np.linalg.norm(a)
                temp.append(cos)
            temp = np.array(temp)
            indices = np.argsort(temp)[:-self.k-1:-1]
            ss = np.take(self.Y,indices)
            found = Counter(ss).most_common(1)[0][0]
            print(found)
            rs.append(found)
        return np.array(rs)
        

测试:

# 用鸢尾花数据集来验证我们上面写的算法
from sklearn.datasets import load_iris
# 使用train_test_split对数据集进行拆分,一部分用于训练,一部分用于测试验证
from sklearn.model_selection import train_test_split
# 1.生成一个kNN模型
myknn = MyKnn(5)
# 2.准备数据集:特征集X_train和标签集y_train
X_train,y_train = load_iris(return_X_y=True)
# 留出30%的数据集用于验证测试
X_train,X_test,y_train,y_test = train_test_split(X_train,y_train,test_size=0.3)
# 3.训练模型
myknn.fit(X_train,y_train)
# 4.预测,acc就是预测结果
acc = myknn.predict(X_test)
# 计算准确率
(acc == y_test).mean()

其实如果余弦相似度来进行分类,那么根据文章最开头讲到的,其实取余弦值最大的点作为预测类型也可以:

import numpy as np

class MyClassicfication:
    
    def fit(self,X,Y):
        self.X = np.array(X)
        self.Y = np.array(Y)
        if self.X.ndim != 2 or self.Y.ndim != 1:
            raise Exception("dimensions are wrong!")
        
        if self.X.shape[0] != self.Y.shape[0]:
            raise Exception("input labels are not correct!")
    
    def predict(self,X_pre):
        
        pre = np.array(X_pre)
        if self.X.ndim != pre.ndim:
            raise Exception("input dimensions are wrong!")
        rs = []
        for p in pre:
            temp = []
            for a in self.X:
                cos = (p @ a)/np.linalg.norm(p)/np.linalg.norm(a)
                temp.append(cos)
            temp = np.array(temp)
            index = np.argsort(temp)[-1]
            found = np.take(self.Y,index)
            rs.append(found)
        return np.array(rs)
        

测试:

# 用鸢尾花数据集来验证我们上面写的算法
from sklearn.datasets import load_iris
# 使用train_test_split对数据集进行拆分,一部分用于训练,一部分用于测试验证
from sklearn.model_selection import train_test_split
# 1.生成一个kNN模型
myCla = MyClassicfication
# 2.准备数据集:特征集X_train和标签集y_train
X_train,y_train = load_iris(return_X_y=True)
# 留出30%的数据集用于验证测试
X_train,X_test,y_train,y_test = train_test_split(X_train,y_train,test_size=0.3)
# 3.训练模型
myCla.fit(X_train,y_train)
# 4.预测,acc就是预测结果
acc = myCla.predict(X_test)
# 计算准确率
(acc == y_test).mean()

经测试,上面两种方式的准确率是差不多的。

相关推荐

  1. 相似度量方法整理

    2024-06-10 09:50:02       28 阅读
  2. 各种距离相似度量及计算

    2024-06-10 09:50:02       11 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-06-10 09:50:02       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-06-10 09:50:02       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-06-10 09:50:02       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-06-10 09:50:02       20 阅读

热门阅读

  1. 1341. 电影评分

    2024-06-10 09:50:02       12 阅读
  2. 如何学好量子计算机技术的两种思路

    2024-06-10 09:50:02       7 阅读
  3. 爬山算法详细介绍

    2024-06-10 09:50:02       12 阅读
  4. 4. 流程控制语句

    2024-06-10 09:50:02       10 阅读
  5. vscode 好用的插件

    2024-06-10 09:50:02       12 阅读
  6. 23种设计模式——创建型模式

    2024-06-10 09:50:02       10 阅读
  7. 2024年6月10日--6月16日(渲染+ue独立游戏)

    2024-06-10 09:50:02       14 阅读
  8. Terminal Multiplexer的使用

    2024-06-10 09:50:02       13 阅读
  9. 什么情况下需要用到动态IP

    2024-06-10 09:50:02       11 阅读
  10. node-mysql中占位符?的使用

    2024-06-10 09:50:02       10 阅读
  11. 007 CentOS 7.9 apache-tomcat-9.0.89安装及配置

    2024-06-10 09:50:02       10 阅读
  12. 设计模式-策略模式

    2024-06-10 09:50:02       11 阅读
  13. 密码学及其应用——安全邮件、公钥密码学和PKI

    2024-06-10 09:50:02       13 阅读