本文整理汇总了Python中tensorpack.utils.logger.info函数的典型用法代码示例。如果您正苦于以下问题:Python info函数的具体用法?Python info怎么用?Python info使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了info函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: get_config
def get_config(model, nr_tower):
    """Assemble the ``TrainConfig`` for training on ``nr_tower`` towers.

    The global ``TOTAL_BATCH_SIZE`` is divided evenly between the towers. The
    learning rate is scheduled linearly from 0.5 at step 0 down to 0 at step
    ``3 * 10**5``. One epoch is ``1280000 // TOTAL_BATCH_SIZE`` steps, and
    enough epochs are run to cover the whole schedule.

    Args:
        model: model description, passed unchanged to ``TrainConfig``.
        nr_tower (int): number of training towers.

    Returns:
        TrainConfig: the training configuration.
    """
    per_tower_batch = TOTAL_BATCH_SIZE // nr_tower
    logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, per_tower_batch))
    train_flow = get_data('train', per_tower_batch)
    val_flow = get_data('val', per_tower_batch)

    steps_per_epoch = 1280000 // TOTAL_BATCH_SIZE
    total_iters = 3 * 10**5
    # +1 so that the epochs together span at least ``total_iters`` steps.
    epochs = total_iters // steps_per_epoch + 1

    callbacks = [
        ModelSaver(),
        ScheduledHyperParamSetter(
            'learning_rate', [(0, 0.5), (total_iters, 0)],
            interp='linear', step_based=True),
    ]
    metrics = [
        ClassificationError('wrong-top1', 'val-error-top1'),
        ClassificationError('wrong-top5', 'val-error-top5'),
    ]
    if nr_tower != 1:
        # Multi-GPU inference; queue prefetching is mandatory in this runner.
        evaluator = DataParallelInferenceRunner(val_flow, metrics, list(range(nr_tower)))
    else:
        # Single-GPU inference with queue prefetching.
        evaluator = InferenceRunner(QueueInput(val_flow), metrics)
    callbacks.append(evaluator)

    return TrainConfig(
        model=model,
        dataflow=train_flow,
        callbacks=callbacks,
        steps_per_epoch=steps_per_epoch,
        max_epoch=epochs,
    )
开发者ID:quanlzheng,项目名称:tensorpack,代码行数:33,代码来源:shufflenet.py
示例2: _sync
def _sync(self):
    """Publish the current variable values and wake up every waiter.

    Evaluates each variable in ``self.vars`` with ``v.eval()``. It stores the
    resulting ``{variable name: value}`` dict under ``'params'`` in
    ``self.shared_dic``, then calls ``notify_all()`` on ``self.condvar``.
    """
    logger.info("Updating weights ...")
    dic = {v.name: v.eval() for v in self.vars}
    self.shared_dic['params'] = dic
    # Using the condition as a context manager guarantees its lock is
    # released even if notify_all() raises; the previous explicit
    # acquire()/release() pair would have left it held.
    # NOTE(review): assumes self.condvar is a threading/multiprocessing
    # Condition (both support ``with``) — confirm.
    with self.condvar:
        self.condvar.notify_all()
开发者ID:j50888,项目名称:tensorpack,代码行数:7,代码来源:simulator.py
示例3: __init__
def __init__(self,
             predictor_io_names,
             player,
             state_shape,
             batch_size,
             memory_size, init_memory_size,
             exploration, end_exploration, exploration_epoch_anneal,
             update_frequency, history_len):
    """
    Every constructor argument is stored as an attribute of the same name.

    Args:
        predictor_io_names (tuple of list of str): input/output names to
            predict Q value from state.
        player (RLEnvironment): the player.
        state_shape: shape of a single state; passed to ``ReplayMemory``.
        batch_size (int): stored as ``self.batch_size``.
        memory_size (int): capacity argument passed to ``ReplayMemory``.
        init_memory_size: converted to ``int`` before being stored.
        exploration: current exploration rate (epsilon).
        end_exploration: stored as ``self.end_exploration``.
        exploration_epoch_anneal: stored as ``self.exploration_epoch_anneal``.
        update_frequency (int): number of new transitions to add to memory
            after sampling a batch of transitions for training.
        history_len (int): length of history frames to concat. Zero-filled
            initial frames.
    """
    init_memory_size = int(init_memory_size)
    # Snapshot locals() into a plain dict before iterating. Iterating the
    # live frame-locals dict while binding loop variables can raise
    # "dictionary changed size during iteration" when a tracer (e.g. pdb)
    # re-syncs frame locals mid-loop.
    params = dict(locals())
    del params['self']
    for name, value in params.items():
        setattr(self, name, value)
    self.num_actions = player.get_action_space().num_actions()
    logger.info("Number of Legal actions: {}".format(self.num_actions))
    self.rng = get_rng(self)
    self._init_memory_flag = threading.Event()  # tell if memory has been initialized
    # TODO just use a semaphore?
    # a bounded queue to receive notifications to populate memory
    self._populate_job_queue = queue.Queue(maxsize=5)
    self.mem = ReplayMemory(memory_size, state_shape, history_len)
开发者ID:j50888,项目名称:tensorpack,代码行数:34,代码来源:expreplay.py
示例4: _init_memory
def _init_memory(self):
    """Populate the replay memory until it holds ``self.init_memory_size`` items.

    Calls ``self._populate_exp()`` repeatedly, advancing a progress bar by
    one per call. Afterwards it sets ``self._init_memory_flag``.
    """
    logger.info("Populating replay memory with epsilon={} ...".format(self.exploration))
    with get_tqdm(total=self.init_memory_size) as bar:
        while True:
            if len(self.mem) >= self.init_memory_size:
                break
            self._populate_exp()
            bar.update()
    self._init_memory_flag.set()
开发者ID:tobyma,项目名称:tensorpack,代码行数:8,代码来源:expreplay.py
示例5: __init__
def __init__(self, dirname, label='phoneme'):
    """Collect every ``.wav`` file found recursively under ``dirname``.

    Args:
        dirname (str): root directory to scan; must be an existing directory.
        label (str): either ``'phoneme'`` or ``'letter'``.

    Raises:
        AssertionError: if ``dirname`` is not a directory, if no ``.wav``
            file was found, or if ``label`` is not an accepted value.
    """
    self.dirname = dirname
    assert os.path.isdir(dirname), dirname
    wav_files = []
    for path in fs.recursive_walk(self.dirname):
        if path.endswith('.wav'):
            wav_files.append(path)
    self.filelists = wav_files
    logger.info("Found {} wav files ...".format(len(wav_files)))
    assert wav_files, "Found no '.wav' files!"
    assert label in ('phoneme', 'letter'), label
    self.label = label
开发者ID:ahuirecome,项目名称:tensorpack,代码行数:9,代码来源:create-lmdb.py
示例6: eval_model_multithread
def eval_model_multithread(pred, nr_eval, get_player_fn):
    """Evaluate ``pred`` on ``nr_eval`` episodes with several worker threads.

    Args:
        pred (OfflinePredictor): state -> Qvalue
        nr_eval (int): number of episode scores to collect.
        get_player_fn: player factory, forwarded to ``eval_with_funcs``.
    """
    # Half of the CPUs, capped at 8, but never fewer than one worker. On a
    # single-CPU machine ``cpu_count() // 2`` is 0, so no worker would be
    # started and nothing would ever produce a score to wait for.
    nr_proc = max(1, min(multiprocessing.cpu_count() // 2, 8))
    with pred.sess.as_default():
        mean_score, max_score = eval_with_funcs([pred] * nr_proc, nr_eval, get_player_fn)
    logger.info("Average Score: {}; Max Score: {}".format(mean_score, max_score))
开发者ID:ahuirecome,项目名称:tensorpack,代码行数:9,代码来源:common.py
示例7: update_target_param
def update_target_param():
    """Build an op that copies the online network's weights into the target network.

    Each global variable whose op name starts with ``'target/'`` is assigned
    the value of the graph tensor named like it with that leading prefix
    removed. Each pair is logged.

    Returns:
        A ``tf.group`` of all assignment ops, named ``'update_target_network'``.
    """
    prefix = 'target/'
    graph = tf.get_default_graph()
    assign_ops = []
    for var in tf.global_variables():
        target_name = var.op.name
        if not target_name.startswith(prefix):
            continue
        # Strip only the leading prefix. ``str.replace`` would also remove any
        # inner 'target/' component. The looser ``startswith('target')`` check
        # matched names such as 'target_foo/W', which were then assigned to
        # themselves.
        new_name = target_name[len(prefix):]
        logger.info("{} <- {}".format(target_name, new_name))
        assign_ops.append(var.assign(graph.get_tensor_by_name(new_name + ':0')))
    return tf.group(*assign_ops, name='update_target_network')
开发者ID:caserzer,项目名称:tensorpack,代码行数:11,代码来源:DQNModel.py
示例8: eval_with_funcs
def eval_with_funcs(predictors, nr_eval, get_player_fn, verbose=False):
    """
    Play evaluation episodes concurrently (one worker thread per predictor)
    and summarize their scores.

    Args:
        predictors ([PredictorBase]): one per worker thread; each is wrapped
            and passed to ``play_one_episode`` as the policy function.
        nr_eval (int): number of scores to wait for before the workers are
            asked to stop.
        get_player_fn: called as ``get_player_fn(train=False)`` inside each
            worker to build that worker's own player.
        verbose (bool): if True, log every collected score.

    Returns:
        tuple: ``(average, max)`` of the collected scores, or ``(0, 0)`` if
        none were collected. Scores queued while the workers are being
        stopped are also counted, so more than ``nr_eval`` scores may
        contribute.
    """
    class Worker(StoppableThread, ShareSessionThread):
        # Thread that plays episodes in a loop and pushes each score to a queue.
        def __init__(self, func, queue):
            super(Worker, self).__init__()
            self._func = func
            self.q = queue

        def func(self, *args, **kwargs):
            # Wrapped predictor: once stop() was requested, raise so that the
            # episode in progress is aborted.
            if self.stopped():
                raise RuntimeError("stopped!")
            return self._func(*args, **kwargs)

        def run(self):
            with self.default_sess():
                player = get_player_fn(train=False)
                while not self.stopped():
                    try:
                        score = play_one_episode(player, self.func)
                    except RuntimeError:
                        # e.g. raised by self.func after stop(); end the thread.
                        return
                    self.queue_put_stoppable(self.q, score)

    q = queue.Queue()
    threads = [Worker(f, q) for f in predictors]

    for k in threads:
        k.start()
        time.sleep(0.1)  # avoid simulator bugs
    stat = StatCounter()

    def fetch():
        # Block until one score is available, then record it.
        # NOTE(review): blocks forever if ``predictors`` is empty or every
        # worker has exited — callers must pass at least one predictor.
        r = q.get()
        stat.feed(r)
        if verbose:
            logger.info("Score: {}".format(r))

    for _ in tqdm(range(nr_eval), **get_tqdm_kwargs()):
        fetch()
    # waiting is necessary, otherwise the estimated mean score is biased
    logger.info("Waiting for all the workers to finish the last run...")
    for k in threads:
        k.stop()
    for k in threads:
        k.join()
    # Drain the scores that were queued while the workers were stopping.
    while q.qsize():
        fetch()

    if stat.count > 0:
        return (stat.average, stat.max)
    return (0, 0)
开发者ID:quanlzheng,项目名称:tensorpack,代码行数:54,代码来源:common.py
示例9: convert_param_name
def convert_param_name(param):
    """Re-key a Caffe parameter dict using ``name_conversion``.

    Args:
        param (dict): maps Caffe names to parameter values.

    Returns:
        dict: the same values, keyed by their converted names.

    Raises:
        Whatever ``name_conversion`` raises; the offending key is logged first.
    """
    def _rename(caffe_name):
        # Convert one name; log which layer failed before propagating.
        try:
            return name_conversion(caffe_name)
        except Exception:
            logger.error("Exception when processing caffe layer {}".format(caffe_name))
            raise

    renamed = {}
    for caffe_name, value in six.iteritems(param):
        converted = _rename(caffe_name)
        logger.info("Name Transform: " + caffe_name + ' --> ' + converted)
        renamed[converted] = value
    return renamed
开发者ID:caserzer,项目名称:tensorpack,代码行数:11,代码来源:load-resnet.py
示例10: compute_mean_std
def compute_mean_std(db, fname):
    """Accumulate feature moments over an LMDB dataset and save them.

    Every row of each datapoint's first component is fed to an
    ``OnlineMoments`` accumulator. The ``[mean, std]`` pair is then
    serialized with ``serialize.dumps`` and written to ``fname`` in binary
    mode.
    """
    ds = LMDBSerializer.load(db, shuffle=False)
    ds.reset_state()
    moments = OnlineMoments()
    for datapoint in get_tqdm(ds):
        frames = datapoint[0]  # len x dim
        for frame in frames:
            moments.feed(frame)
    logger.info("Writing to {} ...".format(fname))
    with open(fname, 'wb') as out:
        out.write(serialize.dumps([moments.mean, moments.std]))
开发者ID:quanlzheng,项目名称:tensorpack,代码行数:11,代码来源:create-lmdb.py
示例11: run
def run(self):
    """Receive and dispatch client messages until the ZMQ context terminates.

    Each message is deserialized into ``(ident, state, reward, isOver)``. The
    client's state object is looked up (created on demand by the
    ``defaultdict``), and its ``ident`` is recorded if not yet set. The
    message is then passed to ``self._process_msg``.
    """
    self.clients = defaultdict(self.ClientState)
    try:
        while True:
            raw = self.c2s_socket.recv(copy=False).bytes
            ident, state, reward, isOver = loads(raw)
            client = self.clients[ident]
            if client.ident is None:
                client.ident = ident
            # maybe check history and warn about dead client?
            self._process_msg(client, state, reward, isOver)
    except zmq.ContextTerminated:
        logger.info("[Simulator] Context was terminated.")
开发者ID:ahuirecome,项目名称:tensorpack,代码行数:13,代码来源:simulator.py
示例12: compute_mean_std
def compute_mean_std(db, fname):
    """Accumulate feature moments over an LMDB dataset and save them.

    Uses the ``LMDBDataPoint`` / ``get_data()`` dataflow API. Every row of
    each datapoint's first component is fed to an ``OnlineMoments``
    accumulator, and the progress bar advances once per datapoint. The
    serialized ``[mean, std]`` pair is written to ``fname``.
    """
    ds = LMDBDataPoint(db, shuffle=False)
    ds.reset_state()
    moments = OnlineMoments()
    with get_tqdm(total=ds.size()) as progress:
        for datapoint in ds.get_data():
            frames = datapoint[0]  # len x dim
            for frame in frames:
                moments.feed(frame)
            progress.update()
    logger.info("Writing to {} ...".format(fname))
    with open(fname, 'wb') as out:
        out.write(serialize.dumps([moments.mean, moments.std]))
开发者ID:ahuirecome,项目名称:tensorpack,代码行数:13,代码来源:create-lmdb.py
示例13: _trigger_epoch
def _trigger_epoch(self):
    """Anneal exploration and publish player statistics at the end of an epoch.

    While ``self.exploration`` is above ``self.end_exploration``, it is
    decreased by ``self.exploration_epoch_anneal`` but never below
    ``self.end_exploration``. For every entry in ``self.player.stats``, the
    mean and max are added as scalar summaries. Finally the player's
    statistics are reset.
    """
    if self.exploration > self.end_exploration:
        # Clamp: when the anneal step does not evenly divide the range, the
        # last decrement would otherwise overshoot below end_exploration
        # (possibly producing a negative epsilon).
        self.exploration = max(self.exploration - self.exploration_epoch_anneal,
                               self.end_exploration)
        logger.info("Exploration changed to {}".format(self.exploration))
    # log player statistics
    stats = self.player.stats
    for k, v in six.iteritems(stats):
        try:
            mean_val, max_val = np.mean(v), np.max(v)
            self.trainer.add_scalar_summary('expreplay/mean_' + k, mean_val)
            self.trainer.add_scalar_summary('expreplay/max_' + k, max_val)
        except Exception:
            # Best-effort summaries (e.g. np.max raises ValueError on an empty
            # sequence). Unlike the former bare ``except:``, this no longer
            # swallows KeyboardInterrupt / SystemExit.
            pass
    self.player.reset_stat()
开发者ID:j50888,项目名称:tensorpack,代码行数:14,代码来源:expreplay.py
示例14: print_class_histogram
def print_class_histogram(self, imgs):
    """Log a table with the number of ground-truth boxes per class.

    Args:
        imgs (list[dict]): each entry provides arrays ``'class'`` and
            ``'is_crowd'`` that are combined element-wise (presumably one
            element per box — confirm). Boxes with class <= 0 or with a
            non-zero ``is_crowd`` flag are not counted.
    """
    nr_class = len(COCOMeta.class_names)
    hist_bins = np.arange(nr_class + 1)
    # Histogram of ground-truth objects.
    # ``np.int`` (an alias of the builtin ``int``) was removed in NumPy 1.24,
    # so the builtin is used directly; the resulting dtype is unchanged.
    gt_hist = np.zeros((nr_class,), dtype=int)
    for entry in imgs:
        # filter crowd?
        gt_inds = np.where(
            (entry['class'] > 0) & (entry['is_crowd'] == 0))[0]
        gt_classes = entry['class'][gt_inds]
        gt_hist += np.histogram(gt_classes, bins=hist_bins)[0]
    data = [[COCOMeta.class_names[i], v] for i, v in enumerate(gt_hist)]
    data.append(['total', sum(x[1] for x in data)])
    table = tabulate(data, headers=['class', '#box'], tablefmt='pipe')
    logger.info("Ground-Truth Boxes:\n" + colored(table, 'cyan'))
开发者ID:quanlzheng,项目名称:tensorpack,代码行数:16,代码来源:coco.py
示例15: texture_loss
def texture_loss(x, p=16):
    """Texture loss between the first and second half of a batch.

    ``x`` is passed through ``normalize`` and cut into non-overlapping
    ``p x p`` spatial patches. The loss is the mean squared error between
    ``gram_matrix`` of the first half's patches and of the second half's.

    Args:
        x: rank-4 tensor, unpacked below as ``_, h, w, c`` (assumed NHWC —
            TODO confirm). ``h`` and ``w`` must be divisible by ``p``.
        p (int): patch size.

    Returns:
        The ``tf.losses.mean_squared_error`` result with ``Reduction.MEAN``.
    """
    _, h, w, c = x.get_shape().as_list()
    x = normalize(x)
    assert h % p == 0 and w % p == 0
    logger.info('Create texture loss for layer {} with shape {}'.format(x.name, x.get_shape()))
    x = tf.space_to_batch_nd(x, [p, p], [[0, 0], [0, 0]])  # [p * p * b, h/p, w/p, c]
    x = tf.reshape(x, [p, p, -1, h // p, w // p, c])  # [p, p, b, h/p, w/p, c]
    x = tf.transpose(x, [2, 3, 4, 0, 1, 5])  # [b, h/p, w/p, p, p, c]
    # tf.split into 2 equal parts requires the batch size b to be even.
    patches_a, patches_b = tf.split(x, 2, axis=0)  # each is [b/2, h/p, w/p, p, p, c]
    patches_a = tf.reshape(patches_a, [-1, p, p, c])  # [b/2 * h/p * w/p, p, p, c]
    patches_b = tf.reshape(patches_b, [-1, p, p, c])  # [b/2 * h/p * w/p, p, p, c]
    return tf.losses.mean_squared_error(
        gram_matrix(patches_a),
        gram_matrix(patches_b),
        reduction=Reduction.MEAN
    )
开发者ID:quanlzheng,项目名称:tensorpack,代码行数:18,代码来源:enet-pat.py
示例16: _parameter_net
def _parameter_net(self, theta, kernel_shape=9):
    """Estimate filters for convolution layers with a small MLP.

    The MLP has two leaky-ReLU hidden layers (64 and 128 units) and a linear
    output layer with ``kernel_shape ** 2`` units. The output is reshaped
    into square filters.

    Args:
        theta: angle of filter
        kernel_shape: side length of each (square) filter
    Returns:
        learned filter as [B, k, k, 1], where B is the global ``BATCH``
    """
    with argscope(FullyConnected, nl=tf.nn.leaky_relu):
        hidden = FullyConnected('fc1', theta, 64)
        hidden = FullyConnected('fc2', hidden, 128)
    # The output layer passes its own (identity) nonlinearity explicitly.
    flat_filter = FullyConnected('fc3', hidden, kernel_shape ** 2, nl=tf.identity)
    filter_shape = [BATCH, kernel_shape, kernel_shape, 1]
    pred_filter = tf.reshape(flat_filter, filter_shape, name="pred_filter")
    logger.info('Parameter net output: {}'.format(pred_filter.get_shape().as_list()))
    return pred_filter
开发者ID:tobyma,项目名称:tensorpack,代码行数:18,代码来源:steering-filter.py
示例17: __init__
def __init__(self, basedir, name):
    """Load the COCO instance annotations of split ``name``.

    Requires ``basedir/<COCOMeta.INSTANCE_TO_BASEDIR[name]>`` to be a
    directory and ``basedir/annotations/instances_<name>.json`` to be a file.
    If ``COCOMeta`` is not yet valid, it is created from this file's
    categories. Otherwise the category names must match the existing ones.
    """
    assert name in COCOMeta.INSTANCE_TO_BASEDIR.keys(), name
    self.name = name
    self._imgdir = os.path.join(basedir, COCOMeta.INSTANCE_TO_BASEDIR[name])
    assert os.path.isdir(self._imgdir), self._imgdir

    json_relpath = 'annotations/instances_{}.json'.format(name)
    annotation_file = os.path.join(basedir, json_relpath)
    assert os.path.isfile(annotation_file), annotation_file
    self.coco = COCO(annotation_file)

    # initialize the meta
    cat_ids = self.coco.getCatIds()
    cat_names = [cat['name'] for cat in self.coco.loadCats(cat_ids)]
    if COCOMeta.valid():
        assert COCOMeta.cat_names == cat_names
    else:
        COCOMeta.create(cat_ids, cat_names)
    logger.info("Instances loaded from {}.".format(annotation_file))
开发者ID:ahuirecome,项目名称:tensorpack,代码行数:20,代码来源:coco.py
示例18: get_logits
def get_logits(self, image):
    """Build the ShuffleNet body on ``image`` and return 1000-way logits.

    The global ``args`` selects the variant (``args.v2``), the group count
    (``args.group``) and the width multiplier (``args.ratio``). The listed
    layer types use ``channels_first``, and ``Conv2D`` layers have no bias.
    """
    with argscope([Conv2D, MaxPooling, AvgPooling, GlobalAvgPooling, BatchNorm],
                  data_format='channels_first'):
        with argscope(Conv2D, use_bias=False):
            group = args.group
            if args.v2:
                # Copied from the paper
                channels = {
                    0.5: [48, 96, 192],
                    1.: [116, 232, 464],
                }[args.ratio]
                first_chan = 24
            else:
                # Copied from the paper
                v1_channels = {
                    3: [240, 480, 960],
                    4: [272, 544, 1088],
                    8: [384, 768, 1536],
                }
                mul = group * 4  # #chan has to be a multiple of this number
                channels = [int(math.ceil(x * args.ratio / mul) * mul)
                            for x in v1_channels[group]]
                # The first channel must be a multiple of group
                first_chan = int(math.ceil(24 * args.ratio / group) * group)
            logger.info("#Channels: " + str([first_chan] + channels))

            net = Conv2D('conv1', image, first_chan, 3, strides=2, activation=BNReLU)
            net = MaxPooling('pool1', net, 3, 2, padding='SAME')
            stage_specs = zip(('stage2', 'stage3', 'stage4'), channels, (4, 8, 4))
            for stage_name, stage_chan, num_blocks in stage_specs:
                net = shufflenet_stage(stage_name, net, stage_chan, num_blocks, group)
            if args.v2:
                net = Conv2D('conv5', net, 1024, 1, activation=BNReLU)
            net = GlobalAvgPooling('gap', net)
            logits = FullyConnected('linear', net, 1000)
    return logits
开发者ID:quanlzheng,项目名称:tensorpack,代码行数:41,代码来源:shufflenet.py
示例19: run
def run(self):
    """Receive and handle client messages until the ZMQ context terminates.

    Each message is ``(ident, state, reward, isOver)``. If the client already
    has memory entries, ``reward`` is attached to the latest one. Then
    ``_on_episode_over`` or ``_on_datapoint`` is called with ``ident``,
    depending on ``isOver``. In every case ``_on_state`` is called for the
    new state.
    """
    self.clients = defaultdict(self.ClientState)
    try:
        while True:
            raw = self.c2s_socket.recv(copy=False).bytes
            ident, state, reward, isOver = loads(raw)
            # TODO check history and warn about dead client
            memory = self.clients[ident].memory
            # In a client's first message only the state is valid, so reward
            # and isOver are used only when earlier entries exist.
            if memory:
                memory[-1].reward = reward
                on_previous = self._on_episode_over if isOver else self._on_datapoint
                on_previous(ident)
            # feed state and return action
            self._on_state(state, ident)
    except zmq.ContextTerminated:
        logger.info("[Simulator] Context was terminated.")
开发者ID:j50888,项目名称:tensorpack,代码行数:21,代码来源:simulator.py
示例20: __init__
def __init__(self,
             predictor_io_names,
             player,
             state_shape,
             batch_size,
             memory_size, init_memory_size,
             init_exploration,
             update_frequency, history_len):
    """
    Every constructor argument is stored as an attribute of the same name.

    Args:
        predictor_io_names (tuple of list of str): input/output names to
            predict Q value from state.
        player (RLEnvironment): the player.
        state_shape (tuple): h, w, c
        batch_size (int): stored as ``self.batch_size``.
        memory_size (int): capacity argument passed to ``ReplayMemory``.
        init_memory_size: converted to ``int`` before being stored.
        init_exploration: initial value of ``self.exploration``.
        update_frequency (int): number of new transitions to add to memory
            after sampling a batch of transitions for training.
        history_len (int): length of history frames to concat. Zero-filled
            initial frames.

    Raises:
        AssertionError: if ``state_shape`` does not have exactly 3 elements.
    """
    assert len(state_shape) == 3, state_shape
    init_memory_size = int(init_memory_size)
    # Snapshot locals() into a plain dict before iterating. Iterating the
    # live frame-locals dict while binding loop variables can raise
    # "dictionary changed size during iteration" when a tracer (e.g. pdb)
    # re-syncs frame locals mid-loop.
    params = dict(locals())
    del params['self']
    for name, value in params.items():
        setattr(self, name, value)
    self.exploration = init_exploration
    self.num_actions = player.action_space.n
    logger.info("Number of Legal actions: {}".format(self.num_actions))
    self.rng = get_rng(self)
    self._init_memory_flag = threading.Event()  # tell if memory has been initialized
    # a bounded queue to receive notifications to populate memory
    self._populate_job_queue = queue.Queue(maxsize=5)
    self.mem = ReplayMemory(memory_size, state_shape, history_len)
    self._current_ob = self.player.reset()
    self._player_scores = StatCounter()
    self._current_game_score = StatCounter()
开发者ID:quanlzheng,项目名称:tensorpack,代码行数:39,代码来源:expreplay.py
注:本文中的tensorpack.utils.logger.info函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论