KMean 聚类
KMean 聚类
1 解决什么问题
假设二维坐标轴上有一些点,现在让你把这些点分个类。于是对我们来说,这个
分类
似乎就是把距离相近的点
画到一类中去。
- 假设要划分
N
类,坐标点M
个 - 从
M
个坐标点随机选取N
个点,作为每个分类的中心点
,这N
个点的列表记录为centerPointList
- 遍历
M
个坐标点中的每个点- 计算当前点和
N
个中心点的距离,dis1、dis2 ... disN
- 从
dis1、dis2 ... disN
找到最小的距离的下标。下标记录为cluster
,那么这个cluster
就是这次遍历时候当前点归属的分类。
- 计算当前点和
- 步骤
3
结束后,每个点都会归属到某个分类。计算每个分类中点集合的均值,把这个均值作为新的中心点
,替换掉centerPointList
。 - 重复
3、4
直到重复次数大于约定次数,或者中心点
变化较小。此时就可以知道每个点归属的分类。
2 java实现计算二维点的聚类案例
package com.forezp.kmean;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;
/**
* @author yuegang
*/
public class KMeanCluster {
/**
* 表示二维空间中的点
*/
public static class Point {
Integer x = 0;
Integer y = 0;
public Point() {
}
public Point(Integer x, Integer y) {
this.x = x;
this.y = y;
}
public void incX(Integer x) {
this.x += x;
}
public void incY(int y) {
this.y += y;
}
public Integer getX() {
return x;
}
public void setX(Integer x) {
this.x = x;
}
public Integer getY() {
return y;
}
public void setY(Integer y) {
this.y = y;
}
@Override
public String toString() {
return "(" + x + ", " + y + ")";
}
}
/**
* 表示二维空间中的点
* 下标是点的顺序
*/
private final List<Point> pointIndexDataMap;
private final List<List<Point>> centerPointList = Lists.newArrayList(); // 记录每一个分类的中心点
private final List<Integer> pointClusterMap = Lists.newArrayList(); // 点所属的分类
private int index = 0; // 计算次数
private int clusterCount = 0; // 分类个数
public KMeanCluster(List<Point> pointIndexDataMap, int clusterCount) {
this.pointIndexDataMap = pointIndexDataMap;
this.clusterCount = clusterCount;
index = 0;
initCenterPoint();
initCluster(pointIndexDataMap);
}
private void initCluster(List<Point> pointIndexDataMap) {
// 初始化每个点的分类,设置一个没有意义的值
for (int j = 0; j < pointIndexDataMap.size(); ++j) {
pointClusterMap.add(-1);
}
}
private void initCenterPoint() {
List<Point> objects = Lists.newArrayListWithExpectedSize(clusterCount);
List<Integer> yList = Lists.newArrayListWithExpectedSize(clusterCount);
Random random = new Random();
for (int i = 0; i < clusterCount; ++i) {
// 注意这个不能相同
int i1 = random.nextInt(pointIndexDataMap.size());
while (yList.contains(i1)) {
i1 = random.nextInt(pointIndexDataMap.size());
}
yList.add(i1);
}
for (int i = 0; i < clusterCount; ++i) {
objects.add(pointIndexDataMap.get(yList.get(i)));
}
centerPointList.add(objects);
}
public void calc() {
List<Point> pointIndices = centerPointList.get(index);
for (int i = 0; i < pointIndexDataMap.size(); ++i) {
Point point = pointIndexDataMap.get(i);
// 计算该点和那个簇最近,把把归属到这个簇中。
int cluster = 0;
double min = Double.MAX_VALUE;
for (int inc = 0; inc < pointIndices.size(); ++inc) {
Point point1 = pointIndices.get(inc);
Integer x = point.getX();
Integer y = point.getY();
Integer x1 = point1.getX();
Integer y1 = point1.getY();
int i1 = x - x1;
int i2 = y - y1;
int total = i1 * i1 + i2 * i2;
double sqrt = Math.sqrt(total);
if (sqrt < min) {
min = sqrt;
cluster = inc;
}
}
pointClusterMap.set(i, cluster);
}
// 计算每个族的中心点;
int size = centerPointList.get(0).size();
Map<Integer, Point> map = Maps.newTreeMap();
Map<Integer, Integer> cluterCount = Maps.newHashMapWithExpectedSize(size);
for (int i = 0; i < pointClusterMap.size(); ++i) {
int cluster = pointClusterMap.get(i);
Point point = map.computeIfAbsent(cluster, sss -> new Point());
cluterCount.put(cluster, cluterCount.getOrDefault(cluster, 0) + 1);
Point point1 = pointIndexDataMap.get(i);
point.incX(point1.getX());
point.incY(point1.getY());
}
for (Map.Entry<Integer, Point> integerPointEntry : map.entrySet()) {
Integer key = integerPointEntry.getKey();
Point point = integerPointEntry.getValue();
Integer integer = cluterCount.get(key);
point.setX(point.getX() / integer);
point.setY(point.getY() / integer);
}
++index;
Map<Integer, List<Point>> curClassfiyMap = Maps.newTreeMap();
for (int i = 0; i < pointClusterMap.size(); ++i) {
Point point = pointIndexDataMap.get(i);
Integer classfly = pointClusterMap.get(i);
List<Point> points = curClassfiyMap.computeIfAbsent(classfly, k -> Lists.newArrayList());
points.add(point);
}
List<Point> curCenterPointList = new ArrayList<>(map.values());
centerPointList.add(curCenterPointList);
show(curClassfiyMap, curCenterPointList);
}
private void show(Map<Integer, List<Point>> curClassfiyMap, List<Point> curCenterPointList) {
System.out.println("计算次数:" + index);
System.out.println("当前分类:" + curClassfiyMap);
System.out.println("当前中心点:" + curCenterPointList);
}
public static void main(String[] args) {
Point point = new Point(100, 100);
Point point1 = new Point(1, 1);
Point point2 = new Point(110, 120);
Point point3 = new Point(10, 20);
Point point4 = new Point(130, 160);
List<Point> pointIndexDataMap = Lists.newArrayList(point, point1, point2, point3, point4);
KMeanCluster oneCalc = new KMeanCluster(pointIndexDataMap, 2);
for (int i = 0; i < 2; ++i) {
oneCalc.calc();
}
}
}
输出
计算次数:1
当前分类:{
0=[(110, 120), (130, 160)], 1=[(100, 100), (1, 1), (10, 20)]}
当前中心点:[(120, 140), (37, 40)]
计算次数:2
当前分类:{
0=[(100, 100), (110, 120), (130, 160)], 1=[(1, 1), (10, 20)]}
当前中心点:[(113, 126), (5, 10)]