实现EM算法的主循环

任务描述

本关任务:用 python 实现 EM 算法的迭代过程。请不要修改 Begin-End 段之外的代码。

相关知识

为了完成本关任务,你需要掌握 EM 算法的迭代流程。

EM 算法的迭代流程

通过上一关的学习,相信你已经对 EM 算法的单次迭代过程以及 EM 算法的核心思想和流程已经有一定的了解了。在这一关中我们会把上一关中没做的事情做完,就是迭代

EM 算法的迭代流程很简单,就是循环调用单次迭代过程而已。不过在循环时需要加上终止条件。一般来说终止条件有两个:

  • 最大迭代次数 : EM 算法的最大循环次数;

  • 参数变化的容忍度 :当 EM 算法估计出来的参数 θ 不怎么变化时,可以提前挑出循环。例如容忍度为 1e-3,则表示若这次迭代的估计结果与上一次迭代的估计结果之间的差异小于 1e-3 则跳出循环。

编程要求

根据提示,在右侧编辑器补充 Begin-End 段中的代码,完成 em(observations, thetas, tol=1e-4, iterations=100)函数。该函数需要完成的功能是模拟抛掷硬币实验并迭代估计硬币 A 与硬币 B 正面朝上的概率。其中:

  • observations :抛掷硬币的实验结果记录,类型为 list 。 list 的行数代表做了几轮实验,列数代表每轮实验用某个硬币抛掷了几次。 list 中的值代表正反面, 0 代表反面朝上, 1 代表正面朝上。如 [[1, 0, 1], [0, 1, 1]] 表示进行了两轮实验,每轮实验用某硬币抛掷三次。第一轮的结果是正反正,第二轮的结果是反正正。

  • thetas :硬币 A 与硬币 B 正面朝上的概率的初始值,类型为 list ,如 [0.2, 0.7] 代表硬币 A 正面朝上的概率为 0.2 ,硬币 B 正面朝上的概率为 0.7 。

  • tol :差异容忍度,即当 EM 算法估计出来的参数 theta 不怎么变化时,可以提前挑出循环。例如容忍度为 1e-4 ,则表示若这次迭代的估计结果与上一次迭代的估计结果之间的 L1 距离小于 1e-4 则跳出循环。为了正确的评测,请不要修改该值。

  • iterations : EM 算法的最大迭代次数。为了正确的评测,请不要修改该值。

  • 返回值:将估计出来的硬币 A 和硬币 B 正面朝上的概率组成 list 或者 ndarray 返回。如 [0.4, 0.6] 表示你认为硬币 A 正面朝上的概率为 0.4 ,硬币 B 正面朝上的概率为 0.6 。
测试说明

平台会对你编写的代码进行测试,你只需完成 em 函数即可。

测试输入: {'init_values':[0.2, 0.7], 'observations':[[1, 1, 0, 1, 0], [0, 0, 1, 1, 0], [1, 0, 0, 0, 0], [1, 0, 0, 1, 1], [0, 1, 1, 0, 0]]} 预期输出: [0.439928, 0.440072]

import numpy as np
from scipy import stats
def em_single(init_values, observations):
    """
    模拟抛掷硬币实验并估计在一次迭代中,硬币A与硬币B正面朝上的概率。请不要修改!!
    :param init_values:硬币A与硬币B正面朝上的概率的初始值,类型为list,如[0.2, 0.7]代表硬币A正面朝上的概率为0.2,硬币B正面朝上的概率为0.7。
    :param observations:抛掷硬币的实验结果记录,类型为list。
    :return:将估计出来的硬币A和硬币B正面朝上的概率组成list返回。如[0.4, 0.6]表示你认为硬币A正面朝上的概率为0.4,硬币B正面朝上的概率为0.6。
    """
    observations = np.array(observations)
    counts = {'A': {'H': 0, 'T': 0}, 'B': {'H': 0, 'T': 0}}
    theta_A = init_values[0]
    theta_B = init_values[1]
    # E step
    for observation in observations:
        len_observation = len(observation)
        num_heads = observation.sum()
        num_tails = len_observation - num_heads
        # 两个二项分布
        contribution_A = stats.binom.pmf(num_heads, len_observation, theta_A)
        contribution_B = stats.binom.pmf(num_heads, len_observation, theta_B)
        weight_A = contribution_A / (contribution_A + contribution_B)
        weight_B = contribution_B / (contribution_A + contribution_B)
        # 更新在当前参数下A、B硬币产生的正反面次数
        counts['A']['H'] += weight_A * num_heads
        counts['A']['T'] += weight_A * num_tails
        counts['B']['H'] += weight_B * num_heads
        counts['B']['T'] += weight_B * num_tails
    # M step
    new_theta_A = counts['A']['H'] / (counts['A']['H'] + counts['A']['T'])
    new_theta_B = counts['B']['H'] / (counts['B']['H'] + counts['B']['T'])
    return [new_theta_A, new_theta_B]
def em(observations, thetas, tol=1e-4, iterations=100):
    """
    模拟抛掷硬币实验并使用EM算法估计硬币A与硬币B正面朝上的概率。
    :param observations: 抛掷硬币的实验结果记录,类型为list。
    :param thetas: 硬币A与硬币B正面朝上的概率的初始值,类型为list,如[0.2, 0.7]代表硬币A正面朝上的概率为0.2,硬币B正面朝上的概率为0.7。
    :param tol: 差异容忍度,即当EM算法估计出来的参数theta不怎么变化时,可以提前挑出循环。例如容忍度为1e-4,则表示若这次迭代的估计结果与上一次迭代的估计结果之间的L1距离小于1e-4则跳出循环。为了正确的评测,请不要修改该值。
    :param iterations: EM算法的最大迭代次数。为了正确的评测,请不要修改该值。
    :return: 将估计出来的硬币A和硬币B正面朝上的概率组成list或者ndarray返回。如[0.4, 0.6]表示你认为硬币A正面朝上的概率为0.4,硬币B正面朝上的概率为0.6。
    """
    #********* Begin *********#
    iteration = 0
    thetas = np.array(thetas)
    while iteration < iterations:
        new_thetas = np.array(em_single(thetas, observations))
        delta_change = np.sum(np.abs(thetas - new_thetas))
        if delta_change < tol:
            break
        else:
            thetas = new_thetas
        iteration += 1
    return thetas
    #********* End *********#

 

最近更新

  1. TCP协议是安全的吗?

    2024-06-12 10:52:01       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-06-12 10:52:01       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-06-12 10:52:01       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-06-12 10:52:01       20 阅读

热门阅读

  1. go语言接口之http.Handler接口

    2024-06-12 10:52:01       8 阅读
  2. 富格林:活用经验可信提高出金

    2024-06-12 10:52:01       7 阅读
  3. 力扣1146.快照数组

    2024-06-12 10:52:01       11 阅读
  4. C++中的享元模式

    2024-06-12 10:52:01       9 阅读
  5. Ubuntu系统介绍

    2024-06-12 10:52:01       7 阅读
  6. $(this) 和 this 关键字在 jQuery 中有何不同?

    2024-06-12 10:52:01       7 阅读
  7. 他很意外,我竟然是女程序员?

    2024-06-12 10:52:01       7 阅读
  8. 掉电安全文件系统littlefs移植

    2024-06-12 10:52:01       4 阅读
  9. 等保测评和安全运维

    2024-06-12 10:52:01       9 阅读
  10. 安全等保评测-什么是“等保“?

    2024-06-12 10:52:01       7 阅读
  11. @vue/cli source and destination must not be the same

    2024-06-12 10:52:01       6 阅读