Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

tensorflow - Replace Validation Monitors with tf.train.SessionRunHook when using Estimators

I am running a DNNClassifier and monitoring its accuracy during training. monitors.ValidationMonitor from contrib/learn has been working great; in my implementation I define it as:

validation_monitor = skflow.monitors.ValidationMonitor(input_fn=lambda: input_fn(A_test, Cl2_test), eval_steps=1, every_n_steps=50)

and then pass it to fit():

clf.fit(input_fn=lambda: input_fn(A, Cl2),
            steps=1000, monitors=[validation_monitor])

where:

clf = tensorflow.contrib.learn.DNNClassifier(...)

This works fine. However, validation monitors appear to be deprecated, with similar functionality meant to be provided by tf.train.SessionRunHook.

I am new to TensorFlow, and it is not obvious to me what such a replacement would look like. Any suggestions are highly appreciated. Again, I need to run validation after a specific number of training steps. Thanks very much in advance.

He who fights with dragons too long becomes a dragon himself; gaze long into the abyss, and the abyss gazes back…

1 Answer

There's an undocumented utility called monitors.replace_monitors_with_hooks() which converts monitors to hooks. The method accepts (i) a list which may contain both monitors and hooks and (ii) the Estimator for which the hooks will be used, and then returns a list of hooks by wrapping a SessionRunHook around each Monitor.

import tensorflow as tf
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib

clf = tf.estimator.Estimator(...)

list_of_monitors_and_hooks = [tf.contrib.learn.monitors.ValidationMonitor(...)]
hooks = monitor_lib.replace_monitors_with_hooks(list_of_monitors_and_hooks, clf)
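
Conceptually, the wrapping described above is an adapter: a monitor-style object (called once per step, signalling "stop" through its return value) is placed behind the hook-style before_run/after_run interface, which signals "stop" via run_context.request_stop(). The sketch below illustrates only that idea and is not the actual contrib implementation. It does not import TensorFlow; MonitorStyleCallback, HookAdapter, FakeRunContext and train_loop are hypothetical stand-ins, and only the method names before_run, after_run and request_stop() are borrowed from tf.train.SessionRunHook and tf.train.SessionRunContext.

```python
# Conceptual sketch only: no TensorFlow import. All class/function names are hypothetical;
# only before_run / after_run / request_stop() mirror tf.train.SessionRunHook / SessionRunContext.


class FakeRunContext:
    """Hypothetical stand-in for tf.train.SessionRunContext: only tracks stop requests."""

    def __init__(self):
        self.stop_requested = False

    def request_stop(self):
        self.stop_requested = True


class MonitorStyleCallback:
    """Hypothetical monitor-style object: called once per step, returns True to stop."""

    def __init__(self, stop_at_step):
        self.stop_at_step = stop_at_step
        self.seen_steps = []

    def step_end(self, step):
        self.seen_steps.append(step)
        return step >= self.stop_at_step


class HookAdapter:
    """Wraps a monitor-style callback behind a hook-style before_run/after_run interface."""

    def __init__(self, monitor):
        self.monitor = monitor
        self.step = 0

    def before_run(self, run_context):
        # A real SessionRunHook may return SessionRunArgs here; this sketch requests nothing.
        return None

    def after_run(self, run_context, run_values):
        self.step += 1
        # Translate the monitor's "return True to stop" convention into request_stop().
        if self.monitor.step_end(self.step):
            run_context.request_stop()


def train_loop(hooks, max_steps):
    """Hypothetical driver that calls hooks around each step, with a fresh context per step."""
    steps_run = 0
    for _ in range(max_steps):
        ctx = FakeRunContext()
        for h in hooks:
            h.before_run(ctx)
        steps_run += 1  # a real loop would run one training step here
        for h in hooks:
            h.after_run(ctx, run_values=None)
        if ctx.stop_requested:
            break
    return steps_run


monitor = MonitorStyleCallback(stop_at_step=3)
print(train_loop([HookAdapter(monitor)], max_steps=10))  # prints 3
```

The training loop only ever talks to the hook interface, so the old monitor keeps working unchanged, which is the same trade-off described here: nothing about the monitor itself is modernized.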

This isn't a true solution to fully replacing the ValidationMonitor; we're just wrapping it in a non-deprecated interface. Still, it has worked for me so far and kept all the functionality I need from the ValidationMonitor (evaluating every n steps, early stopping on a metric, etc.).
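
If you eventually want to drop the contrib wrapper, the two behaviours listed above (evaluating every n steps and early stopping on a metric) fit naturally into the hook shape. The following is a framework-free sketch under stated assumptions: EarlyStoppingHook, StubRunContext, run and the fake loss values are hypothetical, and evaluate_fn is an assumed callable returning a validation metric to minimize. Here patience counts evaluations, which may differ from how ValidationMonitor's early_stopping_rounds is measured. It does not touch TensorFlow.

```python
# Framework-free sketch: all names are hypothetical. Only the after_run(run_context, run_values)
# shape and run_context.request_stop() mirror tf.train.SessionRunHook / SessionRunContext.


class StubRunContext:
    """Hypothetical stand-in for tf.train.SessionRunContext."""

    def __init__(self):
        self.stop_requested = False

    def request_stop(self):
        self.stop_requested = True


class EarlyStoppingHook:
    """Evaluate every `every_n_steps` steps; stop after `patience` evaluations without improvement.

    `evaluate_fn(step)` is an assumed callable returning a validation metric to minimize.
    """

    def __init__(self, evaluate_fn, every_n_steps, patience):
        self.evaluate_fn = evaluate_fn
        self.every_n_steps = every_n_steps
        self.patience = patience
        self.step = 0
        self.best = float("inf")
        self.evals_without_improvement = 0
        self.history = []  # (step, metric) pairs, kept for inspection

    def after_run(self, run_context, run_values=None):
        self.step += 1
        if self.step % self.every_n_steps != 0:
            return  # not an evaluation step
        metric = self.evaluate_fn(self.step)
        self.history.append((self.step, metric))
        if metric < self.best:
            self.best = metric
            self.evals_without_improvement = 0
        else:
            self.evals_without_improvement += 1
            if self.evals_without_improvement >= self.patience:
                run_context.request_stop()


def run(hook, max_steps):
    """Hypothetical driver: one StubRunContext per step, stopping when a stop is requested."""
    for _ in range(max_steps):
        ctx = StubRunContext()
        hook.after_run(ctx)
        if ctx.stop_requested:
            break
    return hook.step


# Fake validation losses keyed by step: improves until step 100, then stops improving.
fake_losses = {50: 0.9, 100: 0.5, 150: 0.6, 200: 0.55, 250: 0.7}
hook = EarlyStoppingHook(lambda step: fake_losses.get(step, 1.0), every_n_steps=50, patience=2)
print(run(hook, max_steps=1000))  # prints 200
```

Evaluations at steps 150 and 200 both fail to beat the best loss from step 100, so with patience=2 the hook requests a stop at step 200.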

One more thing: to use these hooks you'll need to switch from tf.contrib.learn.Estimator (which only accepts monitors) to the more full-fledged and official tf.estimator.Estimator (which only accepts hooks). So instantiate your classifier as a tf.estimator.DNNClassifier and train it with its train() method instead (essentially a renaming of fit()):

clf = tf.estimator.DNNClassifier(...)

...

clf.train(
    input_fn=...,
    ...,
    hooks=hooks)
