本文整理汇总了Python中tensorflow.python.util.compat.as_bytes函数的典型用法代码示例。如果您正苦于以下问题:Python as_bytes函数的具体用法?Python as_bytes怎么用?Python as_bytes使用的例子?那么恭喜您,这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了as_bytes函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: _save_and_write_assets
def _save_and_write_assets(self, assets_collection_to_add=None):
    """Saves asset to the meta graph and writes asset files to disk.

    Every source file returned by `self._save_assets` is copied into the
    assets sub-directory of the export directory (created on demand),
    overwriting any file that already has the same basename there.

    Args:
      assets_collection_to_add: The collection where the asset paths are setup.
    """
    asset_source_filepath_list = self._save_assets(assets_collection_to_add)

    # Return if there are no assets to write. Truthiness is used here because
    # `len(...) is 0` compares object identity with an int literal, which only
    # works by accident of small-int caching.
    if not asset_source_filepath_list:
        tf_logging.info("No assets to write.")
        return

    assets_destination_dir = os.path.join(
        compat.as_bytes(self._export_dir),
        compat.as_bytes(constants.ASSETS_DIRECTORY))

    if not file_io.file_exists(assets_destination_dir):
        file_io.recursive_create_dir(assets_destination_dir)

    # Copy each asset from source path to destination path.
    for asset_source_filepath in asset_source_filepath_list:
        asset_source_filename = os.path.basename(asset_source_filepath)
        asset_destination_filepath = os.path.join(
            compat.as_bytes(assets_destination_dir),
            compat.as_bytes(asset_source_filename))
        file_io.copy(
            asset_source_filepath, asset_destination_filepath, overwrite=True)

    # The destination is a bytes path; decode it so the log line shows a plain
    # path instead of a bytes repr.
    tf_logging.info("Assets written to: %s",
                    compat.as_text(assets_destination_dir))
开发者ID:Qstar,项目名称:tensorflow,代码行数:31,代码来源:builder.py
示例2: save
def save(self, as_text=False):
    """Writes a `SavedModel` protocol buffer to disk.

    The export directory is created first when it does not exist. The protocol
    buffer is then written either as its `str()` text rendering or as the
    output of `SerializeToString()`, under the matching file name constant.

    Args:
      as_text: Writes the SavedModel protocol buffer in text format to disk.

    Returns:
      The path to which the SavedModel protocol buffer was written.
    """
    export_dir = self._export_dir
    if not file_io.file_exists(export_dir):
        file_io.recursive_create_dir(export_dir)

    # Only the file name and the serialization differ between the two formats.
    filename = (constants.SAVED_MODEL_FILENAME_PBTXT
                if as_text else constants.SAVED_MODEL_FILENAME_PB)
    path = os.path.join(compat.as_bytes(export_dir), compat.as_bytes(filename))

    if as_text:
        contents = str(self._saved_model)
    else:
        contents = self._saved_model.SerializeToString()
    file_io.write_string_to_file(path, contents)

    tf_logging.info("SavedModel written to: %s", path)
    return path
开发者ID:1000sprites,项目名称:tensorflow,代码行数:28,代码来源:builder_impl.py
示例3: get_timestamped_dir
def get_timestamped_dir(dir_base):
    """Builds a path to a new subdirectory within the base directory.

    The subdirectory is named after the current time, truncated to whole
    seconds since the epoch. If a directory with that name already exists the
    function sleeps for one second and tries again, giving up after
    `MAX_DIRECTORY_CREATION_ATTEMPTS` attempts.

    Args:
      dir_base: A string containing a directory to create the subdirectory under.

    Returns:
      The full path of the new subdirectory (which is not actually created yet).
      Both components are passed through `compat.as_bytes` before joining.

    Raises:
      RuntimeError: if repeated attempts fail to obtain a unique timestamped
        directory name.
    """
    attempts = 0
    while attempts < MAX_DIRECTORY_CREATION_ATTEMPTS:
        timestamp = int(time.time())

        result_dir = os.path.join(
            compat.as_bytes(dir_base), compat.as_bytes(str(timestamp)))
        if not gfile.Exists(result_dir):
            # Collisions are still possible (though extremely unlikely): this
            # directory is not actually created yet, but it will be almost
            # instantly on return from this function.
            return result_dir

        # Wait for the clock to tick over so the next candidate name differs.
        time.sleep(1)
        attempts += 1
        # `warning` replaces the deprecated `warn` alias; the arguments are
        # passed separately so formatting is left to the logging call.
        logging.warning(
            'Directory %s already exists; retrying (attempt %s/%s)',
            result_dir, attempts, MAX_DIRECTORY_CREATION_ATTEMPTS)

    raise RuntimeError('Failed to obtain a unique export directory name after '
                       '{} attempts.'.format(MAX_DIRECTORY_CREATION_ATTEMPTS))
开发者ID:AnishShah,项目名称:tensorflow,代码行数:35,代码来源:util.py
示例4: to_proto
def to_proto(self, export_scope=None):  # pylint: disable=unused-argument
    """Converts a `HParams` object to a `HParamDef` protocol buffer.

    Each registered hyperparameter is written into the proto field selected by
    `HParams._get_kind_name`. Values destined for a field whose kind name
    starts with 'bytes' are converted with `compat.as_bytes` first.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `HParamDef` protocol buffer.
    """
    proto = hparam_pb2.HParamDef()
    for name, (param_type, is_list) in self._hparam_types.items():
        kind = HParams._get_kind_name(param_type, is_list)
        wants_bytes = kind.startswith('bytes')
        value = getattr(self, name)

        if is_list:
            # List-valued kinds are messages with a repeated `value` field.
            if wants_bytes:
                items = [compat.as_bytes(item) for item in value]
            else:
                items = list(value)
            getattr(proto.hparam[name], kind).value.extend(items)
            continue

        # Scalar kinds are assigned directly on the hparam entry.
        if wants_bytes:
            value = compat.as_bytes(value)
        setattr(proto.hparam[name], kind, value)

    return proto
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:28,代码来源:hparam.py
示例5: _get_asset_tensors
def _get_asset_tensors(export_dir, meta_graph_def_to_load):
    """Gets the asset tensors, if defined in the meta graph def to load.

    Args:
      export_dir: Directory where the SavedModel is located.
      meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.

    Returns:
      A dictionary of asset tensors, keyed by the name of the asset tensor. Each
      value is the asset's file name joined onto the assets directory under
      `export_dir`. The dictionary is empty when the meta graph def has no
      assets collection.
    """
    # Collection-def that may contain the assets key.
    collection_def = meta_graph_def_to_load.collection_def
    if constants.ASSETS_KEY not in collection_def:
        return {}

    # Location of the assets for SavedModel.
    assets_directory = os.path.join(
        compat.as_bytes(export_dir),
        compat.as_bytes(constants.ASSETS_DIRECTORY))

    asset_tensor_dict = {}
    # Each entry is a packed `Any`; unpack it to read tensor name and file name.
    for packed_asset in collection_def[constants.ASSETS_KEY].any_list.value:
        asset_proto = meta_graph_pb2.AssetFileDef()
        packed_asset.Unpack(asset_proto)
        asset_path = os.path.join(
            compat.as_bytes(assets_directory),
            compat.as_bytes(asset_proto.filename))
        asset_tensor_dict[asset_proto.tensor_info.name] = asset_path
    return asset_tensor_dict
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:29,代码来源:loader_impl.py
示例6: gfile_copy_callback
def gfile_copy_callback(files_to_copy, export_dir_path):
    """Callback to copy files using `gfile.Copy` to an export directory.

    This method is used as the default `assets_callback` in `Exporter.init` to
    copy assets from the `assets_collection`. It can also be invoked directly to
    copy additional supplementary files into the export directory (in which case
    it is not a callback).

    The export directory is created if needed, and a destination file that
    already exists is removed before the copy.

    Args:
      files_to_copy: A dictionary that maps original file paths to desired
        basename in the export directory.
      export_dir_path: Directory to copy the files to.
    """
    logging.info("Write assest into: %s using gfile_copy.", export_dir_path)
    gfile.MakeDirs(export_dir_path)

    for source, target_name in files_to_copy.items():
        destination = os.path.join(
            compat.as_bytes(export_dir_path), compat.as_bytes(target_name))
        logging.info("Copying asset %s to path %s.", source, destination)

        # Guard against being restarted while copying assets, and the file
        # existing and being in an unknown state.
        # TODO(b/28676216): Do some file checks before deleting.
        already_there = gfile.Exists(destination)
        if already_there:
            logging.info("Removing file %s.", destination)
            gfile.Remove(destination)

        gfile.Copy(source, destination)
开发者ID:2020zyc,项目名称:tensorflow,代码行数:27,代码来源:exporter.py
示例7: testShardsRunOnRequestedDevices
def testShardsRunOnRequestedDevices(self):
    """Checks that ops in a partitioned call run on the devices they requested."""
    config = config_pb2.ConfigProto(device_count={"CPU": 4})

    @function.Defun()
    def Body():
        # Serialize DT_RESOURCE handles as DT_STRINGs, which encode the device on
        # which the resource was created, so that we can verify that ops were
        # actually run on the requested devices.
        #
        # TODO(akshayka): Provide a cleaner, more idiomatic API for obtaining the
        # name of the device on which a resource lives / for determining the
        # device on which an op ran.
        handles = []
        for device in ("/cpu:0", "/cpu:1", "/cpu:2"):
            with ops.device(device):
                iterator = iterator_ops.Iterator.from_structure(
                    (dtypes.float32,))
                handles.append(iterator.string_handle())
        return tuple(handles)

    with self.test_session(config=config, use_gpu=True) as sess:
        outputs = sess.run(functional_ops.partitioned_call(args=[], f=Body))

    # The i-th serialized handle must mention the i-th requested device.
    for index, device_tag in enumerate(("CPU:0", "CPU:1", "CPU:2")):
        self.assertIn(compat.as_bytes(device_tag), outputs[index])
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:28,代码来源:functional_ops_test.py
示例8: testComplexCodeView
def testComplexCodeView(self):
    """Profiles the full test model in Python code view and checks the report.

    One traced step of the model is run, the profile is dumped to a file, each
    dumped line is cut to at most 80 characters, and the result is compared with
    the expected table. The returned root node's totals and the names of its
    eight children are checked as well.
    """
    ops.reset_default_graph()
    outfile = os.path.join(test.get_temp_dir(), 'dump')

    options_builder = builder(builder.trainable_variables_parameter())
    options_builder = options_builder.with_file_output(outfile)
    options_builder = options_builder.with_accounted_types(['.*'])
    options_builder = options_builder.with_node_names(
        show_name_regexes=['.*model_analyzer_testlib.py.*'])
    options_builder = options_builder.account_displayed_op_only(False)
    opts = options_builder.select(['params', 'float_ops']).build()

    # pylint: disable=line-too-long
    expected_report = 'node name | # parameters | # float_ops\n_TFProfRoot (--/2.84k params, --/168.86k flops)\n model_analyzer_testlib.py:63:BuildFullModel (0/1.80k params, 0/45.37k flops)\n model_analyzer_testlib.py:40:BuildSmallModel (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:44:BuildSmallModel (0/4 params, 0/8 flops)\n model_analyzer_testlib.py:48:BuildSmallModel (0/648 params, 0/1.30k flops)\n model_analyzer_testlib.py:49:BuildSmallModel (0/0 params, 0/23.33k flops)\n model_analyzer_testlib.py:53:BuildSmallModel (0/1.15k params, 0/2.30k flops)\n model_analyzer_testlib.py:54:BuildSmallModel (0/0 params, 0/18.43k flops)\n model_analyzer_testlib.py:63:BuildFullModel (gradient) (0/0 params, 0/67.39k f\n model_analyzer_testlib.py:49:BuildSmallModel (gradient) (0/0 params, 0/46.66\n model_analyzer_testlib.py:54:BuildSmallModel (gradient) (0/0 params, 0/20.74\n model_analyzer_testlib.py:67:BuildFullModel (0/1.04k params, 0/18.58k flops)\n model_analyzer_testlib.py:67:BuildFullModel (gradient) (0/0 params, 0/37.00k f\n model_analyzer_testlib.py:69:BuildFullModel (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:70:BuildFullModel (0/0 params, 0/258 flops)\n model_analyzer_testlib.py:70:BuildFullModel (gradient) (0/0 params, 0/129 flop\n model_analyzer_testlib.py:72:BuildFullModel (0/0 params, 0/141 flops)\n'
    # pylint: enable=line-too-long

    # Expected names of the root's children, in order.
    expected_children = [
        'model_analyzer_testlib.py:63:BuildFullModel',
        'model_analyzer_testlib.py:63:BuildFullModel (gradient)',
        'model_analyzer_testlib.py:67:BuildFullModel',
        'model_analyzer_testlib.py:67:BuildFullModel (gradient)',
        'model_analyzer_testlib.py:69:BuildFullModel',
        'model_analyzer_testlib.py:70:BuildFullModel',
        'model_analyzer_testlib.py:70:BuildFullModel (gradient)',
        'model_analyzer_testlib.py:72:BuildFullModel',
    ]

    with profile_context.ProfileContext(test.get_temp_dir(),
                                        trace_steps=[],
                                        dump_steps=[]) as pctx:
        with session.Session() as sess:
            x = lib.BuildFullModel()

            sess.run(variables.global_variables_initializer())
            pctx.trace_next_step()
            _ = sess.run(x)
            tfprof_node = pctx.profiler.profile_python(options=opts)

            with gfile.Open(outfile, 'r') as f:
                dumped_lines = f.read().split('\n')
                # Keep at most the first 80 characters of every line.
                result = '\n'.join([line[:80] for line in dumped_lines])
                self.assertEqual(
                    compat.as_bytes(expected_report),
                    compat.as_bytes(lib.CheckAndRemoveDoc(result)))

            self.assertLess(0, tfprof_node.total_exec_micros)
            self.assertEqual(2844, tfprof_node.total_parameters)
            self.assertEqual(168863, tfprof_node.total_float_ops)
            self.assertEqual(8, len(tfprof_node.children))
            self.assertEqual('_TFProfRoot', tfprof_node.name)
            for index, expected_name in enumerate(expected_children):
                self.assertEqual(
                    expected_name, tfprof_node.children[index].name)
开发者ID:andrewharp,项目名称:tensorflow,代码行数:60,代码来源:model_analyzer_test.py
示例9: testAssets
def testAssets(self):
    """Checks that only files in the asset collection are exported.

    A file that is written to the temp directory but never added to the asset
    collection must not appear in the export's assets directory.
    """
    export_dir = self._get_export_dir("test_assets")
    builder = saved_model_builder.SavedModelBuilder(export_dir)
    ignored_name = compat.as_bytes("ignored.txt")

    with self.test_session(graph=ops.Graph()) as sess:
        self._init_and_validate_variable(sess, "v", 42)

        # Write a file that is deliberately left out of the asset collection.
        unexported_file = os.path.join(
            compat.as_bytes(test.get_temp_dir()), ignored_name)
        file_io.write_string_to_file(unexported_file, "will be ignored")

        # Build an asset collection.
        asset_collection = self._build_asset_collection(
            "hello42.txt", "foo bar baz", "asset_file_tensor")
        builder.add_meta_graph_and_variables(
            sess, ["foo"], assets_collection=asset_collection)

    # Save the SavedModel to disk.
    builder.save()

    with self.test_session(graph=ops.Graph()) as sess:
        foo_graph = loader.load(sess, ["foo"], export_dir)
        self._validate_asset_collection(
            export_dir, foo_graph.collection_def,
            "hello42.txt", "foo bar baz", "asset_file_tensor:0")

        exported_copy = os.path.join(
            compat.as_bytes(export_dir),
            compat.as_bytes(constants.ASSETS_DIRECTORY),
            ignored_name)
        self.assertFalse(file_io.file_exists(exported_copy))
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:32,代码来源:saved_model_test.py
示例10: testNoShuffle
def testNoShuffle(self):
    """Checks that `list_files(shuffle=False)` repeats files in the same order."""
    filenames = ['a', 'b', 'c']
    self._touchTempFiles(filenames)

    # Repeat the list twice and ensure that the order is the same each time.
    # NOTE(mrry): This depends on an implementation detail of `list_files()`,
    # which is that the list of files is captured when the iterator is
    # initialized. Otherwise, or if e.g. the iterator were initialized more than
    # once, it's possible that the non-determinism of `tf.matching_files()`
    # would cause this test to fail. However, it serves as a useful confirmation
    # that the `shuffle=False` argument is working as intended.
    # TODO(b/73959787): Provide some ordering guarantees so that this test is
    # more meaningful.
    dataset = dataset_ops.Dataset.list_files(
        path.join(self.tmp_dir, '*'), shuffle=False).repeat(2)

    repeated = filenames * 2
    count = len(filenames)
    with self.cached_session() as sess:
        itr = dataset.make_one_shot_iterator()
        next_element = itr.get_next()

        full_filenames = [
            compat.as_bytes(path.join(self.tmp_dir, name)) for name in repeated
        ]
        # Pull exactly one element per expected file name.
        produced_filenames = [
            compat.as_bytes(sess.run(next_element)) for _ in repeated
        ]

        with self.assertRaises(errors.OutOfRangeError):
            sess.run(itr.get_next())

        self.assertItemsEqual(full_filenames, produced_filenames)
        # The first and second pass must agree element by element.
        self.assertEqual(produced_filenames[:count],
                         produced_filenames[count:])
开发者ID:ThunderQi,项目名称:tensorflow,代码行数:30,代码来源:list_files_dataset_op_test.py
示例11: testSaveAsText
def testSaveAsText(self):
    """Saves a two-tag SavedModel in text format and reloads both tags."""
    export_dir = os.path.join(
        compat.as_bytes(tf.test.get_temp_dir()), compat.as_bytes("astext"))
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Graph with a single variable. SavedModel invoked to:
    # - add with weights.
    with self.test_session(graph=tf.Graph()) as sess:
        saved_variable = tf.Variable(42, name="v")
        sess.run(tf.initialize_all_variables())
        self.assertEqual(42, saved_variable.eval())
        builder.add_meta_graph_and_variables(sess, ["foo"])

    # Graph with the same single variable. SavedModel invoked to:
    # - simply add the model (weights are not updated).
    with self.test_session(graph=tf.Graph()) as sess:
        unsaved_variable = tf.Variable(43, name="v")
        sess.run(tf.initialize_all_variables())
        self.assertEqual(43, unsaved_variable.eval())
        builder.add_meta_graph(["bar"])

    # Save the SavedModel to disk in text format.
    builder.save(as_text=True)

    # Restore the graph with tag "foo", whose variables were saved, and then
    # with tag "bar", whose variables were not saved. Both are expected to
    # evaluate to the value stored alongside "foo".
    for tag in ("foo", "bar"):
        with self.test_session(graph=tf.Graph()) as sess:
            loader.load(sess, [tag], export_dir)
            restored = tf.get_collection(tf.GraphKeys.VARIABLES)[0]
            self.assertEqual(42, restored.eval())
开发者ID:apollos,项目名称:tensorflow,代码行数:33,代码来源:saved_model_test.py
示例12: testFileMiddles
def testFileMiddles(self):
    """Checks that a glob with a wildcard mid-name matches only those files."""
    filenames = ['a.txt', 'b.py', 'c.pyc']
    self._touchTempFiles(filenames)

    filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
    dataset = dataset_ops.Dataset.list_files(filename_placeholder)

    # Only the entries after the first are expected to match '*.py*'.
    matching = filenames[1:]
    with self.cached_session() as sess:
        itr = dataset.make_initializable_iterator()
        next_element = itr.get_next()
        pattern = path.join(self.tmp_dir, '*.py*')
        sess.run(itr.initializer, feed_dict={filename_placeholder: pattern})

        full_filenames = [
            compat.as_bytes(path.join(self.tmp_dir, name)) for name in matching
        ]
        produced_filenames = [
            compat.as_bytes(sess.run(next_element)) for _ in matching
        ]
        self.assertItemsEqual(full_filenames, produced_filenames)

        with self.assertRaises(errors.OutOfRangeError):
            sess.run(itr.get_next())
开发者ID:ThunderQi,项目名称:tensorflow,代码行数:25,代码来源:list_files_dataset_op_test.py
示例13: _AddOpInternal
def _AddOpInternal(self, op):
    """Validates `op` and tags it as belonging to this TPU computation.

    Logs an error for blacklisted op types, records not-implemented ops in
    `self._unsupported_ops`, and then sets the replicate attribute (plus the
    outside-compilation attribute when a cluster is set) on the op.

    Args:
      op: The operation being added.

    Raises:
      NotImplementedError: if any input of `op` has a reference dtype.
      ValueError: if `op` already carries the TPU replicate attribute.
    """
    # pylint: disable=protected-access
    op_type = op.type
    if op_type in _BLACKLISTED_OPS:
        logging.error("Operation of type %s (%s) is not supported on the TPU. "
                      "Execution will fail if this op is used in the graph. " %
                      (op_type, op.name))

    if op_type in _NOT_IMPLEMENTED_OPS:
        self._unsupported_ops.append(op)

    uses_ref_input = any(x.dtype._is_ref_dtype for x in op.inputs)
    if uses_ref_input:
        raise NotImplementedError(
            "Non-resource Variables are not supported inside TPU computations "
            "(operator name: %s)" % op.name)
    if _TPU_REPLICATE_ATTR in op.node_def.attr:
        raise ValueError("TPU computations cannot be nested")

    replicate_value = attr_value_pb2.AttrValue(s=compat.as_bytes(self._name))
    op._set_attr(_TPU_REPLICATE_ATTR, replicate_value)

    cluster = self._outside_compilation_cluster
    if cluster:
        cluster_value = attr_value_pb2.AttrValue(s=compat.as_bytes(cluster))
        op._set_attr(_OUTSIDE_COMPILATION_ATTR, cluster_value)

    if self._num_replicas > 1 or not cluster:
        # Prevent feeding or fetching anything that is being compiled,
        # and any replicated outside_compilation Op.
        graph = op.graph
        graph.prevent_feeding(op)
        graph.prevent_fetching(op)
开发者ID:jinxin0924,项目名称:tensorflow,代码行数:28,代码来源:tpu.py
示例14: get_matching_files_v2
def get_matching_files_v2(pattern):
    """Returns a list of files that match the given pattern(s).

    Args:
      pattern: string (text or bytes) or iterable of strings. The glob
        pattern(s).

    Returns:
      A list of strings containing filenames that match the given pattern(s).
      Matches for several patterns are concatenated in pattern order.

    Raises:
      errors.OpError: If there are filesystem / directory listing errors.
    """
    # Treat a lone bytes pattern as a single pattern too. On Python 3
    # `six.string_types` is only `(str,)`, so a bytes pattern would otherwise be
    # iterated as integers and rejected by `compat.as_bytes`.
    if isinstance(pattern, six.string_types + (bytes,)):
        patterns = [pattern]
    else:
        patterns = pattern

    with errors.raise_exception_on_not_ok_status() as status:
        return [
            # Convert the filenames to string from bytes.
            compat.as_str_any(matching_filename)
            for single_filename in patterns
            for matching_filename in pywrap_tensorflow.GetMatchingFiles(
                compat.as_bytes(single_filename), status)
        ]
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:28,代码来源:file_io.py
示例15: _save_and_write_assets
def _save_and_write_assets(self, assets_collection_to_add=None):
    """Saves asset to the meta graph and writes asset files to disk.

    Assets whose destination file already exists are left untouched.

    Args:
      assets_collection_to_add: The collection where the asset paths are setup.
    """
    asset_filename_map = _maybe_save_assets(assets_collection_to_add)

    # Return if there are no assets to write.
    if not asset_filename_map:
        tf_logging.info("No assets to write.")
        return

    assets_destination_dir = saved_model_utils.get_or_create_assets_dir(
        self._export_dir)

    # Copy each asset from source path to destination path.
    for basename, source_path in asset_filename_map.items():
        destination_path = os.path.join(
            compat.as_bytes(assets_destination_dir),
            compat.as_bytes(basename))

        # Only copy the asset file to the destination if it does not already
        # exist. This is to ensure that an asset with the same name defined as
        # part of multiple graphs is only copied the first time.
        if file_io.file_exists(destination_path):
            continue
        file_io.copy(source_path, destination_path)

    tf_logging.info("Assets written to: %s",
                    compat.as_text(assets_destination_dir))
开发者ID:AnishShah,项目名称:tensorflow,代码行数:30,代码来源:builder_impl.py
示例16: testReadFileIgnoreError
def testReadFileIgnoreError(self):
    """Checks that `ignore_errors()` skips files that can no longer be read.

    Five files are created, each containing its own path. The dataset reads them
    all back; after the first file is deleted, a second pass is expected to
    yield only the remaining four.
    """
    filenames = [os.path.join(self.get_temp_dir(), "file_%d.txt" % i)
                 for i in range(5)]
    # Each file's contents are its own path.
    for filename in filenames:
        with open(filename, "w") as f:
            f.write(filename)

    dataset = (dataset_ops.Dataset.from_tensor_slices(filenames)
               .map(io_ops.read_file, num_threads=2, output_buffer_size=2)
               .ignore_errors())
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:

        def run_pass(expected_files):
            # Re-initialize, read one element per expected file, then expect
            # the iterator to be exhausted.
            sess.run(init_op)
            for expected in expected_files:
                self.assertEqual(compat.as_bytes(expected), sess.run(get_next))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)

        # All of the files are present.
        run_pass(filenames)

        # Delete one of the files.
        os.remove(filenames[0])

        # Attempting to read filenames[0] will fail, but ignore_errors()
        # will catch the error.
        run_pass(filenames[1:])
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:34,代码来源:map_dataset_op_test.py
示例17: _shouldResolve
def _shouldResolve(self):
    """Returns whether `self._tpu` still has to be resolved.

    Returns:
      False when `self._tpu` is empty, equals 'local', or starts with '/bns' or
      'grpc://' (all compared as bytes); True otherwise.
    """
    tpu = self._tpu
    literal_targets = (compat.as_bytes(''), compat.as_bytes('local'))
    address_prefixes = (compat.as_bytes('/bns'), compat.as_bytes('grpc://'))

    if tpu in literal_targets:
        return False
    # `startswith` accepts a tuple and matches if any prefix matches.
    return not tpu.startswith(address_prefixes)
开发者ID:ebrevdo,项目名称:tensorflow,代码行数:7,代码来源:tpu_cluster_resolver.py
示例18: testBasic
def testBasic(self):
    """Loads the half_plus_two session bundle and checks assets and signatures."""
    base_path = tf.test.test_src_dir_path(
        "contrib/session_bundle/example/half_plus_two/00000123")
    tf.reset_default_graph()
    sess, meta_graph_def = session_bundle.load_session_bundle_from_path(
        base_path, target="", config=tf.ConfigProto(device_count={"CPU": 2}))

    self.assertTrue(sess)
    asset_path = os.path.join(base_path, constants.ASSETS_DIRECTORY)
    with sess.as_default():
        # The two filename tensors should evaluate to paths inside the
        # bundle's assets directory.
        path1, path2 = sess.run(["filename1:0", "filename2:0"])
        self.assertEqual(
            compat.as_bytes(os.path.join(asset_path, "hello1.txt")), path1)
        self.assertEqual(
            compat.as_bytes(os.path.join(asset_path, "hello2.txt")), path2)

        collection_def = meta_graph_def.collection_def

        signatures_any = collection_def[constants.SIGNATURES_KEY].any_list.value
        # `assertEquals` is a deprecated alias that newer unittest versions no
        # longer provide; use `assertEqual`.
        self.assertEqual(len(signatures_any), 1)

        signatures = manifest_pb2.Signatures()
        signatures_any[0].Unpack(signatures)
        self._checkRegressionSignature(signatures, sess)
        self._checkNamedSigantures(signatures, sess)
开发者ID:821760408-sp,项目名称:tensorflow,代码行数:25,代码来源:session_bundle_test.py
示例19: tf_record_iterator
def tf_record_iterator(path, options=None):
    """An iterator that read the records from a TFRecords file.

    The underlying reader is closed when iteration finishes, when reading
    raises, and when the generator is closed or discarded before it is
    exhausted.

    Args:
      path: The path to the TFRecords file.
      options: (optional) A TFRecordOptions object.

    Yields:
      Strings (the value of `reader.record()` after each successful read).

    Raises:
      IOError: If `path` cannot be opened for reading.
    """
    compression_type = TFRecordOptions.get_compression_type_string(options)
    with errors.raise_exception_on_not_ok_status() as status:
        reader = pywrap_tensorflow.PyRecordReader_New(
            compat.as_bytes(path), 0, compat.as_bytes(compression_type), status)

    if reader is None:
        raise IOError("Could not open %s." % path)

    # `finally` guarantees the reader is released even if the consumer stops
    # iterating early or `GetNext` fails with something other than
    # end-of-file.
    try:
        while True:
            try:
                with errors.raise_exception_on_not_ok_status() as status:
                    reader.GetNext(status)
            except errors.OutOfRangeError:
                # End of file reached.
                break
            yield reader.record()
    finally:
        reader.Close()
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:28,代码来源:tf_record.py
示例20: export_fn
def export_fn(estimator, export_dir, checkpoint_path=None, eval_result=None):
    """A wrapper to export to SavedModel, and convert it to other formats.

    After the base export, the saved model is reloaded, the serialized tree
    ensemble is parsed, the optional converter is run, and the feature
    importances are written (highest first) to
    `assets.extra/feature_importances` inside the export result directory.

    NOTE(review): `base_strategy`, `convert_fn`, `sorted_feature_names`,
    `dense_floats`, `sparse_float_indices` and `sparse_int_indices` are not
    defined in this block; presumably they come from an enclosing scope.

    Args:
      estimator: Passed through to `base_strategy.export`.
      export_dir: Passed through to `base_strategy.export`.
      checkpoint_path: Passed through to `base_strategy.export`.
      eval_result: Passed through to `base_strategy.export` and `convert_fn`.

    Returns:
      The directory returned by `base_strategy.export`.
    """
    result_dir = base_strategy.export(
        estimator, export_dir, checkpoint_path, eval_result)

    with ops.Graph().as_default() as graph:
        with tf_session.Session(graph=graph) as sess:
            saved_model_loader.load(sess, [tag_constants.SERVING], result_dir)

            # Note: This is GTFlow internal API and might change.
            serialize_op = graph.get_operation_by_name(
                "ensemble_model/TreeEnsembleSerialize")
            _, serialized_ensemble = sess.run(serialize_op.outputs)
            dtec = tree_config_pb2.DecisionTreeEnsembleConfig()
            dtec.ParseFromString(serialized_ensemble)

            num_dense = len(dense_floats)
            num_sparse_float = len(sparse_float_indices)
            num_sparse_int = len(sparse_int_indices)

            # Export the result in the same folder as the saved model.
            if convert_fn:
                convert_fn(dtec, sorted_feature_names, num_dense,
                           num_sparse_float, num_sparse_int, result_dir,
                           eval_result)

            feature_importances = _get_feature_importances(
                dtec, sorted_feature_names, num_dense, num_sparse_float,
                num_sparse_int)
            # Negated key sorts by descending importance.
            sorted_by_importance = sorted(
                feature_importances.items(), key=lambda x: -x[1])

            assets_dir = os.path.join(
                compat.as_bytes(result_dir), compat.as_bytes("assets.extra"))
            gfile.MakeDirs(assets_dir)
            importances_path = os.path.join(
                compat.as_bytes(assets_dir),
                compat.as_bytes("feature_importances"))
            with gfile.GFile(importances_path, "w") as f:
                # One "name, importance" line per feature.
                f.write("\n".join(
                    "%s, %f" % (k, v) for k, v in sorted_by_importance))
    return result_dir
开发者ID:ThunderQi,项目名称:tensorflow,代码行数:35,代码来源:custom_export_strategy.py
注:本文中的tensorflow.python.util.compat.as_bytes函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论