TensoeFlow学习笔记5:线性回归
线性回归
线性回归(Linear Regression)是一种通过属性的线性组合来进行预测的线性模型,其目的是找到一条直线或者一个平面或者更高维的超平面,使得预测值与真实值之间的误差最小化。
简单实现
具体代码:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
num_points=1000
vectors_set=[]
for i in range(num_points):
# 横坐标,进行随机高斯处理化,以0为均值,以0.55为标准差
x1=np.random.normal(0.0,0.55)
# 纵坐标,数据点在y1=x1*0.1+0.3上小范围浮动
y1=x1*0.1+0.3+np.random.normal(0.0,0.03)
vectors_set.append([x1,y1])
x_data=[v[0] for v in vectors_set]
y_data=[v[1] for v in vectors_set]
# 生成1维的W矩阵,取值是[-1,1]之间的随机数
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='W')
# 生成1维的b矩阵,初始值是0
b = tf.Variable(tf.zeros([1]), name='b')
# 经过计算得出预估值y
y = W * x_data + b
# 以预估值y和实际值y_data之间的均方误差作为损失
loss = tf.reduce_mean(tf.square(y - y_data), name='loss')
# 采用梯度下降法来优化参数 学习率为0.5
optimizer = tf.train.GradientDescentOptimizer(0.5)
# 训练的过程就是最小化这个误差值
train = optimizer.minimize(loss, name='train')
# sess = tf.Session()
init = tf.global_variables_initializer()
# 创建会话
with tf.Session() as sess:
sess.run(init)
# 执行20次训练
for step in range(20):
sess.run(train) # 输出训练好的W和b
print("W =", sess.run(W), "b =", sess.run(b), "loss =", sess.run(loss))
plt.scatter(x_data,y_data,c='r')
plt.plot(x_data,sess.run(y))
plt.show()
输出结果:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20W = [0.41536754] b = [0.29331774] loss = 0.028788242
W = [0.32664812] b = [0.29488927] loss = 0.015261736
W = [0.26295245] b = [0.29606903] loss = 0.008289363
W = [0.2172218] b = [0.29691604] loss = 0.004695383
W = [0.18438922] b = [0.29752415] loss = 0.002842829
W = [0.1608169] b = [0.29796076] loss = 0.0018879117
W = [0.143893] b = [0.29827422] loss = 0.0013956894
W = [0.1317424] b = [0.2984993] loss = 0.0011419685
W = [0.12301882] b = [0.29866084] loss = 0.0010111856
W = [0.11675566] b = [0.29877687] loss = 0.00094377215
W = [0.112259] b = [0.29886016] loss = 0.0009090233
W = [0.1090306] b = [0.29891995] loss = 0.0008911116
W = [0.10671275] b = [0.29896286] loss = 0.00088187883
W = [0.10504864] b = [0.2989937] loss = 0.0008771197
W = [0.10385388] b = [0.29901582] loss = 0.00087466656
W = [0.1029961] b = [0.2990317] loss = 0.00087340205
W = [0.10238025] b = [0.29904312] loss = 0.0008727502
W = [0.10193809] b = [0.2990513] loss = 0.0008724143
W = [0.10162064] b = [0.2990572] loss = 0.0008722411
W = [0.10139273] b = [0.29906142] loss = 0.00087215187