I have a TensorArray (a) to store the values computed within the tf.while_loop. However, I cannot convert the TensorArray to a Numpy array. For some reason, there seems to be a mismatch between int32 and float32.
import time
import tensorflow as tf
import numpy as np
#Importing a generic dataset from Keras
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(
path='mnist.npz'
)
x_batch = tf.convert_to_tensor(x_train)
s_pred_im = tf.convert_to_tensor(x_batch)
iters = tf.constant(10)
a = tf.TensorArray(tf.float32, size=10)
def cond(value, a, s_pred_im, x_batch, i, iters):
return tf.less(i, iters)
def body(value, a, s_pred_im, x_batch, i, iters):
value = tf.math.reduce_sum(tf.image.ssim(s_pred_im, x_batch, max_val=255, filter_size = 28))
a = a.write(i,value)
return [value, a, s_pred_im, x_batch, tf.add(i,1), iters]
res = tf.while_loop(cond, body, [0, a, s_pred_im, x_batch, 0, iters])
b = res[1].stack()
with tf.Session() as sess:
b.eval()
Doing this gives the following error -
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
~anaconda3envsestlibsite-packagesensorflow_corepythonclientsession.py in _do_call(self, fn, *args)
1364 try:
-> 1365 return fn(*args)
1366 except errors.OpError as e:
~anaconda3envsestlibsite-packagesensorflow_corepythonclientsession.py in _run_fn(feed_dict, fetch_list, target_list, options, run_metadata)
1347 # Ensure any changes to the graph are reflected in the runtime.
-> 1348 self._extend_graph()
1349 return self._call_tf_sessionrun(options, feed_dict, fetch_list,
~anaconda3envsestlibsite-packagesensorflow_corepythonclientsession.py in _extend_graph(self)
1387 with self._graph._session_run_lock(): # pylint: disable=protected-access
-> 1388 tf_session.ExtendSession(self._session)
1389
InvalidArgumentError: Input 1 of node while_1/Merge_1 was passed float from while_1/NextIteration_1:0 incompatible with expected int32.
During handling of the above exception, another exception occurred:
InvalidArgumentError Traceback (most recent call last)
<ipython-input-72-5642d29d3bf6> in <module>
1 with tf.Session() as sess:
----> 2 b.eval()
~anaconda3envsestlibsite-packagesensorflow_corepythonframeworkops.py in eval(self, feed_dict, session)
796
797 """
--> 798 return _eval_using_default_session(self, feed_dict, self.graph, session)
799
800 def experimental_ref(self):
~anaconda3envsestlibsite-packagesensorflow_corepythonframeworkops.py in _eval_using_default_session(tensors, feed_dict, graph, session)
5405 "the tensor's graph is different from the session's "
5406 "graph.")
-> 5407 return session.run(tensors, feed_dict)
5408
5409
~anaconda3envsestlibsite-packagesensorflow_corepythonclientsession.py in run(self, fetches, feed_dict, options, run_metadata)
954 try:
955 result = self._run(None, fetches, feed_dict, options_ptr,
--> 956 run_metadata_ptr)
957 if run_metadata:
958 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
~anaconda3envsestlibsite-packagesensorflow_corepythonclientsession.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1178 if final_fetches or final_targets or (handle and feed_dict_tensor):
1179 results = self._do_run(handle, final_targets, final_fetches,
-> 1180 feed_dict_tensor, options, run_metadata)
1181 else:
1182 results = []
~anaconda3envsestlibsite-packagesensorflow_corepythonclientsession.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
1357 if handle is None:
1358 return self._do_call(_run_fn, feeds, fetches, targets, options,
-> 1359 run_metadata)
1360 else:
1361 return self._do_call(_prun_fn, handle, feeds, fetches)
~anaconda3envsestlibsite-packagesensorflow_corepythonclientsession.py in _do_call(self, fn, *args)
1382 '
session_config.graph_options.rewrite_options.'
1383 'disable_meta_optimizer = True')
-> 1384 raise type(e)(node_def, op, message)
1385
1386 def _extend_graph(self):
InvalidArgumentError: Input 1 of node while_1/Merge_1 was passed float from while_1/NextIteration_1:0 incompatible with expected int32.
PS: This is an edit from an earlier post wherein I was trying to evaluate the value of the tensorarray incorrectly.