• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python data_feeder.setup_train_data_feeder函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.contrib.learn.python.learn.io.data_feeder.setup_train_data_feeder函数的典型用法代码示例。如果您想了解Python中setup_train_data_feeder函数的具体用法、调用方式或使用示例，那么这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了setup_train_data_feeder函数的10个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: fit

  def fit(self, x, y, steps=None, monitors=None, logdir=None):
    """Trains the model built by `model_fn` on the given training data.

    The first call constructs the graph and initializes its variables; later
    calls keep training the same model, mirroring scikit-learn's
    `partial_fit()` interface. Create a new estimator to start over.

    Args:
      x: Training features: a matrix or tensor of shape
        [n_samples, n_features...], or an iterator returning feature arrays.
      y: Training targets: a vector [n_samples] or matrix
        [n_samples, n_outputs], or an iterator returning target arrays
        (class labels for classification, real numbers for regression).
      steps: Number of training steps. When None or 0, `self.steps` is used.
      monitors: List of `BaseMonitor` objects used to report training
        progress and trigger early stopping.
      logdir: Optional directory for log files usable for visualization;
        when given it overwrites `self._model_dir`.

    Returns:
      self.
    """
    if logdir is not None:
      self._model_dir = logdir
    feeder = setup_train_data_feeder(
        x, y, n_classes=self.n_classes, batch_size=self.batch_size)
    self._data_feeder = feeder
    self._train_model(
        input_fn=feeder.input_builder,
        feed_fn=feeder.get_feed_dict_fn(),
        steps=steps or self.steps,
        monitors=monitors,
    )
    return self
开发者ID:285219011,项目名称:hello-world,代码行数:34,代码来源:base.py


示例2: __init__

 def __init__(self, val_X, val_y, n_classes=0, print_steps=100,
              early_stopping_rounds=None):
     """Initializes the monitor with a feeder over the validation data.

     Args:
       val_X: Validation features, passed to `setup_train_data_feeder`.
       val_y: Validation targets, passed to `setup_train_data_feeder`.
       n_classes: Forwarded to `setup_train_data_feeder`.
       print_steps: Forwarded to the base monitor's constructor.
       early_stopping_rounds: Forwarded to the base monitor's constructor.
     """
     base_kwargs = dict(print_steps=print_steps,
                        early_stopping_rounds=early_stopping_rounds)
     super(ValidationMonitor, self).__init__(**base_kwargs)
     # NOTE(review): batch size -1 presumably means "feed the whole validation
     # set in one batch" -- confirm against data_feeder.
     self.val_feeder = setup_train_data_feeder(val_X, val_y, n_classes, -1)
     # Validation-loss buffers start empty; they are filled outside __init__.
     self.print_val_loss_buffer = list()
     self.all_val_loss_buffer = list()
开发者ID:2er0,项目名称:tensorflow,代码行数:7,代码来源:monitors.py


示例3: _get_input_fn

def _get_input_fn(x, y, input_fn, feed_fn, batch_size, shuffle=False, epochs=1):
  """Make inputs into input and feed functions."""
  if input_fn is None:
    if x is None:
      raise ValueError('Either x or input_fn must be provided.')

    if contrib_framework.is_tensor(x) or (y is not None and
                                          contrib_framework.is_tensor(y)):
      raise ValueError('Inputs cannot be tensors. Please provide input_fn.')

    if feed_fn is not None:
      raise ValueError('Can not provide both feed_fn and x or y.')

    df = data_feeder.setup_train_data_feeder(x, y, n_classes=None,
                                             batch_size=batch_size,
                                             shuffle=shuffle,
                                             epochs=epochs)
    return df.input_builder, df.get_feed_dict_fn()

  if (x is not None) or (y is not None):
    raise ValueError('Can not provide both input_fn and x or y.')
  if batch_size is not None:
    raise ValueError('Can not provide both input_fn and batch_size.')

  return input_fn, feed_fn
开发者ID:AntHar,项目名称:tensorflow,代码行数:25,代码来源:estimator.py


示例4: evaluate

 def evaluate(self, x=None, y=None, input_fn=None, steps=None):
     """See base class.

     When `x` is given, `x`/`y` are wrapped in a one-epoch data feeder whose
     input and feed functions replace `input_fn`; otherwise `input_fn` is used
     with no feed function. `steps` falls back to `self.steps` when None or 0.
     """
     if x is None:
         feed_fn = None
     else:
         feeder = setup_train_data_feeder(
             x, y,
             n_classes=self.n_classes,
             batch_size=self.batch_size,
             epochs=1,
         )
         input_fn = feeder.input_builder
         feed_fn = feeder.get_feed_dict_fn()
     return self._evaluate_model(
         input_fn=input_fn,
         feed_fn=feed_fn,
         steps=steps or self.steps,
     )
开发者ID:scott89,项目名称:tensorflow,代码行数:9,代码来源:base.py


示例5: _get_predict_input_fn

def _get_predict_input_fn(x, batch_size):
  # TODO(ipoloshukin): Remove this when refactor of data_feeder is done
  if hasattr(x, 'create_graph'):
    def input_fn():
      return x.create_graph()
    return input_fn, None

  df = data_feeder.setup_train_data_feeder(x, None,
                                           n_classes=None,
                                           batch_size=batch_size, epochs=1)
  return df.input_builder, df.get_feed_dict_fn()
开发者ID:01-,项目名称:tensorflow,代码行数:11,代码来源:estimator.py


示例6: predict

 def predict(self, x=None, input_fn=None, batch_size=None, outputs=None,
             axis=1):
   """Run inference on `x` (through a data feeder) or on `input_fn`.

   Args:
     x: Input data. When given it takes precedence over `input_fn` and is fed
       in a single, unshuffled epoch.
     input_fn: Input function used only when `x` is None.
     batch_size: Feeder batch size for `x`; `self.batch_size` is used when
       this is falsy.
     outputs: Forwarded to `_infer_model`.
     axis: When the deprecated class count is greater than 1 and `axis` is
       not None, the result is reduced with `np.argmax` along this axis.

   Returns:
     The inference result, or its argmax along `axis` as described above.
   """
   if x is None:
     result = super(DeprecatedMixin, self)._infer_model(
         input_fn=input_fn, outputs=outputs)
   else:
     feeder = setup_train_data_feeder(
         x, None, n_classes=None,
         batch_size=batch_size or self.batch_size,
         shuffle=False, epochs=1)
     result = super(DeprecatedMixin, self)._infer_model(
         input_fn=feeder.input_builder,
         feed_fn=feeder.get_feed_dict_fn(),
         outputs=outputs)
   wants_argmax = self.__deprecated_n_classes > 1 and axis is not None
   if not wants_argmax:
     return result
   return np.argmax(result, axis)
开发者ID:AngleFork,项目名称:tensorflow,代码行数:17,代码来源:base.py


示例7: _predict

    def _predict(self, x, axis=-1, batch_size=None):
        """Run inference on `x` and return its "predictions" output.

        Args:
            x: Input data, fed once (single epoch) without shuffling.
            axis: With the default `-1` the raw predictions are returned. Any
                other value, when `self.n_classes > 1`, returns their argmax
                along that axis.
            batch_size: Feeder batch size; `self.batch_size` when None.

        Raises:
            NotFittedError: if `self._graph` is None.
        """
        if self._graph is None:
            raise NotFittedError()

        # Fall back to the fitting batch size when the caller gives none.
        effective_batch_size = self.batch_size if batch_size is None else batch_size
        feeder = setup_train_data_feeder(
            x,
            None,
            n_classes=None,
            batch_size=effective_batch_size,
            shuffle=False,
            epochs=1,
        )
        inferred = self._infer_model(
            input_fn=feeder.input_builder,
            feed_fn=feeder.get_feed_dict_fn(),
        )
        # NOTE: -1 acts as a "no argmax" sentinel here, not as "last axis".
        reduce_to_labels = self.n_classes > 1 and axis != -1
        predictions = inferred["predictions"]
        if reduce_to_labels:
            return predictions.argmax(axis=axis)
        return predictions
开发者ID:scott89,项目名称:tensorflow,代码行数:19,代码来源:base.py


示例8: fit

    def fit(self, X, y, monitor=None, logdir=None):
        """Builds a neural network model from `model_fn` and training data X, y.

        Note: the first call constructs the graph and initializes its
        variables. Subsequent calls continue training the same model, unless
        `self.continue_training` is false, in which case training is set up
        again. This follows the partial_fit() interface in scikit-learn.

        To restart learning, create a new estimator.

        Args:
            X: matrix or tensor of shape [n_samples, n_features...]. Can be an
                iterator that returns arrays of features. The training input
                samples for fitting the model.
            y: vector or matrix [n_samples] or [n_samples, n_outputs]. Can be
                an iterator that returns arrays of targets. The training target
                values (class labels in classification, real numbers in
                regression).
            monitor: Monitor object to print training progress and invoke
                early stopping. When None, `monitors.default_monitor` is
                created with `self.verbose`.
            logdir: directory to save the log file that can be used for
                optional visualization. A summary writer is created only if
                none exists yet; when `logdir` is falsy the writer is reset to
                None and None is passed to the trainer.

        Returns:
            Returns self.
        """
        # Sets up data feeder.
        self._data_feeder = setup_train_data_feeder(X, y,
                                                    self.n_classes,
                                                    self.batch_size)

        if monitor is None:
            self._monitor = monitors.default_monitor(verbose=self.verbose)
        else:
            self._monitor = monitor

        if not self.continue_training or not self._initialized:
            # Sets up model and trainer.
            self._setup_training()
            self._initialized = True
        else:
            # Continuing training: hand the existing input/output
            # placeholders to the freshly created feeder.
            self._data_feeder.set_placeholders(self._inp, self._out)

        # _summary_writer is not a model parameter, so it is not created in
        # __init__: it may be missing entirely, or None from a previous run.
        # Create it only when a log dir is given and no writer exists yet.
        if logdir:
            if getattr(self, "_summary_writer", None) is None:
                self._setup_summary_writer(logdir)
        else:
            self._summary_writer = None

        # Attach monitor to this estimator.
        self._monitor.set_estimator(self)

        # Train model for given number of steps.
        trainer.train(
            self._session, self._train,
            self._model_loss, self._global_step,
            self._data_feeder.get_feed_dict_fn(),
            steps=self.steps,
            monitor=self._monitor,
            summary_writer=self._summary_writer,
            summaries=self._summaries,
            feed_params_fn=self._data_feeder.get_feed_params)
        return self
开发者ID:01bui,项目名称:tensorflow,代码行数:68,代码来源:base.py


示例9: _get_predict_input_fn

def _get_predict_input_fn(x, y, batch_size):
  """Wrap `x`/`y` in an unshuffled, single-epoch feeder for prediction.

  Returns:
    The feeder's `(input_builder, feed_dict_fn)` pair.
  """
  feeder = data_feeder.setup_train_data_feeder(
      x, y,
      n_classes=None,
      batch_size=batch_size,
      shuffle=False,
      epochs=1)
  input_fn = feeder.input_builder
  feed_fn = feeder.get_feed_dict_fn()
  return input_fn, feed_fn
开发者ID:Baaaaam,项目名称:tensorflow,代码行数:5,代码来源:estimator.py


示例10: _get_input_fn

def _get_input_fn(x, y, batch_size):
  """Wrap `x`/`y` in a training data feeder.

  Only `batch_size` is set explicitly; shuffling and epoch count are left to
  the feeder's defaults.

  Returns:
    The feeder's `(input_builder, feed_dict_fn)` pair.
  """
  feeder = data_feeder.setup_train_data_feeder(
      x, y, batch_size=batch_size, n_classes=None)
  input_fn, feed_fn = feeder.input_builder, feeder.get_feed_dict_fn()
  return input_fn, feed_fn
开发者ID:Baaaaam,项目名称:tensorflow,代码行数:4,代码来源:estimator.py



注:本文中的tensorflow.contrib.learn.python.learn.io.data_feeder.setup_train_data_feeder函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap