It took me some time to pin down the problem. Here it is:
import tensorflow as tf

class ComplicatedStuff:
    def __init__(self):
        self.result = None

    def fun(self, val):
        self.result = val

@tf.function
def no_fun(x, blabla):
    s = ComplicatedStuff()
    # s.do_this(blabla)
    # s.do_that(blabla)
    if x > .5:  # x is a tensor, so AutoGraph converts this if into a tf.cond
        s.fun(2*x)
    else:
        s.fun(x)
    return s.result  # the result is only communicated through a side effect on s

no_fun(tf.constant(1.), ...)
# <tf.Tensor: shape=(), dtype=float32, numpy=1.0>
I would really expect to get 2.0 back instead of 1.0. I figured out that the reason is that both branches of the conditional are traced (x is a tensor, so the if becomes a graph conditional), and because I return the value through a side effect on s, only the result of the second, last-traced branch survives. The question is: how do I code around this limitation? Using return values would solve it, but it would definitely uglify the code, because ComplicatedStuff wraps a bunch of intermediate results that I don't want to expose that way. Is there a better option?
The workaround I came up with that more or less preserves the structure is this hack:
class ComplicatedStuff(dict):
    # attributes are stored as dict items instead of instance attributes
    def __init__(self):
        super().__init__()
        self.result = None

    def fun(self, val):
        self.result = val

    def __setattr__(self, item, value):
        self[item] = value

    def __getattribute__(self, item):
        if item.startswith("__") or item not in self:
            return super().__getattribute__(item)
        else:
            return self[item]

@tf.function
def no_fun(x, blabla):
    s = ComplicatedStuff()
    # s.do_this(blabla)
    # s.do_that(blabla)
    if x > .5:
        s.fun(2*x)
        s = s  # no-op rebinding: marks s as assigned in this branch
    else:
        s.fun(x)
        s = s
    return s.result

no_fun(tf.constant(1.), ...)
# <tf.Tensor: shape=(), dtype=float32, numpy=2.0>
There must be a better option, right?
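The hack above presumably works because the s = s lines make AutoGraph treat s as a variable assigned in both branches, so s becomes an output of the conditional; since ComplicatedStuff is a dict, its entries can be returned from each branch, which effectively snapshots each branch's state before one of them is selected. Below is a toy pure-Python model of that idea, again using the hypothetical traced_cond helper in place of the graph conditional, and copy.copy as a stand-in for "return this branch's state as an output". It is an illustration of the mechanism under those assumptions, not TensorFlow code.

```python
import copy


class ComplicatedStuff:
    def __init__(self):
        self.result = None

    def fun(self, val):
        self.result = val


def traced_cond(pred, true_fn, false_fn):
    """Hypothetical stand-in for a graph conditional: both branches run once,
    then the output of the branch chosen by pred is selected."""
    true_out = true_fn()
    false_out = false_fn()
    return true_out if pred else false_out


def no_fun(x):
    s = ComplicatedStuff()

    def true_fn():
        s.fun(2 * x)
        return copy.copy(s)  # snapshot of the state at the end of this branch

    def false_fn():
        s.fun(x)
        return copy.copy(s)  # later mutation cannot touch the true-branch snapshot

    # Rebind s to the selected branch's state, analogous to the s = s trick.
    s = traced_cond(x > .5, true_fn, false_fn)
    return s.result


print(no_fun(1.0))  # prints 2.0
```

The design point the model isolates: state that leaves a branch as an output is selected correctly, while state that leaves a branch only as a side effect is overwritten by whichever branch runs last.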
question from:
https://stackoverflow.com/questions/66062501/tensorflow-tf-function-conditionals