ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

numpy学习线性回归, 并用matplotlib画动态图

2021-04-14 21:03:32  阅读:290  来源: 互联网

标签:plt matplotlib spread 参数 动态图 np Theta 100 numpy


线性回归

线性回归

准备

假设函数为一元一次函数
h = Θ 0 + Θ 1 x h = \Theta _{0} + \Theta _{1}x h=Θ0​+Θ1​x

代价函数
J ( Θ 0 , Θ 1 ) = 1 2 m ∑ i = 1 m ( h ( x i ) ) − y i ) 2 J(\Theta _{0}, \Theta _{1}) = \frac{1}{2m}\sum_{i=1}^{m}(h(x_{i})) - y_i)^2 J(Θ0​,Θ1​)=2m1​i=1∑m​(h(xi​))−yi​)2
我们的目的就是找到参数 Θ 0 \Theta_0 Θ0​和 Θ 1 \Theta_1 Θ1​使得代价函数值最小

梯度下降方式求

梯度下降算法, 参数 = 参数 - 学习率 * 代价函数对参数求偏导
Θ j = Θ j − α ∂ ∂ Θ j J ( Θ 0 , Θ 1 ) \Theta_j = \Theta_j - \alpha \frac{\partial }{\partial \Theta_j }J(\Theta_0, \Theta_1) Θj​=Θj​−α∂Θj​∂​J(Θ0​,Θ1​)

import numpy as np  # 引入numpy

# 生成数据
x = np.arange(100).reshape(100, 1)
# 按照y = 10x + 100 生成数据, 并且数据会上下100随机误差
y = 10 * x + 100  + np.random.randint(-100, 100, size=(100, 1))

# 假设函数 h = ax + b
# 初始化参数, 这里都设置为0, 也能设置两个随机数
a = 0
b = 0
learning_rate = 0.0001  # 学习率

# 死循环梯度下降
while True:
	# 这里不要把求出的偏导值直接修改参数, 避免影响其他参数的计算, 所以先保留原参数值
    a_spread = np.mean((a*x+b-y)*x)  # 记录a参数求偏导值
    b_spread = np.mean(a*x+b-y)  # 记录b参数求偏导值
    
    a = a - a_spread * learning_rate  # 修改参数
    b = b - b_spread * learning_rate
	
	if abs(a_spread) <= 0.1 and abs(b_spread) <= 0.1:
		# 这里设置一个阈值, 本身最好观察梯度下降函数图像, 发现基本不下降了才停止
		break
print(a, b)  # 打印梯度下降求出的两个参数值

在matplotlib画出图像

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output, display

# 生成数据
x = np.arange(100).reshape(100, 1)
# 按照y = 10x + 100 生成数据, 并且数据会上下100随机误差
y = 10 * x + 100  + np.random.randint(-100, 100, size=(100, 1))
# 初始化 h = ax + b, a=0, b=0
# 假设函数 h = ax + b
# 初始化参数, 这里都设置为0, 也能设置两个随机数
a = 0
b = 0
learning_rate = 0.0001  # 学习率
while True:
    a_spread = np.mean((a*x+b-y)*x)
    b_spread = np.mean(a*x+b-y)
    a = a - a_spread * learning_rate
    b = b - b_spread * learning_rate
    # 画出函数图像
    _x = np.linspace(x[0], x[-1], 2*len(x))
    _y = _x * a + b
    # 这里可以每隔一定次数在画出图像, 因为b参数会变化很慢
    clear_output(wait=True)  # 清除打印信息
    plt.ylim(0, 1300)  # 控制y轴显示范围
    plt.xlim(0, 200)  # 控制x轴显示范围
    plt.scatter(x, y, s=5)  # 画出散点图
    plt.plot(_x, _y, color='red', linewidth=1.0, linestyle='--')  # 画出假设函数h图像
    plt.annotate(  # 显示一个文本框指向最后一个数据
        s=f"{a=:.2f} {b=:.2f}",  # 文本内容
        xy=(_x[-1], _y[-1]),  # 箭头点所在坐标
        xytext=(_x[-1]+10, _y[-1]-100),  # 文本内容所在坐标
        weight='bold',  # 字体线型
        color='aqua',  # 字体颜色
        arrowprops=dict(arrowstyle='-|>', connectionstyle='arc3', color='red'),
        bbox=dict(boxstyle='round,pad=0.5', fc='yellow', ec='k',lw=1 ,alpha=0.4)
    )
    plt.pause(0.1)
    if abs(a_spread) <= 0.1 and abs(b_spread) <= 0.1:
        plt.savefig('result.png', bbox_inches='tight', pad_inches=0)  # 保存结果
        plt.close()
        break

正规矩阵方式求

公式: Θ = ( X T X ) − 1 X T y \Theta = (X^TX)^-1X^Ty Θ=(XTX)−1XTy

import numpy as np
# 生成数据
x = np.arange(100).reshape(100, 1)
# 按照y = 10x + 100 生成数据, 并且数据会上下100随机误差
y = 10 * x + 100 + np.random.randint(-100, 100, size=(100, 1))

# 构建X矩阵
X = np.matrix(np.c_[np.ones((100, 1)), x])  # 添加一列全为1的一列, 作为x0
Theta = (X.T * X).I * X.T * y  # .T是转置, .I是逆矩阵, 逆矩阵也可以用np.linalg.pinv求解伪逆, 避免不存在逆矩阵情况
print(Theta)  # 打印求解后的参数

标签:plt,matplotlib,spread,参数,动态图,np,Theta,100,numpy
来源: https://blog.csdn.net/flyinghu123/article/details/115709211

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有