• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

TensorFlow编程指南: Estimators

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文件介绍Estimator – TensorFlow高级API,Estimators大大简化了机器学习编程,它包含以下功能:

  • 训练
  • 评测
  • 预测
  • 导出提供服务

您可以使用我们提供的预开发(Tensorflow自带)的Estimators,也可以编写自定义Estimators。所有Estimators – 无论是Tensorflow自带还是用户自定义 – 都是基于tf.estimator.Estimator类的类。

注意:TensorFlow也包括已弃用的Estimatortf.contrib.learn.Estimator,这个后续不建议使用。

Estimators的优点

Estimators有以下优点:

  • 您可以在本地主机上或在分布式multi-server环境中运行基于Estimators模型,而无需更改模型。此外,您可以在CPU,GPU或TPU上运行基于Estimators的模型,而无需重新编码模型。
  • Estimators简化了模型开发人员之间的实现共享。
  • 您可以使用high-level直观的代码开发最先进的模型。简而言之,使用Estimators创建模型通常比使用low-level TensorFlow API更容易。
  • Estimators本身是建立在tf.layers上,这简化了自定义操作。
  • Estimators为你建立计算图。换句话说,你不必自己建立计算图。
  • Estimators提供安全的分布式训练循环,控制如何和何时:
    • 建立计算图
    • 初始化变量
    • 启动队列
    • 处理异常
    • 创建检查点文件并从失败中恢复
    • 保存TensorBoard的摘要

在使用Estimators编写应用程序时,必须将数据输入管道与模型分开。这种分离简化了使用不同数据集时的实验开销。

TensorFlow自带的Estimators

自带的Estimators使您能够在比基本TensorFlow API更高的概念层面上工作。由于Estimators为您处理所有麻烦的工作,您不必再担心创建计算图或会话的问题。也就是说,自带的Estimators直接为你创建和管理GraphSession对象。此外,自带的Estimators让您只需进行最少的代码更改就可以尝试不同的模型架构。例如,DNNClassifier,这个自带的Estimator类通过密集的前馈神经网络训练分类模型。

自带的Estimators程序的结构

依赖自带Estimators的TensorFlow程序通常由以下四个步骤组成:

  1. 编写一个或多个数据集导入功能。例如,您可以创建一个函数来导入训练集,另一个函数导入测试集。每个数据集导入函数都必须返回两个对象:

    • 一个字典,其中的键是特征名称和值是张量(Tensors或SparseTensors)包含相应特征数据
    • 包含一个或多个标签的张量

    例如,下面的代码演示了输入函数的基本框架:

    def input_fn(dataset):
       ...  # manipulate dataset, extracting feature names and the label
       return feature_dict, label
    

    (见导入数据了解有关详细信息。)

  2. 定义特征列。tf.feature_column标识特征名称、类型以及任何输入的预处理。例如,以下片段创建三个保存整数或浮点数据的特征列。前两个特征列只是标识特征的名称和类型。第三个特征列还指定了一个lambda程序来缩放原始数据:

    # Define three numeric feature columns.
    population = tf.feature_column.numeric_column('population')
    crime_rate = tf.feature_column.numeric_column('crime_rate')
    median_education = tf.feature_column.numeric_column('median_education',
                        normalizer_fn='lambda x: x - global_education_mean')
    
  3. 实例化相关的自带Estimators。例如,下面是一个名为LinearClassifier自带Estimators的示例:

    # Instantiate an estimator, passing the feature columns.
    estimator = tf.estimator.Estimator.LinearClassifier(
        feature_columns=[population, crime_rate, median_education],
        )
    
  4. 调用训练,评估或预测方法。例如,所有Estimators都提供了一个train方法,它训练一个模型。

    # my_training_set is the function created in Step 1
    estimator.train(input_fn=my_training_set, steps=2000)
    

TensorFlow自带的Estimators的好处

TensorFlow自带的Estimators包含最佳实践,提供以下好处:

  • 确定计算图的不同部分应该在哪里运行的最佳实践,在单个机器或集群上实施策略。
  • 事件(摘要)写入和普遍有用的摘要的最佳实践。

如果您不使用自带的Estimators,则必须自己实现上述功能。

自定义Estimators

每个Estimators的核心 – 无论是系统自带还是自定义 – 其核心都是模型函数,它是建立训练,评估和预测图的方法。当您使用自带的Estimators时,其他人已经实现了这些模型功能。当依靠自定义的Estimators时,您必须自己编写模型函数。Creating Estimators in tf.estimator解释如何编写模型函数。

我们推荐以下工作流程:

  1. 假设存在一个合适的自带的Estimators,请使用它来构建您的第一个模型并使用其结果来建立基线。
  2. 使用此自带的Estimators构建和测试您的整体流水线,包括数据的完整性和可靠性。
  3. 如果有合适的可替代的自带的Estimators可用,运行实验以确定哪个自带的Estimators能够产生最佳结果。
  4. 可能通过构建您自己的自定义Estimators来进一步改进您的模型。

从Keras模型创建Estimators

您可以将现有的Keras模型转换为Estimators。这样做可以使您的Keras模型获得Estimator的优势,例如分布式训练。调用tf.keras.estimator.model_to_estimator,按下例所示的方式:

# Instantiate a Keras inception v3 model.
keras_inception_v3 = tf.keras.applications.inception_v3.InceptionV3(weights=None)
# Compile model with the optimizer, loss, and metrics you'd like to train with.
keras_inception_v3.compile(optimizer=tf.keras.optimizers.SGD(lr=0.0001, momentum=0.9),
                          loss='categorical_crossentropy',
                          metric='accuracy')
# Create an Estimator from the compiled Keras model.
est_inception_v3 = tf.keras.estimator.model_to_estimator(keras_model=keras_inception_v3)
# Treat the derived Estimator as you would any other Estimator. For example,
# the following derived Estimator calls the train method:
est_inception_v3.train(input_fn=my_training_set, steps=2000)

更多详细信息,请参阅文档tf.keras.estimator.model_to_estimator

参考资料

  • Estimators | TensorFlow

鲜花

握手

雷人

路过

鸡蛋
专题导读
上一篇:
TensorBoard直方图仪表板发布时间:2022-05-14
下一篇:
TensorFlow编程指南: Embedding发布时间:2022-05-14
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap