
Python boosted_trees_ops.training_predict Function Code Examples


This article collects typical usage examples of tensorflow.python.ops.boosted_trees_ops.training_predict in Python. If you want to know what training_predict does, how to call it, or what real-world uses of it look like, the selected examples below may help.



Ten examples of training_predict are shown below, ordered by popularity by default.

Example 1: testCachedPredictionOnEmptyEnsemble

  def testCachedPredictionOnEmptyEnsemble(self):
    """Tests that prediction on a dummy ensemble does not fail."""
    with self.cached_session() as session:
      # Create a dummy ensemble.
      tree_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto='')
      tree_ensemble_handle = tree_ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      # No previous cached values.
      cached_tree_ids = [0, 0]
      cached_node_ids = [0, 0]

      # We have two features: 0 and 1. Values don't matter here on a dummy
      # ensemble.
      feature_0_values = [67, 5]
      feature_1_values = [9, 17]

      # Run training prediction on the ensemble.
      predict_op = boosted_trees_ops.training_predict(
          tree_ensemble_handle,
          cached_tree_ids=cached_tree_ids,
          cached_node_ids=cached_node_ids,
          bucketized_features=[feature_0_values, feature_1_values],
          logits_dimension=1)

      logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)

      # Nothing changed.
      self.assertAllClose(cached_tree_ids, new_tree_ids)
      self.assertAllClose(cached_node_ids, new_node_ids)
      self.assertAllClose([[0], [0]], logits_updates)
Developer ID: kylin9872, Project: tensorflow, Lines of code: 32, Source: prediction_ops_test.py
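The assertions above pin down what an empty ensemble returns. A minimal pure-Python sketch of that contract, assuming (as the test suggests) that with no trees the op echoes the cached ids and returns an all-zero update of shape [batch_size, logits_dimension]; predict_on_empty_ensemble is a hypothetical name, not a TensorFlow function:

```python
from typing import List, Tuple


def predict_on_empty_ensemble(
    cached_tree_ids: List[int],
    cached_node_ids: List[int],
    logits_dimension: int,
) -> Tuple[List[List[float]], List[int], List[int]]:
    """Hypothetical model of training_predict when the ensemble has no trees."""
    batch_size = len(cached_tree_ids)
    # No tree can contribute, so every example gets a zero update per logit.
    logits_updates = [[0.0] * logits_dimension for _ in range(batch_size)]
    # The cached positions are returned unchanged.
    return logits_updates, list(cached_tree_ids), list(cached_node_ids)


updates, tree_ids, node_ids = predict_on_empty_ensemble([0, 0], [0, 0], logits_dimension=1)
print(updates, tree_ids, node_ids)  # [[0.0], [0.0]] [0, 0] [0, 0]
```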


Example 2: testCachedPredictionTheWholeTreeWasPruned

  def testCachedPredictionTheWholeTreeWasPruned(self):
    """Tests that prediction based on previous node in the tree works."""
    with self.cached_session() as session:
      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge("""
        trees {
          nodes {
            leaf {
              scalar: 0.00
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 1
          is_finalized: true
          post_pruned_nodes_meta {
            new_node_id: 0
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 0
            logit_change: -6.0
          }
          post_pruned_nodes_meta {
            new_node_id: 0
            logit_change: 5.0
          }
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 1
        }
      """, tree_ensemble_config)

      # Create existing ensemble.
      tree_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
      tree_ensemble_handle = tree_ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      cached_tree_ids = [
          0,
          0,
      ]
      # The predictions were cached in 1 and 2, both were pruned to the root.
      cached_node_ids = [1, 2]

      # We have two features: 0 and 1. These are not going to be used anywhere.
      feature_0_values = [12, 17]
      feature_1_values = [12, 12]

      # Run training prediction on the ensemble.
      predict_op = boosted_trees_ops.training_predict(
          tree_ensemble_handle,
          cached_tree_ids=cached_tree_ids,
          cached_node_ids=cached_node_ids,
          bucketized_features=[feature_0_values, feature_1_values],
          logits_dimension=1)

      logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)

      # We are in the last tree.
      self.assertAllClose([0, 0], new_tree_ids)
      self.assertAllClose([0, 0], new_node_ids)

      self.assertAllClose([[-6.0], [5.0]], logits_updates)
Developer ID: kylin9872, Project: tensorflow, Lines of code: 67, Source: prediction_ops_test.py
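Judging from the expected values, an example cached at a node that post-pruning removed is remapped through post_pruned_nodes_meta, indexed by the old node id: the entry gives the node the example now belongs to and a logit_change to add to its cached logit. The pure-Python sketch below models only that lookup; remap_pruned_nodes is a hypothetical helper, and because this test uses tree_weights: 1.0 it does not show whether the change is scaled by the tree weight.

```python
from typing import List, Sequence, Tuple

# (new_node_id, logit_change) pairs, indexed by the node id the example was
# cached at before pruning; values copied from the tree_metadata above.
POST_PRUNED_NODES_META: List[Tuple[int, float]] = [(0, 0.0), (0, -6.0), (0, 5.0)]


def remap_pruned_nodes(
    cached_node_ids: Sequence[int],
    meta: Sequence[Tuple[int, float]],
) -> Tuple[List[int], List[float]]:
    """Hypothetical helper: map stale node ids to post-pruning ids and logit deltas."""
    new_ids: List[int] = []
    deltas: List[float] = []
    for node_id in cached_node_ids:
        new_id, logit_change = meta[node_id]
        new_ids.append(new_id)
        deltas.append(logit_change)
    return new_ids, deltas


new_ids, deltas = remap_pruned_nodes([1, 2], POST_PRUNED_NODES_META)
print(new_ids, deltas)  # [0, 0] [-6.0, 5.0]
```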


Example 3: testCachedPredictionFromThePreviousTreeWithPostPrunedNodes


#......... (some code omitted here) .........
            }
          }
        }
        trees {
          nodes {
            leaf {
              scalar: 0.55
            }
          }
        }
        tree_weights: 1.0
        tree_weights: 1.0
        tree_metadata {
          num_layers_grown: 3
          is_finalized: true
          post_pruned_nodes_meta {
            new_node_id: 0
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 2
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.07
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.083
          }
          post_pruned_nodes_meta {
            new_node_id: 3
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 4
            logit_change: 0.0
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.22
          }
          post_pruned_nodes_meta {
            new_node_id: 1
            logit_change: -0.57
          }
        }
        tree_metadata {
          num_layers_grown: 1
          is_finalized: false
        }
        growing_metadata {
          num_trees_attempted: 2
          num_layers_attempted: 4
        }
      """, tree_ensemble_config)

      # Create existing ensemble.
      tree_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
      tree_ensemble_handle = tree_ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      cached_tree_ids = [0, 0, 0, 0, 0, 0]
      # Leaves 3, 4, 7 and 8 got deleted during post-pruning; leaves 5 and 6
      # changed their ids to 3 and 4 respectively.
      cached_node_ids = [3, 4, 5, 6, 7, 8]

      # We have two features: 0 and 1.
      feature_0_values = [12, 17, 35, 36, 23, 11]
      feature_1_values = [12, 12, 17, 18, 123, 24]

      # Run training prediction on the ensemble.
      predict_op = boosted_trees_ops.training_predict(
          tree_ensemble_handle,
          cached_tree_ids=cached_tree_ids,
          cached_node_ids=cached_node_ids,
          bucketized_features=[feature_0_values, feature_1_values],
          logits_dimension=1)

      logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)

      # We are in the last tree.
      self.assertAllClose([1, 1, 1, 1, 1, 1], new_tree_ids)
      # Examples from leaves 3, 4, 7 and 8 should end up in leaf 1, and examples
      # from leaves 5 and 6 in leaves 3 and 4 of tree 0. In tree 1, all of the
      # examples are in the root node.
      self.assertAllClose([0, 0, 0, 0, 0, 0], new_node_ids)

      cached_values = [[0.08], [0.093], [0.0553], [0.0783], [0.15 + 0.08],
                       [0.5 + 0.08]]
      root = 0.55
      self.assertAllClose([[root + 0.01], [root + 0.01], [root + 0.0553],
                           [root + 0.0783], [root + 0.01], [root + 0.01]],
                          logits_updates + cached_values)
Developer ID: kylin9872, Project: tensorflow, Lines of code: 101, Source: prediction_ops_test.py
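The final assertion can be checked by hand. Assuming each update is the logit_change from the cached node's post_pruned_nodes_meta entry plus the weighted value of tree 1's single leaf (both tree weights are 1.0), this Python snippet reproduces the expected sums. The interpretation is inferred from the test, not from op documentation.

```python
import math

# post_pruned_nodes_meta of tree 0 as (new_node_id, logit_change), indexed by old node id.
META = [(0, 0.0), (1, 0.0), (2, 0.0), (1, -0.07), (1, -0.083),
        (3, 0.0), (4, 0.0), (1, -0.22), (1, -0.57)]
ROOT = 0.55          # the only leaf of tree 1
TREE_1_WEIGHT = 1.0  # second tree_weights entry

cached_node_ids = [3, 4, 5, 6, 7, 8]
cached_values = [0.08, 0.093, 0.0553, 0.0783, 0.15 + 0.08, 0.5 + 0.08]

# Assumed update: pruning correction for tree 0 plus tree 1's weighted leaf.
updates = [META[n][1] + TREE_1_WEIGHT * ROOT for n in cached_node_ids]
final = [c + u for c, u in zip(cached_values, updates)]

expected = [ROOT + 0.01, ROOT + 0.01, ROOT + 0.0553,
            ROOT + 0.0783, ROOT + 0.01, ROOT + 0.01]
assert all(math.isclose(f, e) for f, e in zip(final, expected))
```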


Example 4: testNoCachedPredictionButTreeExists

  def testNoCachedPredictionButTreeExists(self):
    """Tests that predictions are updated once trees are added."""
    with self.cached_session() as session:
      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge("""
        trees {
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 15
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
        }
        tree_weights: 0.1
        tree_metadata {
          is_finalized: true
          num_layers_grown: 1
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
      """, tree_ensemble_config)

      # Create existing ensemble with one root split
      tree_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
      tree_ensemble_handle = tree_ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      # Two examples, none were cached before.
      cached_tree_ids = [0, 0]
      cached_node_ids = [0, 0]

      feature_0_values = [67, 5]

      # Run training prediction on the ensemble.
      predict_op = boosted_trees_ops.training_predict(
          tree_ensemble_handle,
          cached_tree_ids=cached_tree_ids,
          cached_node_ids=cached_node_ids,
          bucketized_features=[feature_0_values],
          logits_dimension=1)

      logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)

      # We are in the first tree.
      self.assertAllClose([0, 0], new_tree_ids)
      self.assertAllClose([2, 1], new_node_ids)
      self.assertAllClose([[0.1 * 8.79], [0.1 * 1.14]], logits_updates)
Developer ID: kylin9872, Project: tensorflow, Lines of code: 65, Source: prediction_ops_test.py
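The expected node ids imply that a bucketized_split sends an example left when its bucket value is less than or equal to threshold, and that the update is the leaf value times tree_weights. A minimal pure-Python sketch of that routing, using a hypothetical tuple encoding of the proto above (route and TREE are illustrative names, not TensorFlow APIs):

```python
from typing import Dict, Sequence, Tuple

# Hypothetical encoding: ("split", feature_id, threshold, left_id, right_id) or ("leaf", value).
TREE: Dict[int, tuple] = {
    0: ("split", 0, 15, 1, 2),
    1: ("leaf", 1.14),
    2: ("leaf", 8.79),
}
TREE_WEIGHT = 0.1


def route(tree: Dict[int, tuple], features: Sequence[int]) -> Tuple[int, float]:
    """Follow bucketized splits (value <= threshold goes left) down to a leaf."""
    node_id, node = 0, tree[0]
    while node[0] == "split":
        _, feature_id, threshold, left_id, right_id = node
        node_id = left_id if features[feature_id] <= threshold else right_id
        node = tree[node_id]
    return node_id, node[1]


examples = [[67], [5]]  # one row per example; column 0 is feature 0
results = [route(TREE, x) for x in examples]
node_ids = [node_id for node_id, _ in results]
updates = [TREE_WEIGHT * leaf for _, leaf in results]
print(node_ids, updates)
```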


Example 5: testCategoricalSplits

  def testCategoricalSplits(self):
    """Tests the training prediction work for categorical splits."""
    with self.cached_session() as session:
      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge(
          """
        trees {
          nodes {
            categorical_split {
              feature_id: 1
              value: 2
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            categorical_split {
              feature_id: 0
              value: 13
              left_id: 3
              right_id: 4
            }
          }
          nodes {
            leaf {
              scalar: 7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
          nodes {
            leaf {
              scalar: 6.0
            }
          }
        }
        tree_weights: 1.0
        tree_metadata {
          is_finalized: true
        }
      """, tree_ensemble_config)

      # Create the existing ensemble.
      tree_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
      tree_ensemble_handle = tree_ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      feature_0_values = [13, 1, 3]
      feature_1_values = [2, 2, 1]

      # No previous cached values.
      cached_tree_ids = [0, 0, 0]
      cached_node_ids = [0, 0, 0]

      # Run training prediction on the ensemble.
      predict_op = boosted_trees_ops.training_predict(
          tree_ensemble_handle,
          cached_tree_ids=cached_tree_ids,
          cached_node_ids=cached_node_ids,
          bucketized_features=[feature_0_values, feature_1_values],
          logits_dimension=1)

      logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)

      self.assertAllClose([0, 0, 0], new_tree_ids)
      self.assertAllClose([3, 4, 2], new_node_ids)
      self.assertAllClose([[5.], [6.], [7.]], logits_updates)
Developer ID: kylin9872, Project: tensorflow, Lines of code: 71, Source: prediction_ops_test.py
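For categorical_split the expected ids fit an equality test: an example whose feature value equals value goes to left_id, anything else to right_id. A pure-Python sketch under that assumption (the tuple encoding and route_categorical are hypothetical; the tree weight is 1.0, so the update equals the leaf value):

```python
from typing import Dict, Sequence, Tuple

# Hypothetical encoding: ("cat", feature_id, value, left_id, right_id) or ("leaf", value).
TREE: Dict[int, tuple] = {
    0: ("cat", 1, 2, 1, 2),
    1: ("cat", 0, 13, 3, 4),
    2: ("leaf", 7.0),
    3: ("leaf", 5.0),
    4: ("leaf", 6.0),
}


def route_categorical(tree: Dict[int, tuple], features: Sequence[int]) -> Tuple[int, float]:
    """Walk categorical splits: an exact match goes left, anything else goes right."""
    node_id, node = 0, tree[0]
    while node[0] == "cat":
        _, feature_id, value, left_id, right_id = node
        node_id = left_id if features[feature_id] == value else right_id
        node = tree[node_id]
    return node_id, node[1]


# Transpose the per-feature lists from the test into one row per example.
feature_0_values = [13, 1, 3]
feature_1_values = [2, 2, 1]
rows = list(zip(feature_0_values, feature_1_values))
results = [route_categorical(TREE, row) for row in rows]
print(results)  # [(3, 5.0), (4, 6.0), (2, 7.0)]
```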


Example 6: testCachedPredictionFromPreviousTree


#......... (some code omitted here) .........
              feature_id: 1
              threshold: 26
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 50
              left_id: 3
              right_id: 4
            }
          }
          nodes {
            leaf {
              scalar: 7
            }
          }
          nodes {
            leaf {
              scalar: 5
            }
          }
          nodes {
            leaf {
              scalar: 6
            }
          }
        }
        trees {
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 34
              left_id: 1
              right_id: 2
            }
          }
          nodes {
            leaf {
              scalar: -7.0
            }
          }
          nodes {
            leaf {
              scalar: 5.0
            }
          }
        }
        tree_metadata {
          is_finalized: true
        }
        tree_metadata {
          is_finalized: true
        }
        tree_metadata {
          is_finalized: false
        }
        tree_weights: 0.1
        tree_weights: 0.1
        tree_weights: 0.1
      """, tree_ensemble_config)

      # Create the existing ensemble.
      tree_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
      tree_ensemble_handle = tree_ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      # Two examples, one was cached in node 1 first, another in node 0.
      cached_tree_ids = [0, 0]
      cached_node_ids = [1, 0]

      # We have two features: 0 and 1.
      feature_0_values = [36, 32]
      feature_1_values = [11, 27]

      # Run training prediction on the ensemble.
      predict_op = boosted_trees_ops.training_predict(
          tree_ensemble_handle,
          cached_tree_ids=cached_tree_ids,
          cached_node_ids=cached_node_ids,
          bucketized_features=[feature_0_values, feature_1_values],
          logits_dimension=1)

      logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)
      # Example 1 will get to node 3 in tree 1 and node 2 of tree 2
      # Example 2 will get to node 2 in tree 1 and node 1 of tree 2

      # We are in the last tree.
      self.assertAllClose([2, 2], new_tree_ids)
      # In the last tree, the first example ends up in node 2, the second in
      # node 1.
      self.assertAllClose([2, 1], new_node_ids)
      # Example 1: tree 0: 8.79, tree 1: 5.0, tree 2: 5.0 =>
      #            change = 0.1*(5.0+5.0)
      # Example 2: tree 0: 1.14, tree 1: 7.0, tree 2: -7.0 =>
      #            change = 0.1*(1.14+7.0-7.0)
      self.assertAllClose([[1], [0.114]], logits_updates)
Developer ID: kylin9872, Project: tensorflow, Lines of code: 101, Source: prediction_ops_test.py
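The update in this test spans several trees: each example is routed through every tree after its cached position and the weighted leaf values are summed. Tree 0 is elided above, so the sketch below routes only trees 1 and 2 and takes tree 0's contribution from the test's own comment (0 for example 1, matching change = 0.1*(5.0+5.0), and 1.14 for example 2). The tuple encoding and route are hypothetical, and the "value <= threshold goes left" rule is inferred from the expected node ids.

```python
from typing import Dict, Sequence, Tuple

# Hypothetical encoding: ("split", feature_id, threshold, left_id, right_id) or ("leaf", value).
TREE_1: Dict[int, tuple] = {
    0: ("split", 1, 26, 1, 2),
    1: ("split", 0, 50, 3, 4),
    2: ("leaf", 7.0),
    3: ("leaf", 5.0),
    4: ("leaf", 6.0),
}
TREE_2: Dict[int, tuple] = {
    0: ("split", 0, 34, 1, 2),
    1: ("leaf", -7.0),
    2: ("leaf", 5.0),
}
WEIGHT = 0.1  # all three tree_weights are 0.1


def route(tree: Dict[int, tuple], features: Sequence[int]) -> Tuple[int, float]:
    """Follow bucketized splits from the root down to a leaf."""
    node_id, node = 0, tree[0]
    while node[0] == "split":
        _, feature_id, threshold, left_id, right_id = node
        node_id = left_id if features[feature_id] <= threshold else right_id
        node = tree[node_id]
    return node_id, node[1]


# Tree 0's contribution is taken from the test comment, not computed here.
tree_0_contribution = [0.0, 1.14]
rows = [(36, 11), (32, 27)]  # (feature 0, feature 1) for each example

updates, final_nodes = [], []
for row, t0 in zip(rows, tree_0_contribution):
    _, leaf_1 = route(TREE_1, row)
    node_2, leaf_2 = route(TREE_2, row)
    updates.append(WEIGHT * (t0 + leaf_1 + leaf_2))
    final_nodes.append(node_2)  # examples finish in tree 2, the last tree
print(final_nodes, updates)
```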


Example 7: testCachedPredictionFromTheSameTree

  def testCachedPredictionFromTheSameTree(self):
    """Tests that prediction based on previous node in the tree works."""
    with self.cached_session() as session:
      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge("""
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 15
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
              original_leaf {
                scalar: -2
              }
            }
          }
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 7
              left_id: 3
              right_id: 4
            }
            metadata {
              gain: 1.4
              original_leaf {
                scalar: 7.14
              }
            }
          }
          nodes {
            bucketized_split {
              feature_id: 0
              threshold: 7
              left_id: 5
              right_id: 6
            }
            metadata {
              gain: 2.7
              original_leaf {
                scalar: -4.375
              }
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
          nodes {
            leaf {
              scalar: -5.875
            }
          }
          nodes {
            leaf {
              scalar: -2.075
            }
          }
        }
        tree_weights: 0.1
        tree_metadata {
          is_finalized: true
          num_layers_grown: 2
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
      """, tree_ensemble_config)

      # Create the existing ensemble.
      tree_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
      tree_ensemble_handle = tree_ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      # Two examples, one was cached in node 1 first, another in node 0.
      cached_tree_ids = [0, 0]
      cached_node_ids = [1, 0]

      # We have two features: 0 and 1.
      feature_0_values = [67, 5]
      feature_1_values = [9, 17]

      # Run training prediction on the ensemble.
      predict_op = boosted_trees_ops.training_predict(
          tree_ensemble_handle,
          cached_tree_ids=cached_tree_ids,
          cached_node_ids=cached_node_ids,
          bucketized_features=[feature_0_values, feature_1_values],
#......... (some code omitted here) .........
Developer ID: kylin9872, Project: tensorflow, Lines of code: 101, Source: prediction_ops_test.py
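The assertions of this test are cut off in this excerpt. Each split node in the config carries metadata.original_leaf, the value the node had while it was still a leaf. One plausible reading, stated here as an assumption rather than documented op behaviour, is that an example cached at such a node is routed down to its current leaf and its update is the tree weight times the difference between the new leaf value and the old original_leaf. The sketch below (hypothetical names and encoding) only works through that arithmetic for the visible inputs; it does not reproduce the elided assertions.

```python
from typing import Dict, Sequence, Tuple

# Hypothetical encoding of the tree above:
#   ("split", feature_id, threshold, left_id, right_id, original_leaf) or ("leaf", value)
TREE: Dict[int, tuple] = {
    0: ("split", 1, 15, 1, 2, -2.0),
    1: ("split", 1, 7, 3, 4, 7.14),
    2: ("split", 0, 7, 5, 6, -4.375),
    3: ("leaf", 1.14),
    4: ("leaf", 8.79),
    5: ("leaf", -5.875),
    6: ("leaf", -2.075),
}
TREE_WEIGHT = 0.1


def update_from_cached_node(features: Sequence[int], cached_node_id: int) -> Tuple[int, float]:
    """Route from the cached node to a leaf; return (leaf id, weighted logit delta).

    Assumption: the cached logit already includes the node's original_leaf value,
    so only the difference to the new leaf is added.
    """
    node_id, node = cached_node_id, TREE[cached_node_id]
    if node[0] == "leaf":
        return node_id, 0.0  # already on a leaf of this single-tree ensemble
    old_value = node[5]  # original_leaf of the node the example was cached at
    while node[0] == "split":
        _, feature_id, threshold, left_id, right_id, _ = node
        node_id = left_id if features[feature_id] <= threshold else right_id
        node = TREE[node_id]
    return node_id, TREE_WEIGHT * (node[1] - old_value)


rows = [(67, 9), (5, 17)]  # (feature 0, feature 1) for each example
cached_node_ids = [1, 0]
results = [update_from_cached_node(r, c) for r, c in zip(rows, cached_node_ids)]
print(results)
```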


Example 8: testCachedPredictionIsCurrent

  def testCachedPredictionIsCurrent(self):
    """Tests that prediction based on previous node in the tree works."""
    with self.cached_session() as session:
      tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
      text_format.Merge("""
        trees {
          nodes {
            bucketized_split {
              feature_id: 1
              threshold: 15
              left_id: 1
              right_id: 2
            }
            metadata {
              gain: 7.62
              original_leaf {
                scalar: -2
              }
            }
          }
          nodes {
            leaf {
              scalar: 1.14
            }
          }
          nodes {
            leaf {
              scalar: 8.79
            }
          }
        }
        tree_weights: 0.1
        tree_metadata {
          is_finalized: true
          num_layers_grown: 2
        }
        growing_metadata {
          num_trees_attempted: 1
          num_layers_attempted: 2
        }
      """, tree_ensemble_config)

      # Create existing ensemble with one root split
      tree_ensemble = boosted_trees_ops.TreeEnsemble(
          'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
      tree_ensemble_handle = tree_ensemble.resource_handle
      resources.initialize_resources(resources.shared_resources()).run()

      # Two examples, one was cached in node 1 first, another in node 2.
      cached_tree_ids = [0, 0]
      cached_node_ids = [1, 2]

      # We have two features: 0 and 1. Values don't matter because trees didn't
      # change.
      feature_0_values = [67, 5]
      feature_1_values = [9, 17]

      # Run training prediction on the ensemble.
      predict_op = boosted_trees_ops.training_predict(
          tree_ensemble_handle,
          cached_tree_ids=cached_tree_ids,
          cached_node_ids=cached_node_ids,
          bucketized_features=[feature_0_values, feature_1_values],
          logits_dimension=1)

      logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)

      # Nothing changed.
      self.assertAllClose(cached_tree_ids, new_tree_ids)
      self.assertAllClose(cached_node_ids, new_node_ids)
      self.assertAllClose([[0], [0]], logits_updates)
Developer ID: kylin9872, Project: tensorflow, Lines of code: 71, Source: prediction_ops_test.py
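Here both examples are already cached on leaves of the only tree, so there is nothing left to route and the test expects zero updates with unchanged ids. A pure-Python sketch of that check, assuming an example is up to date when it sits on a leaf of the last tree in the ensemble (is_up_to_date and the tuple encoding are hypothetical):

```python
from typing import Dict, List

# Hypothetical encoding of the one-tree ensemble above.
TREE: Dict[int, tuple] = {
    0: ("split", 1, 15, 1, 2),
    1: ("leaf", 1.14),
    2: ("leaf", 8.79),
}
NUM_TREES = 1  # no tree exists after tree 0


def is_up_to_date(cached_tree_id: int, cached_node_id: int) -> bool:
    """True if the example already sits on a leaf of the last tree in the ensemble."""
    return cached_tree_id == NUM_TREES - 1 and TREE[cached_node_id][0] == "leaf"


cached_tree_ids = [0, 0]
cached_node_ids = [1, 2]
flags: List[bool] = [is_up_to_date(t, n) for t, n in zip(cached_tree_ids, cached_node_ids)]
print(flags)  # [True, True]
```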


Example 9: _bt_model_fn


#......... (some code omitted here) .........
          tree_ensemble_handle=tree_ensemble.resource_handle,
          bucketized_features=input_feature_list,
          logits_dimension=head.logits_dimension)
    else:
      if is_single_machine:
        local_tree_ensemble = tree_ensemble
        ensemble_reload = control_flow_ops.no_op()
      else:
        # Have a local copy of ensemble for the distributed setting.
        with ops.device(worker_device):
          local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
              name=name + '_local', is_local=True)
        # TODO(soroush): Do partial updates if this becomes a bottleneck.
        ensemble_reload = local_tree_ensemble.deserialize(
            *tree_ensemble.serialize())
      if training_state_cache:
        cached_tree_ids, cached_node_ids, cached_logits = (
            training_state_cache.lookup())
      else:
        # Always start from the beginning when no cache is set up.
        batch_size = array_ops.shape(labels)[0]
        cached_tree_ids, cached_node_ids, cached_logits = (
            array_ops.zeros([batch_size], dtype=dtypes.int32),
            array_ops.zeros([batch_size], dtype=dtypes.int32),
            array_ops.zeros(
                [batch_size, head.logits_dimension], dtype=dtypes.float32))
      with ops.control_dependencies([ensemble_reload]):
        (stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
         last_layer_nodes_range) = local_tree_ensemble.get_states()
        summary.scalar('ensemble/num_trees', num_trees)
        summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
        summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)

        partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict(
            tree_ensemble_handle=local_tree_ensemble.resource_handle,
            cached_tree_ids=cached_tree_ids,
            cached_node_ids=cached_node_ids,
            bucketized_features=input_feature_list,
            logits_dimension=head.logits_dimension)
      logits = cached_logits + partial_logits

    # Create training graph.
    def _train_op_fn(loss):
      """Run one training iteration."""
      if training_state_cache:
        train_op.append(training_state_cache.insert(tree_ids, node_ids, logits))
      if closed_form_grad_and_hess_fn:
        gradients, hessians = closed_form_grad_and_hess_fn(logits, labels)
      else:
        gradients = gradients_impl.gradients(loss, logits, name='Gradients')[0]
        hessians = gradients_impl.gradients(
            gradients, logits, name='Hessians')[0]

      stats_summaries_list = []
      for i, feature_ids in enumerate(feature_ids_list):
        num_buckets = bucket_size_list[i]
        summaries = [
            array_ops.squeeze(
                boosted_trees_ops.make_stats_summary(
                    node_ids=node_ids,
                    gradients=gradients,
                    hessians=hessians,
                    bucketized_features_list=[input_feature_list[f]],
                    max_splits=max_splits,
                    num_buckets=num_buckets),
                axis=0) for f in feature_ids
Developer ID: Huoxubeiyin, Project: tensorflow, Lines of code: 67, Source: boosted_trees.py
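In _bt_model_fn the op is called on cached state: training_predict returns only partial_logits for the part of the ensemble an example has not yet been routed through, and the full logits are rebuilt as cached_logits + partial_logits before being written back with training_state_cache.insert. The sketch below imitates that loop in plain Python; InMemoryStateCache is a hypothetical stand-in, not a TensorFlow class, and the partial logits are illustrative values rather than real op output.

```python
from typing import List, Sequence, Tuple

Logits = List[List[float]]


class InMemoryStateCache:
    """Hypothetical stand-in for training_state_cache (not a TensorFlow class)."""

    def __init__(self, batch_size: int, logits_dimension: int) -> None:
        # Same starting point as the "no cache" branch: zero ids and zero logits.
        self.tree_ids = [0] * batch_size
        self.node_ids = [0] * batch_size
        self.logits: Logits = [[0.0] * logits_dimension for _ in range(batch_size)]

    def lookup(self) -> Tuple[List[int], List[int], Logits]:
        return list(self.tree_ids), list(self.node_ids), [row[:] for row in self.logits]

    def insert(self, tree_ids: Sequence[int], node_ids: Sequence[int], logits: Logits) -> None:
        self.tree_ids, self.node_ids = list(tree_ids), list(node_ids)
        self.logits = [row[:] for row in logits]


def add_partial(cached: Logits, partial: Logits) -> Logits:
    """Element-wise cached_logits + partial_logits."""
    return [[c + p for c, p in zip(c_row, p_row)] for c_row, p_row in zip(cached, partial)]


cache = InMemoryStateCache(batch_size=2, logits_dimension=1)

# Step 1: pretend training_predict returned these partial logits and positions.
_, _, cached_logits = cache.lookup()
logits = add_partial(cached_logits, [[0.879], [0.114]])
cache.insert([0, 0], [2, 1], logits)

# Step 2: the ensemble did not change, so the partial logits are all zero
# and the full logits come straight from the cache.
_, _, cached_logits = cache.lookup()
logits = add_partial(cached_logits, [[0.0], [0.0]])
print(logits)  # [[0.879], [0.114]]
```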


Example 10: _bt_model_fn


#......... (some code omitted here) .........
    center_bias_var = variable_scope.variable(
        initial_value=center_bias, name='center_bias_needed', trainable=False)
    if is_single_machine:
      local_tree_ensemble = tree_ensemble
      ensemble_reload = control_flow_ops.no_op()
    else:
      # Have a local copy of ensemble for the distributed setting.
      with ops.device(worker_device):
        local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
            name=name + '_local', is_local=True)
      # TODO(soroush): Do partial updates if this becomes a bottleneck.
      ensemble_reload = local_tree_ensemble.deserialize(
          *tree_ensemble.serialize())

    if training_state_cache:
      cached_tree_ids, cached_node_ids, cached_logits = (
          training_state_cache.lookup())
    else:
      # Always start from the beginning when no cache is set up.
      batch_size = array_ops.shape(labels)[0]
      cached_tree_ids, cached_node_ids, cached_logits = (
          array_ops.zeros([batch_size], dtype=dtypes.int32),
          _DUMMY_NODE_ID * array_ops.ones([batch_size], dtype=dtypes.int32),
          array_ops.zeros(
              [batch_size, head.logits_dimension], dtype=dtypes.float32))

    with ops.control_dependencies([ensemble_reload]):
      (stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
       last_layer_nodes_range) = local_tree_ensemble.get_states()
      summary.scalar('ensemble/num_trees', num_trees)
      summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
      summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)

      partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict(
          tree_ensemble_handle=local_tree_ensemble.resource_handle,
          cached_tree_ids=cached_tree_ids,
          cached_node_ids=cached_node_ids,
          bucketized_features=input_feature_list,
          logits_dimension=head.logits_dimension)
      logits = cached_logits + partial_logits

    # Create training graph.
    def _train_op_fn(loss):
      """Run one training iteration."""
      if training_state_cache:
        # Cache logits only after center_bias is complete, if it's in progress.
        train_op.append(
            control_flow_ops.cond(
                center_bias_var, control_flow_ops.no_op,
                lambda: training_state_cache.insert(tree_ids, node_ids, logits))
        )

      if closed_form_grad_and_hess_fn:
        gradients, hessians = closed_form_grad_and_hess_fn(logits, labels)
      else:
        gradients = gradients_impl.gradients(loss, logits, name='Gradients')[0]
        hessians = gradients_impl.gradients(
            gradients, logits, name='Hessians')[0]

      # TODO(youngheek): perhaps storage could be optimized by storing stats
      # with the dimension max_splits_per_layer, instead of max_splits (for the
      # entire tree).
      max_splits = _get_max_splits(tree_hparams)

      stats_summaries_list = []
      for i, feature_ids in enumerate(feature_ids_list):
Developer ID: ZhangXinNan, Project: tensorflow, Lines of code: 67, Source: boosted_trees.py
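Compared with Example 9, this version gates the cache write on center_bias_var: while the bias is still being centered, control_flow_ops.cond selects control_flow_ops.no_op, so logits computed before centering finishes are not cached. A plain-Python sketch of that branch; maybe_cache and the list-based cache are hypothetical stand-ins, not TensorFlow APIs, and the cached entry holds illustrative values.

```python
from typing import List, Tuple


def maybe_cache(center_bias_needed: bool, cache: List[Tuple], entry: Tuple) -> bool:
    """Hypothetical analogue of
    control_flow_ops.cond(center_bias_var, no_op, lambda: cache.insert(...)).

    Returns True if the entry was written.
    """
    if center_bias_needed:
        # Bias centering still in progress: behave like no_op and write nothing.
        return False
    cache.append(entry)
    return True


cache: List[Tuple] = []
entry = ([0, 0], [2, 1], [[0.879], [0.114]])  # (tree_ids, node_ids, logits), illustrative
wrote_during_centering = maybe_cache(True, cache, entry)
wrote_after_centering = maybe_cache(False, cache, entry)
print(wrote_during_centering, wrote_after_centering, len(cache))  # False True 1
```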



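Note: The tensorflow.python.ops.boosted_trees_ops.training_predict examples in this article were compiled by 纯净天空 from source code and documentation platforms such as GitHub and MSDocs. The snippets are taken from open-source projects contributed by their respective developers, and copyright of the source code remains with the original authors. Refer to each project's License before distributing or using the code; do not reproduce this article without permission.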

