Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

Categories

0 votes
196 views
in Technique[技术] by (71.8m points)

python - InvalidArgumentError: Input 1 of node while_1/Merge_1 was passed float from while_1/NextIteration_1:0

I have a TensorArray (a) to store the values computed within the tf.while_loop. However, I cannot convert the TensorArray to a Numpy array. For some reason, there seems to be a mismatch between int32 and float32.

import time
import tensorflow as tf
import numpy as np


#Importing a generic dataset from Keras
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(
    path='mnist.npz'
)


x_batch = tf.convert_to_tensor(x_train)
s_pred_im = tf.convert_to_tensor(x_batch)
iters = tf.constant(10)
a = tf.TensorArray(tf.float32, size=10)

def cond(value, a, s_pred_im, x_batch, i, iters):
    return tf.less(i, iters)

def body(value, a, s_pred_im, x_batch, i, iters):
    value = tf.math.reduce_sum(tf.image.ssim(s_pred_im, x_batch, max_val=255, filter_size = 28))
    a = a.write(i,value)
    return [value, a, s_pred_im, x_batch, tf.add(i,1), iters]

res = tf.while_loop(cond, body, [0, a, s_pred_im, x_batch, 0, iters])

b = res[1].stack()

with tf.Session() as sess:
    b.eval()

Doing this gives the following error -

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)

~anaconda3envsestlibsite-packagesensorflow_corepythonclientsession.py in _do_call(self, fn, *args)
   1364     try:
-> 1365       return fn(*args)
   1366     except errors.OpError as e:

~anaconda3envsestlibsite-packagesensorflow_corepythonclientsession.py in _run_fn(feed_dict, fetch_list, target_list, options, run_metadata)
   1347       # Ensure any changes to the graph are reflected in the runtime.
-> 1348       self._extend_graph()
   1349       return self._call_tf_sessionrun(options, feed_dict, fetch_list,

~anaconda3envsestlibsite-packagesensorflow_corepythonclientsession.py in _extend_graph(self)
   1387     with self._graph._session_run_lock():  # pylint: disable=protected-access
-> 1388       tf_session.ExtendSession(self._session)
   1389 

InvalidArgumentError: Input 1 of node while_1/Merge_1 was passed float from while_1/NextIteration_1:0 incompatible with expected int32.

During handling of the above exception, another exception occurred:

InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-72-5642d29d3bf6> in <module>
      1 with tf.Session() as sess:
----> 2     b.eval()

~anaconda3envsestlibsite-packagesensorflow_corepythonframeworkops.py in eval(self, feed_dict, session)
    796 
    797     """
--> 798     return _eval_using_default_session(self, feed_dict, self.graph, session)
    799 
    800   def experimental_ref(self):

~anaconda3envsestlibsite-packagesensorflow_corepythonframeworkops.py in _eval_using_default_session(tensors, feed_dict, graph, session)
   5405                        "the tensor's graph is different from the session's "
   5406                        "graph.")
-> 5407   return session.run(tensors, feed_dict)
   5408 
   5409 

~anaconda3envsestlibsite-packagesensorflow_corepythonclientsession.py in run(self, fetches, feed_dict, options, run_metadata)
    954     try:
    955       result = self._run(None, fetches, feed_dict, options_ptr,
--> 956                          run_metadata_ptr)
    957       if run_metadata:
    958         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

~anaconda3envsestlibsite-packagesensorflow_corepythonclientsession.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1178     if final_fetches or final_targets or (handle and feed_dict_tensor):
   1179       results = self._do_run(handle, final_targets, final_fetches,
-> 1180                              feed_dict_tensor, options, run_metadata)
   1181     else:
   1182       results = []

~anaconda3envsestlibsite-packagesensorflow_corepythonclientsession.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1357     if handle is None:
   1358       return self._do_call(_run_fn, feeds, fetches, targets, options,
-> 1359                            run_metadata)
   1360     else:
   1361       return self._do_call(_prun_fn, handle, feeds, fetches)

~anaconda3envsestlibsite-packagesensorflow_corepythonclientsession.py in _do_call(self, fn, *args)
   1382                     '
session_config.graph_options.rewrite_options.'
   1383                     'disable_meta_optimizer = True')
-> 1384       raise type(e)(node_def, op, message)
   1385 
   1386   def _extend_graph(self):

InvalidArgumentError: Input 1 of node while_1/Merge_1 was passed float from while_1/NextIteration_1:0 incompatible with expected int32.

PS: This is an edit from an earlier post wherein I was trying to evaluate the value of the tensorarray incorrectly.


与恶龙缠斗过久,自身亦成为恶龙;凝视深渊过久,深渊将回以凝视…
Welcome To Ask or Share your Answers For Others

1 Answer

0 votes
by (71.8m points)
等待大神答复

与恶龙缠斗过久,自身亦成为恶龙;凝视深渊过久,深渊将回以凝视…
Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Click Here to Ask a Question

...