机器学习:P-R曲线(Python)

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, label_binarize
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import precision_recall_curve, average_precision_score
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis


def data_preproce(n_noise_factor=10, test_size=0.3, random_state=0):
    """Load the digits dataset, add noise features, split and standardize it.

    The train/test split happens *before* standardization and the scaler is
    fitted on the training split only, so statistics of the held-out data do
    not leak into preprocessing.

    Parameters
    ----------
    n_noise_factor : int, default 10
        Number of Gaussian-noise columns appended per original feature
        (``n_noise_factor * n_features`` noise columns in total).
    test_size : float, default 0.3
        Fraction of samples held out for the test split.
    random_state : int, default 0
        Seed used both for the noise generator and for the split.

    Returns
    -------
    X_train, X_test, y_train, y_test
        Standardized feature matrices and label indicator matrices produced
        by ``label_binarize`` (one column per digit class).
    """
    # Load the data.
    digits = datasets.load_digits()
    X, y = digits.data, digits.target
    rng = np.random.RandomState(random_state)
    n_example, n_features = X.shape
    # Append pure-noise features to make the classification task harder.
    X = np.c_[X, rng.randn(n_example, n_noise_factor * n_features)]

    # Split first, stratifying on the integer labels so that every class keeps
    # its proportion in both splits.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state,
        shuffle=True, stratify=y)

    # Fit the scaler on the training data only, then apply it to both splits.
    scaler = StandardScaler().fit(X_train)
    X_train = scaler.transform(X_train)
    X_test = scaler.transform(X_test)

    # Binarize against the full class list so both splits have the same columns.
    classes = np.unique(y)
    y_train = label_binarize(y_train, classes=classes)
    y_test = label_binarize(y_test, classes=classes)
    return X_train, X_test, y_train, y_test


def model_trian(model, X_fit=None, y_fit=None, X_eval=None):
    """Fit ``model`` one-vs-rest and return its decision scores on eval data.

    Parameters
    ----------
    model
        An unfitted scikit-learn estimator; it is wrapped in
        ``OneVsRestClassifier``, so one binary copy is trained per label column.
    X_fit, y_fit : optional
        Training features and label indicator matrix. Default to the
        module-level ``X_train`` / ``y_train`` so existing one-argument calls
        keep working.
    X_eval : optional
        Features to score. Defaults to the module-level ``X_test``.

    Returns
    -------
    The result of ``decision_function`` on ``X_eval``: one score column per
    label column of ``y_fit`` (the columns are indexed per class downstream).
    """
    # Fall back to the globals created by the script so the original call
    # style ``model_trian(SomeEstimator())`` is unchanged.
    if X_fit is None:
        X_fit = X_train
    if y_fit is None:
        y_fit = y_train
    if X_eval is None:
        X_eval = X_test

    classifier = OneVsRestClassifier(model)
    classifier.fit(X_fit, y_fit)
    return classifier.decision_function(X_eval)


def micor_PR(y_test, y_score):
    """Compute per-class and micro-averaged precision-recall statistics.

    Parameters
    ----------
    y_test : 2-D array
        Binary label indicator matrix, one column per class.
    y_score : 2-D array
        Scores with the same column layout as ``y_test``; the number of
        columns determines the number of classes.

    Returns
    -------
    precision, recall, average_precision : dict
        Keyed by the integer class index and by ``"micro"`` for the
        micro-average that pools all (sample, class) pairs together.
    """
    precision, recall, average_precision = {}, {}, {}

    for cls in range(y_score.shape[1]):
        truth = y_test[:, cls]
        scores = y_score[:, cls]
        # Curve points for every distinct score threshold; the thresholds
        # themselves are not needed here.
        precision[cls], recall[cls], _ = precision_recall_curve(truth, scores)
        # AP = sum_n (R_n - R_{n-1}) * P_n
        average_precision[cls] = average_precision_score(truth, scores)

    # Micro-average: flatten the matrices so every (sample, class) pair
    # becomes one binary decision.
    precision["micro"], recall["micro"], _ = precision_recall_curve(
        y_test.ravel(), y_score.ravel())
    average_precision["micro"] = average_precision_score(
        y_test, y_score, average="micro")
    return precision, recall, average_precision

def plt_PR_curve(precision, recall, average_precision, label):
    """Draw the micro-averaged P-R curve on the current axes as a step plot.

    ``label`` (a string) is extended with the micro-averaged AP and used as
    the legend entry. Only the ``"micro"`` entries of the dicts are plotted.
    """
    micro_ap = average_precision["micro"]
    legend_text = label + ': AP={0:0.2f}'.format(micro_ap)
    micro_recall = recall['micro']
    micro_precision = precision['micro']
    # where='post': the segment after each recall point keeps that point's
    # precision, giving the staircase shape of a P-R curve.
    plt.step(micro_recall, micro_precision, where='post', lw=2, label=legend_text)

X_train, X_test, y_train, y_test = data_preproce()

# Micro-averaged P-R curves of several one-vs-rest classifiers, all trained
# and evaluated on the same split, drawn on one figure for comparison.
plt.figure(figsize=(8, 6))
for model_label, model in [
    ("LogisticRegression", LogisticRegression()),
    ("svm.SVC", SVC()),
    ("LinearDiscriminantAnalysis", LinearDiscriminantAnalysis()),
]:
    y_score = model_trian(model)
    precision, recall, average_precision = micor_PR(y_test, y_score)
    plt_PR_curve(precision, recall, average_precision, model_label)

# The chance level of a P-R curve is the positive rate, not the y = x
# diagonal used for ROC plots. For the micro-average that is the fraction of
# ones in the binarized test labels.
chance_level = y_test.mean()
plt.plot([0, 1], [chance_level, chance_level], color='navy', linestyle='--')
plt.xlabel('Recall', fontsize=12)
plt.ylabel('Precision', fontsize=12)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.grid()
plt.title('Average precision score, micro_averaged over all classes', fontsize=14)
plt.legend(fontsize=12)
plt.show()


from itertools import cycle
# setup plot details

# Per-class and micro-averaged P-R curves for logistic regression, drawn over
# iso-F1 contours.
y_score = model_trian(LogisticRegression())
precision, recall, average_precision = micor_PR(y_test, y_score)
plt.figure(figsize=(9, 8))
f_scores = np.linspace(0.2, 0.8, num=4)
lines, labels = [], []  # legend handles and their labels

# Iso-F1 contours: from F1 = 2PR / (P + R), the precision along a contour is
# P = F1 * R / (2R - F1); only the branch with P >= 0 is drawn.
label_recall = 0.9  # recall at which each contour is annotated
for f_score in f_scores:
    iso_recall = np.linspace(0.01, 1)
    iso_precision = f_score * iso_recall / (2 * iso_recall - f_score)
    keep = iso_precision >= 0
    # alpha (0..1) sets transparency; smaller values give fainter lines.
    line, = plt.plot(iso_recall[keep], iso_precision[keep], color='gray', ls='--', alpha=0.5)
    # Put the label on its own contour exactly at x == label_recall instead of
    # reading a hard-coded sample index that lies at a slightly different recall.
    label_precision = f_score * label_recall / (2 * label_recall - f_score)
    plt.annotate('f1={0:0.1f}'.format(f_score), xy=(label_recall, label_precision + 0.02))

# One legend entry stands for all iso-F1 contours.
lines.append(line)
labels.append('iso-f1 curves')
line, = plt.plot(recall["micro"], precision["micro"], ls="-.", lw=2)
lines.append(line)
labels.append('micro-average(area = {0:0.2f})'.format(average_precision["micro"]))

for i in range(y_score.shape[1]):
    line, = plt.plot(recall[i], precision[i], lw=1.5)
    lines.append(line)
    labels.append('class {0} (area = {1:0.2f})'.format(i, average_precision[i]))

# Chance level of a P-R curve is the positive rate, not the y = x diagonal.
chance_level = y_test.mean()
plt.plot([0, 1], [chance_level, chance_level], color='navy', linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall', fontsize=12)
plt.ylabel('Precision', fontsize=12)
plt.title('Extension of Precision-Recall curve to multi-class', fontsize=14)
plt.legend(lines, labels, loc=(1.02, 0), prop=dict(size=12))

# tight_layout computes all subplot margins itself, so a manual
# subplots_adjust beforehand would simply be overwritten.
plt.tight_layout()
plt.show()


 

相关推荐

最近更新

  1. TCP协议是安全的吗?

    2024-01-24 07:08:02       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-01-24 07:08:02       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-01-24 07:08:02       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-01-24 07:08:02       20 阅读

热门阅读

  1. 【计算机二级考试C语言】C文件读写

    2024-01-24 07:08:02       32 阅读
  2. golang常用库之-golang-jwt/jwt包

    2024-01-24 07:08:02       29 阅读
  3. 美易官方:黑石股价盘前走低受评级下调影响

    2024-01-24 07:08:02       37 阅读
  4. 基于开源库imu_gps_localization进行组合导航

    2024-01-24 07:08:02       35 阅读
  5. Qt容器QHash类

    2024-01-24 07:08:02       37 阅读
  6. C++类的入门

    2024-01-24 07:08:02       20 阅读
  7. windows主题与壁纸所在文件夹路径

    2024-01-24 07:08:02       37 阅读