TF Learn : 基于Scikit-learn和TensorFlow的深度学习利器

作者: 汪昊 2018-09-06 08:00:00

【51CTO.com原创稿件】了解国外数据科学市场的人都知道,2017年海外数据科学最常用的三项技术是 Spark ,Python 和 MongoDB 。说到 Python ,做大数据的人都不会对 Scikit-learn 和 Pandas 感到陌生。

TF Learn : 基于Scikit-learn和TensorFlow的深度学习利器

Scikit-learn 是最常用的 Python 机器学习框架,在各大互联网公司做算法的工程师在实现单机版本的算法的时候或多或少都会用到 Scikit-learn 。TensorFlow 就更是大名鼎鼎,做深度学习的人都不可能不知道 TensorFlow。

下面我们先来看一段样例,这段样例是传统的机器学习算法逻辑回归的实现:

TF Learn : 基于Scikit-learn和TensorFlow的深度学习利器

可以看到,样例中仅仅使用了 3 行代码就完成了逻辑回归的主要功能。下面我们来看一下如果用 TensorFlow 来实现同样的代码,需要多少行?下面的代码来自 GitHub :

  1. '' 
  2. A logistic regression learning algorithm example using TensorFlow library.  
  3. This example is using the MNIST database of handwritten digits  
  4. (https://yann.lecun.com/exdb/mnist/)  
  5. Author: Aymeric Damien  
  6. Project: https://github.com/aymericdamien/TensorFlow-Examples/  
  7. '' 
  8. from __future__ import print_function  
  9. import tensorflow as tf  
  10. # Import MNIST data  
  11. from tensorflow.examples.tutorials.mnist import input_data  
  12. mnist = input_data.read_data_sets("/tmp/data/", one_hot=True 
  13.  
  14. # Parameters  
  15. learning_rate = 0.01  
  16. training_epochs = 25  
  17. batch_size = 100  
  18. display_step = 1  
  19.  
  20. # tf Graph Input  
  21. x = tf.placeholder(tf.float32, [None, 784]) # mnist data image of shape 28*28=784  
  22. y = tf.placeholder(tf.float32, [None, 10]) # 0-9 digits recognition => 10 classes  
  23.  
  24. Set model weights  
  25. W = tf.Variable(tf.zeros([784, 10]))  
  26. b = tf.Variable(tf.zeros([10]))  
  27.  
  28. # Construct model  
  29. pred = tf.nn.softmax(tf.matmul(x, W) + b) # Softmax  
  30.  
  31. # Minimize error using cross entropy  
  32. cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1))  
  33. # Gradient Descent  
  34. optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)  
  35.  
  36. # Initialize the variables (i.e. assign their default value)  
  37. init = tf.global_variables_initializer()  
  38.  
  39. # Start training  
  40. with tf.Session() as sess:  
  41.     # Run the initializer  
  42.     sess.run(init)  
  43.     # Training cycle  
  44.     for epoch in range(training_epochs):  
  45.         avg_cost = 0.  
  46.         total_batch = int(mnist.train.num_examples/batch_size)  
  47.         # Loop over all batches  
  48.         for i in range(total_batch):  
  49.             batch_xs, batch_ys = mnist.train.next_batch(batch_size)  
  50.             # Run optimization op (backprop) and cost op (to get loss value)  
  51.             _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs,  
  52.                                                           y: batch_ys})  
  53.             # Compute average loss  
  54.             avg_cost += c / total_batch  
  55.         # Display logs per epoch step  
  56.         if (epoch+1) % display_step == 0:  
  57.             print("Epoch:"'%04d' % (epoch+1), "cost=""{:.9f}".format(avg_cost))  
  58.  
  59.     print("Optimization Finished!" 
  60.  
  61.     # Test model  
  62.     correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))  
  63.     # Calculate accuracy  
  64.     accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))  
  65. print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels})) 

一个相对来说比较简单的机器学习算法,用 Tensorflow 来实现却花费了大量的篇幅。然而 Scikit-learn 本身没有 Tensorflow 那样丰富的深度学习的功能。有没有什么办法,能够在保证 Scikit-learn 的简单易用性的前提下,能够让 Scikit-learn 像 Tensorflow 那样支持深度学习呢?答案是有的,那就是 Scikit-Flow 开源项目。该项目后来被集成到了 Tensorflow 项目里,变成了现在的 TF Learn 模块。

我们来看一个 TF Learn 实现线性回归的样例:

  1. """ Linear Regression Example """ 
  2. from __future__ import absolute_import, division, print_function  
  3. import tflearn  
  4. # Regression data  
  5. X = [3.3,4.4,5.5,6.71,6.93,4.168,9.779,6.182,7.59,2.167,7.042,10.791,5.313,7.997,5.654,9.27,3.1]  
  6. Y = [1.7,2.76,2.09,3.19,1.694,1.573,3.366,2.596,2.53,1.221,2.827,3.465,1.65,2.904,2.42,2.94,1.3]  
  7. # Linear Regression graph  
  8. input_ = tflearn.input_data(shape=[None])  
  9. linear = tflearn.single_unit(input_)  
  10. regression = tflearn.regression(linear, optimizer='sgd', loss='mean_square' 
  11.                                 metric='R2', learning_rate=0.01)  
  12. m = tflearn.DNN(regression)  
  13. m.fit(X, Y, n_epoch=1000, show_metric=True, snapshot_epoch=False 
  14. print("\nRegression result:" 
  15. print("Y = " + str(m.get_weights(linear.W)) +  
  16.       "*X + " + str(m.get_weights(linear.b)))  
  17.  
  18. print("\nTest prediction for x = 3.2, 3.3, 3.4:" 
  19. print(m.predict([3.2, 3.3, 3.4])) 

我们可以看到,TF Learn 继承了 Scikit-Learn 的简洁编程风格,在处理传统的机器学习方法的时候非常的方便。下面我们看一段 TF Learn 实现 CNN (MNIST数据集)的样例:

  1. """ Convolutional Neural Network for MNIST dataset classification task.  
  2. References 
  3.     Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. "Gradient-based  
  4.     learning applied to document recognition." Proceedings of the IEEE,  
  5.     86(11):2278-2324, November 1998.  
  6. Links:  
  7.     [MNIST Dataset] https://yann.lecun.com/exdb/mnist/  
  8. "" 
  9.  
  10. from __future__ import division, print_function, absolute_import   
  11. import tflearn  
  12. from tflearn.layers.core import input_data, dropout, fully_connected  
  13. from tflearn.layers.conv import conv_2d, max_pool_2d  
  14. from tflearn.layers.normalization import local_response_normalization  
  15. from tflearn.layers.estimator import regression  
  16.  
  17. # Data loading and preprocessing  
  18. import tflearn.datasets.mnist as mnist  
  19. X, Y, testX, testY = mnist.load_data(one_hot=True 
  20. X = X.reshape([-1, 28, 28, 1])  
  21. testX = testX.reshape([-1, 28, 28, 1])  
  22. # Building convolutional network  
  23. network = input_data(shape=[None, 28, 28, 1], name='input' 
  24. network = conv_2d(network, 32, 3, activation='relu', regularizer="L2" 
  25. network = max_pool_2d(network, 2)  
  26. network = local_response_normalization(network)  
  27. network = conv_2d(network, 64, 3, activation='relu', regularizer="L2" 
  28. network = max_pool_2d(network, 2)  
  29. network = local_response_normalization(network)  
  30. network = fully_connected(network, 128, activation='tanh' 
  31. network = dropout(network, 0.8)  
  32. network = fully_connected(network, 256, activation='tanh' 
  33. network = dropout(network, 0.8)  
  34. network = fully_connected(network, 10, activation='softmax'
  35. network = regression(network, optimizer='adam', learning_rate=0.01,  
  36.                      loss='categorical_crossentropy'name='target' 
  37.  
  38. # Training  
  39. model = tflearn.DNN(network, tensorboard_verbose=0)  
  40. model.fit({'input': X}, {'target': Y}, n_epoch=20,  
  41.            validation_set=({'input': testX}, {'target': testY}),  
  42. snapshot_step=100, show_metric=True, run_id='convnet_mnist'

可以看到,基于 TF Learn 的深度学习代码也是非常的简洁。

TF Learn 是 TensorFlow 的高层次类 Scikit-Learn 封装,提供了原生版 TensorFlow 和 Scikit-Learn 之外的又一种选择。对于熟悉了 Scikit-Learn 和厌倦了 TensorFlow 冗长代码的用户来说,不啻为一种福音,也值得机器学习和数据挖掘的从业者认真学习和掌握。

汪昊,恒昌利通大数据部负责人/资深架构师,美国犹他大学本科/硕士,对外经贸大学在职MBA。曾在百度,新浪,网易,豆瓣等公司有多年的研发和技术管理经验,擅长机器学习,大数据,推荐系统,社交网络分析等技术。在 TVCG 和 ASONAM 等国际会议和期刊发表论文 8 篇。本科毕业论文获国际会议 IEEE SMI 2008 ***论文奖。

【51CTO原创稿件,合作站点转载请注明原文作者和出处为51CTO.com】

深度学习 TensorFlow Python
上一篇:科普 | 从TensorFlow.js入手了解机器学习 下一篇:警惕数字化转型三大陷阱,英特尔赋能平安医疗科技“端到端”的AI能力
评论
取消
暂无评论,快去成为第一个评论的人吧

更多资讯推荐

2020年搞深度学习需要什么样的GPU:请上48G显存

在 lambda 最新的一篇显卡横向测评文章中,开发者们探讨了哪些 GPU 可以再不出现内存错误的情况下训练模型。当然,还有这些 GPU 的 AI 性能。

机器之心 ·  13h前
为什么用Go编写机器学习的基础架构,而不是Python?

虽然Python是使用广泛的语言,并用于每个主要的机器学习框架中。然而,你能想象?在Cortex(将机器学习模型部署为API的开放源代码平台之一)代码库中,87.5%的代码都是使用GO编写。

读芯术 ·  2020-02-14 13:13:04
一个案例掌握深度学习

近期我们将连载一个深度学习专题,由百度深度学习技术平台部主任架构师毕然分享,让你快速入门深度学习,参与到人工智能浪潮中。

佚名 ·  2020-02-12 17:10:54
20条理由告诉你,为什么当前的深度学习成了人工智能的死胡同?

在深度学习刚刚进入视线时,大多数AI研究人员嗤之以鼻,但短短几年后,它的触角已经横跨医疗、教育、汽车等众多领域。

AI科技评论 ·  2020-02-10 13:36:30
2020,人工智能和深度学习未来的五大趋势

虽然近年来人工智能经常成为热门议题,但它还远未实现真正的成就。人工智能技术发展的主要障碍在于投资成本,投资成本影响短期内的回报。而当时机成熟时,投资AI的公司却可以获得巨大的回报。在最近的一份报告中,麦肯锡预测人工智能领头企业未来将会实现现金流翻倍。

CSDN App ·  2020-02-09 17:27:29
Golang还是Python?哪种语言更适合AI?

Google所开发的Golang正在成为机器学习和AI的主流编程语言,全球已有数百万用户。那么很多IT技术人员肯定会有这样的疑问,前些年火起来的Python在科学计算及AI领域也有广泛应用,那么这两者今后谁才是AI的菜呢?

橘子键盘手 ·  2020-02-07 12:55:29
深度学习火了那么多年,到底怎么搞?使用Numpy快速入门

近期我们将连载一个深度学习专题,由百度深度学习技术平台部主任架构师毕然分享,让你快速入门深度学习,参与到人工智能浪潮中。

佚名 ·  2020-02-05 11:25:29
请别再把深度学习与机器学习混为一谈了!

虽说机器学习和深度学习都能发现数据中的模式与特征,但是它们所涉及到的技术和具体的应用场景却截然不同。

陈峻 ·  2020-02-05 09:00:00
Copyright©2005-2020 51CTO.COM 版权所有 未经许可 请勿转载