Decision Tree Example: Learning with Python

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text
from sklearn import metrics

# Load the Iris dataset
iris = load_iris()
X = iris.data
y = iris.target

# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Create a decision tree classifier
clf = DecisionTreeClassifier(random_state=42)

# Train the classifier on the training set
clf.fit(X_train, y_train)

# Predictions on the training set
y_train_pred = clf.predict(X_train)

# Predictions on the testing set
y_test_pred = clf.predict(X_test)

# Calculate accuracy
accuracy_train = metrics.accuracy_score(y_train, y_train_pred)
accuracy_test = metrics.accuracy_score(y_test, y_test_pred)

# Visualize the decision tree (text representation)
tree_rules = export_text(clf, feature_names=iris.feature_names)
print("Decision Tree Rules:\n", tree_rules)

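# Plotting the training set: first two features (sepal length vs. sepal width),
# coloured by the predicted class. The tree splits mostly on petal features,
# so this 2-D view is only an illustration.
plt.figure(figsize=(12, 6))

plt.subplot(1, 2, 1)
plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train_pred, cmap='viridis', edgecolors='k')
plt.xlabel(iris.feature_names[0])
plt.ylabel(iris.feature_names[1])
plt.title(f"Decision Tree - Training Accuracy: {accuracy_train:.2f}")

# Plotting the testing set
plt.subplot(1, 2, 2)
plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test_pred, cmap='viridis', edgecolors='k')
plt.xlabel(iris.feature_names[0])
plt.ylabel(iris.feature_names[1])
plt.title(f"Decision Tree - Testing Accuracy: {accuracy_test:.2f}")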

plt.tight_layout()
plt.show()
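The export_text call in the script prints the whole tree. For larger trees it can help to print only the top levels. A minimal sketch, assuming the same data split and random_state as the script above: export_text's max_depth parameter limits how many levels are printed, and show_weights=True additionally shows the class weights stored at each leaf. The variable name shallow_rules is just for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

# Same data split and random_state as the main example
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42)
clf = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

# Print only the top two levels; deeper subtrees are reported as truncated.
# show_weights=True adds the per-class weights stored at each printed leaf.
shallow_rules = export_text(
    clf,
    feature_names=list(iris.feature_names),
    max_depth=2,
    show_weights=True,
)
print(shallow_rules)
```

Limiting the printed depth does not change the model; it only shortens the report.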


This output is the text representation of the trained decision tree, as printed by export_text. Here is how to read it:

|--- petal length (cm) <= 2.45
|   |--- class: 0
|--- petal length (cm) >  2.45
|   |--- petal length (cm) <= 4.75
|   |   |--- petal width (cm) <= 1.65
|   |   |   |--- class: 1
|   |   |--- petal width (cm) >  1.65
|   |   |   |--- class: 2
|   |--- petal length (cm) >  4.75
|   |   |--- petal width (cm) <= 1.75
|   |   |   |--- petal length (cm) <= 4.95
|   |   |   |   |--- class: 1
|   |   |   |--- petal length (cm) >  4.95
|   |   |   |   |--- petal width (cm) <= 1.55
|   |   |   |   |   |--- class: 2
|   |   |   |   |--- petal width (cm) >  1.55
|   |   |   |   |   |--- petal length (cm) <= 5.45
|   |   |   |   |   |   |--- class: 1
|   |   |   |   |   |--- petal length (cm) >  5.45
|   |   |   |   |   |   |--- class: 2
|   |   |--- petal width (cm) >  1.75
|   |   |   |--- petal length (cm) <= 4.85
|   |   |   |   |--- sepal width (cm) <= 3.10
|   |   |   |   |   |--- class: 2
|   |   |   |   |--- sepal width (cm) >  3.10
|   |   |   |   |   |--- class: 1
|   |   |   |--- petal length (cm) >  4.85
|   |   |   |   |--- class: 2

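The output shows the structure of the tree: each line is either a split condition (a branch) or a leaf, and the indentation depth shows the level in the tree. For example, the first two lines say that if the petal length is less than or equal to 2.45 cm, the predicted class is 0. If the petal length is greater than 2.45 cm, the tree branches on the next condition (petal length (cm) <= 4.75), and so on.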
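To make the reading concrete, the printed rules can be transcribed by hand into nested if/else statements. The sketch below uses a hypothetical helper, predict_from_rules, that mirrors the tree shown above branch for branch, with the thresholds copied from the printed output. It only describes this particular trained tree; a different split or scikit-learn version could grow a different tree.

```python
def predict_from_rules(sepal_width, petal_length, petal_width):
    """Hypothetical hand-transcription of the printed tree.

    Inputs are in centimetres; returns the class index 0, 1 or 2.
    Sepal length never appears in the tree, so it is not a parameter.
    """
    if petal_length <= 2.45:
        return 0
    if petal_length <= 4.75:
        # Branch: 2.45 < petal length <= 4.75
        return 1 if petal_width <= 1.65 else 2
    if petal_width <= 1.75:
        # Branch: petal length > 4.75 and petal width <= 1.75
        if petal_length <= 4.95:
            return 1
        if petal_width <= 1.55:
            return 2
        return 1 if petal_length <= 5.45 else 2
    # Branch: petal length > 4.75 and petal width > 1.75
    if petal_length <= 4.85:
        return 2 if sepal_width <= 3.10 else 1
    return 2


# First setosa, versicolor and virginica samples of the iris dataset
print(predict_from_rules(3.5, 1.4, 0.2))  # 0
print(predict_from_rules(3.2, 4.7, 1.4))  # 1
print(predict_from_rules(3.3, 6.0, 2.5))  # 2
```

If your fitted tree matches the one printed above, this function should agree with clf.predict; otherwise the transcription has to be redone from your own export_text output.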

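Lines of the form class: X are the leaf nodes of the tree, where X is the predicted class index (in iris.target_names, 0 = setosa, 1 = versicolor, 2 = virginica).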
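The leaves can also be inspected programmatically. A minimal sketch, assuming the same data split and random_state as the main example: it uses the fitted estimator's low-level tree_ attribute (children_left, value) together with the apply, get_depth and get_n_leaves methods to find which leaf each test sample lands in and which class that leaf predicts.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Same data split and random_state as the main example
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42)
clf = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

tree = clf.tree_
# In the low-level structure, a leaf has no children (child id == -1)
is_leaf = tree.children_left == -1

print("depth:", clf.get_depth())
print("number of leaves:", clf.get_n_leaves())

# apply() returns, for each sample, the id of the leaf node it ends up in
leaf_ids = clf.apply(X_test[:5])
for sample, leaf in zip(X_test[:5], leaf_ids):
    # tree.value[leaf] has shape (1, n_classes); its argmax is the leaf's class
    predicted = tree.value[leaf].argmax()
    print(sample, "-> leaf", leaf, "->", iris.target_names[predicted])
```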

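During training (clf.fit), the tree learned these split features and thresholds from the training data. At prediction time it routes each new sample down the tree according to its feature values until the sample reaches a leaf, whose class becomes the prediction.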
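How are thresholds such as 2.45 chosen? DecisionTreeClassifier uses the Gini impurity criterion by default (criterion="gini"): at each node it searches for the feature and threshold that minimize the size-weighted impurity of the two children, and it places thresholds midway between consecutive sorted feature values. The pure-Python sketch below illustrates that search for a single feature on made-up toy values; the helpers gini and best_threshold are hypothetical and simplified, not scikit-learn's optimized implementation, which repeats the search over every feature at every node.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a list of class labels: 1 - sum_k p_k ** 2."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def best_threshold(values, labels):
    """Try every midpoint between consecutive distinct values of one feature
    and return (threshold, weighted child impurity) for the best '<=' split."""
    distinct = sorted(set(values))
    best_t, best_score = None, float("inf")
    for lo, hi in zip(distinct, distinct[1:]):
        t = (lo + hi) / 2
        left = [y for x, y in zip(values, labels) if x <= t]
        right = [y for x, y in zip(values, labels) if x > t]
        # Weight each child's impurity by the fraction of samples it receives
        score = (len(left) * gini(left) + len(right) * gini(right)) / len(labels)
        if score < best_score:
            best_t, best_score = t, score
    return best_t, best_score


# Made-up petal lengths (cm) with class labels 0, 1, 2
petal_length = [1.4, 1.3, 1.5, 4.5, 4.7, 4.9, 5.1, 6.0]
species = [0, 0, 0, 1, 1, 1, 2, 2]
threshold, score = best_threshold(petal_length, species)
print(threshold, round(score, 3))  # 3.0 0.3
```

The winning split sends all class-0 samples left (impurity 0), which mirrors how the first split of the iris tree isolates class 0.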
