This article collects typical usage examples of the Python class theano.compat.OrderedDict. If you have been wondering what theano.compat.OrderedDict is for, or how to use it, the hand-picked class examples below should help.
The following shows 19 code examples of the OrderedDict class, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
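Before the examples, a minimal sketch of the class itself may help. In the Theano versions these snippets come from, theano.compat.OrderedDict is for all practical purposes the standard collections.OrderedDict (re-exported, with a backport fallback on very old Pythons); Theano uses it wherever a deterministic iteration order matters, such as graph bookkeeping and update dictionaries:

from theano.compat import OrderedDict

# Insertion order is preserved, which keeps iteration (and therefore
# compiled-graph construction) deterministic across runs. A plain dict
# did not guarantee this before Python 3.7.
updates = OrderedDict()
updates['weights'] = 'weights - lr * grad_w'   # placeholder values
updates['biases'] = 'biases - lr * grad_b'
assert list(updates.keys()) == ['weights', 'biases']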
Example 1: get_monitoring_channels

def get_monitoring_channels(self, data):
    rval = OrderedDict()

    try:
        rval.update(self.mlp.get_monitoring_channels(data))
    except Exception:
        warnings.warn("something went wrong with compressor.mlp's monitoring channels")

    return rval

Developer: vinmisra | Project: adversary-compress | Lines: 7 | Source: CAN.py
Example 2: get_monitoring_channels

def get_monitoring_channels(self, data):
    rval = OrderedDict()

    g_ch = self.generator.get_monitoring_channels(data)
    d_ch = self.discriminator.get_monitoring_channels((data, None))
    samples, _, conditional_data, _ = self.generator.sample_and_noise(100)
    d_samp_ch = self.discriminator.get_monitoring_channels(((samples, conditional_data), None))

    i_ch = OrderedDict()
    if self.inferer is not None:
        batch_size = self.inference_monitoring_batch_size
        sample, noise, conditional_data, _ = self.generator.sample_and_noise(batch_size)
        i_ch.update(self.inferer.get_monitoring_channels(((sample, conditional_data), noise)))

    if self.monitor_generator:
        for key in g_ch:
            rval["gen_" + key] = g_ch[key]
    if self.monitor_discriminator:
        for key in d_ch:
            rval["dis_on_data_" + key] = d_ch[key]
        for key in d_samp_ch:
            rval["dis_on_samp_" + key] = d_samp_ch[key]
    if self.monitor_inference:
        for key in i_ch:
            rval["inf_" + key] = i_ch[key]
    return rval

Developer: hit-computer | Project: adversarial | Lines: 27 | Source: __init__.py
Example 3: orderings

def orderings(self):
    """
    Return dict d s.t. d[node] is a list of nodes that must be evaluated
    before node itself can be evaluated.

    This is used primarily by the destroy_handler feature to ensure that
    all clients of any destroyed inputs have already computed their outputs.

    Notes
    -----
    This only calls the orderings() fct on all features. It does not
    take care of computing the dependencies by itself.
    """
    ords = OrderedDict()
    assert isinstance(self._features, list)
    for feature in self._features:
        if hasattr(feature, 'orderings'):
            orderings = feature.orderings(self)
            if not isinstance(orderings, OrderedDict):
                raise TypeError("Non-deterministic return value from " +
                                str(feature.orderings) +
                                ". Nondeterministic object is " +
                                str(orderings))
            for node, prereqs in iteritems(orderings):
                if not isinstance(prereqs, (list, OrderedSet)):
                    raise TypeError(
                        "prereqs must be a type with a "
                        "deterministic iteration order, or toposort "
                        "will be non-deterministic.")
                ords.setdefault(node, []).extend(prereqs)
    # eliminate duplicate prereqs
    for (node, prereqs) in iteritems(ords):
        ords[node] = list(OrderedSet(prereqs))
    return ords

Developer: chinnadhurai | Project: Theano | Lines: 35 | Source: fg.py
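The contract described in the docstring of Example 3 is easiest to see with a toy feature. The following is a minimal sketch, not code from Theano: the ForceOrder class and its attribute names are hypothetical, but its return value has exactly the shape orderings() validates (an OrderedDict mapping each node to a list of prerequisite nodes):

from theano.compat import OrderedDict

class ForceOrder(object):
    """Hypothetical fgraph feature that imposes extra evaluation order."""

    def __init__(self, prereqs_by_node):
        # prereqs_by_node: maps an Apply node to a list of Apply nodes
        # that must be evaluated before it.
        self._prereqs = prereqs_by_node

    def orderings(self, fgraph):
        # Returning an OrderedDict with list values satisfies the
        # determinism checks enforced by FunctionGraph.orderings().
        return OrderedDict(self._prereqs)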
Example 4: __init__

def __init__(self, valid=None, invalid=None, valid_equivalent=None):
    '''
    Check if variables can be expressed without using variables in invalid.

    init_valid_equivalent provides a dictionary mapping some invalid
    variables to valid ones that can be used instead.
    '''
    if valid is None:
        valid = []
    if invalid is None:
        invalid = []
    if valid_equivalent is None:
        valid_equivalent = OrderedDict()

    # Nodes that are valid to have in the graph computing outputs
    self.valid = set(valid)

    # Nodes that are NOT valid to have in the graph computing outputs
    self.invalid = set(invalid)

    # Mapping from invalid variables to equivalent valid ones.
    self.valid_equivalent = valid_equivalent.copy()
    self.valid.update(valid_equivalent.values())
    self.invalid.update(valid_equivalent.keys())

Developer: LEEKYOUNGHUN | Project: Theano | Lines: 25 | Source: scan_utils.py
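A side note on the `if valid is None: valid = []` pattern used above: Python evaluates default argument values once, at function definition time, so a mutable default such as `[]` or `OrderedDict()` would be shared across every call. A short, self-contained demonstration of the pitfall this pattern avoids:

def bad(x, acc=[]):        # the shared list is created only once
    acc.append(x)
    return acc

print(bad(1))              # [1]
print(bad(2))              # [1, 2] -- state leaked from the previous call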
Example 5: get_layer_monitoring_channels

def get_layer_monitoring_channels(self, state_below=None, state=None, targets=None):
    W, = self.transformer.get_params()
    assert W.ndim == 4

    sq_W = T.sqr(W)
    row_norms = T.sqrt(sq_W.sum(axis=(0, 1, 2)))

    P = state
    rval = OrderedDict()

    vars_and_prefixes = [(P, '')]

    for var, prefix in vars_and_prefixes:
        if not hasattr(var, 'ndim') or var.ndim != 4:
            print("expected 4D tensor, got ")
            print(var)
            print(type(var))
            if isinstance(var, tuple):
                print("tuple length: ", len(var))
            assert False

        v_max = var.max(axis=3)
        v_min = var.min(axis=3)
        v_mean = var.mean(axis=3)
        v_range = v_max - v_min

        v_max = v_max.max(axis=(1, 2))
        v_min = v_min.min(axis=(1, 2))

        # max_x.mean_u is "the mean over *u*nits of the max over
        # e*x*amples". The x and u are included in the name because
        # otherwise it's hard to remember which axis is which when reading
        # the monitor. I use inner.outer rather than outer_of_inner or
        # something like that because I want mean_x.* to appear next to
        # each other in the alphabetical list, as these are commonly
        # plotted together.
        for key, val in [('max_x.max_u', v_max.max()),
                         ('max_x.mean_u', v_max.mean()),
                         ('max_x.min_u', v_max.min()),
                         ('min_x.max_u', v_min.max()),
                         ('min_x.mean_u', v_min.mean()),
                         ('min_x.min_u', v_min.min()),
                         ('range_x.max_u', v_range.max()),
                         ('range_x.mean_u', v_range.mean()),
                         ('range_x.min_u', v_range.min()),
                         ('mean_x.max_u', v_mean.max()),
                         ('mean_x.mean_u', v_mean.mean()),
                         ('mean_x.min_u', v_mean.min())]:
            rval[prefix + key] = val

    rval.update(OrderedDict([('kernel_norms_min', row_norms.min()),
                             ('kernel_norms_mean', row_norms.mean()),
                             ('kernel_norms_max', row_norms.max()), ]))

    return rval

Developer: cc13ny | Project: galatea | Lines: 57 | Source: deconv.py
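The max_x.mean_u naming convention in the comment above can be checked with plain NumPy. This is an illustrative sketch, not project code; it assumes the last axis indexes examples (x) and the remaining axes index units (u):

import numpy as np

acts = np.random.randn(3, 4, 5)   # hypothetical activations, examples on the last axis
max_x = acts.max(axis=-1)         # max over examples, one value per unit
print(max_x.mean())               # the 'max_x.mean_u' channel
print(max_x.max())                # the 'max_x.max_u' channel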
Example 6: on_attach

def on_attach(self, fgraph):
    """
    When attaching to a new fgraph, check that

    1) This DestroyHandler wasn't already attached to some fgraph
       (its data structures are only set up to serve one).
    2) The FunctionGraph doesn't already have a DestroyHandler.
       This would result in it validating everything twice, causing
       compilation to be slower.

    Give the FunctionGraph instance:

    1) A new method "destroyers(var)"
       TODO: what does this do exactly?
    2) A new attribute, "destroy_handler"
       TODO: WRITEME: what does this do besides the checks?
    """
    # Do the checking #
    already_there = False
    if self.fgraph is fgraph:
        already_there = True

    if self.fgraph is not None:
        raise Exception(
            "A DestroyHandler instance can only serve one"
            " FunctionGraph. (Matthew 6:24)")

    for attr in ('destroyers', 'destroy_handler'):
        if hasattr(fgraph, attr):
            already_there = True

    if already_there:
        # FunctionGraph.attach_feature catches AlreadyThere and cancels
        # the attachment.
        raise toolbox.AlreadyThere(
            "DestroyHandler feature is already present"
            " or in conflict with another plugin.")

    # Annotate the FunctionGraph #
    self.unpickle(fgraph)
    fgraph.destroy_handler = self

    self.fgraph = fgraph
    self.destroyers = OrderedSet()  # set of Apply instances with non-null destroy_map
    self.view_i = OrderedDict()  # variable -> variable used in calculation
    self.view_o = OrderedDict()  # variable -> set of variables that use this one as a direct input
    # clients: how many times does an apply use a given variable
    self.clients = OrderedDict()  # variable -> apply -> ninputs
    self.stale_droot = True

    self.debug_all_apps = OrderedSet()
    if self.do_imports_on_attach:
        toolbox.Bookkeeper.on_attach(self, fgraph)

Developer: 12190143 | Project: Theano | Lines: 50 | Source: destroyhandler.py
Example 7: __init__

def __init__(self, *axis):
    # Sort them to make sure we merge all possible cases.
    items = sorted(axis)
    self.axis = OrderedDict(items)
    for axis, broad in iteritems(self.axis):
        assert isinstance(axis, (numpy.integer, int)), (
            "Rebroadcast needs integer axes. Got ", axis)
        assert isinstance(broad, bool), (
            "Rebroadcast needs bool for new broadcast pattern. Got ", broad)

Developer: ZhangAustin | Project: attention-lvcsr | Lines: 7 | Source: ops.py
Example 8: get_layer_monitoring_channels

def get_layer_monitoring_channels(self, state_below=None,
                                  state=None, targets=None):
    W, = self.transformer.get_params()
    assert W.ndim == 5

    sq_W = T.sqr(W)
    row_norms = T.sqrt(sq_W.sum(axis=(1, 2, 3, 4)))

    rval = OrderedDict([
        ('kernel_norms_min', row_norms.min()),
        ('kernel_norms_mean', row_norms.mean()),
        ('kernel_norms_max', row_norms.max()),
    ])

    cost = self.cost
    orval = self.nonlin.get_monitoring_channels_from_state(state,
                                                           targets,
                                                           cost_fn=cost)
    rval.update(orval)
    return rval

Developer: robintibor | Project: pylearn3dconv | Lines: 17 | Source: base.py
Example 9: __init__

def __init__(self, valid=None, invalid=None, valid_equivalent=None):
    if valid is None:
        valid = []
    if invalid is None:
        invalid = []
    if valid_equivalent is None:
        valid_equivalent = OrderedDict()

    # Nodes that are valid to have in the graph computing outputs
    self.valid = set(valid)

    # Nodes that are NOT valid to have in the graph computing outputs
    self.invalid = set(invalid)

    # Mapping from invalid variables to equivalent valid ones.
    self.valid_equivalent = valid_equivalent.copy()
    self.valid.update(list(valid_equivalent.values()))
    self.invalid.update(list(valid_equivalent.keys()))

Developer: ALISCIFP | Project: Segmentation | Lines: 18 | Source: scan_utils.py
Example 10: get_monitoring_channels

def get_monitoring_channels(self, data):
    if data is None:
        m = 100
    else:
        m = data.shape[0]
    n = self.mlp.get_input_space().get_total_dimension()
    noise = self.get_noise((m, n))
    rval = OrderedDict()

    try:
        rval.update(self.mlp.get_monitoring_channels((noise, None)))
    except Exception:
        warnings.warn("something went wrong with generator.mlp's monitoring channels")

    if self.monitor_ll:
        rval['ll'] = T.cast(self.ll(data, self.ll_n_samples, self.ll_sigma),
                            theano.config.floatX).mean()
        rval['nll'] = -rval['ll']
    return rval

Developer: HyoungWooPark | Project: adversarial | Lines: 19 | Source: __init__.py
Example 11: get_gradients

def get_gradients(self, model, data, **kwargs):
    space, sources = self.get_data_specs(model)
    space.validate(data)
    assert isinstance(model, CompressAdversaryPair)
    g = model.compressor
    d = model.discriminator

    # get raw gradients for d and g objectives...
    d_obj, g_obj = self.get_objectives(model, data)
    g_params = g.get_params()
    d_params = d.get_params()
    for param in g_params:
        assert param not in d_params
    for param in d_params:
        assert param not in g_params
    d_grads = T.grad(d_obj, d_params)
    g_grads = T.grad(g_obj, g_params)

    # if self.scale_grads:
    #     S_grad = T.grad(g_obj, S)
    #     scale = T.maximum(1., self.target_scale / T.sqrt(T.sqr(S_grad).sum()))
    #     g_grads = [g_grad * scale for g_grad in g_grads]

    # adjust raw gradients with control signals
    rval = OrderedDict()
    zeros = itertools.repeat(theano.tensor.constant(0., dtype='float32'))

    if self.ever_train_discriminator:
        rval.update(OrderedDict(safe_zip(d_params, [self.now_train_discriminator * dg for dg in d_grads])))
    else:
        rval.update(OrderedDict(zip(d_params, zeros)))

    if self.ever_train_compressor:
        rval.update(OrderedDict(safe_zip(g_params, [self.now_train_compressor * gg for gg in g_grads])))
    else:
        rval.update(OrderedDict(zip(g_params, zeros)))

    # update control signals using the updates return functionality
    updates = OrderedDict()

    # first, the clock
    self.future_train_clock = T.switch(
        T.ge(self.train_clock,
             self.discriminator_steps + self.joint_steps + self.compressor_steps),
        1.,
        self.train_clock + 1.)
    updates[self.train_clock] = self.future_train_clock
    # then the control signals
    updates[self.now_train_discriminator] = T.switch(
        T.le(self.future_train_clock, self.discriminator_steps + self.joint_steps),
        1., 0.)
    updates[self.now_train_compressor] = T.switch(
        T.gt(self.future_train_clock, self.discriminator_steps),
        1., 0.)

    return rval, updates

Developer: vinmisra | Project: adversary-compress | Lines: 48 | Source: CAN.py
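The train-clock logic in Example 11 cycles through discriminator-only, joint, and compressor-only phases. A pure-Python trace of the same switch logic makes the schedule concrete; the step counts and the initial clock value of 0 are assumptions chosen for illustration:

D, J, C = 2, 1, 1          # discriminator_steps, joint_steps, compressor_steps
clock = 0.
for step in range(8):
    clock = 1. if clock >= D + J + C else clock + 1.
    train_d = 1. if clock <= D + J else 0.
    train_c = 1. if clock > D else 0.
    print(step, clock, train_d, train_c)
# Each cycle of D + J + C ticks: the discriminator trains for the first
# D + J ticks, the compressor for the last J + C, and both train during
# the J joint ticks in the middle.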
Example 12: __init__

def __init__(self, *axis):
    # Sort them to make sure we merge all possible cases.
    items = sorted(axis)
    self.axis = OrderedDict(items)
    for axis, broad in iteritems(self.axis):
        if not isinstance(axis, (numpy.integer, integer_types)):
            raise TypeError("Rebroadcast needs integer axes. "
                            "Got {}".format(axis))
        if not isinstance(broad, (numpy.bool_, bool)):
            raise TypeError("Rebroadcast needs bool for new broadcast "
                            "pattern. Got {}".format(broad))

Developer: Azrael1 | Project: Theano | Lines: 12 | Source: ops.py
Example 13: get_lr_scalers

def get_lr_scalers(self):
    """
    .. todo::

        WRITEME
    """
    rval = OrderedDict()

    params = self.get_params()

    for layer in self.hidden_layers + [self.visible_layer]:
        contrib = layer.get_lr_scalers()

        # No two layers can contend to scale a parameter
        assert not any([key in rval for key in contrib])
        # Don't try to scale anything that's not a parameter
        assert all([key in params for key in contrib])

        rval.update(contrib)
    assert all([isinstance(val, float) for val in rval.values()])

    return rval

Developer: JakeMick | Project: pylearn2 | Lines: 22 | Source: dbm.py
Example 14: scan

# ... (part of the code omitted here) ...

    # Make sure we get rid of numpy arrays or ints or anything like that
    # passed as inputs to scan
    non_seqs = []
    for elem in wrap_into_list(non_sequences):
        if not isinstance(elem, gof.Variable):
            non_seqs.append(tensor.as_tensor_variable(elem))
        else:
            non_seqs.append(elem)

    # If we provided a known number of steps (before compilation)
    # and if that number is 1 or -1, then we can skip the Scan Op,
    # and just apply the inner function once.
    # To do that we check here to see the nature of n_steps
    n_fixed_steps = None

    if isinstance(n_steps, (float, int)):
        n_fixed_steps = int(n_steps)
    else:
        try:
            n_fixed_steps = opt.get_scalar_constant_value(n_steps)
        except tensor.basic.NotScalarConstantError:
            n_fixed_steps = None

    # Check n_steps is an int
    if (hasattr(n_steps, 'dtype') and
            str(n_steps.dtype)[:3] not in ('uin', 'int')):
        raise ValueError(' n_steps must be an int. dtype provided '
                         'is %s' % n_steps.dtype)

    # compute number of sequences and number of outputs
    n_seqs = len(seqs)
    n_outs = len(outs_info)

    return_steps = OrderedDict()
    # wrap sequences in a dictionary if they are not already dictionaries
    for i in xrange(n_seqs):
        if not isinstance(seqs[i], dict):
            seqs[i] = OrderedDict([('input', seqs[i]), ('taps', [0])])
        elif seqs[i].get('taps', None) is not None:
            seqs[i]['taps'] = wrap_into_list(seqs[i]['taps'])
        elif seqs[i].get('taps', None) is None:
            # seqs dictionary does not have the ``taps`` key
            seqs[i]['taps'] = [0]

    # wrap outputs info in a dictionary if they are not already in one
    for i in xrange(n_outs):
        if outs_info[i] is not None:
            if isinstance(outs_info[i], dict):
                # DEPRECATED :
                if outs_info[i].get('return_steps', None) is not None:
                    raise ValueError(
                        "Using `return_steps` has been deprecated. "
                        "Simply select the entries you need using a "
                        "subtensor. Scan will optimize memory "
                        "consumption, so do not worry about that.")
                # END

            if not isinstance(outs_info[i], dict):
                # by default any output has a tap value of -1
                outs_info[i] = OrderedDict([('initial', outs_info[i]), ('taps', [-1])])
            elif (outs_info[i].get('initial', None) is None and
                    outs_info[i].get('taps', None) is not None):
                # ^ no initial state but taps provided
                raise ValueError(('If you are using slices of an output '
                                  'you need to provide an initial state '
                                  'for it'), outs_info[i])

Developer: Micseb | Project: Theano | Lines: 67 | Source: scan.py
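The dictionary normalization in Example 14 is what allows the user-facing scan call to accept either a bare variable or a taps dictionary for each sequence. A minimal usage sketch of the public theano.scan API showing both equivalent forms:

import theano
import theano.tensor as tensor

x = tensor.vector('x')

# A bare sequence is silently wrapped as OrderedDict([('input', x), ('taps', [0])]).
out1, _ = theano.scan(fn=lambda x_t: 2 * x_t, sequences=[x])

# The explicit dictionary form is equivalent.
out2, _ = theano.scan(fn=lambda x_t: 2 * x_t,
                      sequences=[dict(input=x, taps=[0])])

f = theano.function([x], [out1, out2])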
Example 15: Rebroadcast

class Rebroadcast(gof.Op):
    """
    Change the input's broadcastable fields in some predetermined way.

    See Also
    --------
    unbroadcast <theano.tensor.unbroadcast>
    addbroadcast <theano.tensor.addbroadcast>
    patternbroadcast <theano.tensor.patternbroadcast>

    Notes
    -----
    Works inplace and works for CudaNdarrayType.

    Example
    -------
    `Rebroadcast((0, True), (1, False))(x)` would make `x` broadcastable in
    axis 0 and not broadcastable in axis 1.

    """

    view_map = {0: [0]}
    _f16_ok = True
    # Mapping from Type to C code (and version) to use.
    # In the C code, the name of the input variable is %(iname)s,
    # the output variable is %(oname)s.
    c_code_and_version = {}

    check_input = False
    __props__ = ("axis",)

    def __init__(self, *axis):
        # Sort them to make sure we merge all possible cases.
        items = sorted(axis)
        self.axis = OrderedDict(items)
        for axis, broad in iteritems(self.axis):
            if not isinstance(axis, (numpy.integer, integer_types)):
                raise TypeError("Rebroadcast needs integer axes. "
                                "Got {}".format(axis))

            if not isinstance(broad, (numpy.bool_, bool)):
                raise TypeError("Rebroadcast needs bool for new broadcast "
                                "pattern. Got {}".format(broad))

    def __hash__(self):
        # Need special __hash__ as dicts aren't hashable.
        # No ambiguity because each item key is unique.
        items = sorted(iteritems(self.axis))
        return hash((type(self), tuple(items)))

    def __str__(self):
        if len(self.axis) == 0:
            broadcast_pattern = []
        else:
            broadcast_pattern = ['?' for i
                                 in xrange(1 + max(self.axis.keys()))]
        for k, v in iteritems(self.axis):
            broadcast_pattern[k] = str(int(v))
        return '%s{%s}' % (self.__class__.__name__,
                           ','.join(broadcast_pattern))

    def make_node(self, x):
        if self.axis.keys() and (x.ndim <= max(self.axis.keys())):
            raise ValueError('Trying to rebroadcast non-existent dimension')
        t = x.type.clone(
            broadcastable=[self.axis.get(i, b)
                           for i, b in enumerate(x.type.broadcastable)])
        return gof.Apply(self, [x], [t()])

    def perform(self, node, inp, out_):
        x, = inp
        out, = out_
        for axis, value in iteritems(self.axis):
            if value and x.shape[axis] != 1:
                raise ValueError('Dimension %s in Rebroadcast\'s input was'
                                 ' supposed to be 1 (got %s instead)' %
                                 (axis, x.shape[axis]))
        out[0] = x

    def grad(self, inp, grads):
        x, = inp
        gz, = grads
        # restore the broadcasting pattern of the input
        return Rebroadcast(*[(axis, x.type.broadcastable[axis])
                             for axis, value in iteritems(self.axis)])(gz),

    def infer_shape(self, node, ishapes):
        assert len(ishapes) == 1
        l = []
        one = theano.tensor.basic.constant(1)
        for ax in xrange(len(ishapes[0])):
            if self.axis.get(ax, False):
                l.append(one)
            else:
                l.append(ishapes[0][ax])

        return [tuple(l)]

    def R_op(self, inputs, eval_points):
        if eval_points[0] is None:
# ... (part of the code omitted here) ...

Developer: Azrael1 | Project: Theano | Lines: 101 | Source: ops.py
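User code rarely instantiates Rebroadcast directly; the helpers named in the docstring's See Also section wrap it. A short sketch using those public helpers:

import numpy
import theano
import theano.tensor as T

x = T.matrix('x')

# Assert that dimension 0 always has length 1, making it broadcastable.
y = T.addbroadcast(x, 0)
# The inverse declaration: treat dimension 0 as a general dimension again.
z = T.unbroadcast(y, 0)

f = theano.function([x], z)
print(f(numpy.zeros((1, 3), dtype=theano.config.floatX)).shape)  # (1, 3)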
Example 16: DestroyHandler

class DestroyHandler(toolbox.Bookkeeper):  # noqa
    """
    The DestroyHandler class detects when a graph is impossible to evaluate
    because of aliasing and destructive operations.

    Several data structures are used to do this.

    An Op can use its view_map property to declare that an output may be
    aliased to an input. If that output is destroyed, the input is also
    considered to be destroyed. The view_maps of several Ops can feed into
    one another and form a directed graph. The consequence of destroying any
    variable in such a graph is that all variables in the graph must be
    considered to be destroyed, because they could all be referring to the
    same underlying storage.

    In the current implementation, that graph is a tree, and the root of
    that tree is called the foundation.

    TODO: why "in the current implementation"? Is there another
    implementation planned?
    TODO: why is the graph a tree? Isn't it possible that one variable
    could be aliased to many variables? For example, don't switch and
    ifelse have to do this?

    The original DestroyHandler (if 0'ed out above) computed several data
    structures from scratch each time it was asked to validate the graph.
    Because this happens potentially thousands of times and each graph to
    validate is extremely similar to the previous one, computing the
    data structures from scratch repeatedly was wasteful and resulted in
    high compile times for large graphs.

    This implementation computes the data structures once at initialization
    and then incrementally updates them.

    It is a work in progress. The following data structures have been
    converted to use the incremental strategy:

        <none>

    The following data structures remain to be converted:

        <unknown>

    """
    pickle_rm_attr = ["destroyers"]

    def __init__(self, do_imports_on_attach=True):
        self.fgraph = None
        self.do_imports_on_attach = do_imports_on_attach

        """
        Maps every variable in the graph to its "foundation" (deepest
        ancestor in view chain).
        TODO: change name to var_to_vroot.
        """
        self.droot = OrderedDict()

        """
        Maps a variable to all variables that are indirect or direct views
        of it (including itself); essentially the inverse of droot.
        TODO: do all variables appear in this dict, or only those that are
        foundations?
        TODO: do only destroyed variables go in here? One old docstring
        said so.
        TODO: rename to x_to_views after reverse engineering what x is.
        """
        self.impact = OrderedDict()

        """
        If a var is destroyed, then this dict will map droot[var] to the
        apply node that destroyed var.
        TODO: rename to vroot_to_destroyer.
        """
        self.root_destroyer = OrderedDict()

    def on_attach(self, fgraph):
        """
        When attaching to a new fgraph, check that

        1) This DestroyHandler wasn't already attached to some fgraph
           (its data structures are only set up to serve one).
        2) The FunctionGraph doesn't already have a DestroyHandler.
           This would result in it validating everything twice, causing
           compilation to be slower.

        Give the FunctionGraph instance:

        1) A new method "destroyers(var)"
           TODO: what does this do exactly?
        2) A new attribute, "destroy_handler"
           TODO: WRITEME: what does this do besides the checks?
        """
        # Do the checking #
        already_there = False
        if self.fgraph is fgraph:
            already_there = True

        if self.fgraph is not None:
            raise Exception(
                "A DestroyHandler instance can only serve one"
                " FunctionGraph. (Matthew 6:24)")

        for attr in ('destroyers', 'destroy_handler'):
            if hasattr(fgraph, attr):
                already_there = True

        if already_there:
            # FunctionGraph.attach_feature catches AlreadyThere and
            # cancels the attachment.
            raise toolbox.AlreadyThere(
                "DestroyHandler feature is already present"
                " or in conflict with another plugin.")

        # Annotate the FunctionGraph #
# ... (part of the code omitted here) ...

Developer: ragavvenkatesan | Project: Theano | Lines: 101 | Source: destroyhandler.py
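The view_map/destroy_map declarations that DestroyHandler consumes live on individual Ops. The following is a minimal sketch of such a declaration; the InplaceAddOne op is hypothetical and exists only to show the shape of the annotation:

import theano
from theano import gof

class InplaceAddOne(gof.Op):
    """Hypothetical op that increments its input buffer in place."""

    # Output 0 reuses (and destroys) the storage of input 0. DestroyHandler
    # reads this declaration to order the graph so that every other client
    # of the input runs before this op destroys it.
    destroy_map = {0: [0]}
    __props__ = ()

    def make_node(self, x):
        x = theano.tensor.as_tensor_variable(x)
        return gof.Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        x, = inputs
        x += 1                      # mutate the input's ndarray directly
        output_storage[0][0] = x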
Example 17: xrange

# ... (part of the code omitted here) ...
    n_outs = n_outs - 1

    outs_info = [OrderedDict() for x in xrange(n_outs)]

    # Step 5.1 Outputs with taps different than -1
    for i, out in enumerate(outs_info):
        if 'taps' in out and out['taps'] != [-1]:
            mit_sot_inner_outputs.append(outputs[i])

    # Step 5.2 Outputs with tap equal to -1
    for i, out in enumerate(outs_info):
        if 'taps' in out and out['taps'] == [-1]:
            sit_sot_inner_outputs.append(outputs[i])

    # Step 5.3 Outputs that correspond to update rules of shared variables
    givens = OrderedDict()
    n_shared_outs = 0
    shared_scan_inputs = []
    shared_inner_inputs = []
    shared_inner_outputs = []
    sit_sot_shared = []
    for input in dummy_f.maker.expanded_inputs:
        if isinstance(input.variable, SharedVariable) and input.update:
            new_var = safe_new(input.variable)
            if getattr(input.variable, 'name', None) is not None:
                new_var.name = input.variable.name + '_copy'
            if isinstance(new_var.type, ops.expandable_types):
                sit_sot_inner_inputs.append(new_var)
                sit_sot_scan_inputs.append(
                    scan_utils.expand(
                        tensor.unbroadcast(

Developer: TimSalimans | Project: Theano | Lines: 31 | Source: scan.py
Example 18: scan

# ... (part of the code omitted here) ...
    allow_gc = config.scan.allow_gc

    # Make sure we get rid of numpy arrays or ints or anything like that
    # passed as inputs to scan
    non_seqs = []
    for elem in wrap_into_list(params):
        if not isinstance(elem, gof.Variable):
            non_seqs.append(tensor.as_tensor_variable(elem))
        else:
            non_seqs.append(elem)

    # If we provided a known number of steps (before compilation)
    # and if that number is 1 or -1, then we can skip the Scan Op,
    # and just apply the inner function once.
    # To do that we check here to see the nature of n_steps
    n_fixed_steps = None

    if isinstance(n_steps, (float, int)):
        n_fixed_steps = int(n_steps)
    else:
        try:
            n_fixed_steps = opt.get_scalar_constant_value(n_steps)
        except tensor.basic.NotScalarConstantError:
            n_fixed_steps = None

    # Check n_steps is an int
    if hasattr(n_steps, "dtype") and str(n_steps.dtype)[:3] not in ("uin", "int"):
        raise ValueError(" n_steps must be an int. dtype provided "
                         "is %s" % n_steps.dtype)

    # compute number of sequences and number of outputs
    n_seqs = len(seqs)
    n_outs = len(outs_info)

    return_steps = OrderedDict()

    # wrap outputs info in a dictionary if they are not already in one
    for i in xrange(n_outs):
        if outs_info[i] is not None:
            if not isinstance(outs_info[i], dict):
                # by default any output has a tap value of -1
                outs_info[i] = dict(membuf=outs_info[i], taps=[-1])
            elif not outs_info[i].get("membuf", None) and outs_info[i].get("taps", None):
                # ^ no initial state but taps provided
                raise ValueError(
                    ("If you are using slices of an output "
                     "you need to provide a memory buffer for "
                     "the state "),
                    outs_info[i],
                )
            elif outs_info[i].get("membuf", None) and not outs_info[i].get("taps", None):
                # ^ initial state but taps not provided
                if "taps" in outs_info[i]:
                    # ^ explicitly provided a None for taps
                    _logger.warning(
                        "Output %s (index %d) has a memory "
                        "buffer but taps is explicitly set to None ",
                        getattr(outs_info[i]["membuf"], "name", "None"),
                        i,
                    )
                outs_info[i]["taps"] = [-1]
        else:
            # if a None is provided as the output info we replace it
            # with a dict(steps=n_steps) to simplify handling
            outs_info[i] = dict(steps=n_steps)

    ##
    # Step 2. Generate inputs and outputs of the inner functions
    # for compiling a dummy function (Iteration #1)
    ##

# ... (part of the code omitted here) ...

Developer: amanrajdce | Project: Theano | Lines: 66 | Source: scan.py
Example 19: get_gradients

def get_gradients(self, model, data, **kwargs):
    space, sources = self.get_data_specs(model)
    space.validate(data)
    assert isinstance(model, AdversaryPair)
    g = model.generator
    d = model.discriminator

    S, d_obj, g_obj, i_obj = self.get_samples_and_objectives(model, data)
    g_params = g.get_params()
    d_params = d.get_params()
    for param in g_params:
        assert param not in d_params
    for param in d_params:
        assert param not in g_params
    d_grads = T.grad(d_obj, d_params)
    g_grads = T.grad(g_obj, g_params)

    if self.scale_grads:
        S_grad = T.grad(g_obj, S)
        scale = T.maximum(1., self.target_scale / T.sqrt(T.sqr(S_grad).sum()))
        g_grads = [g_grad * scale for g_grad in g_grads]

    rval = OrderedDict()
    zeros = itertools.repeat(theano.tensor.constant(0., dtype='float32'))

    if self.ever_train_discriminator:
        rval.update(OrderedDict(safe_zip(d_params, [self.now_train_discriminator * dg for dg in d_grads])))
    else:
        rval.update(OrderedDict(zip(d_params, zeros)))

    if self.ever_train_generator:
        rval.update(OrderedDict(safe_zip(g_params, [self.now_train_generator * gg for gg in g_grads])))
    else:
        rval.update(OrderedDict(zip(g_params, zeros)))

    if self.ever_train_inference and model.inferer is not None:
        i_params = model.inferer.get_params()
        i_grads = T.grad(i_obj, i_params)
        rval.update(OrderedDict(safe_zip(i_params, [self.now_train_inference * ig for ig in i_grads])))
    elif model.inferer is not None:
        rval.update(OrderedDict(zip(model.inferer.get_params(), zeros)))

    updates = OrderedDict()

    # Two d steps for every g step
    if self.alternate_g:
        updates[self.now_train_generator] = 1. - self.now_train_generator

    return rval, updates

Developer: HyoungWooPark | Project: adversarial | Lines: 47 | Source: __init__.py
Note: the theano.compat.OrderedDict class examples in this article were compiled by 纯净天空 from GitHub/MSDocs and other source-code and documentation hosting platforms. The snippets were selected from open-source projects contributed by various programmers, and the copyright of each snippet belongs to its original author. Please follow the license of the corresponding project when redistributing or using them; do not reproduce without permission.