sklearn notes: a custom metric for neighbors.NearestNeighbors

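1 Data

Suppose we have the following data, tst_lst, which holds the Mercator coordinates of 5 trajectories. We want to use the sum of the point-by-point Manhattan distances as the distance between two trajectories.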

[array([[11549759.51313693,   148744.89246911],
        [11549751.49813359,   148732.97804463],
        [11549757.62070558,   148738.21148336],
        [11549877.73443613,   148886.64075531],
        [11549855.1365795 ,   148900.67083319]]),
 array([[11556428.51911408,   145454.58226351],
        [11557035.91165162,   145493.83259114],
        [11557310.50343952,   145408.66217089],
        [11557748.16714946,   145339.9824732 ],
        [11558124.96136184,   145498.27539452]]),
 array([[11560299.60987809,   143642.48133694],
        [11560236.88134503,   143437.08940241],
        [11560254.26944949,   143331.75455279],
        [11560222.79942945,   143349.26953089],
        [11560224.0350758 ,   143354.70329418]]),
 array([[11559757.30584681,   143885.2194761 ],
        [11560304.02926187,   143639.87580025],
        [11560743.21804884,   143750.12120076],
        [11560626.52182665,   144103.28312704],
        [11560722.44583186,   144272.53199179]]),
 array([[11569978.06036478,   151723.38135785],
        [11569938.73118869,   151248.5811628 ],
        [11569616.11617246,   150791.67584703],
        [11569571.34347327,   150687.55191842],
        [11569688.57402901,   150674.10077112]])]
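To make the target distance concrete, here is a minimal sketch of it computed directly with NumPy, before any sklearn machinery is involved. The function name trajectory_distance and the two small trajectories are made up for illustration and are not taken from tst_lst; the sketch assumes both trajectories have the same number of points.

```python
import numpy as np

def trajectory_distance(a, b):
    """Sum of point-by-point Manhattan distances between two trajectories."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    if a.shape != b.shape:
        raise ValueError("trajectories must have the same shape")
    # |x_a - x_b| + |y_a - y_b| for each pair of matching points, then summed
    return float(np.abs(a - b).sum())

# Two short, made-up trajectories (hypothetical values)
traj_a = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
traj_b = np.array([[1.0, 0.0], [1.0, 3.0], [0.0, 2.0]])

d = trajectory_distance(traj_a, traj_b)
print(d)  # 1 + 2 + 2 = 5.0
```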

2 Preprocessing the raw data

2.1 The problem with feeding the data in directly

Fitting NearestNeighbors directly on the data above raises an error:

from sklearn.neighbors import NearestNeighbors

cellKDtree=NearestNeighbors().fit(tst_lst)
cellKDtree
'''
ValueError: Found array with dim 3. NearestNeighbors expected <= 2.
'''

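The ValueError is raised because we tried to fit NearestNeighbors on a three-dimensional array. NearestNeighbors expects a two-dimensional array as input, in which each row is one data point and each column is one feature.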
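A quick way to see where the third dimension comes from is to inspect the shape of the stacked list. The sketch below uses a zero-filled stand-in (demo_lst, a hypothetical name) with the same layout as tst_lst: 5 trajectories, 5 points each, 2 coordinates per point.

```python
import numpy as np

# Hypothetical stand-in for tst_lst: a list of five (5, 2) arrays
demo_lst = [np.zeros((5, 2)) for _ in range(5)]

demo_arr = np.array(demo_lst)
demo_shape = demo_arr.shape  # (n_trajectories, n_points, n_coordinates)
demo_ndim = demo_arr.ndim    # 3, whereas NearestNeighbors accepts at most 2

print(demo_shape, demo_ndim)
```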

2.2 Reshaping the data

Flatten each trajectory's 2-D matrix into a 1-D vector:

import numpy as np

tst_lst=np.array(tst_lst)
tst_lst_new=[]

# Each trajectory of shape (5, 2) becomes one flat row [x0, y0, x1, y1, ...]
for i in range(len(tst_lst)):
    tst_lst_new.append(np.hstack(tst_lst[i]).tolist())
tst_lst_new

'''
[[11549759.513136925,
  148744.89246911363,
  11549751.49813359,
  148732.97804463338,
  11549757.620705582,
  148738.2114833576,
  11549877.734436132,
  148886.6407553058,
  11549855.136579504,
  148900.67083319122],
 [11556428.519114085,
  145454.58226351053,
  11557035.911651615,
  145493.83259113596,
  11557310.503439516,
  145408.66217089174,
  11557748.167149458,
  145339.9824731981,
  11558124.961361844,
  145498.2753945235],
 [11560299.609878086,
  143642.48133694328,
  11560236.881345032,
  143437.0894024146,
  11560254.269449493,
  143331.75455278732,
  11560222.79942945,
  143349.26953088713,
  11560224.035075797,
  143354.7032941798],
 [11559757.305846812,
  143885.21947610297,
  11560304.02926187,
  143639.8758002481,
  11560743.218048835,
  143750.12120075937,
  11560626.521826653,
  144103.28312704086,
  11560722.445831856,
  144272.53199179273],
 [11569978.060364777,
  151723.38135785353,
  11569938.731188687,
  151248.58116280191,
  11569616.116172463,
  150791.67584703089,
  11569571.343473272,
  150687.55191841844,
  11569688.57402901,
  150674.1007711226]]
'''

The data can now be passed to NearestNeighbors:

from sklearn.neighbors import NearestNeighbors

cellKDtree=NearestNeighbors().fit(tst_lst_new)
cellKDtree
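The flattening loop above can also be sketched as a single reshape, assuming every trajectory has the same number of points. The array demo below is a made-up stand-in for tst_lst (3 trajectories of 4 points each); the comparison checks that both approaches produce the same rows.

```python
import numpy as np

# Hypothetical stand-in for tst_lst: 3 trajectories, 4 points each
demo = np.arange(24, dtype=float).reshape(3, 4, 2)

# Loop-based flattening, as in the text
flat_loop = [np.hstack(traj).tolist() for traj in demo]

# One-step alternative: one row per trajectory, columns x0, y0, x1, y1, ...
flat_reshape = demo.reshape(len(demo), -1)

print(flat_reshape.shape)
print(flat_reshape.tolist() == flat_loop)
```

The reshape only works when all trajectories have equal length; ragged trajectories would need padding or resampling first.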

3 The custom distance function

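from scipy.spatial.distance import cdist
import numpy as np

def disfunc(x,y):
    # x and y are two rows of the matrix that was passed to NearestNeighbors.fit

    # Turn each flat row [x0, y0, x1, y1, ...] back into an (n_points, 2) matrix
    x_points=np.array([(x[i],x[i+1]) for i in range(0,len(x),2)])
    y_points=np.array([(y[i],y[i+1]) for i in range(0,len(y),2)])

    # cdist(x_points,y_points,metric='cityblock') returns a 2-D matrix holding the
    # Manhattan distance between every point of x and every point of y.
    # np.diag keeps the diagonal, i.e. the distances between points at the same position.
    # Their sum is the distance between the two trajectories.
    return float(np.sum(np.diag(cdist(x_points,y_points,metric='cityblock'))))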
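As a sanity check on the function above: because matching points are compared coordinate by coordinate, the result should coincide with the ordinary Manhattan distance between the two flattened vectors. The sketch below re-declares disfunc (using reshape instead of the list comprehension, which is an assumption of equivalence, not the original code) and compares it with scipy's cityblock on random, made-up inputs.

```python
import numpy as np
from scipy.spatial.distance import cdist, cityblock

def disfunc(x, y):
    # Rebuild (n_points, 2) matrices from the flat rows
    x_points = np.asarray(x, dtype=float).reshape(-1, 2)
    y_points = np.asarray(y, dtype=float).reshape(-1, 2)
    # Diagonal of the pairwise matrix = distances between matching points
    return float(np.sum(np.diag(cdist(x_points, y_points, metric="cityblock"))))

# Made-up flattened trajectories of 5 points each
rng = np.random.default_rng(0)
x = rng.uniform(0, 100, size=10)
y = rng.uniform(0, 100, size=10)

custom = disfunc(x, y)
builtin = float(cityblock(x, y))
print(np.isclose(custom, builtin))
```

If the two values agree, the built-in metric='manhattan' would likely give the same neighbors for this particular distance; the callable route remains useful for trajectory distances that do not reduce to a standard metric.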

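4 Using NearestNeighbors

Note: with a custom metric, algorithm seemingly has to be 'brute' (or left at its default, 'auto'); 'kd_tree' and 'ball_tree' do not seem to work.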

from sklearn.neighbors import NearestNeighbors

cellKDtree=NearestNeighbors(metric=disfunc).fit(tst_lst_new)
cellKDtree
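Once the model is fitted, neighbors can be queried with kneighbors. The following is a minimal sketch on made-up flattened trajectories, not the author's data; pointwise_manhattan is a hypothetical stand-in for disfunc, and algorithm="brute" is set explicitly to match the note above.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def pointwise_manhattan(x, y):
    # Hypothetical stand-in for disfunc: sum of |difference| over all coordinates
    return float(np.abs(np.asarray(x) - np.asarray(y)).sum())

# Made-up data: 4 flattened trajectories of 3 points each (x0, y0, x1, y1, x2, y2)
X = np.array([
    [0, 0, 1, 1, 2, 2],
    [0, 1, 1, 2, 2, 3],
    [10, 10, 11, 11, 12, 12],
    [10, 11, 11, 12, 12, 13],
], dtype=float)

nn = NearestNeighbors(n_neighbors=2, algorithm="brute",
                      metric=pointwise_manhattan).fit(X)

# For each trajectory: distances to and indices of its 2 nearest neighbors
# (the first neighbor is the trajectory itself, at distance 0)
dist, idx = nn.kneighbors(X)
print(idx)
print(dist)
```

A Python callable is evaluated once per pair of rows, so this approach may become slow on large datasets.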
