本文整理汇总了Python中tensorflow.python.lib.io.file_io.get_matching_files函数的典型用法代码示例。如果您正苦于以下问题：Python get_matching_files函数的具体用法？Python get_matching_files怎么用？Python get_matching_files使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了get_matching_files函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: latest_checkpoint
def latest_checkpoint(checkpoint_dir, latest_filename=None):
    """Locate the most recently saved checkpoint under `checkpoint_dir`.

    The checkpoint state file names a candidate prefix. That prefix is
    returned only if files for it can actually be found, trying the V2
    on-disk layout first and the V1 layout second.

    Args:
      checkpoint_dir: Directory where the variables were saved.
      latest_filename: Optional name for the protocol buffer file that
        contains the list of most recent checkpoint filenames. See the
        corresponding argument to `Saver.save()`.

    Returns:
      The full path to the latest checkpoint, or `None` if the state file
      records none or no files match it (the latter case is also logged).
    """
    state = get_checkpoint_state(checkpoint_dir, latest_filename)
    if not state or not state.model_checkpoint_path:
        return None

    prefix = state.model_checkpoint_path
    v2_glob = _prefix_to_checkpoint_path(prefix, saver_pb2.SaverDef.V2)
    v1_glob = _prefix_to_checkpoint_path(prefix, saver_pb2.SaverDef.V1)
    # The V1 pattern is only globbed when nothing matches the V2 pattern.
    if file_io.get_matching_files(v2_glob) or file_io.get_matching_files(v1_glob):
        return prefix
    logging.error("Couldn't match files for checkpoint %s", prefix)
    return None
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:27,代码来源:checkpoint_management.py
示例2: get_train_eval_files
def get_train_eval_files(input_dir):
    """Return the (train, eval) lists of preprocessed TFRecord files.

    Both lists come from the data directory selected by
    `_get_latest_data_dir(input_dir)`. They are matched with the glob
    patterns `train*.tfrecord.gz` and `eval*.tfrecord.gz` respectively.
    """
    latest_dir = _get_latest_data_dir(input_dir)

    def _glob(split_name):
        # e.g. 'train' -> '<latest_dir>/train*.tfrecord.gz'
        return file_io.get_matching_files(
            os.path.join(latest_dir, split_name + '*.tfrecord.gz'))

    return _glob('train'), _glob('eval')
开发者ID:googledatalab,项目名称:pydatalab,代码行数:8,代码来源:_util.py
示例3: _GetBaseApiMap
def _GetBaseApiMap(self):
    """Build a mapping from graph op name to its base ApiDef.

    The base ApiDefs in `_BASE_API_DIR` are first converted from the
    multiline format to proto text format. This is done by running
    `_CONVERT_FROM_MULTILINE_SCRIPT` into a directory under the test temp
    dir. Every `api_def_*.pbtxt` file produced there is then parsed. When
    several files define the same op, the last one read wins.

    Returns:
      Dictionary mapping graph op name to corresponding ApiDef.
    """
    out_dir = os.path.join(test.get_temp_dir(), 'temp_base_api_defs')
    converter = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _CONVERT_FROM_MULTILINE_SCRIPT)
    subprocess.check_call([converter, _BASE_API_DIR, out_dir])

    by_op_name = {}
    pattern = os.path.join(out_dir, 'api_def_*.pbtxt')
    for pbtxt_path in file_io.get_matching_files(pattern):
        if not file_io.file_exists(pbtxt_path):
            continue
        parsed = api_def_pb2.ApiDefs()
        text_format.Merge(file_io.read_file_to_string(pbtxt_path), parsed)
        by_op_name.update((op_def.graph_op_name, op_def) for op_def in parsed.op)
    return by_op_name
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:25,代码来源:api_compatibility_test.py
示例4: testAPIBackwardsCompatibility
def testAPIBackwardsCompatibility(self):
    """Compare the live `tf` public API against the checked-in golden protos.

    `tf.contrib` and `tf.GPUOptions.Experimental` are not descended into.
    When FLAGS.update_goldens is set, differences are reported but the
    test does not fail.
    """
    proto_visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
    api_walker = public_api.PublicAPIVisitor(proto_visitor)
    api_walker.do_not_descend_map['tf'].append('contrib')
    api_walker.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
    traverse.traverse(tf, api_walker)
    live_protos = proto_visitor.GetProtos()

    # Every golden file, keyed the same way as the live protos.
    golden_glob = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*'))

    def _load_golden(path):
        """Parse one golden text file into a TFAPIObject proto."""
        message = api_objects_pb2.TFAPIObject()
        text_format.Merge(file_io.read_file_to_string(path), message)
        return message

    golden_protos = {}
    for path in file_io.get_matching_files(golden_glob):
        key = _FileNameToKey(path)
        golden_protos[key] = _load_golden(path)

    self._AssertProtoDictEquals(
        golden_protos,
        live_protos,
        verbose=FLAGS.verbose_diffs,
        update_goldens=FLAGS.update_goldens)
开发者ID:DILASSS,项目名称:tensorflow,代码行数:35,代码来源:api_compatibility_test.py
示例5: _batch_predict
def _batch_predict(args, cell):
if args['cloud_config'] and not args['cloud']:
raise ValueError('"cloud_config" is provided but no "--cloud". '
'Do you want local run or cloud run?')
if args['cloud']:
parts = args['model'].split('.')
if len(parts) != 2:
raise ValueError('Invalid model name for cloud prediction. Use "model.version".')
version_name = ('projects/%s/models/%s/versions/%s' %
(Context.default().project_id, parts[0], parts[1]))
cloud_config = args['cloud_config'] or {}
job_id = cloud_config.pop('job_id', None)
job_request = {
'version_name': version_name,
'data_format': 'TEXT',
'input_paths': file_io.get_matching_files(args['prediction_data']['csv']),
'output_path': args['output'],
}
job_request.update(cloud_config)
job = datalab_ml.Job.submit_batch_prediction(job_request, job_id)
_show_job_link(job)
else:
print('local prediction...')
_local_predict.local_batch_predict(args['model'],
args['prediction_data']['csv'],
args['output'],
args['format'],
args['batch_size'])
print('done.')
开发者ID:javiervicho,项目名称:pydatalab,代码行数:32,代码来源:_ml.py
示例6: raw_training_input_fn
def raw_training_input_fn():
    """Training input function that reads raw data and applies transforms.

    Reads CSV lines from every file matching `raw_data_file_pattern`. The
    lines are batched, and shuffled when `randomize_input` is true. They are
    then parsed into raw feature tensors, passed through the preprocessing
    function built from `analysis_output_dir`, and the target is split off.

    NOTE(review): all of `raw_data_file_pattern`, `num_epochs`,
    `randomize_input`, `training_batch_size`, `reader_num_threads`,
    `min_after_dequeue`, `allow_smaller_final_batch`, `features`, `schema`,
    `stats` and `analysis_output_dir` are free variables. They are
    presumably bound by an enclosing factory function not shown here;
    confirm against the caller.

    Returns:
      A `(transformed_features, transformed_target)` pair. The first item is
      a dict of transformed tensors without the target. Any `tf.Tensor` of
      rank 1 in it is expanded with a trailing dimension. The second item is
      the transformed target tensor.

    Raises:
      ValueError: if the target name is empty or has no transformed tensor.
    """
    # Accept either a single pattern string or a list of patterns.
    if isinstance(raw_data_file_pattern, six.string_types):
        filepath_list = [raw_data_file_pattern]
    else:
        filepath_list = raw_data_file_pattern
    # Expand every glob into the concrete list of input files.
    files = []
    for path in filepath_list:
        files.extend(file_io.get_matching_files(path))
    filename_queue = tf.train.string_input_producer(
        files, num_epochs=num_epochs, shuffle=randomize_input)
    # read_up_to yields up to training_batch_size (key, line) pairs per run,
    # which is why the batching calls below use enqueue_many=True.
    csv_id, csv_lines = tf.TextLineReader().read_up_to(filename_queue, training_batch_size)
    queue_capacity = (reader_num_threads + 3) * training_batch_size + min_after_dequeue
    if randomize_input:
        _, batch_csv_lines = tf.train.shuffle_batch(
            tensors=[csv_id, csv_lines],
            batch_size=training_batch_size,
            capacity=queue_capacity,
            min_after_dequeue=min_after_dequeue,
            enqueue_many=True,
            num_threads=reader_num_threads,
            allow_smaller_final_batch=allow_smaller_final_batch)
    else:
        _, batch_csv_lines = tf.train.batch(
            tensors=[csv_id, csv_lines],
            batch_size=training_batch_size,
            capacity=queue_capacity,
            enqueue_many=True,
            num_threads=reader_num_threads,
            allow_smaller_final_batch=allow_smaller_final_batch)
    # Parse the batched lines; the header includes the target column.
    csv_header, record_defaults = csv_header_and_defaults(features, schema, stats, keep_target=True)
    parsed_tensors = tf.decode_csv(batch_csv_lines, record_defaults, name='csv_to_tensors')
    # Map each CSV column name to its parsed tensor.
    raw_features = dict(zip(csv_header, parsed_tensors))
    transform_fn = make_preprocessing_fn(analysis_output_dir, features, keep_target=True)
    transformed_tensors = transform_fn(raw_features)
    # Expand the dims of non-sparse tensors. This is needed by tf.learn.
    transformed_features = {}
    for k, v in six.iteritems(transformed_tensors):
        if isinstance(v, tf.Tensor) and v.get_shape().ndims == 1:
            transformed_features[k] = tf.expand_dims(v, -1)
        else:
            transformed_features[k] = v
    # Remove the target tensor, and return it directly
    target_name = get_target_name(features)
    if not target_name or target_name not in transformed_features:
        raise ValueError('Cannot find target transform in features')
    transformed_target = transformed_features.pop(target_name)
    return transformed_features, transformed_target
开发者ID:parthea,项目名称:pydatalab,代码行数:60,代码来源:feature_transforms.py
示例7: create_object_test
def create_object_test():
    """Verify file_io's object manipulation methods against a GCS bucket.

    The test:
      1. Creates a directory under FLAGS.gcs_bucket_url, named after the
         current time in milliseconds.
      2. Writes one file into it.
      3. Checks that a glob lists exactly that file.
      4. Deletes the file and then the directory.

    The directory is removed even if an earlier step fails, so a failing
    run does not leave test objects behind in the bucket.
    """
    starttime = int(round(time.time() * 1000))
    dir_name = "%s/tf_gcs_test_%s" % (FLAGS.gcs_bucket_url, starttime)
    print("Creating dir %s." % dir_name)
    file_io.create_dir(dir_name)
    try:
        # Create a file in this directory.
        file_name = "%s/test_file.txt" % dir_name
        print("Creating file %s." % file_name)
        file_io.write_string_to_file(file_name, "test file creation.")

        list_files_pattern = "%s/test_file*.txt" % dir_name
        print("Getting files matching pattern %s." % list_files_pattern)
        files_list = file_io.get_matching_files(list_files_pattern)
        print(files_list)
        # These asserts are the smoke test's checks, not input validation.
        assert len(files_list) == 1
        assert files_list[0] == file_name

        # Cleanup test files.
        print("Deleting file %s." % file_name)
        file_io.delete_file(file_name)
    finally:
        # Delete the directory recursively. This runs on the failure path
        # too, so anything created above is removed.
        print("Deleting directory %s." % dir_name)
        file_io.delete_recursively(dir_name)
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:27,代码来源:gcs_smoke.py
示例8: testGetMatchingFiles
def testGetMatchingFiles(self):
    """get_matching_files handles a glob, an empty tuple, and a list of paths."""
    dir_path = os.path.join(self._base_dir, "temp_dir")
    file_io.create_dir(dir_path)
    files = ["file1.txt", "file2.txt", "file3.txt"]
    for name in files:
        file_path = os.path.join(dir_path, name)
        # Close each writer explicitly instead of leaving an open handle
        # for garbage collection to finalize.
        with file_io.FileIO(file_path, mode="w") as f:
            f.write("testing")
    expected_match = [os.path.join(dir_path, name) for name in files]
    # A single glob pattern matches all three files.
    self.assertItemsEqual(
        file_io.get_matching_files(os.path.join(dir_path, "file*.txt")),
        expected_match)
    # An empty sequence of patterns matches nothing.
    self.assertItemsEqual(file_io.get_matching_files(tuple()), [])
    # A list of literal paths matches exactly those paths.
    files_subset = [
        os.path.join(dir_path, files[0]), os.path.join(dir_path, files[2])
    ]
    self.assertItemsEqual(
        file_io.get_matching_files(files_subset), files_subset)
    file_io.delete_recursively(dir_path)
    self.assertFalse(file_io.file_exists(os.path.join(dir_path, "file3.txt")))
开发者ID:1000sprites,项目名称:tensorflow,代码行数:19,代码来源:file_io_test.py
示例9: _run_transform
def _run_transform(self):
    """Run transform.py to produce tf.Example files for train and eval.

    The train CSV is transformed with `--shuffle` plus the cloud flags
    (a Dataflow job with 3 workers). The eval CSV is transformed locally
    to save time. Afterwards, outputs with both the `features_train` and
    `features_eval` prefixes must exist.
    """
    use_cloud = True
    cloud_flags = []
    if use_cloud:
        cloud_flags = [
            '--cloud',
            '--job-name=test-mltoolbox-df-%s' % uuid.uuid4().hex,
            '--project-id=%s' % self._get_default_project_id(),
            '--num-workers=3',
        ]
    script = 'python %s' % os.path.join(CODE_PATH, 'transform.py')

    def _transform(csv_path, prefix, trailing_flags):
        # Build the command line, log it, and run it through the shell.
        argv = [script,
                '--csv=' + csv_path,
                '--analysis=' + self._analysis_output,
                '--prefix=' + prefix,
                '--output=' + self._transform_output] + trailing_flags
        command = ' '.join(argv)
        self._logger.debug('Running subprocess: %s \n\n' % command)
        subprocess.check_call(command, shell=True)

    _transform(self._csv_train_filename, 'features_train',
               ['--shuffle'] + cloud_flags)
    # Don't waste time running a second Dataflow job; run eval locally.
    _transform(self._csv_eval_filename, 'features_eval', [])

    # Check that both transforms wrote output files.
    train_files = file_io.get_matching_files(
        os.path.join(self._transform_output, 'features_train*'))
    eval_files = file_io.get_matching_files(
        os.path.join(self._transform_output, 'features_eval*'))
    self.assertNotEqual([], train_files)
    self.assertNotEqual([], eval_files)
开发者ID:googledatalab,项目名称:pydatalab,代码行数:41,代码来源:test_cloud_workflow.py
示例10: testGetMatchingFiles
def testGetMatchingFiles(self):
    """A `file*.txt` glob matches exactly the three files written."""
    root = os.path.join(self._base_dir, "temp_dir")
    file_io.create_dir(root)
    written = [os.path.join(root, leaf)
               for leaf in ("file1.txt", "file2.txt", "file3.txt")]
    for target in written:
        file_io.write_string_to_file(target, "testing")
    self.assertItemsEqual(
        file_io.get_matching_files(os.path.join(root, "file*.txt")),
        written)
    # After recursive deletion the last file must be gone.
    file_io.delete_recursively(root)
    self.assertFalse(file_io.file_exists(written[-1]))
开发者ID:AriaAsuka,项目名称:tensorflow,代码行数:13,代码来源:file_io_test.py
示例11: checkpoint_exists
def checkpoint_exists(checkpoint_prefix):
    """Report whether a V1 or V2 checkpoint exists for `checkpoint_prefix`.

    Prefer this over a plain file-existence check, because it accounts for
    the different on-disk naming of the V1 and V2 formats.

    Args:
      checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
        priority. Typically the result of `Saver.save()` or that of
        `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
        V1/V2.

    Returns:
      True iff something matches the V2 path derived from the prefix, or
      the prefix itself used as a pattern.
    """
    v2_pattern = _prefix_to_checkpoint_path(checkpoint_prefix,
                                            saver_pb2.SaverDef.V2)
    # The bare prefix is only globbed when the V2 pattern matches nothing.
    return bool(file_io.get_matching_files(v2_pattern) or
                file_io.get_matching_files(checkpoint_prefix))
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:22,代码来源:checkpoint_management.py
示例12: testNewAPIBackwardsCompatibility
def testNewAPIBackwardsCompatibility(self):
    """Diff the new `api` module's public surface against the golden protos.

    `user_ops` is an empty module that exists in the current TensorFlow API
    but is not kept in the new API. It is therefore removed from the golden
    `tensorflow` module before diffing. This test never updates the goldens
    (`update_goldens=False`).
    """
    proto_visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
    api_walker = public_api.PublicAPIVisitor(proto_visitor)
    api_walker.do_not_descend_map['tf'].append('contrib')
    api_walker.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
    # TODO(annarev): Make slide_dataset available in API.
    api_walker.private_map['tf'] = ['slide_dataset']
    traverse.traverse(api, api_walker)
    live_protos = proto_visitor.GetProtos()

    golden_glob = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*'))

    def _load_golden(path):
        """Parse one golden text file into a TFAPIObject proto."""
        message = api_objects_pb2.TFAPIObject()
        text_format.Merge(file_io.read_file_to_string(path), message)
        return message

    golden_protos = {}
    for path in file_io.get_matching_files(golden_glob):
        key = _FileNameToKey(path)
        golden_protos[key] = _load_golden(path)

    # TODO(annarev): remove user_ops from goldens once we switch to new API.
    golden_members = golden_protos['tensorflow'].tf_module.member
    for idx, member in enumerate(golden_members):
        if member.name == 'user_ops':
            # Safe: iteration stops right after the single deletion.
            del golden_members[idx]
            break

    self._AssertProtoDictEquals(
        golden_protos,
        live_protos,
        verbose=FLAGS.verbose_diffs,
        update_goldens=False,
        additional_missing_object_message=
        'Check if tf_export decorator/call is missing for this symbol.')
开发者ID:PuchatekwSzortach,项目名称:tensorflow,代码行数:50,代码来源:api_compatibility_test.py
示例13: _checkBackwardsCompatibility
def _checkBackwardsCompatibility(self,
                                 root,
                                 golden_file_pattern,
                                 api_version,
                                 additional_private_map=None,
                                 omit_golden_symbols_map=None):
    """Diff the public API reachable from `root` against golden API protos.

    Args:
      root: Object whose public API is walked by `traverse.traverse`.
      golden_file_pattern: Glob for the golden files. Each file is parsed as
        a text-format `api_objects_pb2.TFAPIObject`.
      api_version: API version being checked. When 2, `enable_v2_behavior`
        is also marked private under 'tf'.
      additional_private_map: Optional dict merged into the visitor's
        `private_map` via `dict.update`. A 'tf' entry here therefore
        replaces (not extends) the default 'tf' list built above.
      omit_golden_symbols_map: Optional map forwarded to
        `_FilterGoldenProtoDict`. Presumably it names golden symbols to
        ignore; confirm against that helper.
    """
    # Extract all API stuff.
    visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
    public_api_visitor = public_api.PublicAPIVisitor(visitor)
    public_api_visitor.private_map['tf'] = ['contrib']
    if api_version == 2:
        public_api_visitor.private_map['tf'].append('enable_v2_behavior')
    public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
    # With --only_test_core_api, non-core packages are not descended into.
    # The golden file list is filtered by _FilterNonCoreGoldenFiles below.
    if FLAGS.only_test_core_api:
        public_api_visitor.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES)
    if additional_private_map:
        public_api_visitor.private_map.update(additional_private_map)
    traverse.traverse(root, public_api_visitor)
    proto_dict = visitor.GetProtos()
    # Read all golden files.
    golden_file_list = file_io.get_matching_files(golden_file_pattern)
    if FLAGS.only_test_core_api:
        golden_file_list = _FilterNonCoreGoldenFiles(golden_file_list)
    def _ReadFileToProto(filename):
        """Read a filename, create a protobuf from its contents."""
        ret_val = api_objects_pb2.TFAPIObject()
        text_format.Merge(file_io.read_file_to_string(filename), ret_val)
        return ret_val
    # Key the golden protos the same way the live protos are keyed.
    golden_proto_dict = {
        _FileNameToKey(filename): _ReadFileToProto(filename)
        for filename in golden_file_list
    }
    golden_proto_dict = _FilterGoldenProtoDict(golden_proto_dict,
                                               omit_golden_symbols_map)
    # Diff them. Do not fail if called with update.
    # If the test is run to update goldens, only report diffs but do not fail.
    self._AssertProtoDictEquals(
        golden_proto_dict,
        proto_dict,
        verbose=FLAGS.verbose_diffs,
        update_goldens=FLAGS.update_goldens,
        api_version=api_version)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:49,代码来源:api_compatibility_test.py
示例14: copy_data_to_tmp
def copy_data_to_tmp(input_files):
    """Copy matched GCS input files into a fresh local directory under /tmp/.

    Args:
      input_files: iterable of strings. Each string is a comma-separated
        list of file paths or glob patterns.

    Returns:
      `input_files` unchanged if the patterns match nothing or if any
      matched file is not a `gs://` path. Otherwise, a one-element list
      holding a glob (`/tmp/<uuid>/*`) over the copied files.

    Raises:
      subprocess.CalledProcessError: if the `gsutil` copy fails.
    """
    files = []
    for entry in input_files:
        for pattern in entry.split(','):
            files.extend(file_io.get_matching_files(pattern))
    # With no matches there is nothing to copy; `gsutil cp` would be called
    # with no source and fail, after creating an empty tmp dir. If any data
    # is already local, leave the inputs as they are.
    if not files or not all(path.startswith('gs://') for path in files):
        return input_files
    tmp_path = os.path.join('/tmp/', str(uuid.uuid4()))
    os.makedirs(tmp_path)
    subprocess.check_call(['gsutil', '-m', '-q', 'cp', '-r'] + files + [tmp_path])
    return [os.path.join(tmp_path, '*')]
开发者ID:cottrell,项目名称:notebooks,代码行数:15,代码来源:task.py
示例15: get_latest_checkpoint
def get_latest_checkpoint():
    """Return the checkpoint prefix with the highest step in FLAGS.train_dir.

    Looks for `model.ckpt-<step>.index` files and picks the one with the
    largest integer step. On equal steps, the lexicographically largest path
    wins. Returns that path without its `.index` suffix, or None when no
    index file exists.
    """
    pattern = os.path.join(FLAGS.train_dir, 'model.ckpt-*.index')
    candidates = file_io.get_matching_files(pattern)
    if not candidates:
        return None

    def _step(path):
        # "model.ckpt-1234.index" -> 1234
        return int(os.path.basename(path).split("-")[-1].split(".")[0])

    _, newest = max((_step(path), path) for path in candidates)
    return newest[:-len(".index")]
开发者ID:vijayky88,项目名称:youtube-8m,代码行数:15,代码来源:eval.py
示例16: read_examples
def read_examples(input_files, batch_size, shuffle, num_epochs=None):
    """Build a queue-based batch pipeline over GZIP-compressed TFRecord files.

    Args:
      input_files: iterable of strings. Each string is a comma-separated
        list of file paths or glob patterns.
      batch_size: number of records per output batch.
      shuffle: if true, both the filename queue and the records are shuffled.
      num_epochs: passed to `tf.train.string_input_producer`; 0 is treated
        the same as None.

    Returns:
      The result of `tf.train.shuffle_batch` (when `shuffle`) or
      `tf.train.batch` over `[example_id, encoded_example]`.
    """
    files = [match
             for entry in input_files
             for pattern in entry.split(',')
             for match in file_io.get_matching_files(pattern)]

    num_threads = multiprocessing.cpu_count()
    # Minimum records kept in the shuffle queue after each dequeue: larger
    # values mix better but need more memory.
    min_after_dequeue = 1000
    # Queue capacity is this many batches: one per batching thread plus a
    # small safety margin of 3.
    capacity_factor = num_threads + 3
    num_epochs = num_epochs or None

    filename_queue = tf.train.string_input_producer(files, num_epochs, shuffle)
    gzip_options = tf.python_io.TFRecordOptions(
        compression_type=tf.python_io.TFRecordCompressionType.GZIP)
    example_id, encoded_example = tf.TFRecordReader(options=gzip_options).read_up_to(
        filename_queue, batch_size)

    if not shuffle:
        return tf.train.batch(
            [example_id, encoded_example],
            batch_size,
            capacity=capacity_factor * batch_size,
            enqueue_many=True,
            num_threads=num_threads)
    return tf.train.shuffle_batch(
        [example_id, encoded_example],
        batch_size,
        min_after_dequeue + capacity_factor * batch_size,
        min_after_dequeue,
        enqueue_many=True,
        num_threads=num_threads)
开发者ID:amygdala,项目名称:tensorflow-workshop,代码行数:47,代码来源:util.py
示例17: local_batch_predict
def local_batch_predict(model_dir, csv_file_pattern, output_dir, output_format, batch_size=100):
    """Run batch prediction locally with a SavedModel and write the results.

    Each matched CSV file `<dir>/<name>.<ext>` produces
    `<output_dir>/predict_results_<name>.<output_format>`. The output schema
    is written to `<output_dir>/predict_results_schema.json`.

    Args:
      model_dir: The model directory containing a SavedModel (usually saved_model.pb).
      csv_file_pattern: a pattern of csv files as batch prediction source.
      output_dir: the path of the output directory (created if needed).
      output_format: csv or json.
      batch_size: Larger batch_size improves performance but may
        cause more memory usage.

    Raises:
      ValueError: if no file matches `csv_file_pattern`.
    """
    file_io.recursive_create_dir(output_dir)
    csv_files = file_io.get_matching_files(csv_file_pattern)
    if not csv_files:
        raise ValueError('No files found given ' + csv_file_pattern)

    with tf.Graph().as_default(), tf.Session() as sess:
        input_alias_map, output_alias_map = _tf_load_model(sess, model_dir)
        # The first input alias is fed the raw CSV lines.
        csv_tensor_name = list(input_alias_map.values())[0]
        output_schema = _get_output_schema(sess, output_alias_map)
        for csv_file in csv_files:
            stem = os.path.splitext(os.path.basename(csv_file))[0]
            output_file = os.path.join(
                output_dir, 'predict_results_' + stem + '.' + output_format)
            with file_io.FileIO(output_file, 'w') as out:
                for batch in _batch_csv_reader(csv_file, batch_size):
                    rows = [line.rstrip() for line in batch if line]
                    predictions = sess.run(fetches=output_alias_map,
                                           feed_dict={csv_tensor_name: rows})
                    lines = _format_results(output_format, output_schema, predictions)
                    out.write('\n'.join(lines) + '\n')

    schema_path = os.path.join(output_dir, 'predict_results_schema.json')
    file_io.write_string_to_file(schema_path, json.dumps(output_schema, indent=2))
开发者ID:googledatalab,项目名称:pydatalab,代码行数:39,代码来源:_local_predict.py
示例18: testMatchingFilesPermission
def testMatchingFilesPermission(self):
    """Matches under an unreadable subdirectory are skipped, not an error.

    test_dir/any holds three readable files. test_dir/noread holds one file
    and has all permissions removed. The glob is expected to return only
    the files in `any`.
    """
    # Create top level directory test_dir.
    dir_path = os.path.join(self._base_dir, "test_dir")
    file_io.create_dir(dir_path)
    # Create second level directories `noread` and `any`.
    noread_path = os.path.join(dir_path, "noread")
    file_io.create_dir(noread_path)
    any_path = os.path.join(dir_path, "any")
    file_io.create_dir(any_path)
    files = ["file1.txt", "file2.txt", "file3.txt"]
    for name in files:
        file_path = os.path.join(any_path, name)
        # Close writers explicitly rather than leaving open handles.
        with file_io.FileIO(file_path, mode="w") as f:
            f.write("testing")
    file_path = os.path.join(noread_path, "file4.txt")
    with file_io.FileIO(file_path, mode="w") as f:
        f.write("testing")
    # Change noread to noread access.
    os.chmod(noread_path, 0)
    try:
        expected_match = [os.path.join(any_path, name) for name in files]
        self.assertItemsEqual(
            file_io.get_matching_files(os.path.join(dir_path, "*", "file*.txt")),
            expected_match)
    finally:
        # Change noread back so that it could be cleaned during tearDown,
        # even when the assertion above fails.
        os.chmod(noread_path, 0o777)
开发者ID:aritratony,项目名称:tensorflow,代码行数:23,代码来源:file_io_test.py
示例19: _run_batch_prediction
def _run_batch_prediction(self):
    """Run batch prediction using the cloudml engine prediction service.

    There is no local version of this step, as it is the last step. The
    job must produce exactly one `prediction.errors_stats*` file, and that
    file must be empty.
    """
    job_name = 'test_mltoolbox_batchprediction_%s' % uuid.uuid4().hex
    submit_command = ' '.join([
        'gcloud ml-engine jobs submit prediction ' + job_name,
        '--data-format=TEXT',
        '--input-paths=' + self._csv_predict_filename,
        '--output-path=' + self._prediction_output,
        '--model-dir=' + os.path.join(self._train_output, 'model'),
        '--runtime-version=1.0',
        '--region=us-central1',
    ])
    self._logger.debug('Running subprocess: %s \n\n' % submit_command)
    # Submission is asynchronous; streaming the logs afterwards is
    # presumably what waits for the job to finish.
    subprocess.check_call(submit_command, shell=True)
    subprocess.check_call('gcloud ml-engine jobs stream-logs ' + job_name, shell=True)

    error_files = file_io.get_matching_files(
        os.path.join(self._prediction_output, 'prediction.errors_stats*'))
    self.assertEqual(1, len(error_files))
    self.assertEqual('', file_io.read_file_to_string(error_files[0]))
开发者ID:googledatalab,项目名称:pydatalab,代码行数:24,代码来源:test_cloud_workflow.py
示例20: load_session_bundle_from_path
def load_session_bundle_from_path(export_dir,
target="",
config=None,
meta_graph_def=None):
"""Load session bundle from the given path.
The function reads input from the export_dir, constructs the graph data to the
default graph and restores the parameters for the session created.
Args:
export_dir: the directory that contains files exported by exporter.
target: The execution engine to connect to. See target in
tf.compat.v1.Session()
config: A ConfigProto proto with configuration options. See config in
tf.compat.v1.Session()
meta_graph_def: optional object of type MetaGraphDef. If this object is
present, then it is used instead of parsing MetaGraphDef from export_dir.
Returns:
session: a tensorflow session created from the variable files.
meta_graph: a meta graph proto saved in the exporter directory.
Raises:
RuntimeError: if the required files are missing or contain unrecognizable
fields, i.e. the exported model is invalid.
"""
if not meta_graph_def:
meta_graph_filename = os.path.join(export_dir,
constants.META_GRAPH_DEF_FILENAME)
if not file_io.file_exists(meta_graph_filename):
raise RuntimeError("Expected meta graph file missing %s" %
meta_graph_filename)
# Reads meta graph file.
meta_graph_def = meta_graph_pb2.MetaGraphDef()
meta_graph_def.ParseFromString(
file_io.read_file_to_string(meta_graph_filename, binary_mode=True))
variables_filename = ""
variables_filename_list = []
checkpoint_sharded = False
variables_index_filename = os.path.join(export_dir,
constants.VARIABLES_INDEX_FILENAME_V2)
checkpoint_v2 = file_io.file_exists(variables_index_filename)
# Find matching checkpoint files.
if checkpoint_v2:
# The checkpoint is in v2 format.
variables_filename_pattern = os.path.join(
export_dir, constants.VARIABLES_FILENAME_PATTERN_V2)
variables_filename_list = file_io.get_matching_files(
variables_filename_pattern)
checkpoint_sharded = True
else:
variables_filename = os.path.join(export_dir, constants.VARIABLES_FILENAME)
if file_io.file_exists(variables_filename):
variables_filename_list = [variables_filename]
else:
variables_filename = os.path.join(export_dir,
constants.VARIABLES_FILENAME_PATTERN)
variables_filename_list = file_io.get_matching_files(variables_filename)
checkpoint_sharded = True
# Prepare the files to restore a session.
if not variables_filename_list:
restore_files = ""
elif checkpoint_v2 or not checkpoint_sharded:
# For checkpoint v2 or v1 with non-sharded files, use "export" to restore
# the session.
restore_files = constants.VARIABLES_FILENAME
else:
restore_files = constants.VARIABLES_FILENAME_PATTERN
assets_dir = os.path.join(export_dir, constants.ASSETS_DIRECTORY)
collection_def = meta_graph_def.collection_def
graph_def = graph_pb2.GraphDef()
if constants.GRAPH_KEY in collection_def:
# Use serving graph_def in MetaGraphDef collection_def if exists
graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
if len(graph_def_any) != 1:
raise RuntimeError("Expected exactly one serving GraphDef in : %s" %
meta_graph_def)
else:
graph_def_any[0].Unpack(graph_def)
# Replace the graph def in meta graph proto.
meta_graph_def.graph_def.CopyFrom(graph_def)
ops.reset_default_graph()
sess = session.Session(target, graph=None, config=config)
# Import the graph.
saver = saver_lib.import_meta_graph(meta_graph_def)
# Restore the session.
if restore_files:
saver.restore(sess, os.path.join(export_dir, restore_files))
init_op_tensor = None
if constants.INIT_OP_KEY in collection_def:
init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
if len(init_ops) != 1:
#.........这里部分代码省略.........
开发者ID:ahmedsaiduk,项目名称:tensorflow,代码行数:101,代码来源:session_bundle.py
注：本文中的tensorflow.python.lib.io.file_io.get_matching_files函数示例由纯净天空整理自GitHub/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论