机器学习——编程从零实现决策树【二】

第一节的内容:机器学习——编程实现从零构造训练集的决策树-CSDN博客

目录

v2:划分验证集,优化预测

1)划分训练集和验证集

2)完善预测过程

3)训练模型并验证

v3:k折交叉验证

1)理论

2)实践


v2:划分验证集,优化预测

1)划分训练集和验证集

1. 从好瓜里面选出3个

2. 从坏瓜里面选出3个

import numpy as np
def show(D):
  for i in D:
    print(i)

good_index = np.arange(0, 8, 1, dtype=np.int16)
np.random.shuffle(good_index)
print(f"好瓜打乱后的索引={good_index}")

bad_index = np.arange(8,17,1,dtype=np.int16)
np.random.shuffle(bad_index)
print(f"坏瓜打乱后的索引{bad_index}")

train_data = [ D[i] for i in good_index[3:]]+[D[i] for i in bad_index[3:]]
val_data = [D[i] for i in good_index[:3]] + [D[i] for i in bad_index[:3]]
print("训练集的数据为")
show(train_data)
print("验证集的数据为")
show(val_data)


结果如下:

2)完善预测过程

考虑到预测样本中可能存在出现新的属性值的情况:

def predict_v2(data,root):
  cur = root
  while cur.label != 1:
    attr = cur.bestattr
    key = data[attr]
    # 如果样本出现新的属性值,则这个样本被标记为当前结点数量最多的类别
    if key not in cur.subDs:
      return cur.max

    cur = cur.subDs[key]
  return cur.Class

分类精度的计算:

def calAccuracy(pred,data):
  n = len(data)
  re = 0
  for i in range(n):
    if pred[i] == data[i]['Class']:
      re+=1
  return re/n

3)训练模型并验证

# 建树
root_v2 = TreeGenerate(train_data,Attr)
# 画图
drawTree(root_v2)


re = []
for i in range(len(val_data)):
  re.append(predict_v2(val_data[i],root_v2))

print(f"精度是:{calAccuracy(re,val_data)}")

结果截图(部分)

最后一行显示分类精度:约为0.8333

v3:k折交叉验证

变更划分方式,选择性能更好的模型

1)理论

2)实践

① 打乱原始数据集的数据

# 打乱顺序
def shuffle(D):
  index = np.arange(0,len(D),1,dtype=np.int16)
  cpD = copy.deepcopy(D)
  np.random.shuffle(cpD)
  return cpD

D_v3 = shuffle(D)

② 将数据集划分若干份

def CreateByK(D,k):
  # 根据k折将数据集D划分出若干大小为k的子集
  n = int(len(D)/k)
  re = []
  for i in range(n):
    re.append(D[k*i:i*k+k])
  re.append(D[n*k:])
  return re


Ds_K = CreateByK(D_v3,3)
show(Ds_K)

③ 模型评估函数,返回指标精度

# 打包验证过程
def evaluate(val_data,root):
  re = []
  for i in range(len(val_data)):
    re.append(predict_v2(val_data[i],root))
  ans = calAccuracy(re,val_data)
  print(f"精度是:{ans}")
  return ans

④ 组合不同的训练集和验证集,对模型进行训练、评估,记录

def trainByK(Ds_K,Attr):
  n = len(Ds_K)
  models = []
  for i in range(n):
    train_data = CreateTrain_data(Ds_K,i)

    val_data = Ds_K[i]
    print(val_data)
    # 生成模型
    root = TreeGenerate(train_data,Attr)

    # 计算训练集上的精度
    acc_train = evaluate(train_data,root)

    # 用验证集预测并计算精度
    acc_val = evaluate(val_data,root)
    models.append({'model':root,'acc':acc_train,'acc_val':acc_val})
  return models

⑤ 找出训练集精度和验证集精度之和最高的模型

models = trainByK(Ds_K,Attr)
target = 0
for i in range(len(models)):
  if models[i]['acc']+models[i]['acc_val'] > models[target]['acc']+models[target]['acc_val']:
    target = i

drawTree(models[target]['model'])
print(f"模型在训练集上的损失:{models[target]['acc']},在验证集上的损失:{models[target]['acc_val']}")

相关推荐

最近更新

  1. TCP协议是安全的吗?

    2024-03-20 12:56:02       19 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-03-20 12:56:02       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-03-20 12:56:02       19 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-03-20 12:56:02       20 阅读

热门阅读

  1. Rust 中的 Vec<u8> 类型

    2024-03-20 12:56:02       16 阅读
  2. golang踩坑记录

    2024-03-20 12:56:02       22 阅读
  3. Flutter插件开发与发布指南

    2024-03-20 12:56:02       20 阅读
  4. Flutter项目组件模块化开发的实践与搭建

    2024-03-20 12:56:02       19 阅读
  5. flutter-elinux的基本介绍及安装调试

    2024-03-20 12:56:02       17 阅读
  6. mysql建表&索引语句

    2024-03-20 12:56:02       20 阅读
  7. Flutter中自定义Dialog

    2024-03-20 12:56:02       19 阅读
  8. jenkins 连接harbor 推送镜像

    2024-03-20 12:56:02       19 阅读
  9. 安卓面试题多线程 91-95

    2024-03-20 12:56:02       18 阅读
  10. leetcode-hot100-图论

    2024-03-20 12:56:02       18 阅读
  11. Spring Data访问Elasticsearch----实体回调Entity Callbacks

    2024-03-20 12:56:02       19 阅读
  12. 医学预测变量筛选的几种方法(R语言版)

    2024-03-20 12:56:02       17 阅读
  13. immer的使用

    2024-03-20 12:56:02       17 阅读