小言_互联网的博客

Fisher判别的推导概念和过程+python代码实现(三分类)

445人阅读  评论(0)

python代码完成Fisher判别的推导

数据集Iris.csv:链接下载:
提取码:eah8

一、Fisher算法的主要思想

  • 线性判别分析(Linear Discriminant Analysis
    简称LDA)是一种经典的线性学习方法,在二分类问题上因为最早由【Fisher,1936年】提出,所以也称为“Fisher 判别分析!”
    Fisher(费歇)判别思想是投影,使多维问题简化为一维问题来处理。选择一个适当的投影轴,使所有的样本点都投影到这个轴上得到一个投影值。对这个投影轴的方向的要求是:使每一类内的投影值所形成的类内离差尽可能小,而不同类间的投影值所形成的类间离差尽可能大。

二、Fisher数学算法步骤

  • 为了找到最佳投影方向,需要计算出 各类样本均值、样本类内离散度矩阵 Si\boldsymbol S_{i}S i和样本总类内离散度矩阵 Sw\boldsymbolS_{w}Sw、样本类间离散度矩阵 Sb\boldsymbol S_{b}Sb ,根据Fisher准则,找到最佳投影向量,将训练集内的所有样本进行投影,投影到一维Y空间,由于Y空间是一维的,则需要求出Y空间的划分边界点,找到边界点后,就可以对待测样本进行一维Y空间投影,判断它的投影点与分界点的关系,将其归类。具体方法如下(以两类问题为例子):

①计算各类样本均值向量 m i m_i , m i m_i 是各个类的均值, N i N_i w i w_i 类的样本个数。

②计算样本类内离散度矩阵 S i S_i 和总类内离散度矩阵 S w S_w

③计算样本类间离散度矩阵 S b S_b

④求投影方向向量 W W (维度和样本的维度相同)。我们希望投影后,在一维 Y Y 空间里各类样本尽可能分开,就是我们希望的两类样本均值之差 ( m 1 m 2 ) (\overline{m_1}-\overline{m_2}) 越大越好,同时希望各类样本内部尽量密集,即是:希望类内离散度越小越好。因此,我们可以定义Fisher准则函数为:

2使得 J F ( w ) J_F(w) 取得最大值 w w 为:

⑤将训练集内所有样本进行投影。

⑥. 计算在投影空间上的分割阈值 y 0 y_0 ,在一维Y空间,各类样本均值 m i \overline{m_i} 为:


样本类内离散度 S i 2 \overline{S_i}^2 和总类内离散度 S w \overline{S_w}

而此时类间离散度就成为两类均值差的平方。

计算阈值 y 0 y_0

⑦对于给定的测试样本 x x ,计算出它在 w w 上的投影点 y y

⑧根据决策规则分类!

三、python实现代码

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt 
import seaborn as sns
path=r'D:/iris-data/iris.csv'
df = pd.read_csv(path, header=0)
Iris1=df.values[0:50,0:4]
Iris2=df.values[50:100,0:4]
Iris3=df.values[100:150,0:4]

#类均值向量
m1=np.mean(Iris1,axis=0)
m2=np.mean(Iris2,axis=0)
m3=np.mean(Iris3,axis=0)

#各类内离散度矩阵
s1=np.zeros((4,4))
s2=np.zeros((4,4))
s3=np.zeros((4,4))
for i in range(0,30,1):
    a=Iris1[i,:]-m1
    a=np.array([a])
    b=a.T
    s1=s1+np.dot(b,a)    
for i in range(0,30,1):
    c=Iris2[i,:]-m2
    c=np.array([c])
    d=c.T
    s2=s2+np.dot(d,c) 
for i in range(0,30,1):
    a=Iris3[i,:]-m3
    a=np.array([a])
    b=a.T
    s3=s3+np.dot(b,a) 

#总类内离散矩阵
sw12=s1+s2
sw13=s1+s3
sw23=s2+s3
#投影方向
a=np.array([m1-m2])
sw12=np.array(sw12,dtype='float')
sw13=np.array(sw13,dtype='float')
sw23=np.array(sw23,dtype='float')
#判别函数以及T
#需要先将m1-m2转化成矩阵才能进行求其转置矩阵
a=m1-m2
a=np.array([a])
a=a.T
b=m1-m3
b=np.array([b])
b=b.T
c=m2-m3
c=np.array([c])
c=c.T
w12=(np.dot(np.linalg.inv(sw12),a)).T
w13=(np.dot(np.linalg.inv(sw13),b)).T
w23=(np.dot(np.linalg.inv(sw23),c)).T
#print(m1+m2) #1x4维度  invsw12 4x4维度  m1-m2 4x1维度
#判别函数以及阈值T(即w0)
T12=-0.5*(np.dot(np.dot((m1+m2),np.linalg.inv(sw12)),a))
T13=-0.5*(np.dot(np.dot((m1+m3),np.linalg.inv(sw13)),b))
T23=-0.5*(np.dot(np.dot((m2+m3),np.linalg.inv(sw23)),c))
kind1=0
kind2=0
kind3=0
newiris1=[]
newiris2=[]
newiris3=[]
for i in range(30,49):
    x=Iris1[i,:]
    x=np.array([x])
    g12=np.dot(w12,x.T)+T12
    g13=np.dot(w13,x.T)+T13
    g23=np.dot(w23,x.T)+T23
    if g12>0 and g13>0:
        newiris1.extend(x)
        kind1=kind1+1
    elif g12<0 and g23>0:
        newiris2.extend(x)
    elif g13<0 and g23<0 :
        newiris3.extend(x)
#print(newiris1)
for i in range(30,49):
    x=Iris2[i,:]
    x=np.array([x])
    g12=np.dot(w12,x.T)+T12
    g13=np.dot(w13,x.T)+T13
    g23=np.dot(w23,x.T)+T23
    if g12>0 and g13>0:
        newiris1.extend(x)
    elif g12<0 and g23>0:
 
        newiris2.extend(x)
        kind2=kind2+1
    elif g13<0 and g23<0 :
        newiris3.extend(x)
for i in range(30,50):
    x=Iris3[i,:]
    x=np.array([x])
    g12=np.dot(w12,x.T)+T12
    g13=np.dot(w13,x.T)+T13
    g23=np.dot(w23,x.T)+T23
    if g12>0 and g13>0:
        newiris1.extend(x)
    elif g12<0 and g23>0:     
        newiris2.extend(x)
    elif g13<0 and g23<0 :
        newiris3.extend(x)
        kind3=kind3+1
correct=(kind1+kind2+kind3)/60
print("样本类内离散度矩阵S1:",s1,'\n')
print("样本类内离散度矩阵S2:",s2,'\n')
print("样本类内离散度矩阵S3:",s3,'\n')
print('-----------------------------------------------------------------------------------------------')
print("总体类内离散度矩阵Sw12:",sw12,'\n')
print("总体类内离散度矩阵Sw13:",sw13,'\n')
print("总体类内离散度矩阵Sw23:",sw23,'\n')
print('-----------------------------------------------------------------------------------------------')
print('判断出来的综合正确率:',correct*100,'%')

运行结果:

以上便是此次实验的所有结果。

 参考博客:http://bob0118.club/?p=266 

转载:https://blog.csdn.net/qq_42585108/article/details/105928164
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场