• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python compat.OrderedDict类代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中theano.compat.OrderedDict的典型用法代码示例。如果您正苦于以下问题：Python OrderedDict类的具体用法？Python OrderedDict怎么用？Python OrderedDict使用的例子？那么恭喜您，这里精选的类代码示例或许可以为您提供帮助。



在下文中一共展示了OrderedDict类的19个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: get_monitoring_channels

 def get_monitoring_channels(self, data):
     """
     Collect the monitoring channels reported by the wrapped MLP.

     Any exception raised while querying ``self.mlp`` is downgraded to a
     warning; whatever was gathered up to that point (possibly nothing)
     is returned.
     """
     channels = OrderedDict()
     try:
         mlp_channels = self.mlp.get_monitoring_channels(data)
         channels.update(mlp_channels)
     except Exception:
         # Best effort: monitoring must not abort the caller.
         warnings.warn("something went wrong with compressor.mlp's monitoring channels")
     return channels
开发者ID:vinmisra,项目名称:adversary-compress,代码行数:7,代码来源:CAN.py


示例2: get_monitoring_channels

    def get_monitoring_channels(self, data):
        """
        Gather monitoring channels of the generator, discriminator and
        (optional) inferer, each group under its own key prefix.

        Prefixes used in the returned mapping:

        - ``gen_``: channels from ``self.generator`` (if
          ``self.monitor_generator``).
        - ``dis_on_data_``: discriminator channels computed on ``data``
          (if ``self.monitor_discriminator``).
        - ``dis_on_samp_``: discriminator channels computed on 100
          generator samples (if ``self.monitor_discriminator``).
        - ``inf_``: channels from ``self.inferer`` (if
          ``self.monitor_inference`` and an inferer is present).

        Returns
        -------
        OrderedDict
            Prefixed channel name -> channel value.
        """
        rval = OrderedDict()

        g_ch = self.generator.get_monitoring_channels(data)
        # Discriminator evaluated on the data batch.
        d_ch = self.discriminator.get_monitoring_channels((data, None))

        # Discriminator evaluated on freshly generated samples.
        samples, _, conditional_data, _ = self.generator.sample_and_noise(100)
        d_samp_ch = self.discriminator.get_monitoring_channels(((samples, conditional_data), None))

        i_ch = OrderedDict()
        if self.inferer is not None:
            batch_size = self.inference_monitoring_batch_size
            sample, noise, conditional_data, _ = self.generator.sample_and_noise(batch_size)
            i_ch.update(self.inferer.get_monitoring_channels(((sample, conditional_data), noise)))

        if self.monitor_generator:
            for key in g_ch:
                rval["gen_" + key] = g_ch[key]
        if self.monitor_discriminator:
            # Each prefix must be fed from the matching source: the two
            # mappings were previously swapped.
            for key in d_ch:
                rval["dis_on_data_" + key] = d_ch[key]
            for key in d_samp_ch:
                rval["dis_on_samp_" + key] = d_samp_ch[key]
        if self.monitor_inference:
            for key in i_ch:
                rval["inf_" + key] = i_ch[key]
        return rval
开发者ID:hit-computer,项目名称:adversarial,代码行数:27,代码来源:__init__.py


示例3: orderings

    def orderings(self):
        """
        Merge the ordering constraints declared by all attached features.

        Returns an OrderedDict ``d`` such that ``d[node]`` is a list of
        nodes that must be evaluated before ``node``. Each feature exposing
        an ``orderings`` attribute is called with this object; the results
        are concatenated per node and duplicate prerequisites are dropped
        while keeping first-seen order.

        Notes
        -----
        Dependencies are not computed here; this only aggregates what the
        features report.

        Raises
        ------
        TypeError
            If a feature returns something other than an OrderedDict, or
            if a prerequisite collection is neither a list nor an
            OrderedSet (both are required for deterministic iteration).

        """
        assert isinstance(self._features, list)
        merged = OrderedDict()
        for feature in self._features:
            if not hasattr(feature, 'orderings'):
                continue
            reported = feature.orderings(self)
            if not isinstance(reported, OrderedDict):
                raise TypeError("Non-deterministic return value from " +
                                str(feature.orderings) +
                                ". Nondeterministic object is " +
                                str(reported))
            for node, prereqs in iteritems(reported):
                if not isinstance(prereqs, (list, OrderedSet)):
                    raise TypeError(
                        "prereqs must be a type with a "
                        "deterministic iteration order, or toposort "
                        " will be non-deterministic.")
                if node not in merged:
                    merged[node] = []
                merged[node].extend(prereqs)
        # Only values of existing keys are replaced below, so rebinding
        # while iterating does not change the set of keys.
        for node, prereqs in iteritems(merged):
            merged[node] = list(OrderedSet(prereqs))
        return merged
开发者ID:chinnadhurai,项目名称:Theano,代码行数:35,代码来源:fg.py


示例4: __init__

    def __init__(self, valid=None, invalid=None, valid_equivalent=None):
        '''
        Set up the valid / invalid variable sets and the substitution map.

        valid and invalid are iterables of variables (empty when None).
        valid_equivalent maps invalid variables to valid ones that can be
        used instead; its values are added to the valid set and its keys
        to the invalid set. A copy of the mapping is stored.
        '''
        valid = [] if valid is None else valid
        invalid = [] if invalid is None else invalid
        equivalents = (OrderedDict() if valid_equivalent is None
                       else valid_equivalent)

        # Nodes allowed / not allowed in the graph computing outputs.
        self.valid = set(valid)
        self.invalid = set(invalid)

        # Invalid variable -> equivalent valid variable (private copy).
        self.valid_equivalent = equivalents.copy()
        self.valid.update(equivalents.values())
        self.invalid.update(equivalents.keys())
开发者ID:LEEKYOUNGHUN,项目名称:Theano,代码行数:25,代码来源:scan_utils.py


示例5: get_layer_monitoring_channels

    def get_layer_monitoring_channels(self, state_below=None, state=None, targets=None):
        """
        Build monitoring channels from the layer state and its kernels.

        Parameters
        ----------
        state_below : object, optional
            Not used by this implementation.
        state : object
            Must expose ``ndim == 4`` together with ``max`` / ``min`` /
            ``mean`` reductions that accept an ``axis`` argument.
        targets : object, optional
            Not used by this implementation.

        Returns
        -------
        OrderedDict
            Twelve ``<stat>_x.<stat>_u`` channels computed from ``state``
            followed by ``kernel_norms_min`` / ``kernel_norms_mean`` /
            ``kernel_norms_max``.

        Raises
        ------
        AssertionError
            If the single transformer parameter is not 4-dimensional, or
            if ``state`` is not a 4-dimensional object. The message
            describes the offending ``state``.
        """
        W, = self.transformer.get_params()

        assert W.ndim == 4

        # Validate `state` up front. The diagnostics travel with the
        # exception instead of being printed, which keeps the block valid
        # on both Python 2 and 3 and does not rely on a bare `assert
        # False` that is stripped under -O.
        if not hasattr(state, 'ndim') or state.ndim != 4:
            message = "expected 4D tensor, got %r of type %s" % (
                state, type(state))
            if isinstance(state, tuple):
                message += " (tuple length: %d)" % len(state)
            raise AssertionError(message)

        rval = OrderedDict()

        v_max = state.max(axis=3)
        v_min = state.min(axis=3)
        v_mean = state.mean(axis=3)
        # The range is taken before max/min are reduced any further.
        v_range = v_max - v_min
        v_max = v_max.max(axis=(1, 2))
        v_min = v_min.min(axis=(1, 2))

        # max_x.mean_u is "the mean over *u*nits of the max over
        # e*x*amples". The x and u are part of the name because it is
        # otherwise hard to remember which axis is which when reading the
        # monitor; inner.outer naming keeps e.g. all mean_x.* channels
        # next to each other in an alphabetical list.
        for name, stat in (('max_x', v_max),
                           ('min_x', v_min),
                           ('range_x', v_range),
                           ('mean_x', v_mean)):
            rval[name + '.max_u'] = stat.max()
            rval[name + '.mean_u'] = stat.mean()
            rval[name + '.min_u'] = stat.min()

        # Norm of W taken over its first three axes.
        row_norms = T.sqrt(T.sqr(W).sum(axis=(0, 1, 2)))
        rval['kernel_norms_min'] = row_norms.min()
        rval['kernel_norms_mean'] = row_norms.mean()
        rval['kernel_norms_max'] = row_norms.max()

        return rval
开发者ID:cc13ny,项目名称:galatea,代码行数:57,代码来源:deconv.py


示例6: on_attach

    def on_attach(self, fgraph):
        """
        Attach this DestroyHandler to ``fgraph`` after sanity checks.

        Checks performed, in order:
            1) If this handler already serves some fgraph
               (``self.fgraph is not None``) a plain Exception is raised;
               its data structures are only set up to serve one.
            2) If ``fgraph`` already has a ``destroyers`` or
               ``destroy_handler`` attribute, toolbox.AlreadyThere is
               raised. Having two handlers would validate everything
               twice and slow compilation down.

        On success the handler resets its bookkeeping containers, sets
        ``fgraph.destroy_handler`` to itself and, when
        ``self.do_imports_on_attach`` is true, delegates to
        ``toolbox.Bookkeeper.on_attach``.

        TODO: WRITEME: what does ``destroyers(var)`` on the fgraph do
        exactly, and what else happens besides the checks?

        """

        # Do the checking #
        # Evaluated before the raise below, like the original flag; it can
        # only matter when both sides are None.
        same_graph = self.fgraph is fgraph
        if self.fgraph is not None:
            raise Exception(
                "A DestroyHandler instance can only serve one"
                " FunctionGraph. (Matthew 6:24)")
        attr_present = [hasattr(fgraph, name)
                        for name in ('destroyers', 'destroy_handler')]

        if same_graph or any(attr_present):
            # FunctionGraph.attach_feature catches AlreadyThere and cancels the attachment
            raise toolbox.AlreadyThere(
                "DestroyHandler feature is already present"
                " or in conflict with another plugin.")

        # Annotate the FunctionGraph #
        self.unpickle(fgraph)
        fgraph.destroy_handler = self

        self.fgraph = fgraph
        # set of Apply instances with non-null destroy_map
        self.destroyers = OrderedSet()
        # variable -> variable used in calculation
        self.view_i = OrderedDict()
        # variable -> set of variables that use this one as a direct input
        self.view_o = OrderedDict()
        # clients: how many times does an apply use a given variable
        # (variable -> apply -> ninputs)
        self.clients = OrderedDict()
        self.stale_droot = True

        self.debug_all_apps = OrderedSet()
        if self.do_imports_on_attach:
            toolbox.Bookkeeper.on_attach(self, fgraph)
开发者ID:12190143,项目名称:Theano,代码行数:50,代码来源:destroyhandler.py


示例7: __init__

 def __init__(self, *axis):
     """
     Store the requested ``(axis, broadcastable)`` pairs.

     The pairs are sorted before being stored in ``self.axis`` so that
     equivalent ops built with differently ordered arguments end up with
     the same mapping. Each key must be an integer and each value a bool;
     this is enforced with assertions.
     """
     # Sort them to make sure we merge all possible case.
     self.axis = OrderedDict(sorted(axis))
     for ax, new_flag in iteritems(self.axis):
         assert isinstance(ax, (numpy.integer, int)), ("Rebroadcast needs integer axes. Got ", ax)
         assert isinstance(new_flag, bool), ("Rebroadcast needs bool for new broadcast pattern. Got ", new_flag)
开发者ID:ZhangAustin,项目名称:attention-lvcsr,代码行数:7,代码来源:ops.py


示例8: get_layer_monitoring_channels

 def get_layer_monitoring_channels(self, state_below=None,
                                   state=None, targets=None):
     """
     Return kernel-norm channels plus those of the nonlinearity.

     The single transformer parameter must be 5-dimensional. Its norm is
     taken over axes 1-4, and the min / mean / max of the result are
     reported first. The channels returned by
     ``self.nonlin.get_monitoring_channels_from_state`` are merged in
     afterwards and override any key they share. ``state_below`` is not
     used.
     """
     W, = self.transformer.get_params()
     assert W.ndim == 5
     row_norms = T.sqrt(T.sqr(W).sum(axis=(1, 2, 3, 4)))

     channels = OrderedDict()
     channels['kernel_norms_min'] = row_norms.min()
     channels['kernel_norms_mean'] = row_norms.mean()
     channels['kernel_norms_max'] = row_norms.max()

     cost_fn = self.cost
     nonlin_channels = self.nonlin.get_monitoring_channels_from_state(
         state, targets, cost_fn=cost_fn)
     channels.update(nonlin_channels)
     return channels
开发者ID:robintibor,项目名称:pylearn3dconv,代码行数:17,代码来源:base.py


示例9: __init__

    def __init__(self, valid=None, invalid=None, valid_equivalent=None):
        """
        Initialise the valid / invalid sets and the equivalence mapping.

        ``valid`` and ``invalid`` default to empty collections.
        ``valid_equivalent`` (default: empty OrderedDict) maps invalid
        variables to valid replacements; a copy of it is stored, its
        values are added to ``self.valid`` and its keys to
        ``self.invalid``.
        """
        if valid_equivalent is None:
            valid_equivalent = OrderedDict()

        # Nodes that are valid to have in the graph computing outputs
        self.valid = set([] if valid is None else valid)

        # Nodes that are NOT valid to have in the graph computing outputs
        self.invalid = set([] if invalid is None else invalid)

        # Mapping from invalid variables to equivalent valid ones.
        self.valid_equivalent = valid_equivalent.copy()
        replacement_vars = list(valid_equivalent.values())
        replaced_vars = list(valid_equivalent.keys())
        self.valid.update(replacement_vars)
        self.invalid.update(replaced_vars)
开发者ID:ALISCIFP,项目名称:Segmentation,代码行数:18,代码来源:scan_utils.py


示例10: get_monitoring_channels

    def get_monitoring_channels(self, data):
        """
        Return the generator's monitoring channels.

        Noise of shape ``(m, n)`` is drawn with ``self.get_noise``, where
        ``m`` is ``data.shape[0]`` (100 when ``data`` is None) and ``n``
        is the total dimension of the MLP input space. The MLP channels
        for that noise are collected on a best-effort basis: a failure
        only produces a warning. When ``self.monitor_ll`` is set, the
        mean of ``self.ll(...)`` is added as ``'ll'`` and its negation as
        ``'nll'``.
        """
        m = 100 if data is None else data.shape[0]
        n = self.mlp.get_input_space().get_total_dimension()
        noise = self.get_noise((m, n))

        channels = OrderedDict()
        try:
            mlp_channels = self.mlp.get_monitoring_channels((noise, None))
            channels.update(mlp_channels)
        except Exception:
            warnings.warn("something went wrong with generator.mlp's monitoring channels")

        if self.monitor_ll:
            log_likelihood = T.cast(
                self.ll(data, self.ll_n_samples, self.ll_sigma),
                theano.config.floatX).mean()
            channels['ll'] = log_likelihood
            channels['nll'] = -log_likelihood
        return channels
开发者ID:HyoungWooPark,项目名称:adversarial,代码行数:19,代码来源:__init__.py


示例11: get_gradients

    def get_gradients(self, model, data, **kwargs):
        """
        Return gated gradients for both networks and control-signal updates.

        Each network's gradients are multiplied by its ``now_train_*``
        control signal when the matching ``ever_train_*`` flag is set, and
        replaced by float32 zero constants otherwise. The second return
        value advances ``self.train_clock`` and recomputes both control
        signals from the advanced clock. As a side effect the advanced
        clock expression is stored in ``self.future_train_clock``.

        Returns
        -------
        (OrderedDict, OrderedDict)
            Parameter -> gradient expression, and variable -> update.
        """
        space, sources = self.get_data_specs(model)
        space.validate(data)
        assert isinstance(model, CompressAdversaryPair)
        compressor = model.compressor
        discriminator = model.discriminator

        # Raw gradients of each objective w.r.t. its own parameters.
        d_obj, g_obj = self.get_objectives(model, data)
        g_params = compressor.get_params()
        d_params = discriminator.get_params()
        # The two parameter lists must not share any entry.
        for param in g_params:
            assert param not in d_params
        for param in d_params:
            assert param not in g_params

        d_grads = T.grad(d_obj, d_params)
        g_grads = T.grad(g_obj, g_params)

        # NOTE: no gradient rescaling (scale_grads) is applied here.

        # Adjust the raw gradients with the control signals.
        rval = OrderedDict()
        zeros = itertools.repeat(theano.tensor.constant(0., dtype='float32'))

        if self.ever_train_discriminator:
            gated_d = [self.now_train_discriminator * dg for dg in d_grads]
            d_entries = safe_zip(d_params, gated_d)
        else:
            d_entries = zip(d_params, zeros)
        rval.update(OrderedDict(d_entries))

        if self.ever_train_compressor:
            gated_g = [self.now_train_compressor * gg for gg in g_grads]
            g_entries = safe_zip(g_params, gated_g)
        else:
            g_entries = zip(g_params, zeros)
        rval.update(OrderedDict(g_entries))

        # Control signals are advanced through the returned updates.
        updates = OrderedDict()
        # First the clock: it restarts at 1 once the full cycle elapsed.
        d_phase_end = self.discriminator_steps + self.joint_steps
        cycle_end = d_phase_end + self.compressor_steps
        next_clock = T.switch(T.ge(self.train_clock, cycle_end),
                              1., self.train_clock + 1.)
        self.future_train_clock = next_clock
        updates[self.train_clock] = next_clock
        # Then the control signals derived from the advanced clock.
        updates[self.now_train_discriminator] = T.switch(
            T.le(next_clock, d_phase_end), 1., 0.)
        updates[self.now_train_compressor] = T.switch(
            T.gt(next_clock, self.discriminator_steps), 1., 0.)

        return rval, updates
开发者ID:vinmisra,项目名称:adversary-compress,代码行数:48,代码来源:CAN.py


示例12: __init__

    def __init__(self, *axis):
        """
        Store the requested ``(axis, broadcastable)`` pairs.

        The pairs are sorted before being stored in ``self.axis`` so that
        ops built from the same pairs in a different order get the same
        mapping.

        Raises
        ------
        TypeError
            If an axis is not an integer or a broadcast flag is not a
            bool (numpy scalar counterparts are accepted).
        """
        # Sort them to make sure we merge all possible case.
        self.axis = OrderedDict(sorted(axis))
        for ax, new_flag in iteritems(self.axis):
            axis_ok = isinstance(ax, (numpy.integer, integer_types))
            if not axis_ok:
                raise TypeError("Rebroadcast needs integer axes. "
                                "Got {}".format(ax))

            flag_ok = isinstance(new_flag, (numpy.bool_, bool))
            if not flag_ok:
                raise TypeError("Rebroadcast needs bool for new broadcast "
                                "pattern. Got {}".format(new_flag))
开发者ID:Azrael1,项目名称:Theano,代码行数:12,代码来源:ops.py


示例13: get_lr_scalers

    def get_lr_scalers(self):
        """
        Merge the learning-rate scalers reported by every layer.

        The hidden layers are visited first, then the visible layer.
        Assertions enforce that no key is reported by two layers, that
        every key is one of ``self.get_params()``, and that every merged
        value is a float.

        Returns
        -------
        OrderedDict
            Key -> scaler, in the order the layers reported them.
        """
        params = self.get_params()
        scalers = OrderedDict()

        layers = self.hidden_layers + [self.visible_layer]
        for layer in layers:
            layer_scalers = layer.get_lr_scalers()

            # No two layers can contend to scale a parameter
            contested = [key for key in layer_scalers if key in scalers]
            assert not contested
            # Don't try to scale anything that's not a parameter
            foreign = [key for key in layer_scalers if key not in params]
            assert not foreign

            scalers.update(layer_scalers)

        assert all(isinstance(value, float) for value in scalers.values())
        return scalers
开发者ID:JakeMick,项目名称:pylearn2,代码行数:22,代码来源:dbm.py


示例14: scan


#.........这里部分代码省略.........
    # Make sure we get rid of numpy arrays or ints or anything like that
    # passed as inputs to scan
    non_seqs = []
    for elem in wrap_into_list(non_sequences):
        if not isinstance(elem, gof.Variable):
            non_seqs.append(tensor.as_tensor_variable(elem))
        else:
            non_seqs.append(elem)

    # If we provided a known number of steps ( before compilation)
    # and if that number is 1 or -1, then we can skip the Scan Op,
    # and just apply the inner function once
    # To do that we check here to see the nature of n_steps
    n_fixed_steps = None

    if isinstance(n_steps, (float, int)):
        n_fixed_steps = int(n_steps)
    else:
        try:
            n_fixed_steps = opt.get_scalar_constant_value(n_steps)
        except tensor.basic.NotScalarConstantError:
            n_fixed_steps = None

    # Check n_steps is an int
    if (hasattr(n_steps, 'dtype') and
        str(n_steps.dtype)[:3] not in ('uin', 'int')):
        raise ValueError(' n_steps must be an int. dtype provided '
                         'is %s' % n_steps.dtype)

    # compute number of sequences and number of outputs
    n_seqs = len(seqs)
    n_outs = len(outs_info)

    return_steps = OrderedDict()
    # wrap sequences in a dictionary if they are not already dictionaries
    for i in xrange(n_seqs):
        if not isinstance(seqs[i], dict):
            seqs[i] = OrderedDict([('input', seqs[i]), ('taps', [0])])
        elif seqs[i].get('taps', None) is not None:
            seqs[i]['taps'] = wrap_into_list(seqs[i]['taps'])
        elif seqs[i].get('taps', None) is None:
            # seqs dictionary does not have the ``taps`` key
            seqs[i]['taps'] = [0]

    # wrap outputs info in a dictionary if they are not already in one
    for i in xrange(n_outs):
        if outs_info[i] is not None:
            if isinstance(outs_info[i], dict):
                # DEPRECATED :
                if outs_info[i].get('return_steps', None) is not None:
                    raise ValueError(
                            "Using `return_steps` has been deprecated. "
                            "Simply select the entries you need using a "
                            "subtensor. Scan will optimize memory "
                            "consumption, so do not worry about that.")
                # END

            if not isinstance(outs_info[i], dict):
                # by default any output has a tap value of -1
                outs_info[i] = OrderedDict([('initial', outs_info[i]), ('taps', [-1])])
            elif (outs_info[i].get('initial', None) is None and
                    outs_info[i].get('taps', None) is not None):
                # ^ no initial state but taps provided
                raise ValueError(('If you are using slices of an output '
                                  'you need to provide a initial state '
                                  'for it'), outs_info[i])
开发者ID:Micseb,项目名称:Theano,代码行数:67,代码来源:scan.py


示例15: Rebroadcast

class Rebroadcast(gof.Op):
    """
    Change the input's broadcastable fields in some predetermined way.

    See Also
    --------
    unbroadcast <theano.tensor.unbroadcast>
    addbroadcast <theano.tensor.addbroadcast>
    patternbroadcast <theano.tensor.patternbroadcast>

    Notes
    -----
    Works inplace and works for CudaNdarrayType.

    Example
    -------
    `Rebroadcast((0, True), (1, False))(x)` would make `x` broadcastable in
    axis 0 and not broadcastable in axis 1.

    """

    view_map = {0: [0]}
    _f16_ok = True
    # Mapping from Type to C code (and version) to use.
    # In the C code, the name of the input variable is %(iname)s,
    # the output variable is %(oname)s.
    c_code_and_version = {}

    check_input = False
    __props__ = ("axis",)

    def __init__(self, *axis):
        """
        Store the ``(axis, broadcastable)`` pairs given as arguments.

        The pairs are sorted before being put in ``self.axis``.

        Raises
        ------
        TypeError
            If an axis is not an integer or a flag is not a bool (numpy
            scalar counterparts are accepted).
        """
        # Sort them to make sure we merge all possible case.
        ordered_pairs = sorted(axis)
        self.axis = OrderedDict(ordered_pairs)
        for ax, new_flag in iteritems(self.axis):
            if not isinstance(ax, (numpy.integer, integer_types)):
                raise TypeError("Rebroadcast needs integer axes. "
                                "Got {}".format(ax))

            if not isinstance(new_flag, (numpy.bool_, bool)):
                raise TypeError("Rebroadcast needs bool for new broadcast "
                                "pattern. Got {}".format(new_flag))

    def __hash__(self):
        """Hash the op type together with its sorted axis items."""
        # A dedicated __hash__ is needed because dicts are not hashable.
        # Keys are unique, so the sorted item tuple is unambiguous.
        axis_items = tuple(sorted(iteritems(self.axis)))
        return hash((type(self), axis_items))

    def __str__(self):
        """
        Render the op as ``ClassName{p0,p1,...}``.

        Each position up to the largest axis in ``self.axis`` shows ``1``
        or ``0`` for an axis listed there and ``?`` otherwise.
        """
        pattern = []
        if len(self.axis) != 0:
            width = 1 + max(self.axis.keys())
            pattern = ['?' for _ in xrange(width)]
        for position, flag in iteritems(self.axis):
            pattern[position] = str(int(flag))
        class_name = self.__class__.__name__
        return '%s{%s}' % (class_name, ','.join(pattern))

    def make_node(self, x):
        """
        Build the Apply node whose output type carries the new pattern.

        The output type is a clone of ``x.type`` in which every axis
        listed in ``self.axis`` takes the configured flag and all other
        axes keep the flag of ``x``.

        Raises
        ------
        ValueError
            If ``self.axis`` refers to a dimension ``x`` does not have.
        """
        axes = self.axis.keys()
        if axes and (x.ndim <= max(axes)):
            raise ValueError('Trying to rebroadcast non-existent dimension')
        new_pattern = [self.axis.get(position, flag)
                       for position, flag in enumerate(x.type.broadcastable)]
        out_type = x.type.clone(broadcastable=new_pattern)
        return gof.Apply(self, [x], [out_type()])

    def perform(self, node, inp, out_):
        """
        Check the input shape, then pass the input through unchanged.

        Raises
        ------
        ValueError
            If an axis flagged as broadcastable in ``self.axis`` does not
            have length 1 in the input.
        """
        (x,) = inp
        (out,) = out_
        for ax, flag in iteritems(self.axis):
            if not flag:
                continue
            if x.shape[ax] != 1:
                raise ValueError("Dimension %s in Rebroadcast's input was"
                                 " supposed to be 1 (got %s instead)" %
                                 (ax, x.shape[ax]))
        # The very same object is stored as the output; no copy is made.
        out[0] = x

    def grad(self, inp, grads):
        """
        Return a 1-tuple with the output gradient rebroadcast back to the
        pattern of the input on every axis this op touches.
        """
        (x,) = inp
        (gz,) = grads
        # restore the broadcasting pattern of the input
        restored = [(ax, x.type.broadcastable[ax])
                    for ax, _ in iteritems(self.axis)]
        return (Rebroadcast(*restored)(gz),)

    def infer_shape(self, node, ishapes):
        """
        Return the output shape as a one-element list.

        Axes flagged true in ``self.axis`` get a constant 1; every other
        axis keeps the corresponding entry of the single input shape.
        """
        assert len(ishapes) == 1
        one = theano.tensor.basic.constant(1)
        in_shape = ishapes[0]
        out_shape = tuple(
            one if self.axis.get(ax, False) else in_shape[ax]
            for ax in xrange(len(in_shape)))

        return [out_shape]

    def R_op(self, inputs, eval_points):
        if eval_points[0] is None:
#.........这里部分代码省略.........
开发者ID:Azrael1,项目名称:Theano,代码行数:101,代码来源:ops.py


示例16: DestroyHandler

class DestroyHandler(toolbox.Bookkeeper):  # noqa
    """
    The DestroyHandler class detects when a graph is impossible to evaluate
    because of aliasing and destructive operations.

    Several data structures are used to do this.

    An Op can use its view_map property to declare that an output may be
    aliased to an input. If that output is destroyed, the input is also
    considered to be destroyed. The view_maps of several Ops can feed into
    one another and form a directed graph. The consequence of destroying any
    variable in such a graph is that all variables in the graph must be
    considered to be destroyed, because they could all be refering to the
    same underlying storage.

    In the current implementation, that graph is a tree, and the root of that
    tree is called the foundation.

    TODO: why "in the current implementation" ? is there another implementation
          planned?
    TODO: why is the graph a tree? isn't it possible that one variable could
          be aliased to many variables? for example, don't switch and ifelse
          have to do this?

    The original DestroyHandler (if 0'ed out above) computed several data
    structures from scratch each time it was asked to validate the graph.
    Because this happens potentially thousands of times and each graph to
    validate is extremely similar to the previous one, computing the
    data structures from scratch repeatedly was wasteful and resulted in
    high compile times for large graphs.

    This implementation computes the data structures once at initialization
    and then incrementally updates them.

    It is a work in progress. The following data structures have been
    converted to use the incremental strategy:
        <none>

    The following data structures remain to be converted:
        <unknown>
    """
    pickle_rm_attr = ["destroyers"]

    def __init__(self, do_imports_on_attach=True):
        """
        Create a handler that is not yet attached to any fgraph.

        ``do_imports_on_attach`` is stored as given; the three bookkeeping
        mappings start out empty.
        """
        self.fgraph = None
        self.do_imports_on_attach = do_imports_on_attach

        # Maps every variable in the graph to its "foundation" (deepest
        # ancestor in view chain).
        # TODO: change name to var_to_vroot
        self.droot = OrderedDict()

        # Maps a variable to all variables that are indirect or direct
        # views of it (including itself); essentially the inverse of droot.
        # TODO: do all variables appear in this dict, or only those that
        #       are foundations?
        # TODO: do only destroyed variables go in here? one old docstring
        #       said so
        # TODO: rename to x_to_views after reverse engineering what x is
        self.impact = OrderedDict()

        # If a var is destroyed, then this dict will map droot[var] to the
        # apply node that destroyed var.
        # TODO: rename to vroot_to_destroyer
        self.root_destroyer = OrderedDict()

    def on_attach(self, fgraph):
        """
        When attaching to a new fgraph, check that
            1) This DestroyHandler wasn't already attached to some fgraph
               (its data structures are only set up to serve one)
            2) The FunctionGraph doesn't already have a DestroyHandler.
               This would result in it validating everything twice, causing
               compilation to be slower.

        Give the FunctionGraph instance:
            1) A new method "destroyers(var)"
                TODO: what does this do exactly?
            2) A new attribute, "destroy_handler"
        TODO: WRITEME: what does this do besides the checks?
        """

        # Do the checking #
        already_there = False
        if self.fgraph is fgraph:
            already_there = True
        if self.fgraph is not None:
            raise Exception(
                "A DestroyHandler instance can only serve one"
                " FunctionGraph. (Matthew 6:24)")
        for attr in ('destroyers', 'destroy_handler'):
            if hasattr(fgraph, attr):
                already_there = True

        if already_there:
            # FunctionGraph.attach_feature catches AlreadyThere and cancels the attachment
            raise toolbox.AlreadyThere(
                "DestroyHandler feature is already present"
                " or in conflict with another plugin.")

        # Annotate the FunctionGraph #
#.........这里部分代码省略.........
开发者ID:ragavvenkatesan,项目名称:Theano,代码行数:101,代码来源:destroyhandler.py


示例17: xrange

            n_outs = n_outs - 1
        outs_info = [OrderedDict() for x in xrange(n_outs)]

    # Step 5.1 Outputs with taps different then -1

    for i, out in enumerate(outs_info):
        if 'taps' in out and out['taps'] != [-1]:
            mit_sot_inner_outputs.append(outputs[i])

    # Step 5.2 Outputs with tap equal to -1
    for i, out in enumerate(outs_info):
        if 'taps' in out and out['taps'] == [-1]:
            sit_sot_inner_outputs.append(outputs[i])

    # Step 5.3 Outputs that correspond to update rules of shared variables
    givens = OrderedDict()
    n_shared_outs = 0
    shared_scan_inputs = []
    shared_inner_inputs = []
    shared_inner_outputs = []
    sit_sot_shared = []
    for input in dummy_f.maker.expanded_inputs:
        if isinstance(input.variable, SharedVariable) and input.update:
            new_var = safe_new(input.variable)
            if getattr(input.variable, 'name', None) is not None:
                new_var.name = input.variable.name + '_copy'
            if isinstance(new_var.type, ops.expandable_types):
                sit_sot_inner_inputs.append(new_var)
                sit_sot_scan_inputs.append(
                    scan_utils.expand(
                        tensor.unbroadcast(
开发者ID:TimSalimans,项目名称:Theano,代码行数:31,代码来源:scan.py


示例18: scan


#.........这里部分代码省略.........
        allow_gc = config.scan.allow_gc

    # Make sure we get rid of numpy arrays or ints or anything like that
    # passed as inputs to scan
    non_seqs = []
    for elem in wrap_into_list(params):
        if not isinstance(elem, gof.Variable):
            non_seqs.append(tensor.as_tensor_variable(elem))
        else:
            non_seqs.append(elem)

    # If we provided a known number of steps ( before compilation)
    # and if that number is 1 or -1, then we can skip the Scan Op,
    # and just apply the inner function once
    # To do that we check here to see the nature of n_steps
    n_fixed_steps = None

    if isinstance(n_steps, (float, int)):
        n_fixed_steps = int(n_steps)
    else:
        try:
            n_fixed_steps = opt.get_scalar_constant_value(n_steps)
        except tensor.basic.NotScalarConstantError:
            n_fixed_steps = None

    # Check n_steps is an int
    if hasattr(n_steps, "dtype") and str(n_steps.dtype)[:3] not in ("uin", "int"):
        raise ValueError(" n_steps must be an int. dtype provided " "is %s" % n_steps.dtype)

    # compute number of sequences and number of outputs
    n_seqs = len(seqs)
    n_outs = len(outs_info)

    return_steps = OrderedDict()
    # wrap outputs info in a dictionary if they are not already in one
    for i in xrange(n_outs):
        if outs_info[i] is not None:
            if not isinstance(outs_info[i], dict):
                # by default any output has a tap value of -1
                outs_info[i] = dict(membuf=outs_info[i], taps=[-1])
            elif not outs_info[i].get("membuf", None) and outs_info[i].get("taps", None):
                # ^ no initial state but taps provided
                raise ValueError(
                    ("If you are using slices of an output " "you need to provide a memory buffer for " "the state "),
                    outs_info[i],
                )
            elif outs_info[i].get("membuf", None) and not outs_info[i].get("taps", None):
                # ^ initial state but taps not provided
                if "taps" in outs_info[i]:
                    # ^ explicitly provided a None for taps
                    _logger.warning(
                        "Output %s (index %d) has a memory " "buffer but taps is explicitly set to None ",
                        getattr(outs_info[i]["membuf"], "name", "None"),
                        i,
                    )
                outs_info[i]["taps"] = [-1]
        else:
            # if a None is provided as the output info we replace it
            # with an dict(steps=n_steps) to simplify handling
            outs_info[i] = dict(steps=n_steps)

    ##
    # Step 2. Generate inputs and outputs of the inner functions
    # for compiling a dummy function (Iteration #1)
    ##
开发者ID:amanrajdce,项目名称:Theano,代码行数:66,代码来源:scan.py


示例19: get_gradients

    def get_gradients(self, model, data, **kwargs):
        """
        Return gated gradients for generator, discriminator and inferer.

        For each network the raw gradient of its objective is multiplied
        by the matching ``now_train_*`` control signal when the
        ``ever_train_*`` flag is set; otherwise every parameter of that
        network is mapped to a float32 zero constant. Generator gradients
        are additionally rescaled when ``self.scale_grads`` is set.

        Parameters
        ----------
        model : AdversaryPair
            Provides ``generator``, ``discriminator`` and ``inferer``.
        data : object
            Validated against the space from ``self.get_data_specs``.

        Returns
        -------
        (OrderedDict, OrderedDict)
            Parameter -> gradient expression, and the updates of the
            control signals (only populated when ``self.alternate_g``).
        """
        space, sources = self.get_data_specs(model)
        space.validate(data)
        assert isinstance(model, AdversaryPair)
        g = model.generator
        d = model.discriminator

        S, d_obj, g_obj, i_obj = self.get_samples_and_objectives(model, data)

        g_params = g.get_params()
        d_params = d.get_params()
        # The two parameter lists must not share any entry.
        for param in g_params:
            assert param not in d_params
        for param in d_params:
            assert param not in g_params
        d_grads = T.grad(d_obj, d_params)
        g_grads = T.grad(g_obj, g_params)

        if self.scale_grads:
            S_grad = T.grad(g_obj, S)
            # The factor is never below 1, so gradients are only scaled up.
            scale = T.maximum(1., self.target_scale / T.sqrt(T.sqr(S_grad).sum()))
            g_grads = [g_grad * scale for g_grad in g_grads]

        rval = OrderedDict()
        # Endless supply of zeros; zip() stops at the end of the params.
        zeros = itertools.repeat(theano.tensor.constant(0., dtype='float32'))
        if self.ever_train_discriminator:
            rval.update(OrderedDict(safe_zip(d_params, [self.now_train_discriminator * dg for dg in d_grads])))
        else:
            rval.update(OrderedDict(zip(d_params, zeros)))
        if self.ever_train_generator:
            rval.update(OrderedDict(safe_zip(g_params, [self.now_train_generator * gg for gg in g_grads])))
        else:
            rval.update(OrderedDict(zip(g_params, zeros)))
        if self.ever_train_inference and model.inferer is not None:
            i_params = model.inferer.get_params()
            i_grads = T.grad(i_obj, i_params)
            rval.update(OrderedDict(safe_zip(i_params, [self.now_train_inference * ig for ig in i_grads])))
        elif model.inferer is not None:
            # OrderedDict accepts a single iterable of pairs; passing the
            # params and the zeros as two positional arguments raised
            # TypeError, so they are zipped like in the branches above.
            rval.update(OrderedDict(zip(model.inferer.get_params(), zeros)))

        updates = OrderedDict()

        # Two d steps for every g step
        if self.alternate_g:
            updates[self.now_train_generator] = 1. - self.now_train_generator

        return rval, updates
开发者ID:HyoungWooPark,项目名称:adversarial,代码行数:47,代码来源:__init__.py



注:本文中的theano.compat.OrderedDict类示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python python2x.any函数代码示例发布时间:2022-05-27
下一篇:
Python compat.izip函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap