昇思学习打卡-16-热门LLM及其他AI应用/K近邻算法实现红酒聚类

算法原理

K近邻算法可以用在分类问题和回归问题上,它的原理如下:要确定一个样本的类别,可以计算它与所有训练样本的距离,然后找出和该样本最接近的k个样本,统计出这些样本的类别并进行投票,票数最多的那个类就是分类的结果。
KNN的三个基本要素:

  • K值,一个样本的分类是由K个邻居的“多数表决”确定的。K值越小,容易受噪声影响,反之,会使类别之间的界限变得模糊。
  • 距离度量,反映了特征空间中两个样本间的相似度,距离越小,越相似。常用的有Lp距离(p=2时,即为欧式距离)、曼哈顿距离、海明距离等。
  • 分类决策规则,通常是多数表决,或者基于距离加权的多数表决(权值与距离成反比)。

距离定义

计算不同样本间的距离,可以使用欧氏距离,有时也是用Mahalanobis距离、Bhattacharyya距离
使用欧式距离时,应将特征向量的每个分量进行归一化,以减少特征值得尺度范围不同带来的干扰

模型构建

class KnnNet(nn.Cell):
    def __init__(self, k):
        super(KnnNet, self).__init__()
        self.k = k

    def construct(self, x, X_train):
        #平铺输入x以匹配X_train中的样本数
        x_tile = ops.tile(x, (128, 1))
        square_diff = ops.square(x_tile - X_train)
        square_dist = ops.sum(square_diff, 1)
        dist = ops.sqrt(square_dist)
        #-dist表示值越大,样本就越接近
        values, indices = ops.topk(-dist, self.k)
        return indices

def knn(knn_net, x, X_train, Y_train):
    x, X_train = ms.Tensor(x), ms.Tensor(X_train)
    indices = knn_net(x, X_train)
    topk_cls = [0]*len(indices.asnumpy())
    for idx in indices.asnumpy():
        topk_cls[Y_train[idx]] += 1
    cls = np.argmax(topk_cls)
    return cls
acc = 0
knn_net = KnnNet(5)
for x, y in zip(X_test, Y_test):
    pred = knn(knn_net, x, X_train, Y_train)
    acc += (pred == y)
    print('label: %d, prediction: %s' % (y, pred))
print('Validation accuracy is %f' % (acc/len(Y_test)))

在这里插入图片描述
此章节学习到此结束,感谢昇思平台。

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-12 14:58:03       67 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-12 14:58:03       72 阅读
  3. 在Django里面运行非项目文件

    2024-07-12 14:58:03       58 阅读
  4. Python语言-面向对象

    2024-07-12 14:58:03       69 阅读

热门阅读

  1. Node.js 模块系统

    2024-07-12 14:58:03       17 阅读
  2. 模板方法模式的实现

    2024-07-12 14:58:03       20 阅读
  3. Android.mk中LOCAL_SDK_VERSION的作用是什么?

    2024-07-12 14:58:03       20 阅读
  4. C++:右值引用

    2024-07-12 14:58:03       22 阅读
  5. Xcode Playgrounds:探索Swift编程的交互式乐园

    2024-07-12 14:58:03       22 阅读
  6. Okhttp实现原理

    2024-07-12 14:58:03       15 阅读
  7. 2713. 矩阵中严格递增的单元格数

    2024-07-12 14:58:03       20 阅读
  8. global::System.Runtime.InteropServices.DllImport

    2024-07-12 14:58:03       20 阅读