ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

Pycharm2019+Tensorflow2.0 学习文档(九):3-1,低阶API示范1

2022-01-07 17:04:07  阅读:231  来源: 互联网

标签:loss plt features Pycharm2019 labels Tensorflow2.0 API tf model


参考:参考1参考2

总体来说,常见的带监督的机器学习问题主要分为两类:分类和回归,我们使用 Tensorflow 来解决这些问题的时候需要自己搭建网络,但是Tensorflow不同级别的API也就产生了不同的模型搭建方式。越底层的API灵活性越大,可以更加自由的添加自己想加入的内容,但是编码难度有所提高;反之,越高阶的API具有更好的封装性,但是灵活度会有所下降,本次内容先从低阶API示意说明。

1、回归问题

1.1 数据生成

1、首先我们得自己设计一个回归问题,也就是建一个方程,然后训练网络去拟合它。

我们熟知的线性方程:Y = W ∗ X + b Y=W*X+bY=W∗X+b

我们这里生成400个数据,X是在(-10,10)之间的均匀分布,W为(2,-2),b=3,另外添加噪音

# 样本数量
n = 400

# 生成测试用数据集
X = tf.random.uniform([n, 2], minval=-10, maxval=10)  # 从均匀分布中随机获得 n 个数据(shape:400*2)
w0 = tf.constant([[2.0], [-3.0]])  # shape: 2*1
b0 = tf.constant([[3.0]])  # shape: 1*1
Y = X @ w0 + b0 + tf.random.normal([n, 1], mean=0.0, stddev=2.0)  # @表示矩阵乘法,增加正态扰动 shape: 400*1

然后显示我们自己生成的数据

plt.figure(figsize=(12, 5))
ax1 = plt.subplot(121)
ax1.scatter(X[:, 0], Y[:, 0], c="b")
plt.xlabel("x1")
plt.ylabel("y", rotation=0)

ax2 = plt.subplot(122)
ax2.scatter(X[:, 1], Y[:, 0], c="g")
plt.xlabel("x2")
plt.ylabel("y", rotation=0)
plt.show()

2、简单说明下:

(1)scatter 函数的参数详解可以参考这个位置看下 scatter()

(2)x[m,n]是通过numpy库引用数组或矩阵中的某一段数据集的一种写法,m代表第m维,n代表m维中取第几段特征数据。

典型用法:x[:,n]或者x[n,:]

x[:,n]表示在全部数组(维)中取第n个数据,直观来说,x[:,n]就是取所有集合的第n个数据,

例如:x[:,0],即x中所有的第0列数据

(3)数据显示结果

 3、构建数据生成器

整体思路:(1)随机打乱数据下标

(2)遍历数据,每一个batch-size作为一个分割,得到打乱后的下标切片(大小为batch_size)

(3)使用tf.gather() 函数将 X、Y分别和上一步得到的随机下标组合,yield 返回生成器。

# 构建数据管道迭代器
def data_iter(features, labels, batch_size=8):
    num_examples = len(features)  # 计算容器中一共有多少条信息(400)
    indices = list(range(num_examples))  # 记录下标
    np.random.shuffle(indices)  # 样本的读取顺序是随机的
    for i in range(0, num_examples, batch_size):
        indexs = indices[i: min(i + batch_size, num_examples)]  # 确认选择的 index
        # 使用tf.gather()函数将X,Y分别和上一步得到的随机下标组合,yield返回生成器
        # tf.gather(params,indices,axis=0)函数是根据indices下标从params中返回对应元素的切片
        yield tf.gather(features, indexs), tf.gather(labels, indexs)


# 测试数据管道效果
batch_size = 8
# next 返回迭代器的下一个条目
(features, labels) = next(data_iter(X, Y, batch_size))  # 结果: features shape 8*2; labels shape 8*1
print(features)
print(labels)

 

4、定义模型

w = tf.Variable(tf.random.normal(w0.shape))
b = tf.Variable(tf.zeros_like(b0, dtype=tf.float32))


# 定义模型
class LinearRegression:
    # 正向传播
    def __call__(self, x):
        return x @ w + b

    # 损失函数
    def loss_func(self, y_true, y_pred):
        return tf.reduce_mean((y_true - y_pred) ** 2 / 2)


model = LinearRegression()

说明: _call_ 是一个特殊方法,一旦定义之后,类的实例可以变成一个可调用对象。

5、训练模型

(1)定义train_step 完成每一步的梯度求取和参数更新

(2)执行训练过程,使用 autograph 进行加速的时候还是很明显的,可以对比观察下。

@tf.function
def train_step(model, features, labels):
    with tf.GradientTape() as tape:
        predictions = model(features)
        loss = model.loss_func(labels, predictions)
    # 反向传播求梯度
    dloss_lw, dloss_db = tape.gradient(loss, [w, b])
    # 梯度下降算法更新参数
    w.assign(w - 0.001 * dloss_lw)
    b.assign(b - 0.001 * dloss_db)
    return loss
def train_model(model,epochs):
    for epoch in tf.range(1,epochs+1):
        for features, labels in data_iter(X,Y,10):
            loss = train_step(model,features,labels)

        if epoch%50==0:
            printbar()
            tf.print("epoch =",epoch,"loss = ",loss)
            tf.print("w =",w)
            tf.print("b =",b)

train_model(model,epochs = 200)

(3)结果可视化

# 结果可视化
plt.figure(figsize=(12, 5))
ax1 = plt.subplot(121)
ax1.scatter(X[:, 0], Y[:, 0], c="b", label="samples")
ax1.plot(X[:, 0], w[0] * X[:, 0] + b[0], "-r", linewidth=5.0, label="model")
ax1.legend()
plt.xlabel("x1")
plt.ylabel("y", rotation=0)

ax2 = plt.subplot(122)
ax2.scatter(X[:, 1], Y[:, 0], c="g", label="samples")
ax2.plot(X[:, 1], w[1] * X[:, 1] + b[0], "-r", linewidth=5.0, label="model")
plt.xlabel("x2")
plt.ylabel("y", rotation=0)

plt.show()

标签:loss,plt,features,Pycharm2019,labels,Tensorflow2.0,API,tf,model
来源: https://blog.csdn.net/sinat_34520704/article/details/122350306

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有