机器学习实战之预测数值型数据:回归_常用的回归预测-程序员宅基地

技术标签: 因吉  机器学习  Smale  # 机器学习  Python  

引入

  分类的目标变量是标称型数据,而回归则是对连续型的数据做出预测。回归能做什么呢?Peter Harrington的观点是可以做任何事情,包括他本人提到的一个比较有新意的应用:预测名人的离婚率。

1 线性回归

线性回归

  优点:结果易于理解,计算并不复杂;
  缺点:对非线性数据拟合效果不好;
  使用数据类型:数值型和标称型数据。

  回归的目的是预测数值型的目标值。最接近的的办法是依据输入写一个目标值的计算公式。假设你想对评估一个自己今天的颓废程度,可能会这样计算:

颓废程度 = 起床时间 × 1.2 + 发呆时间 × 1.1 - 学习时间 × 1.3

  以上是个示例,便是所谓的回归方程(regression equation),其中的数字称作回归系数(regression weights),求回归系数的过程就是回归。一旦求得回归系数,再给定输入,便能轻松地获得预测值。
  说到回归,一般指线性回归(linear regression),之后所述的回归都是这个意思。需要说明的是,存在另一种称为非线性回归的回归模型,该模型不认同上面的做法,比如认为输出可能是输入的乘积。例如:

颓废程度 =1.2 × 起床时间 × 发呆时间 × 1.1 / 学习时间

1.1 基本概念

“回归”一词的来历

  今天所知道的回归是由达尔文(Charles Darwin)的表兄弟Francis Galton发明的。Galton于1877年完成了第一次回归预测,目的是根据上一代豌豆种子(双亲)的尺寸来预测下一代豌豆种子(孩子)的尺寸。Galton在大量对象上应用了回归分析,甚至包括人的身高。他注意到,如果双亲的高度比平均高度高,他们的子女也倾向于比平均高度高,但尚不及双亲。孩子的高度向着平均高度回退(回归)。Galton在多项研究上都注意到这个现象,所有尽管这个英文单词根数值预测没有任何关系,但这种研究方法仍被称为回归。

  如何求回归系数呢?假设输入数据都存放在矩阵 X X X中,而回归系数存放在向量 W W W中。那么对于给定的输入 X i X_i Xi,预测结果将会通过 Y i = X i t W Y_i=X^t_iW Yi=XitW给出。现在的问题是,已有一部分 X X X和对应的 Y Y Y,如何找到 W W W呢?一个常用的方法就是找出使误差最小的 W W W。这里的误差是指预测值 Y Y Y和真实值 Y Y Y之间的差值,使用该误差的简单累加将使得正差值和负差值相互抵消,,故可采用平方误差。平方误差可以写做:
∑ i = 1 m ( y i − x i T w ) 2 (1-1) \sum^m_{i=1}(y_i-x^T_iw)^2\tag{1-1} i=1m(yixiTw)2(1-1)  用矩阵表示还可以写做 ( Y − X W ) T ( Y − X W ) (Y-XW)^T(Y-XW) (YXW)T(YXW)。如果对 W W W求导,得到 X T ( Y − X W ) X^T(Y-XW) XT(YXW),令其等于0,解出 W W W如下:
W ^ = ( X T X ) − 1 X T Y \hat W=(X^TX)^{-1}X^TY W^=(XTX)1XTY   W W W上的帽子表示:这是当前可以估计出的最优解。值得注意的是,上述公式中包含 ( X T X ) − 1 (X^TX)^{-1} (XTX)1,也就是需要对矩阵求逆,因此这个方程只在矩阵存在的时候使用。然而,矩阵的逆不一定存在,故需在代码中进行判断。

1.2 标准线性回归

  上述通过 W ^ = ( X T X ) − 1 X T Y \hat W=(X^TX)^{-1}X^TY W^=(XTX)1XTY的方式求解最佳回归系数,该方法被称为OLS,即普通最小二乘法(ordinary least squares)。示例数据集如下,存于文件ex0.txt中:

1.000000	0.067732	3.176513
1.000000	0.427810	3.816464
1.000000	0.995731	4.550095
1.000000	0.738336	4.256571
1.000000	0.981083	4.560815
1.000000	0.526171	3.929515
1.000000	0.378887	3.526170
1.000000	0.033859	3.156393
1.000000	0.132791	3.110301
1.000000	0.138306	3.149813
1.000000	0.247809	3.476346
1.000000	0.648270	4.119688
1.000000	0.731209	4.282233
1.000000	0.236833	3.486582
1.000000	0.969788	4.655492
1.000000	0.607492	3.965162
1.000000	0.358622	3.514900
1.000000	0.147846	3.125947
1.000000	0.637820	4.094115
1.000000	0.230372	3.476039
1.000000	0.070237	3.210610
1.000000	0.067154	3.190612
1.000000	0.925577	4.631504
1.000000	0.717733	4.295890
1.000000	0.015371	3.085028
1.000000	0.335070	3.448080
1.000000	0.040486	3.167440
1.000000	0.212575	3.364266
1.000000	0.617218	3.993482
1.000000	0.541196	3.891471
1.000000	0.045353	3.143259
1.000000	0.126762	3.114204
1.000000	0.556486	3.851484
1.000000	0.901144	4.621899
1.000000	0.958476	4.580768
1.000000	0.274561	3.620992
1.000000	0.394396	3.580501
1.000000	0.872480	4.618706
1.000000	0.409932	3.676867
1.000000	0.908969	4.641845
1.000000	0.166819	3.175939
1.000000	0.665016	4.264980
1.000000	0.263727	3.558448
1.000000	0.231214	3.436632
1.000000	0.552928	3.831052
1.000000	0.047744	3.182853
1.000000	0.365746	3.498906
1.000000	0.495002	3.946833
1.000000	0.493466	3.900583
1.000000	0.792101	4.238522
1.000000	0.769660	4.233080
1.000000	0.251821	3.521557
1.000000	0.181951	3.203344
1.000000	0.808177	4.278105
1.000000	0.334116	3.555705
1.000000	0.338630	3.502661
1.000000	0.452584	3.859776
1.000000	0.694770	4.275956
1.000000	0.590902	3.916191
1.000000	0.307928	3.587961
1.000000	0.148364	3.183004
1.000000	0.702180	4.225236
1.000000	0.721544	4.231083
1.000000	0.666886	4.240544
1.000000	0.124931	3.222372
1.000000	0.618286	4.021445
1.000000	0.381086	3.567479
1.000000	0.385643	3.562580
1.000000	0.777175	4.262059
1.000000	0.116089	3.208813
1.000000	0.115487	3.169825
1.000000	0.663510	4.193949
1.000000	0.254884	3.491678
1.000000	0.993888	4.533306
1.000000	0.295434	3.550108
1.000000	0.952523	4.636427
1.000000	0.307047	3.557078
1.000000	0.277261	3.552874
1.000000	0.279101	3.494159
1.000000	0.175724	3.206828
1.000000	0.156383	3.195266
1.000000	0.733165	4.221292
1.000000	0.848142	4.413372
1.000000	0.771184	4.184347
1.000000	0.429492	3.742878
1.000000	0.162176	3.201878
1.000000	0.917064	4.648964
1.000000	0.315044	3.510117
1.000000	0.201473	3.274434
1.000000	0.297038	3.579622
1.000000	0.336647	3.489244
1.000000	0.666109	4.237386
1.000000	0.583888	3.913749
1.000000	0.085031	3.228990
1.000000	0.687006	4.286286
1.000000	0.949655	4.628614
1.000000	0.189912	3.239536
1.000000	0.844027	4.457997
1.000000	0.333288	3.513384
1.000000	0.427035	3.729674
1.000000	0.466369	3.834274
1.000000	0.550659	3.811155
1.000000	0.278213	3.598316
1.000000	0.918769	4.692514
1.000000	0.886555	4.604859
1.000000	0.569488	3.864912
1.000000	0.066379	3.184236
1.000000	0.335751	3.500796
1.000000	0.426863	3.743365
1.000000	0.395746	3.622905
1.000000	0.694221	4.310796
1.000000	0.272760	3.583357
1.000000	0.503495	3.901852
1.000000	0.067119	3.233521
1.000000	0.038326	3.105266
1.000000	0.599122	3.865544
1.000000	0.947054	4.628625
1.000000	0.671279	4.231213
1.000000	0.434811	3.791149
1.000000	0.509381	3.968271
1.000000	0.749442	4.253910
1.000000	0.058014	3.194710
1.000000	0.482978	3.996503
1.000000	0.466776	3.904358
1.000000	0.357767	3.503976
1.000000	0.949123	4.557545
1.000000	0.417320	3.699876
1.000000	0.920461	4.613614
1.000000	0.156433	3.140401
1.000000	0.656662	4.206717
1.000000	0.616418	3.969524
1.000000	0.853428	4.476096
1.000000	0.133295	3.136528
1.000000	0.693007	4.279071
1.000000	0.178449	3.200603
1.000000	0.199526	3.299012
1.000000	0.073224	3.209873
1.000000	0.286515	3.632942
1.000000	0.182026	3.248361
1.000000	0.621523	3.995783
1.000000	0.344584	3.563262
1.000000	0.398556	3.649712
1.000000	0.480369	3.951845
1.000000	0.153350	3.145031
1.000000	0.171846	3.181577
1.000000	0.867082	4.637087
1.000000	0.223855	3.404964
1.000000	0.528301	3.873188
1.000000	0.890192	4.633648
1.000000	0.106352	3.154768
1.000000	0.917886	4.623637
1.000000	0.014855	3.078132
1.000000	0.567682	3.913596
1.000000	0.068854	3.221817
1.000000	0.603535	3.938071
1.000000	0.532050	3.880822
1.000000	0.651362	4.176436
1.000000	0.901225	4.648161
1.000000	0.204337	3.332312
1.000000	0.696081	4.240614
1.000000	0.963924	4.532224
1.000000	0.981390	4.557105
1.000000	0.987911	4.610072
1.000000	0.990947	4.636569
1.000000	0.736021	4.229813
1.000000	0.253574	3.500860
1.000000	0.674722	4.245514
1.000000	0.939368	4.605182
1.000000	0.235419	3.454340
1.000000	0.110521	3.180775
1.000000	0.218023	3.380820
1.000000	0.869778	4.565020
1.000000	0.196830	3.279973
1.000000	0.958178	4.554241
1.000000	0.972673	4.633520
1.000000	0.745797	4.281037
1.000000	0.445674	3.844426
1.000000	0.470557	3.891601
1.000000	0.549236	3.849728
1.000000	0.335691	3.492215
1.000000	0.884739	4.592374
1.000000	0.918916	4.632025
1.000000	0.441815	3.756750
1.000000	0.116598	3.133555
1.000000	0.359274	3.567919
1.000000	0.814811	4.363382
1.000000	0.387125	3.560165
1.000000	0.982243	4.564305
1.000000	0.780880	4.215055
1.000000	0.652565	4.174999
1.000000	0.870030	4.586640
1.000000	0.604755	3.960008
1.000000	0.255212	3.529963
1.000000	0.730546	4.213412
1.000000	0.493829	3.908685
1.000000	0.257017	3.585821
1.000000	0.833735	4.374394
1.000000	0.070095	3.213817
1.000000	0.527070	3.952681
1.000000	0.116163	3.129283

  选取第二、三列绘制如下:
在这里插入图片描述

图1-1 训练数据集


  创建regression.py文件并添加以下代码:

程序清单1-1: 标准回归函数、数据导入函数及测试函数

import matplotlib.pyplot as plt
from numpy import *

def load_data_set(file_name):
    with open(file_name) as fd:
        fd_data = fd.readlines()
    x_set = []; y_set = []
    for data in fd_data:
        data = data.strip().split('\t')
        data = [float(value) for value in data]
        x_set.append(data[:-1])
        y_set.append(data[-1])
        plt.scatter(data[1], data[2], c='red')
    return x_set, y_set

def stand_regres(x_set, y_set):
    x_mat = mat(x_set); y_mat = mat(y_set).T
    x_tx = x_mat.T * x_mat
    if linalg.det(x_tx) == 0.0:    #linalg.det()用于计算行列式,若行列式为0,则矩阵不可进行求逆运算
        print("This matrix is singular,cannot do inverse")
        return
    w = x_tx.I * (x_mat.T * y_mat)    #对应普通二乘法公式
    return w

def test1():
    x_set, y_set = load_data_set('E:/Machine Learing/myMachineLearning/data/ex0.txt')
    w = stand_regres(x_set, y_set)
    print("W:", w.T)
    x_copy = mat(x_set).copy()    #拷贝
    x_copy .sort(0)    #排序点
    y_hat = x_copy * w    #获得预测值
    plt.plot(x_copy [:,1], y_hat)
    plt.show()

if __name__ == '__main__':
    test1()

  运行结果:

W: [[3.00774324 1.69532264]]

在这里插入图片描述

图1-2 训练数据集和它的最佳拟合直线


  需要注意的是给定数据集中的第一列总是等于1.0,即X0。这是因为我们假定偏移量是一个常数,第二、三列才是数据真实的属性。在绘图时,预测结果保存于y_hat中,但是原本数据集中实例是杂乱的,故在绘制前进行排序。
  
  这一模型简单,几乎所有数据集都可以用上述方式建立模型,那么,如何判断这些模型的优劣呢?有种方式是计算预测值y_hat与真实值y的匹配程度,即相关系数。Python中计算相关系数的命令是corrcoef(y_hat.T,y_mat)。最终结果如下:

Correlation coefficient:
 [[1.         0.98647356]
 [0.98647356 1.        ]]

  对角线上的数据是1.0,这是因为自己与自己匹配的结果;而预测结果与真实结果的相关性达到了0.98。

2 局部加权线性回归

  线性回归很可能出现欠拟合现象,因为它求的是具有最小均方误差的无偏估计。显然,如果模型欠拟合则不能取得最好的预测效果。所以有的方法允许在估计中引入一些偏差,从而降低预测的均方误差。

2.1 基本概念

此处介绍一个方法:局部加权线性回归(Locally Weighted Linear Regression, LWLR)。在该方法中,给待预测点附近的每个点赋予一定的权重,其他则与标准回归一致。
  与kNN一样,该算法每次预测均需事先选取出对应的数据子集,其所对应的 W W W如下:
W ^ = ( X T W ′ X ) − 1 X T W ′ Y (2-1) \hat W=(X^TW'X)^{-1}X^TW'Y\tag{2-1} W^=(XTWX)1XTWY(2-1)其中 W ′ W' W特指权重。
  LWLR使用“核”来对附近的点赋予更高的权重。核的类型可以自由选择,最常用的则是高斯核,高斯核对应的权重如下:
W ′ ( i , i ) = e x p ( ∣ x ( i ) − x ∣ − 2 k 2 ) (2-2) W'(i,i)=exp(\frac{|x^{(i)}-x|}{-2k^2})\tag{2-2} W(i,i)=exp(2k2x(i)x)(2-2)  由此就构建了一个只包含对角元素的权重矩阵 W ′ W' W,并且点 x x x x ( i ) x(i) x(i)越近, W ′ ( i , i ) W'(i,i) W(i,i)便越大。式 2 − 2 2-2 22中包含了一个用户指定的参数 k k k,它决定了对附近的点赋予多大的权重。当然,上述公式还可以写作以下形式:
W ′ ( i , i ) = e x p ( − ( x − x ( i ) ) 2 2 k 2 ) (2-3) W'(i,i)=exp(-\frac{(x-x(i))^2}{2k^2})\tag{2-3} W(i,i)=exp(2k2(xx(i))2)(2-3)  取k分布等于0.5、0.1、0.01且 x ( i ) = 0.5 x(i)=0.5 x(i)=0.5作为示例绘制权重变化图如下:
在这里插入图片描述

图2-1 k相关权重图


  以下为具体实现过程。于regression.py文件并添加以下代码:

程序清单2-1: 局部加权线性回归函数

def lwlr(test_point, x_set, y_set, k=1.0):    #输入参数:单个实例、x、y、k
    x_mat = mat(x_set); y_mat = mat(y_set).T
    m = shape(x_mat)[0]
    weights = mat(eye((m)))    #创建对角矩阵
    for j in range(m):    #计算权重
        diff_mat = test_point - x_mat[j,:]
        weights[j, j] = exp(diff_mat * diff_mat.T / (-2.0 * k**2))
    x_tx = x_mat.T * (weights * x_mat)
    if linalg.det(x_tx) == 0.0:    #判断行列式是否为零
        print("This matrix is singular, cannot do inverse")
        return
    w = x_tx.I * (x_mat.T * (weights * y_mat))
    return test_point * w

def lwlr_test(test_set, x_set, y_set, k=1.0):
    test_mat = mat(test_set)
    m = shape(test_mat)[0]
    y_hat = zeros(m)
    for i in range(m):    #对每一个实例进行预测
        y_hat[i] = lwlr(test_mat[i], x_set, y_set, k)
    return y_hat

def test2():
    x_set, y_set = load_data_set('E:\Machine Learing\myMachineLearning\data\ex0.txt')
    y_hat = lwlr_test(x_set, x_set, y_set)
    print("The predicted is:\n", y_hat[:6])
    print("The real is:\n", y_set[:6])
    x_mat = mat(x_set); y_mat = mat(y_set)
    sorted_index = x_mat[:,1].argsort(0)    #获取矩阵排序后的索引,但不改变矩阵
    x_sorted = x_mat[sorted_index][:,0,:]    #排序后的sorted_index
    plt.axis([0, 1, 3, 5])
    plt.subplot(311)
    plt.plot(x_sorted[:,1], y_hat[sorted_index])
    plt.scatter(x_mat[:,1].flatten().A[0], y_mat.T.flatten().A[0], s=2, c='red')
    plt.subplot(312)
    y_hat = lwlr_test(x_set, x_set, y_set, 0.1)
    plt.plot(x_sorted[:, 1], y_hat[sorted_index])
    plt.scatter(x_mat[:, 1].flatten().A[0], y_mat.T.flatten().A[0], s=2, c='red')
    plt.subplot(313)
    y_hat = lwlr_test(x_set, x_set, y_set, 0.003)
    plt.plot(x_sorted[:, 1], y_hat[sorted_index])
    plt.scatter(x_mat[:, 1].flatten().A[0], y_mat.T.flatten().A[0], s=2, c='red')
    plt.show()

if __name__ == '__main__':
    test2()
    # test1()

  运行结果:

The predicted is:
 [3.12204471 3.73284336 4.69692033 4.25997574 4.67205815 3.89979584]
The real is:
 [3.176513, 3.816464, 4.550095, 4.256571, 4.560815, 3.929515]

在这里插入图片描述

图2-2 不同k值下局部加权线性回归的结果


  如图2-2:
  1)k=1.0时,权重很大,如同将所有数据视为等权重,得出的最佳拟合直线与标准回归一致,出现欠拟合;
  2)k=0.01时,权重适中,抓住了数据的潜在模式;
  3)k=0.003时,权重较小,纳入了太多噪声点,拟合的直线与数据点过于贴切,出现过拟合。

2.2 示例:预测鲍鱼的年龄

  鲍鱼数据集‘abalone.txt’来源于UCI数据集合,记录了鲍鱼的年龄。鲍鱼的年龄可以从鲍鱼壳的层数推算得到,具体内容如下,其中每个实例的最后一个元素代表鲍鱼真实年龄(由于数据集较大,只列出前一百行):

1	0.455	0.365	0.095	0.514	0.2245	0.101	0.15	15
1	0.35	0.265	0.09	0.2255	0.0995	0.0485	0.07	7
-1	0.53	0.42	0.135	0.677	0.2565	0.1415	0.21	9
1	0.44	0.365	0.125	0.516	0.2155	0.114	0.155	10
0	0.33	0.255	0.08	0.205	0.0895	0.0395	0.055	7
0	0.425	0.3	0.095	0.3515	0.141	0.0775	0.12	8
-1	0.53	0.415	0.15	0.7775	0.237	0.1415	0.33	20
-1	0.545	0.425	0.125	0.768	0.294	0.1495	0.26	16
1	0.475	0.37	0.125	0.5095	0.2165	0.1125	0.165	9
-1	0.55	0.44	0.15	0.8945	0.3145	0.151	0.32	19
-1	0.525	0.38	0.14	0.6065	0.194	0.1475	0.21	14
1	0.43	0.35	0.11	0.406	0.1675	0.081	0.135	10
1	0.49	0.38	0.135	0.5415	0.2175	0.095	0.19	11
-1	0.535	0.405	0.145	0.6845	0.2725	0.171	0.205	10
-1	0.47	0.355	0.1	0.4755	0.1675	0.0805	0.185	10
1	0.5	0.4	0.13	0.6645	0.258	0.133	0.24	12
0	0.355	0.28	0.085	0.2905	0.095	0.0395	0.115	7
-1	0.44	0.34	0.1	0.451	0.188	0.087	0.13	10
1	0.365	0.295	0.08	0.2555	0.097	0.043	0.1	7
1	0.45	0.32	0.1	0.381	0.1705	0.075	0.115	9
1	0.355	0.28	0.095	0.2455	0.0955	0.062	0.075	11
0	0.38	0.275	0.1	0.2255	0.08	0.049	0.085	10
-1	0.565	0.44	0.155	0.9395	0.4275	0.214	0.27	12
-1	0.55	0.415	0.135	0.7635	0.318	0.21	0.2	9
-1	0.615	0.48	0.165	1.1615	0.513	0.301	0.305	10
-1	0.56	0.44	0.14	0.9285	0.3825	0.188	0.3	11
-1	0.58	0.45	0.185	0.9955	0.3945	0.272	0.285	11
1	0.59	0.445	0.14	0.931	0.356	0.234	0.28	12
1	0.605	0.475	0.18	0.9365	0.394	0.219	0.295	15
1	0.575	0.425	0.14	0.8635	0.393	0.227	0.2	11
1	0.58	0.47	0.165	0.9975	0.3935	0.242	0.33	10
-1	0.68	0.56	0.165	1.639	0.6055	0.2805	0.46	15
1	0.665	0.525	0.165	1.338	0.5515	0.3575	0.35	18
-1	0.68	0.55	0.175	1.798	0.815	0.3925	0.455	19
-1	0.705	0.55	0.2	1.7095	0.633	0.4115	0.49	13
1	0.465	0.355	0.105	0.4795	0.227	0.124	0.125	8
-1	0.54	0.475	0.155	1.217	0.5305	0.3075	0.34	16
-1	0.45	0.355	0.105	0.5225	0.237	0.1165	0.145	8
-1	0.575	0.445	0.135	0.883	0.381	0.2035	0.26	11
1	0.355	0.29	0.09	0.3275	0.134	0.086	0.09	9
-1	0.45	0.335	0.105	0.425	0.1865	0.091	0.115	9
-1	0.55	0.425	0.135	0.8515	0.362	0.196	0.27	14
0	0.24	0.175	0.045	0.07	0.0315	0.0235	0.02	5
0	0.205	0.15	0.055	0.042	0.0255	0.015	0.012	5
0	0.21	0.15	0.05	0.042	0.0175	0.0125	0.015	4
0	0.39	0.295	0.095	0.203	0.0875	0.045	0.075	7
1	0.47	0.37	0.12	0.5795	0.293	0.227	0.14	9
-1	0.46	0.375	0.12	0.4605	0.1775	0.11	0.15	7
0	0.325	0.245	0.07	0.161	0.0755	0.0255	0.045	6
-1	0.525	0.425	0.16	0.8355	0.3545	0.2135	0.245	9
0	0.52	0.41	0.12	0.595	0.2385	0.111	0.19	8
1	0.4	0.32	0.095	0.303	0.1335	0.06	0.1	7
1	0.485	0.36	0.13	0.5415	0.2595	0.096	0.16	10
-1	0.47	0.36	0.12	0.4775	0.2105	0.1055	0.15	10
1	0.405	0.31	0.1	0.385	0.173	0.0915	0.11	7
-1	0.5	0.4	0.14	0.6615	0.2565	0.1755	0.22	8
1	0.445	0.35	0.12	0.4425	0.192	0.0955	0.135	8
1	0.47	0.385	0.135	0.5895	0.2765	0.12	0.17	8
0	0.245	0.19	0.06	0.086	0.042	0.014	0.025	4
-1	0.505	0.4	0.125	0.583	0.246	0.13	0.175	7
1	0.45	0.345	0.105	0.4115	0.18	0.1125	0.135	7
1	0.505	0.405	0.11	0.625	0.305	0.16	0.175	9
-1	0.53	0.41	0.13	0.6965	0.302	0.1935	0.2	10
1	0.425	0.325	0.095	0.3785	0.1705	0.08	0.1	7
1	0.52	0.4	0.12	0.58	0.234	0.1315	0.185	8
1	0.475	0.355	0.12	0.48	0.234	0.1015	0.135	8
-1	0.565	0.44	0.16	0.915	0.354	0.1935	0.32	12
-1	0.595	0.495	0.185	1.285	0.416	0.224	0.485	13
-1	0.475	0.39	0.12	0.5305	0.2135	0.1155	0.17	10
0	0.31	0.235	0.07	0.151	0.063	0.0405	0.045	6
1	0.555	0.425	0.13	0.7665	0.264	0.168	0.275	13
-1	0.4	0.32	0.11	0.353	0.1405	0.0985	0.1	8
-1	0.595	0.475	0.17	1.247	0.48	0.225	0.425	20
1	0.57	0.48	0.175	1.185	0.474	0.261	0.38	11
-1	0.605	0.45	0.195	1.098	0.481	0.2895	0.315	13
-1	0.6	0.475	0.15	1.0075	0.4425	0.221	0.28	15
1	0.595	0.475	0.14	0.944	0.3625	0.189	0.315	9
-1	0.6	0.47	0.15	0.922	0.363	0.194	0.305	10
-1	0.555	0.425	0.14	0.788	0.282	0.1595	0.285	11
-1	0.615	0.475	0.17	1.1025	0.4695	0.2355	0.345	14
-1	0.575	0.445	0.14	0.941	0.3845	0.252	0.285	9
1	0.62	0.51	0.175	1.615	0.5105	0.192	0.675	12
-1	0.52	0.425	0.165	0.9885	0.396	0.225	0.32	16
1	0.595	0.475	0.16	1.3175	0.408	0.234	0.58	21
1	0.58	0.45	0.14	1.013	0.38	0.216	0.36	14
-1	0.57	0.465	0.18	1.295	0.339	0.2225	0.44	12
1	0.625	0.465	0.14	1.195	0.4825	0.205	0.4	13
1	0.56	0.44	0.16	0.8645	0.3305	0.2075	0.26	10
-1	0.46	0.355	0.13	0.517	0.2205	0.114	0.165	9
-1	0.575	0.45	0.16	0.9775	0.3135	0.231	0.33	12
1	0.565	0.425	0.135	0.8115	0.341	0.1675	0.255	15
1	0.555	0.44	0.15	0.755	0.307	0.1525	0.26	12
1	0.595	0.465	0.175	1.115	0.4015	0.254	0.39	13
-1	0.625	0.495	0.165	1.262	0.507	0.318	0.39	10
1	0.695	0.56	0.19	1.494	0.588	0.3425	0.485	15
1	0.665	0.535	0.195	1.606	0.5755	0.388	0.48	14
1	0.535	0.435	0.15	0.725	0.269	0.1385	0.25	9
1	0.47	0.375	0.13	0.523	0.214	0.132	0.145	8
1	0.47	0.37	0.13	0.5225	0.201	0.133	0.165	7
-1	0.475	0.375	0.125	0.5785	0.2775	0.085	0.155	10

于regression.py文件并添加以下代码:

程序清单2-2: 预测鲍鱼年龄

def test3():
    x_set, y_set = load_data_set('E:\Machine Learing\myMachineLearning\data/abalone.txt')
    x_label = list(arange(0.1, 10, 0.1))
    y_hat01 = lwlr_test(x_set[0:99], x_set[0:99], y_set[0:99], 0.1)
    rss_error01 = rss_error(y_set[0:99], y_hat01.T)
    y_hat1 = lwlr_test(x_set[0:99], x_set[0:99], y_set[0:99], 1)
    rss_error1 = rss_error(y_set[0:99], y_hat1.T)
    y_hat10 = lwlr_test(x_set[0:99], x_set[0:99], y_set[0:99], 10)
    rss_error10 = rss_error(y_set[0:99], y_hat10.T)
    print("The training error:\nThe prediction error when k=0.1:", rss_error01)
    print("The prediction error when k=1:", rss_error1)
    print("The prediction error when k=10:", rss_error10)
    y_hat01 = lwlr_test(x_set[100:199], x_set[0:99], y_set[0:99], 0.1)
    rss_error01 = rss_error(y_set[100:199], y_hat01.T)
    y_hat1 = lwlr_test(x_set[100:199], x_set[0:99], y_set[0:99], 1)
    rss_error1 = rss_error(y_set[100:199], y_hat1.T)
    y_hat10 = lwlr_test(x_set[100:199], x_set[0:99], y_set[0:99], 10)
    rss_error10 = rss_error(y_set[100:199], y_hat10.T)
    print("The test error:\nThe prediction error when k=0.1:", rss_error01)
    print("The prediction error when k=1:", rss_error1)
    print("The prediction error when k=10:", rss_error10)

if __name__ == '__main__':
    test3()

  运行结果:

The training error:
The prediction error when k=0.1: 56.78868743050092
The prediction error when k=1: 429.89056187038
The prediction error when k=10: 549.1181708827924
The test error:
The prediction error when k=0.1: 57913.51550155911
The prediction error when k=1: 573.5261441895982
The prediction error when k=10: 517.5711905381903

  如前所述,k过小时尽管训练预测误差较小,但测试预测误差较大,即欠拟合;反之k过大,则出现训练预测误差大于测试误差的情况。具体的k取值多少,则需要依多次实验而定。

3 缩减系数来“理解”数据

  当数据的特征比样本点多时,线性回归和局部加权线性回归便不再适用,这是因为在计算 ( X T X ) − 1 (X^TX)^{-1} (XTX)1时会出错。即此时输入数据的矩阵 X X X不是满秩矩阵。
  为了解决该问题,统计学家引入了岭回归(ridge regression)的概念。

3.1 岭回归

  简单来说,岭回归就是在矩阵 X T X X^TX XTX上加一个 λ I \lambda I λI从而使得矩阵满秩,进而能对 X T X + λ I X^TX+\lambda I XTX+λI求逆。其中矩阵 I I I是一个 n × n n×n n×n的单位矩阵,对角线的元素全为1,其他元素全为0.而 λ \lambda λ是一个用户输入的数值。由此回归系数的公式变为:
W ^ = ( X T X + λ I ) − 1 X T Y (3-1) \hat{W}=(X^TX+\lambda I)^{-1}X^TY\tag{3-1} W^=(XTX+λI)1XTY(3-1)

岭回归中的岭是什么?
  岭回归使用了单位矩阵乘以常量 λ \lambda λ,观察其中的单位矩阵 i i i,可以发现值 I I I贯穿整个对角线,其余元素全是0。形象的,在0构成的平面上有一条1组成的“岭”,这便是“岭”的由来。

  岭回归最先用来处理特征数多于样本数的情况,现在也用于在估计中加入偏差,从而得到更好的估计。这里通过引入 λ \lambda λ来限制所有 W W W的和,通过引入该惩罚项,能够减少不重要的参数,这个技术在统计学也叫做 缩减(shrinkage)。
缩减方法可以去掉不中的参数,因此能够更好地理解数据,自然能比简单线性回归取得更好的结果。与之前类似,通过预测误差最小化得到 λ \lambda λ,再求得 W W W。于regression.py文件并添加以下代码:

程序清单3-1: 岭回归

"""岭回归"""
def ridge_regres(x_mat, y_mat, lam=0.2):
    x_tx = x_mat.T * x_mat
    denom = x_tx + eye(shape(x_mat)[1]) * lam    #对应岭回归公式
    if linalg.det(denom) == 0.0:    #lam=0时依然会出现错误
        print("This matrix is singular,cannot do inverse")
        return
    w = denom.I * (x_mat.T * y_mat)    #.I为求逆
    return w

def ridge_test(x_set, y_set, N=30):
    x_mat = mat(x_set); y_mat = mat(y_set).T
    y_mean = mean(y_mat, 0)
    y_mat = y_mat - y_mean    #特征标准化处理,使得每维特征具有同等重要性,不考虑特征代表什么
    x_means = mean(x_mat, 0)
    x_var = var(x_mat, 0)    #每列方差
    x_mat = (x_mat - x_means) / x_var
    num_test = N
    w_mat = zeros((num_test, shape(x_mat)[1]))
    for i in range(num_test):
        w = ridge_regres(x_mat, y_mat, exp(i - 10))    #lam呈指数级变化,这样可以看出lam为较小值与较大值时对结果的影响
        w_mat[i,:] = w.T
    return w_mat

def test4():
    x_set, y_set = load_data_set('E:\Machine Learing\myMachineLearning\data/abalone.txt')
    w_mat = ridge_test(x_set, y_set)
    print("W matrix:", w_mat)
    plt.plot(w_mat)
    plt.xlabel('log(lambda)')
    plt.axis([-1, 30, -1, 2.5])
    plt.show()

  运行结果:

...
 [-3.13618246e-06  3.43488557e-04  4.29265642e-04  9.86279863e-04
   8.16188652e-05  1.39858822e-04  3.40121256e-04  3.34847052e-04]]

在这里插入图片描述

图3-1 岭回归回归系数变化图


   λ \lambda λ非常小,系数与普通回归系数一样; λ \lambda λ非常大时,所有回归系数缩减为0。可以在中间某处找到使得预测结果最好的 λ \lambda λ。如何找到?则可以交叉验证。
  还有其他一些缩减方法,如lasso、LAR、PCA回归以及子集选择等。与岭回归一样,这些方法不仅可以提高预测精度,也可以解释回归系数。

3.2 lasso

  在增加约束的时,普通的最小二乘法回归会得到与岭回归一样的公式:
∑ k = 1 n W k 2 ≤ λ (3-2) \sum^n_{k=1}W^2_k\le\lambda\tag{3-2} k=1nWk2λ(3-2)  上式限定了所有回归系数的平方和不能大于 λ \lambda λ,从而避免当两个或更多的特征相关时,出现一个很大正系数或者负系数的情况。与此类似,lasso也对回归系数做了限定:
∑ k = 1 n ∣ W k ∣ ≤ λ (3-3) \sum^n_{k=1}|W_k|\le\lambda\tag{3-3} k=1nWkλ(3-3)  这里的约束条件用绝对值代替平方和,使得在 λ \lambda λ足够小时,一些系数被迫缩减为0。这个特征可以更好地理解数据,但是也大大增加了计算复杂度,如果需要在此条件下解出回归系数,需要使用二次规划算法。

3.3 前向逐步回归

  前向逐步回归在更加简单的情况下,也能达到和lasso差不多的效果。它属于一种贪心算法,即每一步都尽可能减小误差。一开始,所有权重都设为1,然后每一步所做的决策是对某个权重增加或者减小一个很小的值。其伪代码如下:
  数据标准化,使其分布满足0均值和单位方差
  在每轮迭代中:
    设置当前最小误差lowest_error为正无穷
    对每个特征:
      增大或减小:
        改变一个系数得到一个新的 W W W
        计算新 W W W下的误差error
        如果当前误差error小于最小误差lowest_error:
          设置W_best等于当前 W W W
      将 W W W设置为新的W_best

于regression.py文件并添加以下代码:

程序清单3-2: 前向逐步线性回归

def stage_wise(x_set, y_set, eps=0.1, max_iter=100):    #eps为每次迭代需要调整的步长
    x_mat = mat(x_set); y_mat = mat(y_set).T
    y_mean = mean(y_mat, 0)    #平均值
    y_mat = y_mat - y_mean
    x_mat = regularize(x_mat)
    m ,n =shape(x_mat)
    return_mat = zeros((max_iter, n))
    w = zeros((n, 1))
    w_max = w.copy()
    for i in range(max_iter):
        # print(w.T)
        lowest_error = inf;
        for j in range(n):
            for sign in [-1, 1]:
                w_test = w.copy()
                w_test[j] += eps * sign
                y_test = x_mat * w_test
                rss_e = rss_error(y_mat.A, y_test.A)
                if rss_e < lowest_error:
                    lowest_error = rss_e
                    w_max = w_test
        w = w_max.copy()
        return_mat[i] = w.T
    return return_mat

def test5():
    x_set, y_set = load_data_set('E:\Machine Learing\myMachineLearning\data/abalone.txt')
    return_mat = stage_wise(x_set, y_set, 0.01, 200)
    print("Return mat:\n", return_mat)
    plt.plot(return_mat)
    plt.show()

if __name__ == '__main__':
    test5()
    # test4()
    # test3()
    # test2()
    # test1()

  运行结果:

Return mat:
 [[ 0.    0.    0.   ...  0.    0.    0.  ]
 [ 0.    0.    0.   ...  0.    0.    0.  ]
 [ 0.    0.    0.   ...  0.    0.    0.  ]
 ...
 [ 0.05  0.    0.09 ... -0.64  0.    0.36]
 [ 0.04  0.    0.09 ... -0.64  0.    0.36]
 [ 0.05  0.    0.09 ... -0.64  0.    0.36]]

在这里插入图片描述

图3-2 前向逐步回归系数变化图(eps=0.01,max_iter=200)


  上述结果中,W1和W6都是0,这表明它们不对目标值造成任何影响,很有可能这些特征是不需要的。另外在eps设置为0.01时,一段时间后系数就已经饱和并在特定值之间来回震荡,这是因为步长太长的原因。接下来尝试更小步长并绘制:
在这里插入图片描述

图3-3 前向逐步回归系数变化图(eps=0.005,max_iter=1000)


  前向逐步线性回归主要的优点在于:
  可以帮助理解现在模型并改进。当构建一个模型之后,可以运行该算法找出重要的特征,这样就有可能及时停止收集那些不重要特征的收集。

4 权衡偏差与方差

  一旦发现模型与测量值之间存在差异,就说明出现了误差。当考虑模型中“噪声”或者说误差时,必须考虑其来源:
  1)对复杂的过程进行简化时,将导致模型和测量值之间出现“噪声”或误差;
  2)若无法理解数据的真实生成过程,会导致差异的发生;
  3)测量过程本身也可能产生“噪声”或问题。
  例如之前用到的‘ex0.txt’数据集,是认为制造的,其生成公式如下:
y = 3.0 + 1.7 x + 0.1 s i n ( 30 x ) + 0.06 N ( 0 , 1 ) y=3.0+1.7x+0.1sin(30x)+0.06N(0,1) y=3.0+1.7x+0.1sin(30x)+0.06N(0,1)其中 N ( 0 , 1 ) N(0,1) N(0,1)是一个均值为0、方差为1的正太分布。若用一条进行进行拟合,那么最佳拟合应该是 3.0 + 1.7 x 3.0 + 1.7x 3.0+1.7x这部分,这样一来,误差部分则是 0.1 s i n ( 30 x ) + 0.06 N ( 0 , 1 ) 0.1sin(30x)+0.06N(0,1) 0.1sin(30x)+0.06N(0,1)
  下图为训练误差和测试误差的曲线图(来源):
在这里插入图片描述

图4-1 偏差方差折中与测试误差及训练误差的关系(eps=0.005,max_iter=1000)


  根据局部加权线性回归中的实验知道:
  1)降低核的大小,那么训练误差将变小,即图中红色线;
  2)降低核的大小,那么测试误差将有一个先变小后变大的过程,即黑色线。
  一般认为,上述两种误差由三部分组成:
  偏差、测量误差、随机噪声。
  
  如果从鲍鱼数据集中取一个随机样本集,例如取其中100个数据,并用线性模型拟合,将会得到一组回归系数。同理,再取另一组随机样本集并拟合,将会得到另一组回归系数。这些系数间的差异就是模型方差1大小的反映。


  1. 方差指模型之间的差异;偏差指模型预测值和数据之间的差异。

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_44575152/article/details/100640051

智能推荐

while循环&CPU占用率高问题深入分析与解决方案_main函数使用while(1)循环cpu占用99-程序员宅基地

文章浏览阅读3.8k次,点赞9次,收藏28次。直接上一个工作中碰到的问题,另外一个系统开启多线程调用我这边的接口,然后我这边会开启多线程批量查询第三方接口并且返回给调用方。使用的是两三年前别人遗留下来的方法,放到线上后发现确实是可以正常取到结果,但是一旦调用,CPU占用就直接100%(部署环境是win server服务器)。因此查看了下相关的老代码并使用JProfiler查看发现是在某个while循环的时候有问题。具体项目代码就不贴了,类似于下面这段代码。​​​​​​while(flag) {//your code;}这里的flag._main函数使用while(1)循环cpu占用99

【无标题】jetbrains idea shift f6不生效_idea shift +f6快捷键不生效-程序员宅基地

文章浏览阅读347次。idea shift f6 快捷键无效_idea shift +f6快捷键不生效

node.js学习笔记之Node中的核心模块_node模块中有很多核心模块,以下不属于核心模块,使用时需下载的是-程序员宅基地

文章浏览阅读135次。Ecmacript 中没有DOM 和 BOM核心模块Node为JavaScript提供了很多服务器级别,这些API绝大多数都被包装到了一个具名和核心模块中了,例如文件操作的 fs 核心模块 ,http服务构建的http 模块 path 路径操作模块 os 操作系统信息模块// 用来获取机器信息的var os = require('os')// 用来操作路径的var path = require('path')// 获取当前机器的 CPU 信息console.log(os.cpus._node模块中有很多核心模块,以下不属于核心模块,使用时需下载的是

数学建模【SPSS 下载-安装、方差分析与回归分析的SPSS实现(软件概述、方差分析、回归分析)】_化工数学模型数据回归软件-程序员宅基地

文章浏览阅读10w+次,点赞435次,收藏3.4k次。SPSS 22 下载安装过程7.6 方差分析与回归分析的SPSS实现7.6.1 SPSS软件概述1 SPSS版本与安装2 SPSS界面3 SPSS特点4 SPSS数据7.6.2 SPSS与方差分析1 单因素方差分析2 双因素方差分析7.6.3 SPSS与回归分析SPSS回归分析过程牙膏价格问题的回归分析_化工数学模型数据回归软件

利用hutool实现邮件发送功能_hutool发送邮件-程序员宅基地

文章浏览阅读7.5k次。如何利用hutool工具包实现邮件发送功能呢?1、首先引入hutool依赖<dependency> <groupId>cn.hutool</groupId> <artifactId>hutool-all</artifactId> <version>5.7.19</version></dependency>2、编写邮件发送工具类package com.pc.c..._hutool发送邮件

docker安装elasticsearch,elasticsearch-head,kibana,ik分词器_docker安装kibana连接elasticsearch并且elasticsearch有密码-程序员宅基地

文章浏览阅读867次,点赞2次,收藏2次。docker安装elasticsearch,elasticsearch-head,kibana,ik分词器安装方式基本有两种,一种是pull的方式,一种是Dockerfile的方式,由于pull的方式pull下来后还需配置许多东西且不便于复用,个人比较喜欢使用Dockerfile的方式所有docker支持的镜像基本都在https://hub.docker.com/docker的官网上能找到合..._docker安装kibana连接elasticsearch并且elasticsearch有密码

随便推点

Python 攻克移动开发失败!_beeware-程序员宅基地

文章浏览阅读1.3w次,点赞57次,收藏92次。整理 | 郑丽媛出品 | CSDN(ID:CSDNnews)近年来,随着机器学习的兴起,有一门编程语言逐渐变得火热——Python。得益于其针对机器学习提供了大量开源框架和第三方模块,内置..._beeware

Swift4.0_Timer 的基本使用_swift timer 暂停-程序员宅基地

文章浏览阅读7.9k次。//// ViewController.swift// Day_10_Timer//// Created by dongqiangfei on 2018/10/15.// Copyright 2018年 飞飞. All rights reserved.//import UIKitclass ViewController: UIViewController { ..._swift timer 暂停

元素三大等待-程序员宅基地

文章浏览阅读986次,点赞2次,收藏2次。1.硬性等待让当前线程暂停执行,应用场景:代码执行速度太快了,但是UI元素没有立马加载出来,造成两者不同步,这时候就可以让代码等待一下,再去执行找元素的动作线程休眠,强制等待 Thread.sleep(long mills)package com.example.demo;import org.junit.jupiter.api.Test;import org.openqa.selenium.By;import org.openqa.selenium.firefox.Firefox.._元素三大等待

Java软件工程师职位分析_java岗位分析-程序员宅基地

文章浏览阅读3k次,点赞4次,收藏14次。Java软件工程师职位分析_java岗位分析

Java:Unreachable code的解决方法_java unreachable code-程序员宅基地

文章浏览阅读2k次。Java:Unreachable code的解决方法_java unreachable code

标签data-*自定义属性值和根据data属性值查找对应标签_如何根据data-*属性获取对应的标签对象-程序员宅基地

文章浏览阅读1w次。1、html中设置标签data-*的值 标题 11111 222222、点击获取当前标签的data-url的值$('dd').on('click', function() { var urlVal = $(this).data('ur_如何根据data-*属性获取对应的标签对象

推荐文章

热门文章

相关标签