本文整理汇总了Python中tensorflow.python.framework.function.define_function函数的典型用法代码示例。如果您正苦于以下问题：Python define_function函数的具体用法？Python define_function怎么用？Python define_function使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了define_function函数的9个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: testGradientFunc
def testGradientFunc(self):
    """Evaluates a function and a hand-built gradient function of it.

    Defines f(x) = x * x + 1 and a second function that calls
    functional_ops._symbolic_gradient on f (referenced by name), then
    evaluates both at x = 2 with upstream gradient dy = 0.1.
    """

    def XSquarePlusOne(x):
        return x * x + 1.0

    def XSquarePlusOneGrad(x, dy):
        # The forward function is referenced by the string name
        # "XSquarePlusOne"; presumably this must match the name under which
        # define_function registers XSquarePlusOne -- keep the two in sync.
        dx = functional_ops._symbolic_gradient(input=[x, dy],
                                               Tout=[tf.float32],
                                               f="XSquarePlusOne",
                                               name="dx")
        return dx

    # Use distinct names for the Graph and the function definitions so the
    # graph object is never shadowed while its context is active.
    graph = tf.Graph()
    with graph.as_default():
        f_def = function.define_function(XSquarePlusOne, {"x": tf.float32})
        grad_def = function.define_function(XSquarePlusOneGrad,
                                            {"x": tf.float32,
                                             "dy": tf.float32})
        epsilon = tf.constant([0.1])
        two = tf.constant([2.0])
        call_f = function.call_function(f_def, two)
        call_g = function.call_function(grad_def, two, epsilon)
        with tf.Session() as sess:
            # f(2) = 2 * 2 + 1 = 5.
            self.assertAllClose([5.0], sess.run(call_f))
            # d/dx (x * x + 1) at x = 2 is 4; scaled by dy = 0.1 gives 0.4.
            self.assertAllClose([0.4], sess.run(call_g))
开发者ID:13331151,项目名称:tensorflow,代码行数:25,代码来源:function_test.py
示例2: testDefineFunctionNoArgs
def testDefineFunctionNoArgs(self):
    """A function taking no arguments can be defined, called and evaluated."""

    def AConstant():
        return tf.constant([42])

    with tf.Graph().as_default():
        f_def = function.define_function(AConstant, {})
        call = function.call_function(f_def)
        # The call op is named after the Python function.
        # assertEqual (not the assertEquals alias, removed in Python 3.12).
        self.assertEqual("AConstant", call.op.name)
        with tf.Session() as sess:
            self.assertAllEqual([42], sess.run(call))
开发者ID:bgyss,项目名称:tensorflow,代码行数:10,代码来源:function_test.py
示例3: testStrippedOpListNestedFunctions
def testStrippedOpListNestedFunctions(self):
    """stripped_op_list_for_graph reports ops used inside nested functions.

    f1 calls f0, which calls tf.square, so "Square" is only reachable two
    function levels deep.
    """
    with self.test_session():
        # Square two levels deep.
        def f0(x):
            return tf.square(x)

        f0 = function.define_function(f0, {"x": tf.int32})

        def f1(x):
            return function.call_function(f0, x)

        f1 = function.define_function(f1, {"x": tf.int32})

        # At this point we've defined two functions but haven't called them,
        # so there should be no used ops.
        op_list = tf.contrib.util.stripped_op_list_for_graph(
            tf.get_default_graph().as_graph_def())
        self.assertEqual(0, len(op_list.op))

        # If we call the function on a constant, there should be two ops.
        function.call_function(f1, tf.constant(7))
        op_list = tf.contrib.util.stripped_op_list_for_graph(
            tf.get_default_graph().as_graph_def())
        self.assertEqual(["Const", "Square"], [op.name for op in op_list.op])
开发者ID:2er0,项目名称:tensorflow,代码行数:21,代码来源:saver_test.py
示例4: testDefineFunction2Args
def testDefineFunction2Args(self):
    """A two-argument function computes a + 2 * b when called."""

    def APlus2B(a, b):
        return a + b * 2

    with tf.Graph().as_default():
        f_def = function.define_function(
            APlus2B, {"a": tf.float32, "b": tf.float32})
        one = tf.constant([1.0])
        two = tf.constant([2.0])
        call = function.call_function(f_def, one, two)
        # The call op is named after the Python function.
        self.assertEqual("APlus2B", call.op.name)
        with tf.Session() as sess:
            # 1 + 2 * 2 = 5.
            self.assertAllEqual([5.0], sess.run(call))
开发者ID:bgyss,项目名称:tensorflow,代码行数:12,代码来源:function_test.py
示例5: testCallErrors
def testCallErrors(self):
    """call_function checks arity and rejects an unknown keyword argument."""

    def Const():
        return tf.constant(1)

    def PlusOne(a):
        return a + 1

    def PlusMinus(a, b):
        return a + b, b - a

    with tf.Graph().as_default():
        one = tf.constant([1])
        two = tf.constant([2])
        const = function.define_function(Const, {})
        plus_one = function.define_function(PlusOne, {"a": tf.int32})
        plus_minus = function.define_function(
            PlusMinus, {"a": tf.int32, "b": tf.int32})

        # (expected ValueError regex, or None for a valid call;
        #  function definition; positional call arguments), in call order.
        arity_cases = [
            (None, const, ()),
            ("arguments: 0", const, (one,)),
            ("arguments: 0", const, (one, two)),
            ("arguments: 1", plus_one, ()),
            (None, plus_one, (one,)),
            ("arguments: 1", plus_one, (one, two)),
            ("arguments: 2", plus_minus, ()),
            ("arguments: 2", plus_minus, (one,)),
            (None, plus_minus, (one, two)),
        ]
        for expected_error, f_def, call_args in arity_cases:
            if expected_error is None:
                function.call_function(f_def, *call_args)
                continue
            with self.assertRaisesRegexp(ValueError, expected_error):
                function.call_function(f_def, *call_args)

        # `name` is accepted as a keyword; `device` is rejected as unknown.
        function.call_function(plus_one, one, name="p1")
        with self.assertRaisesRegexp(ValueError, "Unknown keyword arguments"):
            function.call_function(plus_one, one, device="/gpu:0")
开发者ID:13331151,项目名称:tensorflow,代码行数:40,代码来源:function_test.py
示例6: testDefineFunctionNames
def testDefineFunctionNames(self):
    """Call ops get uniquified default names, explicit names and scopes."""

    def Foo(a):
        return a + 1

    with tf.Graph().as_default():
        f_def = function.define_function(Foo, {"a": tf.float32})
        one = tf.constant([1.0])
        # Default name is the Python function name, uniquified on reuse.
        call1 = function.call_function(f_def, one)
        self.assertEqual("Foo", call1.op.name)
        call2 = function.call_function(f_def, one)
        self.assertEqual("Foo_1", call2.op.name)
        # An explicit `name` replaces the default...
        call3 = function.call_function(f_def, one, name="mine")
        self.assertEqual("mine", call3.op.name)
        # ...and is still prefixed by the enclosing name scope.
        with tf.name_scope("my"):
            call4 = function.call_function(f_def, one, name="precious")
            self.assertEqual("my/precious", call4.op.name)
开发者ID:bgyss,项目名称:tensorflow,代码行数:16,代码来源:function_test.py
示例7: testDefineErrors
def testDefineErrors(self):
    """define_function rejects unsupported signatures and bad type maps."""

    def NoResult():
        pass

    def VarArgs(*unused_b):
        return tf.constant([1])

    def DefaultArg(unused_a=12):
        return tf.constant([1])

    def KwArgs(**unused_kwargs):
        return tf.constant([1])

    def PlusMinus(a, b):
        return a + b, b - a

    with tf.Graph().as_default():
        # (expected ValueError regex, Python function, input type map),
        # checked in this order.
        bad_definitions = (
            ("return at least one tensor", NoResult, {}),
            ("plain arglists are supported", VarArgs, {}),
            ("plain arglists are supported", DefaultArg, {}),
            ("plain arglists are supported", KwArgs, {}),
            ("specified input types", PlusMinus, {}),
            ("specified input types", PlusMinus, {"c": tf.float32}),
            ("type for argument: b", PlusMinus,
             {"a": tf.float32, "c": tf.float32}),
            ("specified input types", PlusMinus,
             {"a": tf.float32, "b": tf.float32, "c": tf.float32}),
        )
        for pattern, py_func, input_types in bad_definitions:
            with self.assertRaisesRegexp(ValueError, pattern):
                function.define_function(py_func, input_types)
开发者ID:13331151,项目名称:tensorflow,代码行数:37,代码来源:function_test.py
示例8: testDefineErrors
def testDefineErrors(self):
    """define_function rejects unsupported signatures and bad type maps.

    Each case reads `.definition` on the result; presumably the definition
    is built on that access, which is where the ValueError surfaces.
    """

    def NoResult():
        pass

    def DefaultArg(unused_a=12):
        return tf.constant([1])

    def KwArgs(**unused_kwargs):
        return tf.constant([1])

    def PlusMinus(a, b):
        return a + b, b - a

    with tf.Graph().as_default():
        # (expected ValueError regex, Python function, input type map),
        # checked in this order.
        bad_definitions = (
            ("return at least one tensor", NoResult, {}),
            ("are not supported", DefaultArg, {}),
            ("are not supported", KwArgs, {}),
            ("specified input types", PlusMinus, {}),
            ("specified input types", PlusMinus, {"c": tf.float32}),
            ("type for argument: b", PlusMinus,
             {"a": tf.float32, "c": tf.float32}),
            ("specified input types", PlusMinus,
             {"a": tf.float32, "b": tf.float32, "c": tf.float32}),
        )
        # pylint: disable=expression-not-assigned
        for pattern, py_func, input_types in bad_definitions:
            with self.assertRaisesRegexp(ValueError, pattern):
                function.define_function(py_func, input_types).definition
开发者ID:apollos,项目名称:tensorflow,代码行数:33,代码来源:function_test.py
示例9: XSquarePlusOne
import tensorflow as tf
from tensorflow.python.framework import function
from tensorflow.python.ops import functional_ops
# Builds f(x) = x * x + 1 and a gradient function for it, writes the graph
# as a text protobuf, then evaluates both at x = 2 with dy = 1.
graph = tf.Graph()
with graph.as_default():
    # NOTE(review): `tt` is never read; it only adds a Const node to the
    # exported graph. Kept so simple.pbtxt is unchanged.
    tt = tf.constant([4.2])

    def XSquarePlusOne(x):
        # NOTE(review): `ph` is never read; it only adds a Placeholder to the
        # function body. Kept so the exported graph is unchanged.
        ph = tf.placeholder("float", shape=[1])
        return x * x + 1.0

    def XSquarePlusOneGrad(x, dy):
        # The forward function is referenced by its string name; keep it in
        # sync with the Python function name above.
        dx = functional_ops._symbolic_gradient(input=[x, dy],
                                               Tout=[tf.float32],
                                               f="XSquarePlusOne",
                                               name="dx")
        return dx

    f = function.define_function(XSquarePlusOne, {"x": tf.float32})
    g = function.define_function(XSquarePlusOneGrad, {"x": tf.float32,
                                                      "dy": tf.float32})
    epsilon = tf.constant([1.0])
    two = tf.constant([2.0])
    call_f = function.call_function(f, two)
    call_g = function.call_function(g, two, epsilon)

# Write the graph as text to /tmp/tfb/simple.pbtxt.
tf.train.write_graph(graph.as_graph_def(), '/tmp/tfb', 'simple.pbtxt', as_text=True)

# Bind the session to `graph` explicitly: outside the `with` block above it
# is no longer the default graph, and call_f/call_g live in it.
with tf.Session(graph=graph) as sess:
    # Single-argument print(...) behaves the same on Python 2 and 3; the
    # bare `print x` statement form is a SyntaxError on Python 3.
    print(sess.run(call_f))
    print(sess.run(call_g))
开发者ID:LaurentMazare,项目名称:tensorflow-ocaml,代码行数:30,代码来源:gradient.py
注：本文中的tensorflow.python.framework.function.define_function函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论