如何保存和恢复TensorFlow训练的模型

作者: Mihajlo Pavloski 2017-11-01 15:13:49

如果深层神经网络模型的复杂度非常高的话,那么训练它可能需要相当长的一段时间,当然这也取决于你拥有的数据量,运行模型的硬件等等。在大多数情况下,你需要通过保存文件来保障你试验的稳定性,防止如果中断(或一个错误),你能够继续从没有错误的地方开始。

更重要的是,对于任何深度学习的框架,像TensorFlow,在成功的训练之后,你需要重新使用模型的学习参数来完成对新数据的预测。

如何保存和恢复TensorFlow训练的模型

在这篇文章中,我们来看一下如何保存和恢复TensorFlow模型,我们在此介绍一些最有用的方法,并提供一些例子。

1. 首先我们将快速介绍TensorFlow模型

TensorFlow的主要功能是通过张量来传递其基本数据结构类似于NumPy中的多维数组,而图表则表示数据计算。它是一个符号库,这意味着定义图形和张量将仅创建一个模型,而获取张量的具体值和操作将在会话(session)中执行,会话(session)一种在图中执行建模操作的机制。会话关闭时,张量的任何具体值都会丢失,这也是运行会话后将模型保存到文件的另一个原因。

通过示例可以帮助我们更容易理解,所以让我们为二维数据的线性回归创建一个简单的TensorFlow模型。

首先,我们将导入我们的库:

  1. import tensorflow as tf   
  2. import numpy as np   
  3. import matplotlib.pyplot as plt   
  4. %matplotlib inline 

下一步是创建模型。我们将生成一个模型,它将以以下的形式估算二次函数的水平和垂直位移:

  1. y = (x - h) ^ 2 + v 

其中h是水平和v是垂直的变化。

以下是如何生成模型的过程(有关详细信息,请参阅代码中的注释):

  1. # Clear the current graph in each run, to avoid variable duplication 
  2. tf.reset_default_graph() 
  3. # Create placeholders for the x and y points 
  4. X = tf.placeholder("float")   
  5. Y = tf.placeholder("float") 
  6. # Initialize the two parameters that need to be learned 
  7. h_est = tf.Variable(0.0, name='hor_estimate')   
  8. v_est = tf.Variable(0.0, name='ver_estimate'
  9. # y_est holds the estimated values on y-axis 
  10. y_est = tf.square(X - h_est) + v_est 
  11. # Define a cost function as the squared distance between Y and y_est 
  12. cost = (tf.pow(Y - y_est, 2)) 
  13. # The training operation for minimizing the cost function. The 
  14. # learning rate is 0.001 
  15. trainop = tf.train.GradientDescentOptimizer(0.001).minimize(cost) 

在创建模型的过程中,我们需要有一个在会话中运行的模型,并且传递一些真实的数据。我们生成一些二次数据(Quadratic data),并给他们添加噪声。

  1. # Use some values for the horizontal and vertical shift 
  2. h = 1   
  3. v = -2 
  4. # Generate training data with noise 
  5. x_train = np.linspace(-2,4,201)   
  6. noise = np.random.randn(*x_train.shape) * 0.4   
  7. y_train = (x_train - h) ** 2 + v + noise 
  8. # Visualize the data  
  9. plt.rcParams['figure.figsize'] = (10, 6)   
  10. plt.scatter(x_train, y_train)   
  11. plt.xlabel('x_train')   
  12. plt.ylabel('y_train') 

2. The Saver class

Saver类是TensorFlow库提供的类,它是保存图形结构和变量的***方法。

(1) 保存模型

在以下几行代码中,我们定义一个Saver对象,并在train_graph()函数中,经过100次迭代的方法最小化成本函数。然后,在每次迭代中以及优化完成后,将模型保存到磁盘。每个保存在磁盘上创建二进制文件被称为“检查点”。

  1. # Create a Saver object 
  2. saver = tf.train.Saver() 
  3.  
  4. init = tf.global_variables_initializer() 
  5.  
  6. # Run a session. Go through 100 iterations to minimize the cost 
  7. def train_graph():   
  8.     with tf.Session() as sess: 
  9.         sess.run(init) 
  10.         for i in range(100): 
  11.             for (x, y) in zip(x_train, y_train): 
  12.  
  13.                 # Feed actual data to the train operation 
  14.                 sess.run(trainop, feed_dict={X: x, Y: y}) 
  15.  
  16.             # Create a checkpoint in every iteration 
  17.             saver.save(sess, 'model_iter', global_step=i
  18.  
  19.         # Save the final model 
  20.         saver.save(sess, 'model_final') 
  21.         h_ = sess.run(h_est) 
  22.         v_ = sess.run(v_est) 
  23.     return h_, v_ 

现在让我们用上述功能训练模型,并打印出训练的参数。

  1. result = train_graph()   
  2. print("h_est = %.2f, v_est = %.2f" % result)   
  3.  
  4. $ python tf_save.py 
  5. h_est = 1.01, v_est = -1.96 

Okay,参数是非常准确的。如果我们检查我们的文件系统,***4次迭代中保存有文件以及最终的模型。

保存模型时,你会注意到需要4种类型的文件才能保存:

  • “.meta”文件:包含图形结构。
  • “.data”文件:包含变量的值。
  • “.index”文件:标识检查点。
  • “checkpoint”文件:具有最近检查点列表的协议缓冲区。

检查点文件保存到磁盘

图1:检查点文件保存到磁盘

调用tf.train.Saver()方法,如上所示,将所有变量保存到一个文件。通过将它们作为参数,表情通过列表或dict传递来保存变量的子集,例如:tf.train.Saver({‘hor_estimate’: h_est})。

Saver构造函数的一些其他有用的参数,也可以控制整个过程,它们是:

  • max_to_keep:最多保留的检查点数。
  • keep_checkpoint_every_n_hours:保存检查点的时间间隔。如果你想要了解更多信息,请查看官方文档的Saver类,它提供了其它有用的信息,你可以探索查看。
  • Restoring Models

恢复TensorFlow模型时要做的***件事就是将图形结构从“.meta”文件加载到当前图形中。

  1. tf.reset_default_graph()   
  2. imported_meta = tf.train.import_meta_graph("model_final.meta") 

也可以使用以下命令探索当前图形tf.get_default_graph()。接着第二步是加载变量的值。提醒:值仅存在于会话(session)中。

  1. with tf.Session() as sess:   
  2.     imported_meta.restore(sess, tf.train.latest_checkpoint('./')) 
  3.     h_est2 = sess.run('hor_estimate:0') 
  4.     v_est2 = sess.run('ver_estimate:0') 
  5.     print("h_est: %.2f, v_est: %.2f" % (h_est2, v_est2)) 
  1. $ python tf_restore.py 
  2. INFO:tensorflow:Restoring parameters from ./model_final   
  3. h_est: 1.01, v_est: -1.96 

如前面所提到的,这种方法只保存图形结构和变量,这意味着通过占位符“X”和“Y”输入的训练数据不会被保存。

无论如何,在这个例子中,我们将使用我们定义的训练数据tf,并且可视化模型拟合。

  1. plt.scatter(x_train, y_train, label='train data')   
  2. plt.plot(x_train, (x_train - h_est2) ** 2 + v_est2, color='red'label='model')   
  3. plt.xlabel('x_train')   
  4. plt.ylabel('y_train')   
  5. plt.legend()  

Saver这个类允许使用一个简单的方法来保存和恢复你的TensorFlow模型(图形和变量)到/从文件,并保留你工作中的多个检查点,这可能是有用的,它可以帮助你的模型在训练过程中进行微调。

4. SavedModel格式(Format)

在TensorFlow中保存和恢复模型的一种新方法是使用SavedModel,Builder和loader功能。这个方法实际上是Saver提供的更高级别的序列化,它更适合于商业目的。

虽然这种SavedModel方法似乎不被开发人员完全接受,但它的创作者指出:它显然是未来。与Saver主要关注变量的类相比,SavedModel尝试将一些有用的功能包含在一个包中,例如Signatures:允许保存具有一组输入和输出的图形,Assets:包含初始化中使用的外部文件。

(1) 使用SavedModel Builder保存模型

接下来我们尝试使用SavedModelBuilder类完成模型的保存。在我们的示例中,我们不使用任何符号,但也足以说明该过程。

  1. tf.reset_default_graph() 
  2. # Re-initialize our two variables 
  3. h_est = tf.Variable(h_est2, name='hor_estimate2')   
  4. v_est = tf.Variable(v_est2, name='ver_estimate2'
  5.  
  6. # Create a builder 
  7. builder = tf.saved_model.builder.SavedModelBuilder('./SavedModel/') 
  8.  
  9. # Add graph and variables to builder and save 
  10. with tf.Session() as sess:   
  11.     sess.run(h_est.initializer) 
  12.     sess.run(v_est.initializer) 
  13.     builder.add_meta_graph_and_variables(sess, 
  14.                                        [tf.saved_model.tag_constants.TRAINING], 
  15.                                        signature_def_map=None
  16.                                        assets_collection=None
  17. builder.save() 
  1. $ python tf_saved_model_builder.py 
  2. INFO:tensorflow:No assets to save.   
  3. INFO:tensorflow:No assets to write.   
  4. INFO:tensorflow:SavedModel written to: b'./SavedModel/saved_model.pb' 

运行此代码时,你会注意到我们的模型已保存到位于“./SavedModel/saved_model.pb”的文件中。

(2) 使用SavedModel Loader程序恢复模型

模型恢复使用tf.saved_model.loader,并且可以恢复会话范围中保存的变量,符号。

在下面的例子中,我们将加载模型,并打印出我们的两个系数(h_est和v_est)的数值。数值如预期的那样,我们的模型已经被成功地恢复了。

  1. with tf.Session() as sess:   
  2.     tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], './SavedModel/') 
  3.     h_est = sess.run('hor_estimate2:0') 
  4.     v_est = sess.run('ver_estimate2:0') 
  5.     print("h_est: %.2f, v_est: %.2f" % (h_est, v_est)) 
  1. $ python tf_saved_model_loader.py 
  2. INFO:tensorflow:Restoring parameters from b'./SavedModel/variables/variables'   
  3. h_est: 1.01, v_est: -1.96 

5. 结论

如果你知道你的深度学习网络的训练可能会花费很长时间,保存和恢复TensorFlow模型是非常有用的功能。该主题太广泛,无法在一篇博客文章中详细介绍。不管怎样,在这篇文章中我们介绍了两个工具:Saver和SavedModel builder/loader,并创建一个文件结构,使用简单的线性回归来说明实例。希望这些能够帮助到你训练出更好的神经网络模型。

TensorFlow 神经网络 深度学习
上一篇:51CTO首届开发者大赛部分作品曝光,等你来补充! 下一篇:从算法实现到MiniFlow实现,打造机器学习的基础架构平台
评论
取消
暂无评论,快去成为第一个评论的人吧

更多资讯推荐

2020年深度学习优秀GPU一览,看看哪一款最适合你!

如果你准备进入深度学习,什么样的GPU才是最合适的呢?下面列出了一些适合进行深度学习模型训练的GPU,并将它们进行了横向比较,一起来看看吧!

大数据文摘 ·  1天前
从零开始构建简单人工神经网络:1个隐藏层

我们在本文中将构建一个有1个输入层、1个隐藏层和1个输出层的神经网络。我们会看到,我们构建的神经网络能够找到非线性边界。

布加迪 ·  2020-03-26 09:00:00
AI芯片之卷积神经网络原理

卷积神经网络(Convolutional Neural Network,CNN)是一种前馈神经网络,它的人工神经元可以响应一部分覆盖范围内的周围单元,对于大型图像处理有出色表现。 它包括卷积层(convolutional layer)和池化层(pooling layer)。

人人都是极客 ·  2020-03-25 09:48:10
从零开始构建简单人工神经网络:1个输入层和1个输出层

本上下篇将介绍仅使用numpy Python库从零开始构建人工神经网络(ANN)。上篇将介绍构建一个很简单的ANN,只有1个输入层和1个输出层,没有隐藏层。下篇将介绍构建一个有1个输入层、1个隐藏层和1个输出层的ANN。

布加迪 ·  2020-03-25 09:00:00
清华开源Jittor:首个国内高校自研深度学习框架,一键转换PyTorch

深度学习框架越来越多,主导的团队也从高校研究机构渐渐转向了科技巨头。但是,学界在这一领域的力量不容忽视。今日。清华大学开发了一个名为计图(Jittor)的深度学习框架。

佚名 ·  2020-03-20 14:33:29
华为开源只用加法的神经网络:实习生领衔打造,效果不输传统CNN

没有乘法的神经网络,你敢想象吗?无论是单个神经元的运算还是卷积运算,都不可避免地要使用乘法。

佚名 ·  2020-03-17 10:01:22
一行代码让性能提升2倍 精选

如果现在向你推荐一款神器,可以实现训练速度翻倍,访存效率翻倍,你心动吗?心动不如行动,来和我一起看看这款神器——基于PaddlePaddle核心框架的自动混合精度技术,简称飞桨 AMP 技术。

佚名 ·  2020-03-13 13:23:42
Facebook研究开放三个新框架,让深度学习更容易

你知道吗?微软、谷歌、Facebook、亚马逊、Uber等科技巨头的研究部门已经成为人工智能(AI)领域开源框架最活跃的贡献者之一。

读芯术 ·  2020-03-10 13:27:28
Copyright©2005-2020 51CTO.COM 版权所有 未经许可 请勿转载