昇思训练营打卡第十八天(K近邻算法实现红酒聚类)

K近邻(K-Nearest Neighbors,KNN)算法是一种基本的机器学习算法,它既可以用于分类任务,也可以用于回归任务。KNN算法的核心思想是,如果一个新样本在特征空间中的K个最邻近的样本大多数属于某一个类别,那么这个新样本也属于这个类别

KNN算法的基本步骤

  1. 选择距离度量:确定样本之间的距离计算方法,常用的距离度量方法有欧氏距离、曼哈顿距离等。

  2. 确定邻居的数量K:K值的选择对KNN算法的结果有重要影响,K值通常需要通过交叉验证等方法来确定。

  3. 选择训练样本:训练样本应该能够代表整个数据集的特性。

  4. 进行分类

    • 对于一个新的输入实例,计算它与训练集中每一个实例的距离。
    • 选择距离最近的K个实例。
    • 根据这K个实例的标签,通过多数投票等方式,确定新实例的类别。

KNN算法的优缺点

优点:
  • 简单易懂,易于实现。
  • 不需要训练模型,因此对于训练数据没有假设,适用于各种类型的决策边界。
  • 可以适用于多分类问题。
缺点:
  • 计算量大,因为需要计算每个测试样本与所有训练样本的距离。
  • 对噪声和离群点敏感,因为它们会影响近邻的选择。
  • 需要预先确定K值,K值的选择对结果有较大影响。

应用场景

KNN算法在现实世界中广泛应用于模式识别、文本分类、图像识别等领域,尤其是在数据分布较为稠密且特征维度不高的情况下表现良好。

from download import download

# 下载红酒数据集
url = "https://ascend-professional-construction-dataset.obs.cn-north-4.myhuaweicloud.com:443/MachineLearning/wine.zip"  
path = download(url, "./", kind="zip", replace=True)
%matplotlib inline
import os
import csv
import numpy as np
import matplotlib.pyplot as plt

import mindspore as ms
from mindspore import nn, ops

ms.set_context(device_target="CPU")
with open('wine.data') as csv_file:
    data = list(csv.reader(csv_file, delimiter=','))
print(data[56:62]+data[130:133])
X = np.array([[float(x) for x in s[1:]] for s in data[:178]], np.float32)
Y = np.array([s[0] for s in data[:178]], np.int32)
attrs = ['Alcohol', 'Malic acid', 'Ash', 'Alcalinity of ash', 'Magnesium', 'Total phenols',
         'Flavanoids', 'Nonflavanoid phenols', 'Proanthocyanins', 'Color intensity', 'Hue',
         'OD280/OD315 of diluted wines', 'Proline']
plt.figure(figsize=(10, 8))
for i in range(0, 4):
    plt.subplot(2, 2, i+1)
    a1, a2 = 2 * i, 2 * i + 1
    plt.scatter(X[:59, a1], X[:59, a2], label='1')
    plt.scatter(X[59:130, a1], X[59:130, a2], label='2')
    plt.scatter(X[130:, a1], X[130:, a2], label='3')
    plt.xlabel(attrs[a1])
    plt.ylabel(attrs[a2])
    plt.legend()
plt.show()
train_idx = np.random.choice(178, 128, replace=False)
test_idx = np.array(list(set(range(178)) - set(train_idx)))
X_train, Y_train = X[train_idx], Y[train_idx]
X_test, Y_test = X[test_idx], Y[test_idx]
class KnnNet(nn.Cell):
    def __init__(self, k):
        super(KnnNet, self).__init__()
        self.k = k

    def construct(self, x, X_train):
        #平铺输入x以匹配X_train中的样本数
        x_tile = ops.tile(x, (128, 1))
        square_diff = ops.square(x_tile - X_train)
        square_dist = ops.sum(square_diff, 1)
        dist = ops.sqrt(square_dist)
        #-dist表示值越大,样本就越接近
        values, indices = ops.topk(-dist, self.k)
        return indices

def knn(knn_net, x, X_train, Y_train):
    x, X_train = ms.Tensor(x), ms.Tensor(X_train)
    indices = knn_net(x, X_train)
    topk_cls = [0]*len(indices.asnumpy())
    for idx in indices.asnumpy():
        topk_cls[Y_train[idx]] += 1
    cls = np.argmax(topk_cls)
    return cls
acc = 0
knn_net = KnnNet(5)
for x, y in zip(X_test, Y_test):
    pred = knn(knn_net, x, X_train, Y_train)
    acc += (pred == y)
    print('label: %d, prediction: %s' % (y, pred))
print('Validation accuracy is %f' % (acc/len(Y_test)))

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-11 12:22:02       7 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-11 12:22:02       8 阅读
  3. 在Django里面运行非项目文件

    2024-07-11 12:22:02       7 阅读
  4. Python语言-面向对象

    2024-07-11 12:22:02       10 阅读

热门阅读

  1. spring boot 3.2.x 使用CDS加速启动

    2024-07-11 12:22:02       10 阅读
  2. 37.深度学习中的梯度下降法及其实现

    2024-07-11 12:22:02       9 阅读
  3. Spring Boot与Spring MVC的区别和联系

    2024-07-11 12:22:02       9 阅读
  4. 代码随想录-DAY⑥-哈希表——leetcode 383 | 454

    2024-07-11 12:22:02       9 阅读
  5. linux去掉行首的#字符

    2024-07-11 12:22:02       7 阅读
  6. 常见的负载均衡算法和实现方式

    2024-07-11 12:22:02       11 阅读
  7. Android焦点之Focused Window的更新(二)

    2024-07-11 12:22:02       8 阅读
  8. SpringBoot源码阅读(9)——转换服务

    2024-07-11 12:22:02       8 阅读
  9. C#中的Dictionary

    2024-07-11 12:22:02       9 阅读
  10. C语言标准库中的函数

    2024-07-11 12:22:02       9 阅读
  11. MVC分页

    MVC分页

    2024-07-11 12:22:02      11 阅读
  12. 整数 d → 字符 ‘d‘ 的转换代码为:d+‘0‘

    2024-07-11 12:22:02       9 阅读
  13. 进阶版智能家居系统Demo[C#]:整合AI和自动化

    2024-07-11 12:22:02       10 阅读
  14. 【C语言】C语言可以做什么?

    2024-07-11 12:22:02       9 阅读