本文整理汇总了 Python 中 tensorflow.python.util.compat.as_text 函数的典型用法代码示例。如果您想了解 Python as_text 函数的具体用法、调用方式或使用示例，这里精选的函数代码示例或许可以为您提供帮助。
下文一共展示了 as_text 函数的 20 个代码示例，这些例子默认按受欢迎程度排序。您可以为喜欢或者觉得有用的代码点赞，您的评价将有助于我们的系统推荐出更好的 Python 代码示例。
示例1: load_model
def load_model(saved_model_path, custom_objects=None):
    """Load a keras.Model from SavedModel.

    load_model reinstantiates model state by:
      1) loading model topology from json (this will eventually come
         from metagraph).
      2) loading model weights from checkpoint.

    Args:
      saved_model_path: a string specifying the path to an existing SavedModel.
      custom_objects: Optional dictionary mapping names (strings) to custom
        classes or functions to be considered during deserialization. It is
        forwarded to `model_from_json`; the default of None keeps the previous
        behavior.

    Returns:
      a keras.Model instance.
    """
    # restore model topology from json string
    model_json_filepath = os.path.join(
        compat.as_bytes(saved_model_path),
        compat.as_bytes(constants.ASSETS_DIRECTORY),
        compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON))
    model_json = file_io.read_file_to_string(model_json_filepath)
    model = model_from_json(model_json, custom_objects=custom_objects)

    # restore model weights
    checkpoint_prefix = os.path.join(
        compat.as_text(saved_model_path),
        compat.as_text(constants.VARIABLES_DIRECTORY),
        compat.as_text(constants.VARIABLES_FILENAME))
    model.load_weights(checkpoint_prefix)
    return model
开发者ID:ZhangXinNan,项目名称:tensorflow,代码行数:29,代码来源:keras_saved_model.py
示例2: testFormatOneTensorOneDimVarySummarize
def testFormatOneTensorOneDimVarySummarize(self):
    # Each pair is (summarize value, rendering expected for range(6)).
    cases = (
        (-1, "[0 1 2 3 4 5]"),
        (1, "[0 ... 5]"),
        (2, "[0 1 ... 4 5]"),
        (10, "[0 1 2 3 4 5]"),
    )
    for summarize, expected in cases:
        # A fresh session per case, as in a hand-unrolled version.
        with self.test_session():
            tensor = math_ops.range(6)
            format_output = string_ops.string_format(
                "{}", tensor, summarize=summarize)
            out = self.evaluate(format_output)
            self.assertEqual(compat.as_text(out), expected)
开发者ID:daiwk,项目名称:tensorflow,代码行数:28,代码来源:string_format_op_test.py
示例3: add_meta_graph_and_variables
def add_meta_graph_and_variables(self,
                                 sess,
                                 tags,
                                 signature_def_map=None,
                                 assets_collection=None,
                                 legacy_init_op=None):
    """Adds the current meta graph to the SavedModel and saves variables.

    Creates a Saver to save the variables from the provided session. Exports the
    corresponding meta graph def. This function assumes that the variables to be
    saved have been initialized. For a given `SavedModelBuilder`, this API must
    be called exactly once and for the first meta graph to save. For subsequent
    meta graph defs to be added, the `add_meta_graph()` API must be used.

    Args:
      sess: The TensorFlow session from which to save the meta graph and
        variables.
      tags: The set of tags with which to save the meta graph.
      signature_def_map: The map of signature defs to add to the meta graph
        def.
      assets_collection: Assets collection to be saved with SavedModel.
      legacy_init_op: Op or group of ops to execute after the restore op upon a
        load.

    Raises:
      AssertionError: If this builder has already saved variables and assets
        (i.e. this method was already called).
    """
    if self._has_saved_variables:
        raise AssertionError("Variables and assets have already been saved. "
                             "Please invoke `add_meta_graph()` instead.")

    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(assets_collection)

    # Create the variables sub-directory, if it does not exist.
    variables_dir = os.path.join(
        compat.as_text(self._export_dir),
        compat.as_text(constants.VARIABLES_DIRECTORY))
    if not file_io.file_exists(variables_dir):
        file_io.recursive_create_dir(variables_dir)

    variables_path = os.path.join(
        compat.as_text(variables_dir),
        compat.as_text(constants.VARIABLES_FILENAME))

    # Add legacy init op to the SavedModel.
    self._maybe_add_legacy_init_op(legacy_init_op)

    # Save the variables and export meta graph def. `all_variables()` is a
    # deprecated alias of `global_variables()`, so use the latter directly.
    saver = tf_saver.Saver(
        variables.global_variables(),
        sharded=True,
        write_version=saver_pb2.SaverDef.V2)
    saver.save(sess, variables_path, write_meta_graph=False)
    meta_graph_def = saver.export_meta_graph()

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

    # Mark this instance of SavedModel as having saved variables, such that
    # subsequent attempts to save variables will fail.
    self._has_saved_variables = True
开发者ID:caikehe,项目名称:tensorflow,代码行数:59,代码来源:builder.py
示例4: _do_run
def _do_run(self, target_list, fetch_list, feed_dict):
    """Runs a step based on the given fetches and feeds.

    Args:
      target_list: A list of byte arrays corresponding to names of tensors
        or operations to be run to, but not fetched.
      fetch_list: A list of byte arrays corresponding to names of tensors to
        be fetched and operations to be run.
      feed_dict: A dictionary that maps tensor names (as byte arrays) to
        numpy ndarrays.

    Returns:
      A list of numpy ndarrays, corresponding to the elements of
      `fetch_list`. If the ith element of `fetch_list` contains the
      name of an operation, the first Tensor output of that operation
      will be returned for that element.

    Raises:
      RuntimeError: If extending the runtime graph with new nodes fails.
      Any exception produced by `errors._make_specific_exception` when a
      `StatusNotOK` error message names a node; otherwise the original
      `StatusNotOK` is re-raised.
    """
    try:
        # Ensure any changes to the graph are reflected in the runtime.
        with self._extend_lock:
            if self._graph.version > self._current_version:
                graph_def = self._graph.as_graph_def(
                    from_version=self._current_version)
                # Allocate the status before the try block: if allocation
                # itself failed inside the try, the `finally` clause would hit
                # an unbound `status` and mask the real error with NameError.
                status = tf_session.TF_NewStatus()
                try:
                    tf_session.TF_ExtendGraph(
                        self._session, graph_def.SerializeToString(), status)
                    if tf_session.TF_GetCode(status) != 0:
                        raise RuntimeError(
                            compat.as_text(tf_session.TF_Message(status)))
                    self._opened = True
                finally:
                    tf_session.TF_DeleteStatus(status)
                self._current_version = self._graph.version

        return tf_session.TF_Run(self._session, feed_dict, fetch_list,
                                 target_list)

    except tf_session.StatusNotOK as e:
        e_type, e_value, e_traceback = sys.exc_info()
        error_message = compat.as_text(e.error_message)
        # If the message names a node, raise a more specific error that
        # carries that node's def and op when they can be found.
        m = BaseSession._NODEDEF_NAME_RE.search(error_message)
        if m is not None:
            node_name = m.group(1)
            node_def = None
            try:
                op = self._graph.get_operation_by_name(node_name)
                node_def = op.node_def
            except KeyError:
                op = None
            # pylint: disable=protected-access
            raise errors._make_specific_exception(node_def, op, error_message,
                                                  e.code)
            # pylint: enable=protected-access
        six.reraise(e_type, e_value, e_traceback)
开发者ID:danvk,项目名称:tensorflow,代码行数:56,代码来源:session.py
示例5: load_from_saved_model
def load_from_saved_model(saved_model_path, custom_objects=None):
    """Loads a keras Model from a SavedModel created by `export_saved_model()`.

    The model is rebuilt in two steps:
      1) its topology is deserialized from the JSON file stored in the
         SavedModel's assets directory (this will eventually come from the
         metagraph);
      2) its weights are restored from the checkpoint in the variables
         directory.

    Example:

    ```python
    import tensorflow as tf

    # Create a tf.keras model.
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(1, input_shape=[10]))
    model.summary()

    # Save the tf.keras model in the SavedModel format.
    path = '/tmp/simple_keras_model'
    tf.keras.experimental.export_saved_model(model, path)

    # Load the saved keras model back.
    new_model = tf.keras.experimental.load_from_saved_model(path)
    new_model.summary()
    ```

    Args:
      saved_model_path: a string specifying the path to an existing SavedModel.
      custom_objects: Optional dictionary mapping names (strings) to custom
        classes or functions to be considered during deserialization.

    Returns:
      a keras.Model instance.
    """
    # Step 1: rebuild the topology from the JSON stored under the assets dir.
    json_path = os.path.join(
        compat.as_bytes(saved_model_path),
        compat.as_bytes(constants.ASSETS_DIRECTORY),
        compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON))
    restored = model_from_json(file_io.read_file_to_string(json_path),
                               custom_objects=custom_objects)

    # Step 2: restore the weights from the checkpoint under the variables dir.
    weights_prefix = os.path.join(
        compat.as_text(saved_model_path),
        compat.as_text(constants.VARIABLES_DIRECTORY),
        compat.as_text(constants.VARIABLES_FILENAME))
    restored.load_weights(weights_prefix)
    return restored
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:51,代码来源:saved_model.py
示例6: testFormatOneTensorOneDim
def testFormatOneTensorOneDim(self):
    # The test expects a bare tensor and a one-element list holding that
    # tensor to render identically.
    for wrap_in_list in (False, True):
        with self.test_session():
            tensor = math_ops.range(10)
            template_input = [tensor] if wrap_in_list else tensor
            rendered = self.evaluate(
                string_ops.string_format("{}", template_input))
            self.assertEqual(compat.as_text(rendered), "[0 1 2 ... 7 8 9]")
开发者ID:daiwk,项目名称:tensorflow,代码行数:14,代码来源:string_format_op_test.py
示例7: _start_local_server
def _start_local_server(self):
    """Starts a local gRPC server and records the coordinator host:port.

    The server is created with the address '0.0.0.0:0'; the port is then read
    back from the server's target string and paired with this instance's IP
    obtained from the metadata service.
    """
    address = self._requestComputeMetadata('instance/network-interfaces/0/ip')
    self._server = server_lib.Server(
        {'local': ['0.0.0.0:0']}, protocol='grpc', config=None, start=True)

    # self._server.target is of the form: grpc://ipaddress:port
    colon = compat.as_bytes(':')
    pieces = compat.as_bytes(self._server.target).split(colon)
    assert len(pieces) == 3, self._server.target
    assert pieces[0] == compat.as_bytes('grpc'), self._server.target

    port_text = compat.as_text(pieces[2])
    self._coordinator_port = port_text
    self._coordinator_address = '%s:%s' % (
        address, compat.as_text(self._coordinator_port))
开发者ID:ThunderQi,项目名称:tensorflow,代码行数:14,代码来源:tpu_cluster_resolver.py
示例8: save_model
def save_model(model, saved_model_path):
    """Save a `tf.keras.Model` into Tensorflow SavedModel format.

    `save_model` generates such files/folders under the `saved_model_path` folder:
      1) an asset folder containing the json string of the model's
         configuration (topology).
      2) a checkpoint containing the model weights.

    Note that subclassed models (those for which `model._is_graph_network` is
    False) can not be saved via this function.

    Also note that `tf.keras.optimizers.Optimizer` instances can not currently be
    saved to checkpoints. Use optimizers from `tf.train`.

    Args:
      model: A `tf.keras.Model` to be saved.
      saved_model_path: a string specifying the path to the SavedModel directory.

    Raises:
      NotImplementedError: If the passed in model is a subclassed model.
    """
    # pylint: disable=protected-access
    if not model._is_graph_network:
        raise NotImplementedError(
            'save_model does not support subclassed models; only graph-network '
            '(e.g. functional or Sequential) models can be saved.')
    # pylint: enable=protected-access

    # save model configuration as a json string under assets folder.
    model_json = model.to_json()
    assets_destination_dir = os.path.join(
        compat.as_bytes(saved_model_path),
        compat.as_bytes(constants.ASSETS_DIRECTORY))
    if not file_io.file_exists(assets_destination_dir):
        file_io.recursive_create_dir(assets_destination_dir)

    model_json_filepath = os.path.join(
        compat.as_bytes(assets_destination_dir),
        compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON))
    file_io.write_string_to_file(model_json_filepath, model_json)

    # save model weights in checkpoint format.
    checkpoint_destination_dir = os.path.join(
        compat.as_bytes(saved_model_path),
        compat.as_bytes(constants.VARIABLES_DIRECTORY))
    if not file_io.file_exists(checkpoint_destination_dir):
        file_io.recursive_create_dir(checkpoint_destination_dir)

    checkpoint_prefix = os.path.join(
        compat.as_text(checkpoint_destination_dir),
        compat.as_text(constants.VARIABLES_FILENAME))
    model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True)
开发者ID:ZhangXinNan,项目名称:tensorflow,代码行数:49,代码来源:keras_saved_model.py
示例9: load_keras_model
def load_keras_model(saved_model_path, custom_objects=None):
    """Load a keras.Model from SavedModel.

    load_keras_model reinstantiates model state by:
      1) loading model topology from json (this will eventually come
         from metagraph).
      2) loading model weights from checkpoint.

    Example:

    ```python
    import tensorflow as tf

    # Create a tf.keras model.
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(1, input_shape=[10]))
    model.summary()

    # Save the tf.keras model in the SavedModel format.
    saved_to_path = tf.contrib.saved_model.save_keras_model(
        model, '/tmp/my_simple_tf_keras_saved_model')

    # Load the saved keras model back.
    model_prime = tf.contrib.saved_model.load_keras_model(saved_to_path)
    model_prime.summary()
    ```

    Args:
      saved_model_path: a string specifying the path to an existing SavedModel.
      custom_objects: Optional dictionary mapping names (strings) to custom
        classes or functions to be considered during deserialization. It is
        forwarded to `model_from_json`; the default of None keeps the previous
        behavior.

    Returns:
      a keras.Model instance.
    """
    # restore model topology from json string
    model_json_filepath = os.path.join(
        compat.as_bytes(saved_model_path),
        compat.as_bytes(constants.ASSETS_DIRECTORY),
        compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON))
    model_json = file_io.read_file_to_string(model_json_filepath)
    model = model_from_json(model_json, custom_objects=custom_objects)

    # restore model weights
    checkpoint_prefix = os.path.join(
        compat.as_text(saved_model_path),
        compat.as_text(constants.VARIABLES_DIRECTORY),
        compat.as_text(constants.VARIABLES_FILENAME))
    model.load_weights(checkpoint_prefix)
    return model
开发者ID:aeverall,项目名称:tensorflow,代码行数:48,代码来源:keras_saved_model.py
示例10: _TestOneEpochWithHopBytes
def _TestOneEpochWithHopBytes(self,
                              files,
                              num_overlapped_records,
                              encoding=None):
    """Reads each file once through a hop_bytes reader and checks every record.

    Args:
      files: Paths enqueued for the reader, indexed by file number.
      num_overlapped_records: Number of records expected per file.
      encoding: Optional encoding forwarded to FixedLengthRecordReader.
    """
    with self.test_session() as sess:
        reader = io_ops.FixedLengthRecordReader(
            header_bytes=self._header_bytes,
            record_bytes=self._record_bytes,
            footer_bytes=self._footer_bytes,
            hop_bytes=self._hop_bytes,
            encoding=encoding,
            name="test_reader")
        queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
        key, value = reader.read(queue)

        queue.enqueue_many([files]).run()
        queue.close().run()

        for file_index in range(self._num_files):
            for record_index in range(num_overlapped_records):
                k, v = sess.run([key, value])
                expected_key = "%s:%d" % (files[file_index], record_index)
                self.assertAllEqual(expected_key, compat.as_text(k))
                self.assertAllEqual(
                    self._OverlappedRecord(file_index, record_index), v)

        # Once every record is consumed, the closed queue must make read fail.
        with self.assertRaisesOpError("is closed and has insufficient elements "
                                      "\\(requested 1, current size 0\\)"):
            k, v = sess.run([key, value])
开发者ID:AnishShah,项目名称:tensorflow,代码行数:26,代码来源:reader_ops_test.py
示例11: testComplexCodeView
def testComplexCodeView(self):
    ops.reset_default_graph()
    outfile = os.path.join(test.get_temp_dir(), 'dump')
    opts = (builder(builder.trainable_variables_parameter())
            .with_file_output(outfile)
            .with_accounted_types(['.*'])
            .with_node_names(show_name_regexes=
                             ['.*model_analyzer_testlib.py.*'])
            .account_displayed_op_only(False)
            .select(['params', 'float_ops']).build())

    with profile_context.ProfileContext(test.get_temp_dir(),
                                        trace_steps=[],
                                        dump_steps=[]) as pctx:
        with session.Session() as sess:
            x = lib.BuildFullModel()

            sess.run(variables.global_variables_initializer())
            pctx.trace_next_step()
            _ = sess.run(x)
            tfprof_node = pctx.profiler.profile_python(options=opts)

            # pylint: disable=line-too-long
            with gfile.Open(outfile, 'r') as f:
                lines = f.read().split('\n')
                self.assertGreater(len(lines), 5)
                # Keep at most the first 80 characters of every line.
                result = '\n'.join([line[:80] for line in lines])
                self.assertTrue(
                    compat.as_text(lib.CheckAndRemoveDoc(result))
                    .startswith('node name | # parameters | # float_ops'))

            self.assertLess(0, tfprof_node.total_exec_micros)
            self.assertEqual(2844, tfprof_node.total_parameters)
            self.assertLess(145660, tfprof_node.total_float_ops)
            self.assertEqual(8, len(tfprof_node.children))
            self.assertEqual('_TFProfRoot', tfprof_node.name)

            # Expected code-view children, in order.
            expected_child_names = [
                'model_analyzer_testlib.py:63:BuildFullModel',
                'model_analyzer_testlib.py:63:BuildFullModel (gradient)',
                'model_analyzer_testlib.py:67:BuildFullModel',
                'model_analyzer_testlib.py:67:BuildFullModel (gradient)',
                'model_analyzer_testlib.py:69:BuildFullModel',
                'model_analyzer_testlib.py:70:BuildFullModel',
                'model_analyzer_testlib.py:70:BuildFullModel (gradient)',
                'model_analyzer_testlib.py:72:BuildFullModel',
            ]
            for index, expected_name in enumerate(expected_child_names):
                self.assertEqual(
                    expected_name,
                    tfprof_node.children[index].name)
开发者ID:Jackiefan,项目名称:tensorflow,代码行数:60,代码来源:model_analyzer_test.py
示例12: run_benchmark
def run_benchmark(sess, init_op, add_op):
    """Runs the addition benchmark and logs per-step MB/s rates as events.

    For each of FLAGS.iters steps, runs `add_op.op` FLAGS.iters_per_step times
    and writes a 'rate' event to an events file under
    FLAGS.logdir_prefix/FLAGS.name. The rate assumes each run of `add_op`
    moves FLAGS.data_mb megabytes (NOTE(review): confirm against how add_op
    is built). Nothing is returned.

    Args:
      sess: Session used to run the ops.
      init_op: Op run once before the timed loop.
      add_op: Tensor whose `.op` is run in the timed inner loop.
    """
    logdir = FLAGS.logdir_prefix + '/' + FLAGS.name
    # Create the directory directly instead of shelling out to `mkdir -p`, so
    # flag-derived text is never interpreted by a shell.
    if not os.path.isdir(logdir):
        os.makedirs(logdir)

    # TODO: make events follow same format as eager writer
    writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(logdir + '/events'))
    try:
        training_util.get_or_create_global_step()
        sess.run(init_op)
        for step in range(FLAGS.iters):
            start_time = time.time()
            for _ in range(FLAGS.iters_per_step):
                sess.run(add_op.op)
            elapsed_time = time.time() - start_time
            # The timed interval covers FLAGS.iters_per_step additions, so that
            # (not the total step count FLAGS.iters) is the work done.
            rate = float(FLAGS.iters_per_step) * FLAGS.data_mb / elapsed_time
            event = make_event('rate', rate, step)
            writer.WriteEvent(event)
            writer.Flush()
    finally:
        # Close the events file even if a step fails.
        writer.Close()
开发者ID:yaroslavvb,项目名称:stuff,代码行数:25,代码来源:benchmark_grpc_recv.py
示例13: save
def save(self, as_text=False):
    """Writes a `SavedModel` protocol buffer to disk.

    The SavedModel protocol buffer is written into the export directory, which
    is created first if it does not exist yet.

    Args:
      as_text: If True, write the protocol buffer in text format (to the
        `SAVED_MODEL_FILENAME_PBTXT` file); otherwise write it in binary
        serialized form (to `SAVED_MODEL_FILENAME_PB`).

    Returns:
      The path to which the SavedModel protocol buffer was written.
    """
    if not file_io.file_exists(self._export_dir):
        file_io.recursive_create_dir(self._export_dir)

    # Choose the target file name and payload for the requested format.
    if as_text:
        file_name = constants.SAVED_MODEL_FILENAME_PBTXT
        payload = str(self._saved_model)
    else:
        file_name = constants.SAVED_MODEL_FILENAME_PB
        payload = self._saved_model.SerializeToString()

    path = os.path.join(compat.as_bytes(self._export_dir),
                        compat.as_bytes(file_name))
    file_io.write_string_to_file(path, payload)
    tf_logging.info("SavedModel written to: %s", compat.as_text(path))
    return path
开发者ID:Jackiefan,项目名称:tensorflow,代码行数:28,代码来源:builder_impl.py
示例14: testReadGzipFiles
def testReadGzipFiles(self):
    """Checks that a GZIP-configured TFRecordReader reads gzipped TFRecords."""
    gzip_files = []
    # Gzip every plain TFRecord file into the temp dir, keeping the order.
    for index, plain_path in enumerate(self._CreateFiles()):
        with open(plain_path, "rb") as source:
            payload = source.read()
        gz_path = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % index)
        with gzip.GzipFile(gz_path, "wb") as sink:
            sink.write(payload)
        gzip_files.append(gz_path)

    with self.test_session() as sess:
        options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)
        reader = io_ops.TFRecordReader(name="test_reader", options=options)
        queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
        key, value = reader.read(queue)

        queue.enqueue_many([gzip_files]).run()
        queue.close().run()

        for file_index in range(self._num_files):
            for record_index in range(self._num_records):
                k, v = sess.run([key, value])
                self.assertTrue(compat.as_text(k).startswith(
                    "%s:" % gzip_files[file_index]))
                self.assertAllEqual(self._Record(file_index, record_index), v)
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:25,代码来源:reader_ops_test.py
示例15: _export_model_json
def _export_model_json(model, saved_model_path):
    """Writes the model's JSON configuration into the SavedModel assets dir.

    Args:
      model: Model whose `to_json()` output is written.
      saved_model_path: SavedModel root; its assets directory is created by
        `saved_model_utils.get_or_create_assets_dir` if needed.
    """
    serialized_config = model.to_json()
    assets_dir = saved_model_utils.get_or_create_assets_dir(saved_model_path)
    destination = os.path.join(
        assets_dir, compat.as_text(constants.SAVED_MODEL_FILENAME_JSON))
    file_io.write_string_to_file(destination, serialized_config)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:7,代码来源:saved_model.py
示例16: testFormatOneTensorOneDimAlmostSummarize
def testFormatOneTensorOneDimAlmostSummarize(self):
    # With summarize=3 and only five elements, the expected output shows
    # every element without an ellipsis.
    with self.test_session():
        formatted = string_ops.string_format(
            "{}", math_ops.range(5), summarize=3)
        rendered = self.evaluate(formatted)
        self.assertEqual(compat.as_text(rendered), "[0 1 2 3 4]")
开发者ID:daiwk,项目名称:tensorflow,代码行数:7,代码来源:string_format_op_test.py
示例17: load_file_system_library
def load_file_system_library(library_filename):
    """Loads a TensorFlow plugin, containing file system implementation.

    Pass `library_filename` to a platform-specific mechanism for dynamically
    loading a library. The rules for determining the exact location of the
    library are platform-specific and are not documented here.

    Args:
      library_filename: Path to the plugin.
        Relative or absolute filesystem path to a dynamic library file.

    Returns:
      None.

    Raises:
      RuntimeError: when unable to load the library.
    """
    status = py_tf.TF_NewStatus()
    try:
        # The load call sits inside the try so the status object is released
        # by the `finally` clause even if the binding call itself raises.
        py_tf.TF_LoadLibrary(library_filename, status)
        error_code = py_tf.TF_GetCode(status)
        if error_code != 0:
            error_msg = compat.as_text(py_tf.TF_Message(status))
            # pylint: disable=protected-access
            raise errors_impl._make_specific_exception(
                None, None, error_msg, error_code)
            # pylint: enable=protected-access
    finally:
        py_tf.TF_DeleteStatus(status)
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:29,代码来源:load_library.py
示例18: testWriteEvents
def testWriteEvents(self):
    file_prefix = os.path.join(self.get_temp_dir(), "events")
    writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(file_prefix))
    filename = compat.as_text(writer.FileName())

    summary_value = summary_pb2.Summary.Value(tag="foo", simple_value=89.0)
    event_written = event_pb2.Event(
        wall_time=123.45, step=67,
        summary=summary_pb2.Summary(value=[summary_value]))
    writer.WriteEvent(event_written)
    writer.Flush()
    writer.Close()

    # Iterating a nonexistent file must raise IOError instead of yielding.
    with self.assertRaises(IOError):
        for _ in tf_record.tf_record_iterator(filename + "DOES_NOT_EXIST"):
            self.assertTrue(False)

    reader = tf_record.tf_record_iterator(filename)
    event_read = event_pb2.Event()

    # First record: must carry the `file_version` field.
    event_read.ParseFromString(next(reader))
    self.assertTrue(event_read.HasField("file_version"))

    # Second record: the event written above.
    event_read.ParseFromString(next(reader))
    self.assertProtoEquals("""
wall_time: 123.45 step: 67
summary { value { tag: 'foo' simple_value: 89.0 } }
""", event_read)

    # Nothing else follows.
    with self.assertRaises(StopIteration):
        next(reader)
开发者ID:0ruben,项目名称:tensorflow,代码行数:31,代码来源:events_writer_test.py
示例19: testZLibFlushRecord
def testZLibFlushRecord(self):
    fn = self._WriteRecordsToFile([b"small record"], "small_record")
    with open(fn, "rb") as h:
        buff = h.read()

    # creating more blocks and trailing blocks shouldn't break reads
    compressor = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS)

    pieces = []
    for byte in buff:
        # Iterating bytes yields ints on Python 3; turn each back into a
        # one-byte string before feeding the compressor.
        if isinstance(byte, int):
            byte = six.int2byte(byte)
        pieces.append(compressor.compress(byte))
        pieces.append(compressor.flush(zlib.Z_FULL_FLUSH))
    # Extra full flushes add trailing blocks before the final stream end.
    pieces.append(compressor.flush(zlib.Z_FULL_FLUSH))
    pieces.append(compressor.flush(zlib.Z_FULL_FLUSH))
    pieces.append(compressor.flush(zlib.Z_FINISH))
    output = b"".join(pieces)

    # overwrite the original file with the compressed data
    with open(fn, "wb") as h:
        h.write(output)

    with self.test_session() as sess:
        options = tf_record.TFRecordOptions(
            compression_type=TFRecordCompressionType.ZLIB)
        reader = io_ops.TFRecordReader(name="test_reader", options=options)
        queue = data_flow_ops.FIFOQueue(1, [dtypes.string], shapes=())
        key, value = reader.read(queue)
        queue.enqueue(fn).run()
        queue.close().run()
        k, v = sess.run([key, value])
        self.assertTrue(compat.as_text(k).startswith("%s:" % fn))
        self.assertAllEqual(b"small record", v)
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:34,代码来源:reader_ops_test.py
示例20: _save_and_write_assets
def _save_and_write_assets(self, assets_collection_to_add=None):
    """Records assets in the meta graph and copies the asset files to disk.

    Args:
      assets_collection_to_add: The collection where the asset paths are setup.
    """
    asset_filename_map = _maybe_save_assets(assets_collection_to_add)

    # Nothing to copy: log and stop early.
    if not asset_filename_map:
        tf_logging.info("No assets to write.")
        return

    assets_destination_dir = saved_model_utils.get_or_create_assets_dir(
        self._export_dir)
    destination_dir_bytes = compat.as_bytes(assets_destination_dir)

    for basename, source_path in asset_filename_map.items():
        target_path = os.path.join(destination_dir_bytes,
                                   compat.as_bytes(basename))
        # An asset shared by several graphs is copied only the first time;
        # later calls find the file already present and skip it.
        if file_io.file_exists(target_path):
            continue
        file_io.copy(source_path, target_path)

    tf_logging.info("Assets written to: %s",
                    compat.as_text(assets_destination_dir))
开发者ID:AnishShah,项目名称:tensorflow,代码行数:30,代码来源:builder_impl.py
注：本文中的 tensorflow.python.util.compat.as_text 函数示例由纯净天空整理自 GitHub/MSDocs 等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的 License；未经允许，请勿转载。
请发表评论