[关闭]
@wuqi0616 2018-03-29T14:08:37.000000Z 字数 11224 阅读 1481

EM算法 - 笔记

机器学习入门资料


1.符号理解:

:表示的概率密度函数,它是一个以为参数的函数。分号左边是随机变量;右边是模型参数。
:前者是Z的以为参数的先验概率密度函数,后者是给定Z下的条件概率。

引 Jensen 不等式

此处输入图片的描述
如图所示:

  1. 如若是凸函数():
    有结论:
  2. 如若是凸函数():
    有结论:
  3. 推论:当且仅当为常数的概率为1时,取到等号
    或者

3.最大似然估计(MLE)


已知时,对求导容易求解

未知时,不能通过对求导获得。

4.EM algorithm

4.1E-Step

选择合适的关于的分布


由Jensen不等式结论可知:

等号才能取到。且的某一分布,有这个性质:

由此可知,应取为给定下,的后验概率分布。

4.2M-Step

寻找参数最大化期望似然

直接求导,依然很麻烦,不过可以用迭代来最大化
(1)由
(2)


只需要最大化联合分布,求出

此处输入图片的描述

5.应用于高斯混合聚类:

概率密度函数为:


其中,是n维均值向量,的协方差矩阵。
高斯混合分布:

分布由个混合成分组成,每个混合部分对应一个高斯分布。其中是第个高斯混合成分的参数,而为相应的“混合系数”,且
E-step
训练集:
随机变量(高斯混合成分):
的先验概率
因此,的后验分布:

M-step
模型参数求解,给定样本集,极大似然估计,最大化(对数)似然


更新:


各混合成分的均值可通过样本加权平均来获得,样本的权重是每个样本属于该成分的后验概率。
更新:

更新混合系数
因要满足,,考虑拉格朗日形式:



每个高斯成分的混合系数由样本属于改成份的平均后验概率确定。

分簇


簇划分是由原型对应后验概率确定。

MOG-EM Algorithm

伪代码:
输入:样本集
高斯混合分布个数.
过程:
1:初始化高斯混合分布的模型参数
2:repeat
3:for do
4:计算由各混合成分生成的后验概率,即

5:end for
6:for do
7:计算新均值向量:
8:计算新协方差矩阵:
9:计算新混合系数:
10:end for
11:将模型参数更新为
12:until满足停止条件
13:
14:for do
15:确定簇标记
16:划入相应的簇
17:end for
输出:簇划分

  1. %% Example of the Mixture-of-Gaussian EM algorithm
  2. % 2018.3.28
  3. % by WuQi
  4. %% Algorithm Start
  5. clear;clc;
  6. tic; % Timing start
  7. %% Read data
  8. Data = load('data.txt');
  9. %% Initialization parameters
  10. k = 3; sample = 4; % cluster's number is 3
  11. [m,n] = size(Data); % m = 80, n =2
  12. alpha = [1/3 1/3 1/3]; % parameter alpha = 1/3
  13. mu = [Data(6,:); Data(22,:); Data(27,:)]; % prior distribution
  14. Sigma(:,:,1)=[0.1,0.0;0.0,0.1]; % parameter sigma = [0.1,0.0;0.0,0.1]
  15. Sigma(:,:,2)=[0.1,0.0;0.0,0.1];
  16. Sigma(:,:,3)=[0.1,0.0;0.0,0.1];
  17. itera = 50; % iteration is 50
  18. count = 0;
  19. Save_mu = zeros(k,n,itera); % parameter mu
  20. Save_index = zeros(m,1,itera); % index of cluster
  21. Sample_itera = [5 10 20 50]; % sampling piont
  22. while count <= itera
  23. %% E-Step :Posterior Distribution
  24. for j = 1:m
  25. sum1=zeros(1,3);
  26. for i = 1:k
  27. sum1(j,i) = alpha(i) * Gf(Data(j,:),mu(i,:),Sigma(:,:,i));
  28. end
  29. sum2=sum(sum1(j,:));
  30. gamma(j,:) = sum1(j,:) / sum2; % gained
  31. end
  32. [max_gamma,index] = max(gamma,[],2); % classification ; large values are retained
  33. %% M-Step :Update the Parameters
  34. sum3 = sum(gamma);
  35. for i = 1:k
  36. sum4 = 0; sum5 = zeros(2,2);
  37. for j = 1:m
  38. sum4 = sum4 + gamma(j,i) * Data(j,:);
  39. end
  40. mu1(i,:) = sum4 / sum3(i); % update the mu
  41. for j = 1:m
  42. Temp = Data(j,:)-mu1(i,:);
  43. Temp = Temp' * Temp;
  44. sum5 = sum5 + gamma(j,i) * Temp;
  45. end
  46. Sigma1(:,:,i) = sum5 / sum3(i); % update the sigma
  47. alpha1(i,:) = sum3(i) / m; % update the alpha
  48. end
  49. mu = mu1;
  50. Sigma = Sigma1;
  51. alpha = alpha1;
  52. count = count + 1;
  53. Save_mu(:,:,count) = mu; % save the parameter mu
  54. Save_index(:,:,count) = index; % save the parameter index
  55. end
  56. %% Plot and Classification
  57. shape = ['o' 's' '^']; % point's shape
  58. color = ['p' 'g' 'k']; % point's color
  59. padding = 'filed';
  60. figure(1);
  61. subplot(221);hold on; % 221
  62. T_index = Save_index(:,:,Sample_itera(1));
  63. for i = 1:k
  64. Temp = find(T_index == i);
  65. scatter(Data(Temp,1),Data(Temp,2),shape(i),color(i),padding);
  66. Temp = [];
  67. end
  68. scatter(Save_mu(:,1,5),Save_mu(:,2,5),'+','r');
  69. xlabel('密度');ylabel('含糖率');
  70. title('(a)5轮迭代后');
  71. subplot(222);hold on; % 222
  72. T_index = Save_index(:,:,Sample_itera(2));
  73. for i = 1:k
  74. Temp = find(T_index == i);
  75. scatter(Data(Temp,1),Data(Temp,2),shape(i),color(i),padding);
  76. Temp = [];
  77. end
  78. scatter(Save_mu(:,1,10),Save_mu(:,2,10),'+','r');
  79. xlabel('密度');ylabel('含糖率');
  80. title('(b)10轮迭代后');
  81. subplot(223);hold on; % 223
  82. T_index = Save_index(:,:,Sample_itera(3));
  83. for i = 1:k
  84. Temp = find(T_index == i);
  85. scatter(Data(Temp,1),Data(Temp,2),shape(i),color(i),padding);
  86. Temp = [];
  87. end
  88. scatter(Save_mu(:,1,20),Save_mu(:,2,20),'+','r');
  89. xlabel('密度');ylabel('含糖率');
  90. title('(c)20轮迭代后');
  91. subplot(224);hold on; % 224
  92. T_index = Save_index(:,:,Sample_itera(4));
  93. for i = 1:k
  94. Temp = find(T_index == i);
  95. scatter(Data(Temp,1),Data(Temp,2),shape(i),color(i),padding);
  96. Temp = [];
  97. end
  98. scatter(Save_mu(:,1,50),Save_mu(:,2,50),'+','r');
  99. xlabel('密度');ylabel('含糖率');
  100. title('(d)50轮迭代后');
  101. toc; % Timing end
  102. ---
  103. %% Likehood Function
  104. function f=Gf(x,u,s)
  105. sum1= (-1/2)*(x-u)*(inv(s))*(x-u)';
  106. sum2= 1/(2*pi*det(s)^(1/2));
  107. f=sum2*exp(sum1);
  108. end

效果图
此处输入图片的描述
高斯混合聚类()在不同轮迭代后的聚类结果。其中样本簇中的样本点分别用“圆形”,“方块”“三角形”表示,各高斯混合成分的均值向量用"+"表示

6.应用于朴素贝叶斯

Example(三硬币模型)
假设有3枚硬币,分别记作A,B,C。这些硬币正面出现的概率分别是.
实验:

  1. 先投掷硬币A,根据其结果选择硬币B或者硬币C。正面选择硬币B,反面选择硬币C。
  2. 然后投掷被选出的硬币,对出现的结果记录。正面记作1;反面计算0

样本:
1,1,0,1,0,0,1,0,1,1
只能观测到硬币投掷后的结果,不能观测其过程,如何估计三硬币正面出现的概率?即三硬币的模型参数。
解:
三硬币模型:
此处输入图片的描述


这里,随机变量是观测变量,表示一次试验观测的结果是1或0;随机变量是隐变量,表示未观测到的投掷硬币A的结果。是模型参数。


即:

求模型参数的极大似然估计,即:

E-Step
选取初值,记作

M-Step
计算模型参数的新估计值

NB-EM Algorithm

伪代码:
输入:样本集
硬币A的状态为.
过程:
1:初始化模型参数
2:repeat
3:for do
4:计算由硬币B,C生成的后验概率,即

5:end for
6:for do
7:计算新的参数
8:计算新的参数
9:end for
10:将模型参数更新为
11:until满足停止条件
12:
13:for do
14:确定簇标记
15:划入相应的簇
16:end for
输出:簇划分

  1. %% Example of the Naive-Bayes EM algorithm
  2. % 2018.3.29
  3. % by WuQi
  4. %% Algorithm Start
  5. clear;clc;
  6. tic; % Timing start
  7. %% Read data
  8. Data = [1 1 0 1 0 0 1 0 1 1]';
  9. %% Initialization parameters
  10. k = 2; % cluster's number is 2
  11. [m,n] = size(Data); % m = 10, n = 1
  12. p_pi = 0.4; % parameter pi
  13. parameter = [0.6 0.7]; % parameter p,q
  14. itera = 50; % iteration is 50
  15. count = 0;
  16. Save_index = zeros(m,1,itera); % index of cluster
  17. Sample_itera = [5 10 20 50]; % sampling piont
  18. while count <= itera
  19. %% E-Step :Posterior Distribution
  20. sum1 = zeros(m,k);
  21. for j = 1:m
  22. for i = 1:k
  23. sum1(j,i) = (p_pi^(2-i)) * ((1-p_pi)^(i-1)) * Bf(Data(j),parameter(i));
  24. end
  25. sum2=sum(sum1(j,:)); % gained
  26. gamma(j,:) = sum1(j,:) / sum2;
  27. end
  28. [max_gamma,index] = max(gamma,[],2); % classification ; large values are retained
  29. %% M-Step :Update the Parameters
  30. sum3 = sum(gamma);
  31. p_pi1 = sum3(1) / m; % update the pi
  32. for i = 1:k
  33. sum4 = 0;
  34. for j = 1:m
  35. sum4 = sum4 + gamma(j,i) * Data(j,:);
  36. end
  37. parameter1(i) = sum4 / sum3(i); % update the parameter
  38. end
  39. p_pi = p_pi1;
  40. parameter = parameter1;
  41. count = count + 1;
  42. Save_index(:,:,count) = index; % save the parameter index
  43. end
  44. %% Plot and Classification
  45. shape = ['o' 's' '^']; % point's shape
  46. color = ['p' 'g' 'k']; % point's color
  47. padding = 'filed';
  48. figure(1);
  49. subplot(221);hold on; % 221
  50. T_index = Save_index(:,:,Sample_itera(1));
  51. for i = 1:k
  52. Temp = find(T_index == i);
  53. [a,b] = size(Temp);
  54. scatter(1:a,Data(Temp),shape(i),color(i),padding);
  55. Temp = [];
  56. end
  57. xlabel('次');ylabel('面');
  58. title('(a)5轮迭代后');
  59. subplot(222);hold on; % 222
  60. T_index = Save_index(:,:,Sample_itera(2));
  61. for i = 1:k
  62. Temp = find(T_index == i);
  63. [a,b] = size(Temp);
  64. scatter(1:a,Data(Temp),shape(i),color(i),padding);
  65. Temp = [];
  66. end
  67. xlabel('次');ylabel('面');
  68. title('(b)10轮迭代后');
  69. subplot(223);hold on; % 223
  70. T_index = Save_index(:,:,Sample_itera(3));
  71. for i = 1:k
  72. Temp = find(T_index == i);
  73. [a,b] = size(Temp);
  74. scatter(1:a,Data(Temp),shape(i),color(i),padding);
  75. Temp = [];
  76. end
  77. xlabel('次');ylabel('面');
  78. title('(c)20轮迭代后');
  79. subplot(224);hold on; % 224
  80. T_index = Save_index(:,:,Sample_itera(4));
  81. for i = 1:k
  82. Temp = find(T_index == i);
  83. [a,b] = size(Temp);
  84. scatter(1:a,Data(Temp),shape(i),color(i),padding);
  85. Temp = [];
  86. end
  87. xlabel('次');ylabel('面');
  88. title('(d)50轮迭代后');
  89. p_pi
  90. parameter
  91. toc; % Timing end
  92. ---
  93. %% Example of the Naive-Bayes EM algorithm
  94. % Likehood Function
  95. % 2018.3.29
  96. % by WuQi
  97. function f=Bf(x,parameter)
  98. sum1 = (parameter ^ x);
  99. sum2 = ((1-parameter) ^ (1-x));
  100. f=sum1 * sum2;
  101. end

一般地:
:表示观测随机变量的数据,
:表示隐随机变量的数据,
: 连在一起称为完全数据(complete - data)
假设:给定观测数据,其概率分布为,其中是需要估计得模型参数,对于不完全数据的似然函数是,对数似然函数
假设:的联合概率分布是,那么完整数据的对数似然函数是

最后补充:
1.EM算法对初始值很敏感。
2.停止迭代条件是:


这里的:


添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注