本文整理汇总了Python中tensorflow.python.data.experimental.ops.map_defun.map_defun函数的典型用法代码示例。如果您正苦于以下问题：Python map_defun函数的具体用法？Python map_defun怎么用？Python map_defun使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了map_defun函数的19个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: testMapDefunWithInvalidInput
def testMapDefunWithInvalidInput(self):
    """Checks that map_defun rejects a scalar input.

    A scalar constant is expected to raise ValueError while the op is being
    built; a placeholder fed with a scalar is expected to raise
    InvalidArgumentError when the op is run.
    """

    @function.Defun(dtypes.int32)
    def simple_fn(x):
        return x * 2

    scalar_const = constant_op.constant(2)
    # Fails at graph construction time for inputs with known shapes.
    with self.assertRaises(ValueError):
        mapped = map_defun.map_defun(
            simple_fn, [scalar_const], [dtypes.int32], [None])[0]

    # The placeholder carries no static shape, so the failure is only
    # expected once a scalar value is actually fed.
    fed_input = array_ops.placeholder(dtypes.int32)
    mapped = map_defun.map_defun(
        simple_fn, [fed_input], [dtypes.int32], [None])[0]
    with session.Session() as sess:
        with self.assertRaises(errors.InvalidArgumentError):
            sess.run(mapped, feed_dict={fed_input: 0})
开发者ID:ThunderQi,项目名称:tensorflow,代码行数:15,代码来源:map_defun_op_test.py
示例2: testMapDefunPartialShapeInference
def testMapDefunPartialShapeInference(self):
    """Checks static shape inference when the leading dimension is unknown.

    The input placeholder has static shape (None, 2) and the per-element
    output shape is declared as (2,); the result's static shape is expected
    to be [None, 2].
    """

    @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
    def fn(x):
        return x

    # Use int32 so the placeholder agrees with the input signature of `fn`
    # and with the declared output dtype (it was int64 before). Only the
    # static shape is asserted below, so the expectation is unchanged.
    elems = array_ops.placeholder(dtypes.int32, (None, 2))
    result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])
    self.assertEqual(result[0].get_shape().as_list(), [None, 2])
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:9,代码来源:map_defun_op_test.py
示例3: testMapDefunShapeInference
def testMapDefunShapeInference(self):
    """Checks static shape inference for a fully specified input.

    A 3x2 int32 constant is mapped with a declared per-element output shape
    of (2,); the result's static shape is expected to be (3, 2).
    """

    @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
    def fn(x):
        return x

    data = constant_op.constant(
        [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
    mapped = map_defun.map_defun(fn, [data], [dtypes.int32], [(2,)])[0]
    self.assertEqual(mapped.get_shape(), (3, 2))
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:10,代码来源:map_defun_op_test.py
示例4: testMapDefunWithWrongOutputShape
def testMapDefunWithWrongOutputShape(self):
    """Checks that a wrongly declared output shape fails at evaluation.

    The mapped function takes elements of shape [2] and returns x * 2 + 3,
    while the output shape is declared as (1,); evaluating the result is
    expected to raise InvalidArgumentError.
    """

    @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
    def simple_fn(x):
        return x * 2 + 3

    data = constant_op.constant(
        [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
    mapped = map_defun.map_defun(
        simple_fn, [data], [dtypes.int32], [(1,)])[0]
    with self.assertRaises(errors.InvalidArgumentError):
        self.evaluate(mapped)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:11,代码来源:map_defun_op_test.py
示例5: testMapDefunSimple
def testMapDefunSimple(self):
    """Checks that mapping x * 2 + 3 over rows matches the direct expression."""

    @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
    def simple_fn(x):
        return x * 2 + 3

    data = constant_op.constant(
        [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
    mapped = map_defun.map_defun(
        simple_fn, [data], [dtypes.int32], [(2,)])[0]
    # Reference result: the same arithmetic applied to the whole tensor.
    reference = data * 2 + 3
    actual_value = self.evaluate(mapped)
    reference_value = self.evaluate(reference)
    self.assertAllEqual(actual_value, reference_value)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:11,代码来源:map_defun_op_test.py
示例6: testMapDefunMismatchedTypes
def testMapDefunMismatchedTypes(self):
    """Checks that an output dtype mismatch fails at evaluation.

    The mapped function casts its input to float64 while the output dtype is
    declared as int32; evaluating the result is expected to raise
    InvalidArgumentError.
    """

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def fn(x):
        return math_ops.cast(x, dtypes.float64)

    data = constant_op.constant(
        [1, 2, 3, 4, 5, 6], dtype=dtypes.int32, name="data")
    mapped = map_defun.map_defun(fn, [data], [dtypes.int32], [()])[0]
    with self.assertRaises(errors.InvalidArgumentError):
        self.evaluate(mapped)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:11,代码来源:map_defun_op_test.py
示例7: testMapDefunRaisesDefunError
def testMapDefunRaisesDefunError(self):
    """Checks that an error raised inside the mapped function surfaces.

    The mapped function asserts that its input equals 0; the input contains
    the value 37, so evaluation is expected to raise InvalidArgumentError.
    """

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def fn(x):
        # The identity is gated on the assertion so the check runs whenever
        # the function output is computed.
        with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
            return array_ops.identity(x)

    values = constant_op.constant([0, 0, 0, 37, 0])
    outputs = map_defun.map_defun(fn, [values], [dtypes.int32], [()])
    with self.assertRaises(errors.InvalidArgumentError):
        self.evaluate(outputs)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:11,代码来源:map_defun_op_test.py
示例8: testMapDefunWithCapturedInputs
def testMapDefunWithCapturedInputs(self):
    """Checks that the mapped function may use a tensor from the outer scope.

    The function adds an outer constant to each element; the mapped result
    is compared against adding that constant to the whole input.
    """
    c = constant_op.constant(2)

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def fn(x):
        # `c` is not a parameter; it is taken from the enclosing test body.
        return x + c

    values = constant_op.constant([1, 2, 3, 4])
    mapped = map_defun.map_defun(fn, [values], [dtypes.int32], [()])[0]
    reference = values + c
    reference_value = self.evaluate(reference)
    mapped_value = self.evaluate(mapped)
    self.assertAllEqual(reference_value, mapped_value)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:11,代码来源:map_defun_op_test.py
示例9: testMapDefunReduceDim
def testMapDefunReduceDim(self):
    """Tests the case where the output has a different rank from the input.

    Each row of shape [2] is reduced to its first entry, so the 3x2 input is
    expected to map to [1, 3, 5].
    """

    @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
    def fn(x):
        return array_ops.gather(x, 0)

    data = constant_op.constant(
        [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
    mapped = map_defun.map_defun(fn, [data], [dtypes.int32], [()])[0]
    first_entries = constant_op.constant([1, 3, 5])
    mapped_value = self.evaluate(mapped)
    expected_value = self.evaluate(first_entries)
    self.assertAllEqual(mapped_value, expected_value)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:12,代码来源:map_defun_op_test.py
示例10: testMapDefunWithDifferentOutputShapeEachRun
def testMapDefunWithDifferentOutputShapeEachRun(self):
    """Checks that one op can yield differently shaped results per run.

    The input placeholder has no static shape and the output shape is left
    as None; the same op is run once with a rank-1 feed and once with a
    rank-2 feed.
    """

    @function.Defun(dtypes.int32)
    def simple_fn(x):
        return x * 2 + 3

    data = array_ops.placeholder(dtypes.int32, name="data")
    mapped = map_defun.map_defun(
        simple_fn, [data], [dtypes.int32], [None])[0]
    with session.Session() as sess:
        rank1_result = sess.run(mapped, feed_dict={data: [0]})
        self.assertAllEqual(rank1_result, [3])
        rank2_result = sess.run(mapped, feed_dict={data: [[0], [1]]})
        self.assertAllEqual(rank2_result, [[3], [5]])
开发者ID:ThunderQi,项目名称:tensorflow,代码行数:12,代码来源:map_defun_op_test.py
示例11: testMapDefunMultipleOutputs
def testMapDefunMultipleOutputs(self):
    """Checks a mapped function that returns two outputs of different dtypes.

    The function returns its input unchanged plus x * 2 + 3 cast to float64;
    both outputs are compared against the equivalent whole-tensor values.
    """

    @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
    def fn(x):
        return (x, math_ops.cast(x * 2 + 3, dtypes.float64))

    data = constant_op.constant(
        [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
    output_dtypes = [dtypes.int32, dtypes.float64]
    output_shapes = [(2,), (2,)]
    outputs = map_defun.map_defun(fn, [data], output_dtypes, output_shapes)
    reference = [data, data * 2 + 3]
    outputs_value = self.evaluate(outputs)
    reference_value = self.evaluate(reference)
    self.assertAllEqual(outputs_value, reference_value)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:12,代码来源:map_defun_op_test.py
示例12: testMapDefunRaisesErrorOnRuntimeShapeMismatch
def testMapDefunRaisesErrorOnRuntimeShapeMismatch(self):
    """Checks the runtime error for inputs with different leading sizes.

    Two shape-less placeholders are fed lists of length 5 and 3; running the
    op is expected to raise InvalidArgumentError with the message asserted
    below.
    """

    @function.Defun(dtypes.int32, dtypes.int32)
    def fn(x, y):
        return x, y

    first = array_ops.placeholder(dtypes.int32)
    second = array_ops.placeholder(dtypes.int32)
    outputs = map_defun.map_defun(
        fn, [first, second], [dtypes.int32, dtypes.int32], [(), ()])
    # The two feeds deliberately disagree in length (5 vs 3).
    feeds = {first: [1, 2, 3, 4, 5], second: [1, 2, 3]}
    with self.cached_session() as sess:
        with self.assertRaisesWithPredicateMatch(
                errors.InvalidArgumentError,
                "All inputs must have the same dimension 0."):
            sess.run(outputs, feed_dict=feeds)
开发者ID:ThunderQi,项目名称:tensorflow,代码行数:15,代码来源:map_defun_op_test.py
示例13: testMapDefunCancelledCorrectly
def testMapDefunCancelledCorrectly(self):
    """Checks the error reported when every mapped call fails.

    The input is the row [1, 2, 3, 4, 5] tiled into 100 rows; the mapped
    function gathers index 10 from each 5-element row, and evaluation is
    expected to raise InvalidArgumentError matching the pattern below.
    """

    @function.defun(input_signature=[tensor_spec.TensorSpec([5], dtypes.int64)])
    def defun(x):
        # x has leading dimension 5, this will raise an error
        return array_ops.gather(x, 10)

    row = constant_op.constant([1, 2, 3, 4, 5], dtype=dtypes.int64)
    tiled_rows = array_ops.tile(array_ops.expand_dims(row, 0), [100, 1])
    mapped = map_defun.map_defun(defun, [tiled_rows], [dtypes.int64], [()])[0]
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 r"indices = 10 is not in \[0, 5\)"):
        self.evaluate(mapped)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:15,代码来源:map_defun_op_test.py
示例14: testMapDefunWithUnspecifiedOutputShape
def testMapDefunWithUnspecifiedOutputShape(self):
    """Checks outputs declared with unknown, partial and full shapes.

    The mapped function returns three tensors (res, res + 1, res + 2); their
    output shapes are declared as None, (None,) and (2,) respectively, and
    each is compared against the equivalent whole-tensor value.
    """

    @function.defun(input_signature=[tensor_spec.TensorSpec([2], dtypes.int32)])
    def simple_fn(x):
        res = x * 2 + 3
        return (res, res + 1, res + 2)

    data = constant_op.constant(
        [[1, 2], [3, 4], [5, 6]], dtype=dtypes.int32, name="data")
    output_dtypes = [dtypes.int32, dtypes.int32, dtypes.int32]
    output_shapes = [None, (None,), (2,)]
    outputs = map_defun.map_defun(
        simple_fn, [data], output_dtypes, output_shapes)
    reference = data * 2 + 3
    self.assertAllEqual(self.evaluate(outputs[0]), self.evaluate(reference))
    self.assertAllEqual(
        self.evaluate(outputs[1]), self.evaluate(reference + 1))
    self.assertAllEqual(
        self.evaluate(outputs[2]), self.evaluate(reference + 2))
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:16,代码来源:map_defun_op_test.py
示例15: testMapDefunWithVariantTensorAsCaptured
def testMapDefunWithVariantTensorAsCaptured(self):
    """Checks a mapped function that returns an outer variant tensor.

    A sparse tensor is serialized to a variant outside the function; the
    function ignores its argument and returns that tensor. Mapping over two
    elements and deserializing is expected to give the sparse tensor
    repeated twice along a new leading dimension.
    """
    sparse = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
    serialized = sparse_ops.serialize_sparse_v2(
        sparse, out_type=dtypes.variant)

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def fn(x):
        # The argument only determines how many times the function runs.
        del x
        return serialized

    two_elems = constant_op.constant([0, 0])
    mapped = map_defun.map_defun(
        fn, [two_elems], [dtypes.variant], [None])[0]
    deserialized = sparse_ops.deserialize_sparse(mapped, dtypes.int32)
    expected = sparse_tensor.SparseTensorValue(
        indices=[[0, 0, 0], [0, 1, 2], [1, 0, 0], [1, 1, 2]],
        values=[1, 2, 1, 2],
        dense_shape=[2, 3, 4])
    actual = self.evaluate(deserialized)
    self.assertSparseValuesEqual(expected, actual)
开发者ID:kylin9872,项目名称:tensorflow,代码行数:20,代码来源:map_defun_op_test.py
示例16: benchmarkDefunVsMapFn
def benchmarkDefunVsMapFn(self):
    """Benchmarks to compare the performance of MapDefun vs tf.map_fn.

    For each input size, an identity function is mapped over a range of that
    many elements, once through map_defun and once through map_fn. Each op
    is handed to `self._run` (not visible in this snippet) together with a
    size-specific name and an iteration count.
    """

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def defun(x):
        return array_ops.identity(x)

    def fn(x):
        return array_ops.identity(x)

    for input_size in [10, 100, 1000, 10000]:
        # Build the input per size so the measured workload matches the size
        # in the reported name. It used to be a fixed range(100) created
        # outside the loop, so every "size" timed the same 100 elements.
        base = math_ops.range(input_size)
        # Fewer iterations for larger inputs: 100000 elements per benchmark.
        num_iters = 100000 // input_size
        map_defun_op = map_defun.map_defun(defun, [base], [dtypes.int32], [()])
        map_fn_op = map_fn.map_fn(fn, base)
        self._run(
            map_defun_op,
            "with_defun_size_%d" % input_size,
            num_iters=num_iters)
        self._run(
            map_fn_op,
            "without_defun_size_%d" % input_size,
            num_iters=num_iters)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:20,代码来源:map_defun_benchmark.py
示例17: testMapDefunWithStrTensor
def testMapDefunWithStrTensor(self):
    """Checks that string tensors pass through an identity mapped function.

    A sparse tensor is serialized to strings and stacked twice; after the
    identity map the result is deserialized and expected to be the sparse
    tensor repeated twice along a new leading dimension.
    """

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
    def fn(x):
        return x

    sparse = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
    single = sparse_ops.serialize_sparse_v2(sparse, out_type=dtypes.string)
    stacked = array_ops.stack([single, single])
    mapped = map_defun.map_defun(
        fn, [stacked], [dtypes.string], [None])[0]
    deserialized = sparse_ops.deserialize_sparse(mapped, dtypes.int32)
    expected = sparse_tensor.SparseTensorValue(
        indices=[[0, 0, 0], [0, 1, 2], [1, 0, 0], [1, 1, 2]],
        values=[1, 2, 1, 2],
        dense_shape=[2, 3, 4])
    actual = self.evaluate(deserialized)
    self.assertSparseValuesEqual(expected, actual)
开发者ID:kylin9872,项目名称:tensorflow,代码行数:20,代码来源:map_defun_op_test.py
示例18: testMapDefunWithParentCancellation
def testMapDefunWithParentCancellation(self):
    """Checks that parent-graph cancellation is threaded through to MapDefunOp.

    The mapped function dequeues from a queue that this test never fills.
    The op is run on a checked thread via `self._assert_op_cancelled` (not
    visible in this snippet) and the session is closed while it is running.
    """

    @function.defun(input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def simple_fn(x):
        del x
        queue = data_flow_ops.FIFOQueue(10, dtypes.int32, ())
        # Blocking
        return queue.dequeue_many(5)

    inputs = constant_op.constant([1, 2, 3, 4, 5])
    mapped = map_defun.map_defun(
        simple_fn, [inputs], [dtypes.int32], [()])[0]

    with self.cached_session() as sess:
        worker = self.checkedThread(
            self._assert_op_cancelled, args=(sess, mapped))
        worker.start()
        # Pause before closing the session; presumably this gives the worker
        # time to start running the op. TODO confirm.
        time.sleep(0.2)
        sess.close()
        worker.join()
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:20,代码来源:map_defun_op_test.py
示例19: benchmarkDefunVsMapFn
def benchmarkDefunVsMapFn(self):
    """Benchmarks to compare the performance of MapDefun vs tf.map_fn.

    For each input size, an identity function is mapped over a range of that
    many elements, once through map_defun (using a legacy `Defun`) and once
    through `functional_ops.map_fn`. Each op is handed to `self._run` (not
    visible in this snippet) with a size-specific name and iteration count.
    """

    @function.Defun(dtypes.int32)
    def defun(x):
        return array_ops.identity(x)

    # Named `identity_fn` rather than `map_fn` so the local function does not
    # share a name with the `map_fn` API it is passed to.
    def identity_fn(x):
        return array_ops.identity(x)

    for input_size in [10, 100, 1000, 10000]:
        # Build the input per size so the measured workload matches the size
        # in the reported name. It used to be a fixed range(100) created
        # outside the loop, so every "size" timed the same 100 elements.
        base = math_ops.range(input_size)
        # Fewer iterations for larger inputs: 100000 elements per benchmark.
        num_iters = 100000 // input_size
        map_defun_op = map_defun.map_defun(defun, [base], [dtypes.int32], [()])
        map_fn_op = functional_ops.map_fn(identity_fn, base)
        self._run(
            map_defun_op,
            "benchmarkMapDefun_size_%d" % input_size,
            num_iters=num_iters)
        self._run(
            map_fn_op,
            "benchmarkMapFn_size_%d" % input_size,
            num_iters=num_iters)
开发者ID:ThunderQi,项目名称:tensorflow,代码行数:22,代码来源:map_defun_op_test.py
注：本文中的tensorflow.python.data.experimental.ops.map_defun.map_defun函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论