本文整理汇总了Python中tensorflow.python.ops.state_ops.scatter_nd_update函数的典型用法代码示例。如果您正苦于以下问题:Python scatter_nd_update函数的具体用法?Python scatter_nd_update怎么用?Python scatter_nd_update使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了scatter_nd_update函数的6个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: testRank3InvalidShape2
def testRank3InvalidShape2(self):
indices = array_ops.zeros([2, 2, 1], dtypes.int32)
updates = array_ops.zeros([2, 2], dtypes.int32)
shape = np.array([2, 2, 2])
ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
with self.assertRaisesWithPredicateMatch(
ValueError, "The inner \\d+ dimensions of input\\.shape="):
state_ops.scatter_nd_update(ref, indices, updates)
开发者ID:abhinav-upadhyay,项目名称:tensorflow,代码行数:8,代码来源:scatter_nd_ops_test.py
示例2: testResVarInvalidOutputShape
def testResVarInvalidOutputShape(self):
res = variables.Variable(
initial_value=lambda: array_ops.zeros(shape=[], dtype=dtypes.float32),
dtype=dtypes.float32)
with self.cached_session():
res.initializer.run()
with self.assertRaisesOpError("Output must be at least 1-D"):
state_ops.scatter_nd_update(res, [[0]], [0.22]).eval()
开发者ID:abhinav-upadhyay,项目名称:tensorflow,代码行数:8,代码来源:scatter_nd_ops_test.py
示例3: testRank3ValidShape
def testRank3ValidShape(self):
indices = array_ops.zeros([2, 2, 2], dtypes.int32)
updates = array_ops.zeros([2, 2, 2], dtypes.int32)
shape = np.array([2, 2, 2])
ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
self.assertAllEqual(
state_ops.scatter_nd_update(ref, indices,
updates).get_shape().as_list(), shape)
开发者ID:abhinav-upadhyay,项目名称:tensorflow,代码行数:8,代码来源:scatter_nd_ops_test.py
示例4: testExtraIndicesDimensions
def testExtraIndicesDimensions(self):
indices = array_ops.zeros([1, 1, 2], dtypes.int32)
updates = array_ops.zeros([1, 1], dtypes.int32)
shape = np.array([2, 2])
ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
scatter_update = state_ops.scatter_nd_update(ref, indices, updates)
self.assertAllEqual(scatter_update.get_shape().as_list(), shape)
expected_result = np.zeros([2, 2], dtype=np.int32)
with self.cached_session():
ref.initializer.run()
self.assertAllEqual(expected_result, scatter_update.eval())
开发者ID:abhinav-upadhyay,项目名称:tensorflow,代码行数:12,代码来源:scatter_nd_ops_test.py
示例5: testSimple
def testSimple(self):
indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32)
ref = variables.Variable([0, 0, 0, 0, 0, 0, 0, 0], dtype=dtypes.float32)
expected = np.array([0, 11, 0, 10, 9, 0, 0, 12])
scatter = state_ops.scatter_nd_update(ref, indices, updates)
init = variables.global_variables_initializer()
with self.session(use_gpu=True) as sess:
sess.run(init)
result = sess.run(scatter)
self.assertAllClose(result, expected)
开发者ID:abhinav-upadhyay,项目名称:tensorflow,代码行数:12,代码来源:scatter_nd_ops_test.py
示例6: testSimple3
def testSimple3(self):
indices = constant_op.constant([[1]], dtype=dtypes.int32)
updates = constant_op.constant([[11., 12.]], dtype=dtypes.float32)
ref = variables.Variable(
[[0., 0.], [0., 0.], [0., 0.]], dtype=dtypes.float32)
expected = np.array([[0., 0.], [11., 12.], [0., 0.]])
scatter = state_ops.scatter_nd_update(ref, indices, updates)
init = variables.global_variables_initializer()
with self.test_session(use_gpu=True) as sess:
sess.run(init)
result = sess.run(scatter)
self.assertAllClose(result, expected)
开发者ID:Jackiefan,项目名称:tensorflow,代码行数:13,代码来源:scatter_nd_ops_test.py
注:本文中的tensorflow.python.ops.state_ops.scatter_nd_update函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。 |
请发表评论