
Python file_io.get_matching_files Function Code Examples


This article collects typical usage examples of the Python function tensorflow.python.lib.io.file_io.get_matching_files. If you have been wondering exactly what get_matching_files does, how to call it, or what real-world uses look like, the curated examples below should help.



Below are 20 code examples of the get_matching_files function, sorted by popularity by default. A minimal usage sketch is given first, followed by the full examples.
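
Before the full examples, here is a minimal sketch of the basic call pattern. The import is the module path used throughout this article; the data directory and glob pattern are hypothetical placeholders. As the examples below demonstrate, get_matching_files accepts a glob pattern (or a list/tuple of patterns) and returns the list of matching file paths, for local paths as well as Google Cloud Storage (gs://) paths.

from tensorflow.python.lib.io import file_io

# Hypothetical local data directory; substitute your own path or a gs:// URI.
pattern = '/tmp/data/train-*.tfrecord.gz'

# Returns a (possibly empty) list of file paths matching the glob pattern.
matching_files = file_io.get_matching_files(pattern)
print(matching_files)

An empty result simply means nothing matched; several of the examples below rely on that to detect missing checkpoints or missing preprocessed files.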

Example 1: latest_checkpoint

def latest_checkpoint(checkpoint_dir, latest_filename=None):
  """Finds the filename of latest saved checkpoint file.

  Args:
    checkpoint_dir: Directory where the variables were saved.
    latest_filename: Optional name for the protocol buffer file that
      contains the list of most recent checkpoint filenames.
      See the corresponding argument to `Saver.save()`.

  Returns:
    The full path to the latest checkpoint or `None` if no checkpoint was found.
  """
  # Pick the latest checkpoint based on checkpoint state.
  ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
  if ckpt and ckpt.model_checkpoint_path:
    # Look for either a V2 path or a V1 path, with priority for V2.
    v2_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
                                         saver_pb2.SaverDef.V2)
    v1_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
                                         saver_pb2.SaverDef.V1)
    if file_io.get_matching_files(v2_path) or file_io.get_matching_files(
        v1_path):
      return ckpt.model_checkpoint_path
    else:
      logging.error("Couldn't match files for checkpoint %s",
                    ckpt.model_checkpoint_path)
  return None
Developer: adit-chandra, Project: tensorflow, Lines: 27, Source: checkpoint_management.py


Example 2: get_train_eval_files

def get_train_eval_files(input_dir):
  """Get preprocessed training and eval files."""
  data_dir = _get_latest_data_dir(input_dir)
  train_pattern = os.path.join(data_dir, 'train*.tfrecord.gz')
  eval_pattern = os.path.join(data_dir, 'eval*.tfrecord.gz')
  train_files = file_io.get_matching_files(train_pattern)
  eval_files = file_io.get_matching_files(eval_pattern)
  return train_files, eval_files
Developer: googledatalab, Project: pydatalab, Lines: 8, Source: _util.py


Example 3: _GetBaseApiMap

  def _GetBaseApiMap(self):
    """Get a map from graph op name to its base ApiDef.

    Returns:
      Dictionary mapping graph op name to corresponding ApiDef.
    """
    # Convert base ApiDef in Multiline format to Proto format.
    converted_base_api_dir = os.path.join(
        test.get_temp_dir(), 'temp_base_api_defs')
    subprocess.check_call(
        [os.path.join(resource_loader.get_root_dir_with_all_resources(),
                      _CONVERT_FROM_MULTILINE_SCRIPT),
         _BASE_API_DIR, converted_base_api_dir])

    name_to_base_api_def = {}
    base_api_files = file_io.get_matching_files(
        os.path.join(converted_base_api_dir, 'api_def_*.pbtxt'))
    for base_api_file in base_api_files:
      if file_io.file_exists(base_api_file):
        api_defs = api_def_pb2.ApiDefs()
        text_format.Merge(
            file_io.read_file_to_string(base_api_file), api_defs)
        for api_def in api_defs.op:
          name_to_base_api_def[api_def.graph_op_name] = api_def
    return name_to_base_api_def
Developer: AbhinavJain13, Project: tensorflow, Lines: 25, Source: api_compatibility_test.py


Example 4: testAPIBackwardsCompatibility

  def testAPIBackwardsCompatibility(self):
    # Extract all API stuff.
    visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()

    public_api_visitor = public_api.PublicAPIVisitor(visitor)
    public_api_visitor.do_not_descend_map['tf'].append('contrib')
    public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
    traverse.traverse(tf, public_api_visitor)

    proto_dict = visitor.GetProtos()

    # Read all golden files.
    expression = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*'))
    golden_file_list = file_io.get_matching_files(expression)

    def _ReadFileToProto(filename):
      """Read a filename, create a protobuf from its contents."""
      ret_val = api_objects_pb2.TFAPIObject()
      text_format.Merge(file_io.read_file_to_string(filename), ret_val)
      return ret_val

    golden_proto_dict = {
        _FileNameToKey(filename): _ReadFileToProto(filename)
        for filename in golden_file_list
    }

    # Diff them. Do not fail if called with update.
    # If the test is run to update goldens, only report diffs but do not fail.
    self._AssertProtoDictEquals(
        golden_proto_dict,
        proto_dict,
        verbose=FLAGS.verbose_diffs,
        update_goldens=FLAGS.update_goldens)
Developer: DILASSS, Project: tensorflow, Lines: 35, Source: api_compatibility_test.py


Example 5: _batch_predict

def _batch_predict(args, cell):
  if args['cloud_config'] and not args['cloud']:
    raise ValueError('"cloud_config" is provided but no "--cloud". '
                     'Do you want local run or cloud run?')

  if args['cloud']:
    parts = args['model'].split('.')
    if len(parts) != 2:
      raise ValueError('Invalid model name for cloud prediction. Use "model.version".')

    version_name = ('projects/%s/models/%s/versions/%s' %
                    (Context.default().project_id, parts[0], parts[1]))

    cloud_config = args['cloud_config'] or {}
    job_id = cloud_config.pop('job_id', None)
    job_request = {
      'version_name': version_name,
      'data_format': 'TEXT',
      'input_paths': file_io.get_matching_files(args['prediction_data']['csv']),
      'output_path': args['output'],
    }
    job_request.update(cloud_config)
    job = datalab_ml.Job.submit_batch_prediction(job_request, job_id)
    _show_job_link(job)
  else:
    print('local prediction...')
    _local_predict.local_batch_predict(args['model'],
                                       args['prediction_data']['csv'],
                                       args['output'],
                                       args['format'],
                                       args['batch_size'])
    print('done.')
Developer: javiervicho, Project: pydatalab, Lines: 32, Source: _ml.py


Example 6: raw_training_input_fn

  def raw_training_input_fn():
    """Training input function that reads raw data and applies transforms."""

    if isinstance(raw_data_file_pattern, six.string_types):
      filepath_list = [raw_data_file_pattern]
    else:
      filepath_list = raw_data_file_pattern

    files = []
    for path in filepath_list:
      files.extend(file_io.get_matching_files(path))

    filename_queue = tf.train.string_input_producer(
        files, num_epochs=num_epochs, shuffle=randomize_input)

    csv_id, csv_lines = tf.TextLineReader().read_up_to(filename_queue, training_batch_size)

    queue_capacity = (reader_num_threads + 3) * training_batch_size + min_after_dequeue
    if randomize_input:
      _, batch_csv_lines = tf.train.shuffle_batch(
          tensors=[csv_id, csv_lines],
          batch_size=training_batch_size,
          capacity=queue_capacity,
          min_after_dequeue=min_after_dequeue,
          enqueue_many=True,
          num_threads=reader_num_threads,
          allow_smaller_final_batch=allow_smaller_final_batch)

    else:
      _, batch_csv_lines = tf.train.batch(
          tensors=[csv_id, csv_lines],
          batch_size=training_batch_size,
          capacity=queue_capacity,
          enqueue_many=True,
          num_threads=reader_num_threads,
          allow_smaller_final_batch=allow_smaller_final_batch)

    csv_header, record_defaults = csv_header_and_defaults(features, schema, stats, keep_target=True)
    parsed_tensors = tf.decode_csv(batch_csv_lines, record_defaults, name='csv_to_tensors')
    raw_features = dict(zip(csv_header, parsed_tensors))

    transform_fn = make_preprocessing_fn(analysis_output_dir, features, keep_target=True)
    transformed_tensors = transform_fn(raw_features)

    # Expand the dims of non-sparse tensors. This is needed by tf.learn.
    transformed_features = {}
    for k, v in six.iteritems(transformed_tensors):
      if isinstance(v, tf.Tensor) and v.get_shape().ndims == 1:
        transformed_features[k] = tf.expand_dims(v, -1)
      else:
        transformed_features[k] = v

    # Remove the target tensor, and return it directly
    target_name = get_target_name(features)
    if not target_name or target_name not in transformed_features:
      raise ValueError('Cannot find target transform in features')

    transformed_target = transformed_features.pop(target_name)

    return transformed_features, transformed_target
Developer: parthea, Project: pydatalab, Lines: 60, Source: feature_transforms.py


Example 7: create_object_test

def create_object_test():
  """Verifies file_io's object manipulation methods ."""
  starttime = int(round(time.time() * 1000))
  dir_name = "%s/tf_gcs_test_%s" % (FLAGS.gcs_bucket_url, starttime)
  print("Creating dir %s." % dir_name)
  file_io.create_dir(dir_name)

  # Create a file in this directory.
  file_name = "%s/test_file.txt" % dir_name
  print("Creating file %s." % file_name)
  file_io.write_string_to_file(file_name, "test file creation.")

  list_files_pattern = "%s/test_file*.txt" % dir_name
  print("Getting files matching pattern %s." % list_files_pattern)
  files_list = file_io.get_matching_files(list_files_pattern)
  print(files_list)

  assert len(files_list) == 1
  assert files_list[0] == file_name

  # Cleanup test files.
  print("Deleting file %s." % file_name)
  file_io.delete_file(file_name)

  # Delete directory.
  print("Deleting directory %s." % dir_name)
  file_io.delete_recursively(dir_name)
Developer: AutumnQYN, Project: tensorflow, Lines: 27, Source: gcs_smoke.py


Example 8: testGetMatchingFiles

  def testGetMatchingFiles(self):
    dir_path = os.path.join(self._base_dir, "temp_dir")
    file_io.create_dir(dir_path)
    files = ["file1.txt", "file2.txt", "file3.txt"]
    for name in files:
      file_path = os.path.join(dir_path, name)
      file_io.FileIO(file_path, mode="w").write("testing")
    expected_match = [os.path.join(dir_path, name) for name in files]
    self.assertItemsEqual(
        file_io.get_matching_files(os.path.join(dir_path, "file*.txt")),
        expected_match)
    self.assertItemsEqual(file_io.get_matching_files(tuple()), [])
    files_subset = [
        os.path.join(dir_path, files[0]), os.path.join(dir_path, files[2])
    ]
    self.assertItemsEqual(
        file_io.get_matching_files(files_subset), files_subset)
    file_io.delete_recursively(dir_path)
    self.assertFalse(file_io.file_exists(os.path.join(dir_path, "file3.txt")))
Developer: 1000sprites, Project: tensorflow, Lines: 19, Source: file_io_test.py


Example 9: _run_transform

  def _run_transform(self):
    """Runs DataFlow for makint tf.example files.

    Only the train file uses DataFlow, the eval file runs beam locally to save
    time.
    """
    cloud = True
    extra_args = []
    if cloud:
      extra_args = ['--cloud',
                    '--job-name=test-mltoolbox-df-%s' % uuid.uuid4().hex,
                    '--project-id=%s' % self._get_default_project_id(),
                    '--num-workers=3']

    cmd = ['python %s' % os.path.join(CODE_PATH, 'transform.py'),
           '--csv=' + self._csv_train_filename,
           '--analysis=' + self._analysis_output,
           '--prefix=features_train',
           '--output=' + self._transform_output,
           '--shuffle'] + extra_args

    self._logger.debug('Running subprocess: %s \n\n' % ' '.join(cmd))
    subprocess.check_call(' '.join(cmd), shell=True)

    # Don't waste time running a 2nd DF job, run it locally.
    cmd = ['python %s' % os.path.join(CODE_PATH, 'transform.py'),
           '--csv=' + self._csv_eval_filename,
           '--analysis=' + self._analysis_output,
           '--prefix=features_eval',
           '--output=' + self._transform_output]

    self._logger.debug('Running subprocess: %s \n\n' % ' '.join(cmd))
    subprocess.check_call(' '.join(cmd), shell=True)

    # Check the files were made
    train_files = file_io.get_matching_files(
        os.path.join(self._transform_output, 'features_train*'))
    eval_files = file_io.get_matching_files(
        os.path.join(self._transform_output, 'features_eval*'))
    self.assertNotEqual([], train_files)
    self.assertNotEqual([], eval_files)
Developer: googledatalab, Project: pydatalab, Lines: 41, Source: test_cloud_workflow.py


Example 10: testGetMatchingFiles

  def testGetMatchingFiles(self):
    dir_path = os.path.join(self._base_dir, "temp_dir")
    file_io.create_dir(dir_path)
    files = ["file1.txt", "file2.txt", "file3.txt"]
    for name in files:
      file_path = os.path.join(dir_path, name)
      file_io.write_string_to_file(file_path, "testing")
    expected_match = [os.path.join(dir_path, name) for name in files]
    self.assertItemsEqual(
        file_io.get_matching_files(os.path.join(dir_path, "file*.txt")),
        expected_match)
    file_io.delete_recursively(dir_path)
    self.assertFalse(file_io.file_exists(os.path.join(dir_path, "file3.txt")))
Developer: AriaAsuka, Project: tensorflow, Lines: 13, Source: file_io_test.py


Example 11: checkpoint_exists

def checkpoint_exists(checkpoint_prefix):
  """Checks whether a V1 or V2 checkpoint exists with the specified prefix.

  This is the recommended way to check if a checkpoint exists, since it takes
  into account the naming difference between V1 and V2 formats.

  Args:
    checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
  Returns:
    A bool, true iff a checkpoint referred to by `checkpoint_prefix` exists.
  """
  pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
                                        saver_pb2.SaverDef.V2)
  if file_io.get_matching_files(pathname):
    return True
  elif file_io.get_matching_files(checkpoint_prefix):
    return True
  else:
    return False
Developer: adit-chandra, Project: tensorflow, Lines: 22, Source: checkpoint_management.py


Example 12: testNewAPIBackwardsCompatibility

  def testNewAPIBackwardsCompatibility(self):
    # Extract all API stuff.
    visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()

    public_api_visitor = public_api.PublicAPIVisitor(visitor)
    public_api_visitor.do_not_descend_map['tf'].append('contrib')
    public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
    # TODO(annarev): Make slide_dataset available in API.
    public_api_visitor.private_map['tf'] = ['slide_dataset']
    traverse.traverse(api, public_api_visitor)

    proto_dict = visitor.GetProtos()

    # Read all golden files.
    expression = os.path.join(
        resource_loader.get_root_dir_with_all_resources(),
        _KeyToFilePath('*'))
    golden_file_list = file_io.get_matching_files(expression)

    def _ReadFileToProto(filename):
      """Read a filename, create a protobuf from its contents."""
      ret_val = api_objects_pb2.TFAPIObject()
      text_format.Merge(file_io.read_file_to_string(filename), ret_val)
      return ret_val

    golden_proto_dict = {
        _FileNameToKey(filename): _ReadFileToProto(filename)
        for filename in golden_file_list
    }

    # user_ops is an empty module. It is currently available in TensorFlow API
    # but we don't keep empty modules in the new API.
    # We delete user_ops from golden_proto_dict to make sure assert passes
    # when diffing new API against goldens.
    # TODO(annarev): remove user_ops from goldens once we switch to new API.
    tf_module = golden_proto_dict['tensorflow'].tf_module
    for i in range(len(tf_module.member)):
      if tf_module.member[i].name == 'user_ops':
        del tf_module.member[i]
        break

    # Diff them. Do not fail if called with update.
    # If the test is run to update goldens, only report diffs but do not fail.
    self._AssertProtoDictEquals(
        golden_proto_dict,
        proto_dict,
        verbose=FLAGS.verbose_diffs,
        update_goldens=False,
        additional_missing_object_message=
        'Check if tf_export decorator/call is missing for this symbol.')
Developer: PuchatekwSzortach, Project: tensorflow, Lines: 50, Source: api_compatibility_test.py


Example 13: _checkBackwardsCompatibility

  def _checkBackwardsCompatibility(self,
                                   root,
                                   golden_file_pattern,
                                   api_version,
                                   additional_private_map=None,
                                   omit_golden_symbols_map=None):
    # Extract all API stuff.
    visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()

    public_api_visitor = public_api.PublicAPIVisitor(visitor)
    public_api_visitor.private_map['tf'] = ['contrib']
    if api_version == 2:
      public_api_visitor.private_map['tf'].append('enable_v2_behavior')

    public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
    if FLAGS.only_test_core_api:
      public_api_visitor.do_not_descend_map['tf'].extend(_NON_CORE_PACKAGES)
    if additional_private_map:
      public_api_visitor.private_map.update(additional_private_map)

    traverse.traverse(root, public_api_visitor)
    proto_dict = visitor.GetProtos()

    # Read all golden files.
    golden_file_list = file_io.get_matching_files(golden_file_pattern)
    if FLAGS.only_test_core_api:
      golden_file_list = _FilterNonCoreGoldenFiles(golden_file_list)

    def _ReadFileToProto(filename):
      """Read a filename, create a protobuf from its contents."""
      ret_val = api_objects_pb2.TFAPIObject()
      text_format.Merge(file_io.read_file_to_string(filename), ret_val)
      return ret_val

    golden_proto_dict = {
        _FileNameToKey(filename): _ReadFileToProto(filename)
        for filename in golden_file_list
    }
    golden_proto_dict = _FilterGoldenProtoDict(golden_proto_dict,
                                               omit_golden_symbols_map)

    # Diff them. Do not fail if called with update.
    # If the test is run to update goldens, only report diffs but do not fail.
    self._AssertProtoDictEquals(
        golden_proto_dict,
        proto_dict,
        verbose=FLAGS.verbose_diffs,
        update_goldens=FLAGS.update_goldens,
        api_version=api_version)
Developer: adit-chandra, Project: tensorflow, Lines: 49, Source: api_compatibility_test.py


Example 14: copy_data_to_tmp

def copy_data_to_tmp(input_files):
  """Copies data to /tmp/ and returns glob matching the files."""
  files = []
  for e in input_files:
    for path in e.split(','):
      files.extend(file_io.get_matching_files(path))

  for path in files:
    if not path.startswith('gs://'):
      return input_files

  tmp_path = os.path.join('/tmp/', str(uuid.uuid4()))
  os.makedirs(tmp_path)
  subprocess.check_call(['gsutil', '-m', '-q', 'cp', '-r'] + files + [tmp_path])
  return [os.path.join(tmp_path, '*')]
Developer: cottrell, Project: notebooks, Lines: 15, Source: task.py


Example 15: get_latest_checkpoint

def get_latest_checkpoint():
  index_files = file_io.get_matching_files(os.path.join(FLAGS.train_dir, 'model.ckpt-*.index'))

  # No files
  if not index_files:
    return None


  # Index file path with the maximum step size.
  latest_index_file = sorted(
      [(int(os.path.basename(f).split("-")[-1].split(".")[0]), f)
       for f in index_files])[-1][1]

  # Chop off .index suffix and return
  return latest_index_file[:-6]
Developer: vijayky88, Project: youtube-8m, Lines: 15, Source: eval.py


Example 16: read_examples

def read_examples(input_files, batch_size, shuffle, num_epochs=None):
  """Creates readers and queues for reading example protos."""
  files = []
  for e in input_files:
    for path in e.split(','):
      files.extend(file_io.get_matching_files(path))
  thread_count = multiprocessing.cpu_count()

  # The minimum number of instances in a queue from which examples are drawn
  # randomly. The larger this number, the more randomness at the expense of
  # higher memory requirements.
  min_after_dequeue = 1000

  # When batching data, the queue's capacity will be larger than the batch_size
  # by some factor. The recommended formula is (num_threads + a small safety
  # margin). For now, we use a single thread for reading, so this can be small.
  queue_size_multiplier = thread_count + 3

  # Convert num_epochs == 0 -> num_epochs is None, if necessary
  num_epochs = num_epochs or None

  # Build a queue of the filenames to be read.
  filename_queue = tf.train.string_input_producer(files, num_epochs, shuffle)

  options = tf.python_io.TFRecordOptions(
      compression_type=tf.python_io.TFRecordCompressionType.GZIP)
  example_id, encoded_example = tf.TFRecordReader(options=options).read_up_to(
      filename_queue, batch_size)

  if shuffle:
    capacity = min_after_dequeue + queue_size_multiplier * batch_size
    return tf.train.shuffle_batch(
        [example_id, encoded_example],
        batch_size,
        capacity,
        min_after_dequeue,
        enqueue_many=True,
        num_threads=thread_count)

  else:
    capacity = queue_size_multiplier * batch_size
    return tf.train.batch(
        [example_id, encoded_example],
        batch_size,
        capacity=capacity,
        enqueue_many=True,
        num_threads=thread_count)
Developer: amygdala, Project: tensorflow-workshop, Lines: 47, Source: util.py


Example 17: local_batch_predict

def local_batch_predict(model_dir, csv_file_pattern, output_dir, output_format, batch_size=100):
  """ Batch Predict with a specified model.

  It does batch prediction, saves results to output files and also creates an output
  schema file. The output file names are input file names prepended by 'predict_results_'.

  Args:
    model_dir: The model directory containing a SavedModel (usually saved_model.pb).
    csv_file_pattern: a pattern of csv files as batch prediction source.
    output_dir: the path of the output directory.
    output_format: csv or json.
    batch_size: Larger batch_size improves performance but may
        cause more memory usage.
  """

  file_io.recursive_create_dir(output_dir)
  csv_files = file_io.get_matching_files(csv_file_pattern)
  if len(csv_files) == 0:
    raise ValueError('No files found given ' + csv_file_pattern)

  with tf.Graph().as_default(), tf.Session() as sess:
    input_alias_map, output_alias_map = _tf_load_model(sess, model_dir)
    csv_tensor_name = list(input_alias_map.values())[0]
    output_schema = _get_output_schema(sess, output_alias_map)
    for csv_file in csv_files:
      output_file = os.path.join(
          output_dir,
          'predict_results_' +
          os.path.splitext(os.path.basename(csv_file))[0] + '.' + output_format)
      with file_io.FileIO(output_file, 'w') as f:
        prediction_source = _batch_csv_reader(csv_file, batch_size)
        for batch in prediction_source:
          batch = [l.rstrip() for l in batch if l]
          predict_results = sess.run(fetches=output_alias_map, feed_dict={csv_tensor_name: batch})
          formatted_results = _format_results(output_format, output_schema, predict_results)
          f.write('\n'.join(formatted_results) + '\n')

  file_io.write_string_to_file(os.path.join(output_dir, 'predict_results_schema.json'),
                               json.dumps(output_schema, indent=2))
Developer: googledatalab, Project: pydatalab, Lines: 39, Source: _local_predict.py


Example 18: testMatchingFilesPermission

  def testMatchingFilesPermission(self):
    # Create top level directory test_dir.
    dir_path = os.path.join(self._base_dir, "test_dir")
    file_io.create_dir(dir_path)
    # Create second level directories `noread` and `any`.
    noread_path = os.path.join(dir_path, "noread")
    file_io.create_dir(noread_path)
    any_path = os.path.join(dir_path, "any")
    file_io.create_dir(any_path)
    files = ["file1.txt", "file2.txt", "file3.txt"]
    for name in files:
      file_path = os.path.join(any_path, name)
      file_io.FileIO(file_path, mode="w").write("testing")
    file_path = os.path.join(noread_path, "file4.txt")
    file_io.FileIO(file_path, mode="w").write("testing")
    # Change noread to noread access.
    os.chmod(noread_path, 0)
    expected_match = [os.path.join(any_path, name) for name in files]
    self.assertItemsEqual(
        file_io.get_matching_files(os.path.join(dir_path, "*", "file*.txt")),
        expected_match)
    # Change noread back so that it could be cleaned during tearDown.
    os.chmod(noread_path, 0o777)
Developer: aritratony, Project: tensorflow, Lines: 23, Source: file_io_test.py


Example 19: _run_batch_prediction

  def _run_batch_prediction(self):
    """Run batch prediction using the cloudml engine prediction service.

    There is no local version of this step as it's the last step.
    """

    job_name = 'test_mltoolbox_batchprediction_%s' % uuid.uuid4().hex
    cmd = ['gcloud ml-engine jobs submit prediction ' + job_name,
           '--data-format=TEXT',
           '--input-paths=' + self._csv_predict_filename,
           '--output-path=' + self._prediction_output,
           '--model-dir=' + os.path.join(self._train_output, 'model'),
           '--runtime-version=1.0',
           '--region=us-central1']
    self._logger.debug('Running subprocess: %s \n\n' % ' '.join(cmd))
    subprocess.check_call(' '.join(cmd), shell=True)  # async call.
    subprocess.check_call('gcloud ml-engine jobs stream-logs ' + job_name, shell=True)

    # Check that there were no errors.
    error_files = file_io.get_matching_files(
        os.path.join(self._prediction_output, 'prediction.errors_stats*'))
    self.assertEqual(1, len(error_files))
    error_str = file_io.read_file_to_string(error_files[0])
    self.assertEqual('', error_str)
Developer: googledatalab, Project: pydatalab, Lines: 24, Source: test_cloud_workflow.py


Example 20: load_session_bundle_from_path

def load_session_bundle_from_path(export_dir,
                                  target="",
                                  config=None,
                                  meta_graph_def=None):
  """Load session bundle from the given path.

  The function reads input from the export_dir, loads the graph data into the
  default graph, and restores the parameters for the session it creates.

  Args:
    export_dir: the directory that contains files exported by exporter.
    target: The execution engine to connect to. See target in
      tf.compat.v1.Session()
    config: A ConfigProto proto with configuration options. See config in
      tf.compat.v1.Session()
    meta_graph_def: optional object of type MetaGraphDef. If this object is
      present, then it is used instead of parsing MetaGraphDef from export_dir.

  Returns:
    session: a tensorflow session created from the variable files.
    meta_graph: a meta graph proto saved in the exporter directory.

  Raises:
    RuntimeError: if the required files are missing or contain unrecognizable
    fields, i.e. the exported model is invalid.
  """
  if not meta_graph_def:
    meta_graph_filename = os.path.join(export_dir,
                                       constants.META_GRAPH_DEF_FILENAME)
    if not file_io.file_exists(meta_graph_filename):
      raise RuntimeError("Expected meta graph file missing %s" %
                         meta_graph_filename)
    # Reads meta graph file.
    meta_graph_def = meta_graph_pb2.MetaGraphDef()
    meta_graph_def.ParseFromString(
        file_io.read_file_to_string(meta_graph_filename, binary_mode=True))

  variables_filename = ""
  variables_filename_list = []
  checkpoint_sharded = False

  variables_index_filename = os.path.join(export_dir,
                                          constants.VARIABLES_INDEX_FILENAME_V2)
  checkpoint_v2 = file_io.file_exists(variables_index_filename)

  # Find matching checkpoint files.
  if checkpoint_v2:
    # The checkpoint is in v2 format.
    variables_filename_pattern = os.path.join(
        export_dir, constants.VARIABLES_FILENAME_PATTERN_V2)
    variables_filename_list = file_io.get_matching_files(
        variables_filename_pattern)
    checkpoint_sharded = True
  else:
    variables_filename = os.path.join(export_dir, constants.VARIABLES_FILENAME)
    if file_io.file_exists(variables_filename):
      variables_filename_list = [variables_filename]
    else:
      variables_filename = os.path.join(export_dir,
                                        constants.VARIABLES_FILENAME_PATTERN)
      variables_filename_list = file_io.get_matching_files(variables_filename)
      checkpoint_sharded = True

  # Prepare the files to restore a session.
  if not variables_filename_list:
    restore_files = ""
  elif checkpoint_v2 or not checkpoint_sharded:
    # For checkpoint v2 or v1 with non-sharded files, use "export" to restore
    # the session.
    restore_files = constants.VARIABLES_FILENAME
  else:
    restore_files = constants.VARIABLES_FILENAME_PATTERN

  assets_dir = os.path.join(export_dir, constants.ASSETS_DIRECTORY)

  collection_def = meta_graph_def.collection_def
  graph_def = graph_pb2.GraphDef()
  if constants.GRAPH_KEY in collection_def:
    # Use serving graph_def in MetaGraphDef collection_def if exists
    graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
    if len(graph_def_any) != 1:
      raise RuntimeError("Expected exactly one serving GraphDef in : %s" %
                         meta_graph_def)
    else:
      graph_def_any[0].Unpack(graph_def)
      # Replace the graph def in meta graph proto.
      meta_graph_def.graph_def.CopyFrom(graph_def)

  ops.reset_default_graph()
  sess = session.Session(target, graph=None, config=config)
  # Import the graph.
  saver = saver_lib.import_meta_graph(meta_graph_def)
  # Restore the session.
  if restore_files:
    saver.restore(sess, os.path.join(export_dir, restore_files))

  init_op_tensor = None
  if constants.INIT_OP_KEY in collection_def:
    init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
    if len(init_ops) != 1:
#......... (the rest of this function is omitted) .........
Developer: ahmedsaiduk, Project: tensorflow, Lines: 101, Source: session_bundle.py



Note: The tensorflow.python.lib.io.file_io.get_matching_files examples in this article were compiled by 纯净天空 from source-code and documentation hosting platforms such as GitHub and MSDocs. The snippets are selected from open-source projects contributed by various developers, and their copyright belongs to the original authors; please consult the corresponding project's License before distributing or using them. Do not reproduce this article without permission.

