
Python context.cpu function code examples


This article collects typical usage examples of the mxnet.context.cpu function in Python. If you are wondering how cpu is used in practice, or are looking for real code that calls it, the curated examples below may help.



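Below are 20 code examples of the cpu function, sorted by popularity by default. Most of them use cpu() as the default device context for loading weights; cpu() and cpu(0) are equivalent because the device ID defaults to 0. Examples 5, 12, 18 and 20 reach the function through a module alias (ctx.cpu()), and Example 11 through context.cpu().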

Example 1: get_mobilenet

def get_mobilenet(multiplier, pretrained=False, ctx=cpu(),
                  root='~/.mxnet/models', **kwargs):
    r"""MobileNet model from the
    `"MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications"
    <https://arxiv.org/abs/1704.04861>`_ paper.

    Parameters
    ----------
    multiplier : float
        The width multiplier for controlling the model size. Only multipliers that are no
        less than 0.25 are supported. The actual number of channels is equal to the original
        channel size multiplied by this multiplier.
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default $MXNET_HOME/models
        Location for keeping the model parameters.
    """
    net = MobileNet(multiplier, **kwargs)

    if pretrained:
        from .model_store import get_model_file
        version_suffix = '{0:.2f}'.format(multiplier)
        if version_suffix in ('1.00', '0.50'):
            version_suffix = version_suffix[:-1]
        net.load_parameters(
            get_model_file('mobilenet%s' % version_suffix, tag=pretrained, root=root), ctx=ctx)
        from ..data import ImageNet1kAttr
        attrib = ImageNet1kAttr()
        net.synset = attrib.synset
        net.classes = attrib.classes
        net.classes_long = attrib.classes_long
    return net
Developer ID: mohamedelsiesyibra, Project: gluon-cv, Lines: 35, Source: mobilenet.py
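The version-suffix logic in get_mobilenet decides which pretrained file name is requested from the model store. The sketch below isolates that string handling in a hypothetical helper, `mobilenet_model_name`, so it can be checked without MXNet installed.

```python
def mobilenet_model_name(multiplier):
    """Hypothetical helper mirroring the file-name logic in get_mobilenet."""
    # Format with two decimals, e.g. 0.75 -> '0.75', 1.0 -> '1.00'.
    version_suffix = '{0:.2f}'.format(multiplier)
    # '1.00' and '0.50' drop their trailing zero and become '1.0' and '0.5'.
    if version_suffix in ('1.00', '0.50'):
        version_suffix = version_suffix[:-1]
    return 'mobilenet%s' % version_suffix


for m in (1.0, 0.75, 0.5, 0.25):
    print(m, mobilenet_model_name(m))
```

So a multiplier of 0.75 keeps both decimals (mobilenet0.75), while 1.0 and 0.5 map to mobilenet1.0 and mobilenet0.5.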


Example 2: get_vgg

def get_vgg(num_layers, pretrained=False, ctx=cpu(),
            root='~/.mxnet/models', **kwargs):
    r"""VGG model from the `"Very Deep Convolutional Networks for Large-Scale Image Recognition"
    <https://arxiv.org/abs/1409.1556>`_ paper.

    Parameters
    ----------
    num_layers : int
        Number of layers for the VGG variant. Options are 11, 13, 16, 19.
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default $MXNET_HOME/models
        Location for keeping the model parameters.
    """
    layers, filters = vgg_spec[num_layers]
    net = VGG(layers, filters, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        batch_norm_suffix = '_bn' if kwargs.get('batch_norm') else ''
        net.load_parameters(get_model_file('vgg%d%s'%(num_layers, batch_norm_suffix),
                                           tag=pretrained, root=root), ctx=ctx)
        from ..data import ImageNet1kAttr
        attrib = ImageNet1kAttr()
        net.synset = attrib.synset
        net.classes = attrib.classes
        net.classes_long = attrib.classes_long
    return net
Developer ID: mohamedelsiesyibra, Project: gluon-cv, Lines: 30, Source: vgg.py
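In get_vgg, a truthy `batch_norm` keyword argument switches the requested weights to the `_bn` variant. A minimal standalone sketch of that naming rule follows; `vgg_model_name` is a hypothetical helper, and it validates the depth up front with a ValueError, whereas get_vgg would instead fail on the `vgg_spec` lookup.

```python
VGG_DEPTHS = (11, 13, 16, 19)  # options listed in the get_vgg docstring


def vgg_model_name(num_layers, **kwargs):
    """Hypothetical helper mirroring how get_vgg builds the pretrained file name."""
    if num_layers not in VGG_DEPTHS:
        raise ValueError('Invalid number of layers: %d' % num_layers)
    # A truthy batch_norm keyword selects the batch-normalized weights.
    batch_norm_suffix = '_bn' if kwargs.get('batch_norm') else ''
    return 'vgg%d%s' % (num_layers, batch_norm_suffix)


print(vgg_model_name(16))                   # vgg16
print(vgg_model_name(19, batch_norm=True))  # vgg19_bn
```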


Example 3: inception_v3

def inception_v3(pretrained=False, ctx=cpu(),
                 root='~/.mxnet/models', **kwargs):
    r"""Inception v3 model from
    `"Rethinking the Inception Architecture for Computer Vision"
    <http://arxiv.org/abs/1512.00567>`_ paper.

    Parameters
    ----------
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default $MXNET_HOME/models
        Location for keeping the model parameters.
    norm_layer : object
        Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
        Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    norm_kwargs : dict
        Additional `norm_layer` arguments, for example `num_devices=4`
        for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    """
    net = Inception3(**kwargs)
    if pretrained:
        from .model_store import get_model_file
        net.load_parameters(get_model_file('inceptionv3',
                                           tag=pretrained, root=root), ctx=ctx)
        from ..data import ImageNet1kAttr
        attrib = ImageNet1kAttr()
        net.synset = attrib.synset
        net.classes = attrib.classes
        net.classes_long = attrib.classes_long
    return net
Developer ID: xiayongtao, Project: gluon-cv, Lines: 33, Source: inception.py


Example 4: inception_v3

def inception_v3(pretrained=False, ctx=cpu(),
                 root='~/.mxnet/models', **kwargs):
    r"""Inception v3 model from
    `"Rethinking the Inception Architecture for Computer Vision"
    <http://arxiv.org/abs/1512.00567>`_ paper.

    Parameters
    ----------
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default $MXNET_HOME/models
        Location for keeping the model parameters.
    """
    net = Inception3(**kwargs)
    if pretrained:
        from .model_store import get_model_file
        net.load_parameters(get_model_file('inceptionv3',
                                           tag=pretrained, root=root), ctx=ctx)
        from ..data import ImageNet1kAttr
        attrib = ImageNet1kAttr()
        net.synset = attrib.synset
        net.classes = attrib.classes
        net.classes_long = attrib.classes_long
    return net
Developer ID: mohamedelsiesyibra, Project: gluon-cv, Lines: 27, Source: inception.py
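Each getter in these examples declares `ctx=cpu()` as a default argument. In Python a default value is evaluated once, when the `def` statement runs, so every call that omits `ctx` receives the same object. The sketch below demonstrates this with a hypothetical stand-in class, `FakeContext`, and a hypothetical `fake_cpu` function; neither is MXNet's real `Context` or `cpu`.

```python
class FakeContext:
    """Hypothetical stand-in for a device context; counts how often it is built."""
    created = 0

    def __init__(self, device_type='cpu', device_id=0):
        FakeContext.created += 1
        self.device_type = device_type
        self.device_id = device_id


def fake_cpu(device_id=0):
    """Hypothetical analogue of a cpu() factory function."""
    return FakeContext('cpu', device_id)


def get_model(pretrained=False, ctx=fake_cpu()):
    # The default `ctx` was constructed once, when this `def` executed.
    return ctx


a = get_model()
b = get_model()
print(a is b, FakeContext.created)  # True 1
```

Sharing one device descriptor this way is usually harmless; the same evaluation rule is what makes mutable defaults such as lists a classic pitfall.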


Example 5: __init__

    def __init__(self, symbol, data_names, label_names,
                 logger=logging, context=ctx.cpu(), work_load_list=None,
                 max_data_shapes=None, max_label_shapes=None, fixed_param_prefix=None):
        super(MutableModule, self).__init__(logger=logger)
        self._symbol = symbol
        self._data_names = data_names
        self._label_names = label_names
        self._context = context
        self._work_load_list = work_load_list

        self._curr_module = None
        self._max_data_shapes = max_data_shapes
        self._max_label_shapes = max_label_shapes
        self._fixed_param_prefix = fixed_param_prefix

        if self._max_data_shapes is None:
            self._max_data_shapes = []
        if self._max_label_shapes is None:
            self._max_label_shapes = []
        if self._fixed_param_prefix is None:
            self._fixed_param_prefix = []

        fixed_param_names = list()
        for name in self._symbol.list_arguments():
            for prefix in self._fixed_param_prefix:
                if prefix in name:
                    fixed_param_names.append(name)
        self._fixed_param_names = fixed_param_names
Developer ID: Alven8816, Project: mxnet, Lines: 28, Source: module.py
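MutableModule marks an argument as fixed when any entry of `fixed_param_prefix` occurs anywhere in its name; the check is substring matching (`in`), not `str.startswith`. The hypothetical helper below reproduces that rule with one difference: using `any()` adds each name at most once, whereas the nested loop in the original appends a name once per matching prefix.

```python
def find_fixed_param_names(arg_names, fixed_param_prefix):
    """Hypothetical helper: keep argument names that contain any fixed prefix."""
    return [name for name in arg_names
            if any(prefix in name for prefix in fixed_param_prefix)]


args = ['data', 'conv1_weight', 'conv1_bias', 'res2a_conv_weight', 'fc_weight']
print(find_fixed_param_names(args, ['conv1', 'res2']))
# ['conv1_weight', 'conv1_bias', 'res2a_conv_weight']
```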


Example 6: resnet18_v1b_89

def resnet18_v1b_89(pretrained=False, root='~/.mxnet/models', ctx=cpu(0), **kwargs):
    """Constructs a ResNetV1b-18_2.6x model. Uses resnet18_v1b construction from resnetv1b.py

    Parameters
    ----------
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    """
    model = ResNetV1b(BasicBlockV1b, [2, 2, 2, 2], name_prefix='resnetv1b_', **kwargs)
    dirname = os.path.dirname(__file__)
    json_filename = os.path.join(dirname, 'resnet%d_v%db_%.1fx' % (18, 1, 2.6) + ".json")
    with open(json_filename, "r") as jsonFile:
        params_shapes = json.load(jsonFile)
    if pretrained:
        from ..model_store import get_model_file
        params_file = get_model_file('resnet%d_v%db_%.1fx' % (18, 1, 2.6), tag=pretrained,
                                     root=root)
        prune_gluon_block(model, model.name, params_shapes, params=ndarray.load(params_file),
                          pretrained=True, ctx=ctx)
    else:
        prune_gluon_block(model, model.name, params_shapes, params=None, pretrained=False, ctx=ctx)
    if pretrained:
        from ...data import ImageNet1kAttr
        attrib = ImageNet1kAttr()
        model.synset = attrib.synset
        model.classes = attrib.classes
        model.classes_long = attrib.classes_long
    return model
Developer ID: xiayongtao, Project: gluon-cv, Lines: 33, Source: resnetv1b_pruned.py
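resnet18_v1b_89 uses one format string both for the shape file stored next to the module and for the weights requested from the model store. A hypothetical helper showing what that string expands to:

```python
def pruned_resnet_names(num_layers=18, version=1, ratio=2.6):
    """Hypothetical helper reproducing the name formatting in resnet18_v1b_89."""
    base = 'resnet%d_v%db_%.1fx' % (num_layers, version, ratio)
    # The JSON of pruned shapes sits beside the module; the weights use `base`.
    return base, base + '.json'


print(pruned_resnet_names())  # ('resnet18_v1b_2.6x', 'resnet18_v1b_2.6x.json')
```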


Example 7: resnet152_v1b

def resnet152_v1b(pretrained=False, root='~/.mxnet/models', ctx=cpu(0), **kwargs):
    """Constructs a ResNetV1b-152 model.

    Parameters
    ----------
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    dilated : bool, default False
        Whether to apply dilation strategy to ResNetV1b, yielding a stride 8 model.
    norm_layer : object
        Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
        Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    last_gamma : bool, default False
        Whether to initialize the gamma of the last BatchNorm layer in each bottleneck to zero.
    use_global_stats : bool, default False
        Whether to force BatchNorm to use global statistics instead of mini-batch statistics;
        optionally set to True when fine-tuning from ImageNet-pretrained classification models.
    """
    model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], name_prefix='resnetv1b_', **kwargs)
    if pretrained:
        from .model_store import get_model_file
        model.load_parameters(get_model_file('resnet%d_v%db'%(152, 1),
                                             tag=pretrained, root=root), ctx=ctx)
        from ..data import ImageNet1kAttr
        attrib = ImageNet1kAttr()
        model.synset = attrib.synset
        model.classes = attrib.classes
        model.classes_long = attrib.classes_long
    return model
Developer ID: xiayongtao, Project: gluon-cv, Lines: 34, Source: resnetv1b.py


Example 8: resnet152_v1s

def resnet152_v1s(pretrained=False, root='~/.mxnet/models', ctx=cpu(0), **kwargs):
    """Constructs a ResNetV1s-152 model.

    Parameters
    ----------
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    dilated : bool, default False
        Whether to apply dilation strategy to ResNetV1b, yielding a stride 8 model.
    norm_layer : object
        Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`).
        Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    """
    model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64,
                      name_prefix='resnetv1s_', **kwargs)
    if pretrained:
        from .model_store import get_model_file
        model.load_parameters(get_model_file('resnet%d_v%ds'%(152, 1),
                                             tag=pretrained, root=root), ctx=ctx)
        from ..data import ImageNet1kAttr
        attrib = ImageNet1kAttr()
        model.synset = attrib.synset
        model.classes = attrib.classes
        model.classes_long = attrib.classes_long
    return model
Developer ID: xiayongtao, Project: gluon-cv, Lines: 30, Source: resnetv1b.py
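The stage lists passed to ResNetV1b determine the depth in the model name. Under the conventional count, which adds one layer for the stem and one for the final fully connected layer (and treats the stem as a single layer even for the deep-stem V1s variant), bottleneck blocks contribute three convolutions each and basic blocks two. The helper below is a hypothetical illustration of that arithmetic, not part of gluon-cv.

```python
def resnet_depth(blocks_per_stage, convs_per_block):
    """Conventional ResNet depth: block convolutions + stem + final FC layer."""
    return convs_per_block * sum(blocks_per_stage) + 2


print(resnet_depth([3, 8, 36, 3], 3))  # 152: BottleneckV1b, as in resnet152_v1b/v1s
print(resnet_depth([2, 2, 2, 2], 2))   # 18: BasicBlockV1b, as in resnet18_v1b_89
```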


Example 9: get_densenet

def get_densenet(num_layers, pretrained=False, ctx=cpu(),
                 root='~/.mxnet/models', **kwargs):
    r"""Densenet-BC model from the
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_ paper.

    Parameters
    ----------
    num_layers : int
        Number of layers for the variant of densenet. Options are 121, 161, 169, 201.
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default $MXNET_HOME/models
        Location for keeping the model parameters.
    norm_layer : object
        Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
        Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    norm_kwargs : dict
        Additional `norm_layer` arguments, for example `num_devices=4`
        for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    """
    num_init_features, growth_rate, block_config = densenet_spec[num_layers]
    net = DenseNet(num_init_features, growth_rate, block_config, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        net.load_parameters(get_model_file('densenet%d'%(num_layers),
                                           tag=pretrained, root=root), ctx=ctx)
        from ..data import ImageNet1kAttr
        attrib = ImageNet1kAttr()
        net.synset = attrib.synset
        net.classes = attrib.classes
        net.classes_long = attrib.classes_long
    return net
Developer ID: xiayongtao, Project: gluon-cv, Lines: 35, Source: densenet.py
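DenseNet-BC depths follow a similar count: each dense layer holds two convolutions (1×1 then 3×3), plus the initial convolution, one transition convolution between each pair of adjacent blocks, and the final classifier. The block configurations below are the commonly used ones for these depths; they are an assumption here, since gluon-cv's `densenet_spec` is not shown in the example.

```python
# Dense layers per block for DenseNet-BC (assumed standard configurations).
BLOCK_CONFIGS = {
    121: [6, 12, 24, 16],
    161: [6, 12, 36, 24],
    169: [6, 12, 32, 32],
    201: [6, 12, 48, 32],
}


def densenet_depth(block_config):
    """2 convs per dense layer + initial conv + transitions + classifier."""
    num_transitions = len(block_config) - 1
    return 2 * sum(block_config) + 1 + num_transitions + 1


for name, config in BLOCK_CONFIGS.items():
    print(name, densenet_depth(config))  # each computed depth equals its key
```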


Example 10: get_deeplab

def get_deeplab(dataset='pascal_voc', backbone='resnet50', pretrained=False,
                root='~/.mxnet/models', ctx=cpu(0), **kwargs):
    r"""DeepLabV3 model.

    Parameters
    ----------
    dataset : str, default pascal_voc
        The dataset that the model is pretrained on (pascal_voc, ade20k).
    backbone : str, default resnet50
        The base network used as the feature extractor.
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.

    Examples
    --------
    >>> model = get_deeplab(dataset='pascal_voc', backbone='resnet50', pretrained=False)
    >>> print(model)
    """
    acronyms = {
        'pascal_voc': 'voc',
        'pascal_aug': 'voc',
        'ade20k': 'ade',
        'coco': 'coco',
    }
    from ..data import datasets
    # infer number of classes
    model = DeepLabV3(datasets[dataset].NUM_CLASS, backbone=backbone, ctx=ctx, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        model.load_parameters(get_model_file('deeplab_%s_%s'%(backbone, acronyms[dataset]),
                                             tag=pretrained, root=root), ctx=ctx)
    return model
Developer ID: xiayongtao, Project: gluon-cv, Lines: 34, Source: deeplabv3.py
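get_deeplab maps dataset names onto shorter acronyms before building the weights file name, so pascal_voc and pascal_aug request the same file. A hypothetical standalone helper with the same mapping (whether a pretrained file actually exists for a given backbone and dataset is not shown in the example):

```python
ACRONYMS = {
    'pascal_voc': 'voc',
    'pascal_aug': 'voc',
    'ade20k': 'ade',
    'coco': 'coco',
}


def deeplab_model_name(dataset='pascal_voc', backbone='resnet50'):
    """Hypothetical helper mirroring the file name requested by get_deeplab."""
    if dataset not in ACRONYMS:
        raise ValueError('Unsupported dataset: %r' % dataset)
    return 'deeplab_%s_%s' % (backbone, ACRONYMS[dataset])


print(deeplab_model_name())                       # deeplab_resnet50_voc
print(deeplab_model_name('ade20k', 'resnet101'))  # deeplab_resnet101_ade
```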


Example 11: fetcher_loop

def fetcher_loop(data_queue, data_buffer, pin_memory=False):
    """Fetcher loop for fetching data from queue and put in reorder dict."""
    while True:
        idx, batch = data_queue.get()
        if idx is None:
            break
        if pin_memory:
            batch = _as_in_context(batch, context.cpu_pinned())
        else:
            batch = _as_in_context(batch, context.cpu())
        data_buffer[idx] = batch
Developer ID: mohamedelsiesyibra, Project: gluon-cv, Lines: 11, Source: dataloader.py
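fetcher_loop drains (index, batch) pairs that workers may produce out of order, stores each batch in data_buffer under its index, and stops when it receives a None index; a consumer can then read batches back in index order. The sketch below is a pure-Python illustration of that reorder-buffer idea using a thread and queue.Queue. It omits the copy to CPU or pinned memory, and `iter_in_order` is a hypothetical consumer, not gluon-cv's DataLoader iterator.

```python
import queue
import threading
import time


def fetcher_loop(data_queue, data_buffer):
    """Simplified fetcher: no device copy, only the index-keyed reorder dict."""
    while True:
        idx, batch = data_queue.get()
        if idx is None:  # sentinel: no more batches
            break
        data_buffer[idx] = batch


def iter_in_order(data_buffer, num_batches, poll_interval=0.001):
    """Yield batches by increasing index, waiting for any not yet fetched."""
    for i in range(num_batches):
        while i not in data_buffer:
            time.sleep(poll_interval)
        yield data_buffer.pop(i)


data_queue = queue.Queue()
data_buffer = {}
fetcher = threading.Thread(target=fetcher_loop, args=(data_queue, data_buffer))
fetcher.start()
for idx in (2, 0, 1):  # workers may finish out of order
    data_queue.put((idx, 'batch%d' % idx))
data_queue.put((None, None))
result = list(iter_in_order(data_buffer, 3))
fetcher.join()
print(result)  # ['batch0', 'batch1', 'batch2']
```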


Example 12: get_params

    def get_params(self, arg_params, aux_params):
        """Average each parameter over all executors and copy the result into
        `arg_params` and `aux_params`.

        Parameters
        ----------
        arg_params : dict of str to NDArray
            Target parameter arrays, keyed by parameter name.
        aux_params : dict of str to NDArray
            Target auxiliary state arrays, keyed by name.

        Notes
        -----
        - This function updates the NDArrays in arg_params and aux_params in place.
        """
        for name, block in zip(self.param_names, self.param_arrays):
            weight = sum(w.copyto(ctx.cpu()) for w in block) / len(block)
            weight.astype(arg_params[name].dtype).copyto(arg_params[name])
        for name, block in zip(self.aux_names, self.aux_arrays):
            weight = sum(w.copyto(ctx.cpu()) for w in block) / len(block)
            weight.astype(aux_params[name].dtype).copyto(aux_params[name])
Developer ID: ktr-hubrt, Project: Deformable-ConvNets, Lines: 20, Source: DataParallelExecutorGroup.py
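get_params copies each device's copy of a parameter to the CPU, sums the copies, and divides by the number of devices, producing an element-wise mean. The hypothetical helper below shows the same arithmetic on plain Python lists instead of NDArrays. When the copies are already identical, the mean equals any single copy (up to floating-point rounding).

```python
def average_across_devices(per_device):
    """Hypothetical pure-Python analogue of the averaging in get_params.

    `per_device` maps a parameter name to one flat list of values per device;
    the result holds the element-wise mean, like sum(copies) / len(copies).
    """
    averaged = {}
    for name, copies in per_device.items():
        averaged[name] = [sum(values) / len(copies) for values in zip(*copies)]
    return averaged


print(average_across_devices({'fc_weight': [[1.0, 2.0], [3.0, 4.0]]}))
# {'fc_weight': [2.0, 3.0]}
```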


Example 13: get_simple_pose_resnet

def get_simple_pose_resnet(base_name, pretrained=False, ctx=cpu(),
                           root='~/.mxnet/models', **kwargs):

    net = SimplePoseResNet(base_name, **kwargs)

    if pretrained:
        from ..model_store import get_model_file
        net.load_parameters(get_model_file('simple_pose_%s'%(base_name),
                                           tag=pretrained, root=root), ctx=ctx)

    return net
Developer ID: xiayongtao, Project: gluon-cv, Lines: 11, Source: simple_pose_resnet.py


Example 14: __init__

    def __init__(self, nclass, backbone='resnet50', aux=True, ctx=cpu(), pretrained_base=True,
                 base_size=520, crop_size=480, **kwargs):
        super(FCN, self).__init__(nclass, aux, backbone, ctx=ctx, base_size=base_size,
                                  crop_size=crop_size, pretrained_base=pretrained_base, **kwargs)
        with self.name_scope():
            self.head = _FCNHead(2048, nclass, **kwargs)
            self.head.initialize(ctx=ctx)
            self.head.collect_params().setattr('lr_mult', 10)
            if self.aux:
                self.auxlayer = _FCNHead(1024, nclass, **kwargs)
                self.auxlayer.initialize(ctx=ctx)
                self.auxlayer.collect_params().setattr('lr_mult', 10)
Developer ID: xiayongtao, Project: gluon-cv, Lines: 12, Source: fcn.py
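Setting `lr_mult` to 10 on the head parameters makes the optimizer update them with ten times its base learning rate, a common choice when freshly initialized heads sit on a pretrained backbone. The arithmetic, as a hypothetical helper with made-up parameter names:

```python
def effective_learning_rates(base_lr, lr_mults):
    """Per-parameter learning rate = optimizer learning rate x lr_mult."""
    return {name: base_lr * mult for name, mult in lr_mults.items()}


rates = effective_learning_rates(0.01, {'backbone_conv_weight': 1.0,
                                        'head_conv_weight': 10.0})
print(rates)
```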


Example 15: get_resnet

def get_resnet(version, num_layers, pretrained=False, ctx=cpu(),
               root='~/.mxnet/models', use_se=False, **kwargs):
    r"""ResNet V1 model from `"Deep Residual Learning for Image Recognition"
    <http://arxiv.org/abs/1512.03385>`_ paper.
    ResNet V2 model from `"Identity Mappings in Deep Residual Networks"
    <https://arxiv.org/abs/1603.05027>`_ paper.

    Parameters
    ----------
    version : int
        Version of ResNet. Options are 1, 2.
    num_layers : int
        Number of layers. Options are 18, 34, 50, 101, 152.
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default $MXNET_HOME/models
        Location for keeping the model parameters.
    use_se : bool, default False
        Whether to use Squeeze-and-Excitation module
    norm_layer : object
        Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
        Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    norm_kwargs : dict
        Additional `norm_layer` arguments, for example `num_devices=4`
        for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    """
    assert num_layers in resnet_spec, \
        "Invalid number of layers: %d. Options are %s"%(
            num_layers, str(resnet_spec.keys()))
    block_type, layers, channels = resnet_spec[num_layers]
    assert 1 <= version <= 2, \
        "Invalid resnet version: %d. Options are 1 and 2."%version
    resnet_class = resnet_net_versions[version-1]
    block_class = resnet_block_versions[version-1][block_type]
    net = resnet_class(block_class, layers, channels, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        if not use_se:
            net.load_parameters(get_model_file('resnet%d_v%d'%(num_layers, version),
                                               tag=pretrained, root=root), ctx=ctx)
        else:
            net.load_parameters(get_model_file('se_resnet%d_v%d'%(num_layers, version),
                                               tag=pretrained, root=root), ctx=ctx)
        from ..data import ImageNet1kAttr
        attrib = ImageNet1kAttr()
        net.synset = attrib.synset
        net.classes = attrib.classes
        net.classes_long = attrib.classes_long
    return net
Developer ID: xiayongtao, Project: gluon-cv, Lines: 52, Source: resnet.py
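get_resnet validates the depth and version, then requests either the plain or the Squeeze-and-Excitation weights. The hypothetical helper below mirrors those checks and the file-name choice. It raises ValueError rather than using `assert`, because assert statements are skipped when Python runs with `-O`.

```python
RESNET_DEPTHS = (18, 34, 50, 101, 152)  # options listed in the get_resnet docstring


def resnet_model_name(version, num_layers, use_se=False):
    """Hypothetical helper mirroring get_resnet's checks and file-name choice."""
    if num_layers not in RESNET_DEPTHS:
        raise ValueError('Invalid number of layers: %d. Options are %s'
                         % (num_layers, RESNET_DEPTHS))
    if not 1 <= version <= 2:
        raise ValueError('Invalid resnet version: %d. Options are 1 and 2.' % version)
    prefix = 'se_resnet' if use_se else 'resnet'
    return '%s%d_v%d' % (prefix, num_layers, version)


print(resnet_model_name(1, 50))                # resnet50_v1
print(resnet_model_name(2, 101, use_se=True))  # se_resnet101_v2
```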


Example 16: __init__

    def __init__(self, nclass, backbone='resnet50', aux=True, ctx=cpu(), pretrained_base=True,
                 base_size=520, crop_size=480, **kwargs):
        super(DeepLabV3, self).__init__(nclass, aux, backbone, ctx=ctx, base_size=base_size,
                                        crop_size=crop_size, pretrained_base=pretrained_base, **kwargs)
        with self.name_scope():
            self.head = _DeepLabHead(nclass, height=self._up_kwargs['height']//8,
                                     width=self._up_kwargs['width']//8, **kwargs)
            self.head.initialize(ctx=ctx)
            self.head.collect_params().setattr('lr_mult', 10)
            if self.aux:
                self.auxlayer = _FCNHead(1024, nclass, **kwargs)
                self.auxlayer.initialize(ctx=ctx)
                self.auxlayer.collect_params().setattr('lr_mult', 10)
Developer ID: xiayongtao, Project: gluon-cv, Lines: 13, Source: deeplabv3.py


Example 17: func

    def func(pretrained=False, tag=None, root='~/.mxnet/models', ctx=cpu(0), **kwargs):
        r"""Quantized model.

        Parameters
        ----------
        pretrained : bool or str
            Boolean value controls whether to load the default pretrained weights for model.
            String value represents the hashtag for a certain version of pretrained weights.
        tag : str, default is None
            Optional length-8 sha1sum of parameter file. If `None`, best parameter file
            will be used.
        ctx : Context, default CPU
            The context in which to load the pretrained weights.
        root : str, default $MXNET_HOME/models
            Location for keeping the model parameters.
        """
        from ..model_zoo import get_model
        from ..model_store import get_model_file
        curr_dir = os.path.abspath(os.path.dirname(__file__))
        model_name = name.replace('mobilenet1_', 'mobilenet1.')
        model_name = model_name.replace('mobilenet0_', 'mobilenet0.')
        json_file = os.path.join(curr_dir, '{}-symbol.json'.format(model_name))
        base_name = '_'.join(model_name.split('_')[:-1])
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            param_file = get_model_file(base_name, tag=tag, root=root) if pretrained else None
            net = get_model('_'.join(model_name.split('_')[:-1]), prefix=sym_prefix)
            classes = getattr(net, 'classes', [])
            sym_net = SymbolBlock.imports(json_file, ['data'], None, ctx=ctx)
            if param_file:
                # Weights saved by save_parameters cannot be imported into a SymbolBlock
                # directly, so load them into the Gluon net and export once to a
                # temporary params file that the SymbolBlock can read.
                import tempfile
                net.load_params(param_file)
                net.hybridize()
                if '512' in base_name:
                    net(mx.nd.zeros((1, 3, 512, 512)))
                elif '300' in base_name:
                    net(mx.nd.zeros((1, 3, 300, 300)))
                else:
                    net(mx.nd.zeros((1, 3, 224, 224)))
                with tempfile.TemporaryDirectory() as tmpdirname:
                    prefix = os.path.join(tmpdirname, 'tmp')
                    net.export(prefix, epoch=0)
                    param_prefix = prefix + '-0000.params'
                    sym_net.collect_params().load(param_prefix)
        sym_net.classes = classes
        sym_net.reset_class = _not_impl
        sym_net.set_nms = _not_impl
        return sym_net
Developer ID: xiayongtao, Project: gluon-cv, Lines: 50, Source: quantized.py
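Before loading anything, the quantized `func` restores the dot in MobileNet multipliers, drops the last underscore-separated token of the name (presumably the quantization suffix) to get the base model name, and picks a dummy input size from that name for the hybridized forward pass. The hypothetical helper below reproduces only that string handling; the example model names are illustrative.

```python
def quantized_base_and_shape(name):
    """Hypothetical helper mirroring the name handling in the quantized `func`."""
    # Restore the dot in MobileNet multipliers, e.g. mobilenet1_0 -> mobilenet1.0.
    model_name = name.replace('mobilenet1_', 'mobilenet1.')
    model_name = model_name.replace('mobilenet0_', 'mobilenet0.')
    # Drop the last '_'-separated token to get the base model name.
    base_name = '_'.join(model_name.split('_')[:-1])
    # Pick the dummy-input size used to trigger one forward pass before export.
    if '512' in base_name:
        size = 512
    elif '300' in base_name:
        size = 300
    else:
        size = 224
    return base_name, (1, 3, size, size)


print(quantized_base_and_shape('ssd_512_mobilenet1_0_voc_int8'))
print(quantized_base_and_shape('resnet50_v1_int8'))
```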


Example 18: __init__

    def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
                 logger=logging, context=ctx.cpu(), work_load_list=None,
                 fixed_param_names=None, state_names=None):
        super(Module, self).__init__(logger=logger)

        if isinstance(context, ctx.Context):
            context = [context]
        self._context = context
        if work_load_list is None:
            work_load_list = [1] * len(self._context)
        assert len(work_load_list) == len(self._context)
        self._work_load_list = work_load_list

        self._symbol = symbol

        data_names = list(data_names) if data_names is not None else []
        label_names = list(label_names) if label_names is not None else []
        state_names = list(state_names) if state_names is not None else []
        fixed_param_names = list(fixed_param_names) if fixed_param_names is not None else []

        _check_input_names(symbol, data_names, "data", True)
        _check_input_names(symbol, label_names, "label", False)
        _check_input_names(symbol, state_names, "state", True)
        _check_input_names(symbol, fixed_param_names, "fixed_param", True)

        arg_names = symbol.list_arguments()
        input_names = data_names + label_names + state_names
        self._param_names = [x for x in arg_names if x not in input_names]
        self._fixed_param_names = fixed_param_names
        self._aux_names = symbol.list_auxiliary_states()
        self._data_names = data_names
        self._label_names = label_names
        self._state_names = state_names
        self._output_names = symbol.list_outputs()

        self._arg_params = None
        self._aux_params = None
        self._params_dirty = False

        self._optimizer = None
        self._kvstore = None
        self._update_on_kvstore = None
        self._updater = None
        self._preload_opt_states = None
        self._grad_req = None

        self._exec_group = None
        self._data_shapes = None
        self._label_shapes = None
Developer ID: ktr-hubrt, Project: Deformable-ConvNets, Lines: 49, Source: module.py
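Module.__init__ accepts either a single context or a list of them, wraps a single one in a list, and, when no work_load_list is given, assigns an equal load of 1 to each device. The sketch below mirrors that normalization with a hypothetical stand-in `Context` class (not MXNet's) and raises ValueError where the original asserts.

```python
class Context:
    """Hypothetical stand-in for a device context."""

    def __init__(self, device_type, device_id=0):
        self.device_type = device_type
        self.device_id = device_id


def normalize_contexts(context, work_load_list=None):
    """Wrap a single context in a list and default to equal work loads."""
    if isinstance(context, Context):
        context = [context]
    if work_load_list is None:
        work_load_list = [1] * len(context)
    if len(work_load_list) != len(context):
        raise ValueError('work_load_list must have one entry per context')
    return context, work_load_list


ctxs, loads = normalize_contexts(Context('cpu'))
print(len(ctxs), loads)  # 1 [1]
ctxs, loads = normalize_contexts([Context('gpu', 0), Context('gpu', 1)])
print(len(ctxs), loads)  # 2 [1, 1]
```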


Example 19: prune_gluon_block

def prune_gluon_block(net, prefix, params_shapes, params=None, pretrained=False, ctx=cpu(0)):
    """
    :param net: original network that is to be pruned; it is modified in place
    :param prefix: name prefix of the original network's layers; it is replaced with
        "resnetv1d" when looking up shapes and parameters
    :param params_shapes: dictionary of shapes of the convolutional and dense weights;
        these shapes determine the number of channels of each layer
    :param params: dictionary of parameter values for the pruned network, used only
        when `pretrained` is True
    :param pretrained: Boolean specifying whether the pretrained model parameters need to be loaded
    :param ctx: context in which to load the parameters, default cpu(0)
    :return: None; `net` is updated in place
    """
    for _, layer in net._children.items():
        if pretrained:
            if isinstance(layer, nn.BatchNorm):
                params_layer = layer._collect_params_with_prefix()
                for param_name in ['beta', 'gamma', 'running_mean', 'running_var']:
                    param_val = params[layer.name.replace(prefix, "resnetv1d") + "_" + param_name]
                    layer.params.get(param_name)._shape = param_val.shape
                    params_layer[param_name]._load_init(param_val, ctx=ctx)

        if isinstance(layer, nn.Conv2D):
            param_shape = params_shapes[layer.name.replace(prefix, "resnetv1d") + "_weight"]
            layer._channels = param_shape[0]
            layer._kwargs['num_filter'] = param_shape[0]

            params_layer = layer._collect_params_with_prefix()
            for param_name in ['weight']:
                param_shape = params_shapes[
                    layer.name.replace(prefix, "resnetv1d") + "_" + param_name]
                layer.params.get(param_name)._shape = param_shape
                if pretrained:
                    param_val = params[layer.name.replace(prefix, "resnetv1d") + "_" + param_name]
                    params_layer[param_name]._load_init(param_val, ctx=ctx)

        if isinstance(layer, nn.Dense):
            layer._in_units = params_shapes[layer.name.replace(prefix, "resnetv1d") + "_weight"][1]

            params_layer = layer._collect_params_with_prefix()
            for param_name in ['weight', 'bias']:
                param_shape = params_shapes[
                    layer.name.replace(prefix, "resnetv1d") + "_" + param_name]
                layer.params.get(param_name)._shape = param_shape
                if pretrained:
                    param_val = params[layer.name.replace(prefix, "resnetv1d") + "_" + param_name]
                    params_layer[param_name]._load_init(param_val, ctx=ctx)
        else:
            prune_gluon_block(layer, prefix, params_shapes, params, pretrained, ctx)
Developer ID: xiayongtao, Project: gluon-cv, Lines: 48, Source: resnetv1b_pruned.py
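prune_gluon_block looks up each layer in params_shapes after replacing the block's own prefix with "resnetv1d"; the first dimension of a Conv2D weight becomes the layer's channel count, and the second dimension of a Dense weight becomes its in_units. The sketch below reproduces only that key mapping and shape reading. The helpers, layer names and shapes are hypothetical.

```python
def shape_key(layer_name, prefix, param_name, target_prefix='resnetv1d'):
    """Build a params_shapes key the way prune_gluon_block does."""
    return layer_name.replace(prefix, target_prefix) + '_' + param_name


def pruned_sizes(params_shapes, layers, prefix):
    """Return the size each layer would be resized to.

    `layers` is a list of (layer_name, kind) pairs, kind being 'conv' or 'dense'.
    """
    sizes = {}
    for name, kind in layers:
        weight_shape = params_shapes[shape_key(name, prefix, 'weight')]
        if kind == 'conv':
            sizes[name] = ('channels', weight_shape[0])  # output channels
        elif kind == 'dense':
            sizes[name] = ('in_units', weight_shape[1])  # input features
    return sizes


# Hypothetical shapes: (out, in, kh, kw) for Conv2D, (out, in) for Dense.
shapes = {'resnetv1d_conv0_weight': (24, 3, 3, 3),
          'resnetv1d_dense0_weight': (1000, 368)}
print(pruned_sizes(shapes, [('resnetv1b_conv0', 'conv'),
                            ('resnetv1b_dense0', 'dense')], 'resnetv1b'))
# {'resnetv1b_conv0': ('channels', 24), 'resnetv1b_dense0': ('in_units', 368)}
```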


Example 20: __init__

    def __init__(self, symbol, data_names, label_names,
                 logger=logging, context=ctx.cpu(), work_load_list=None,
                 max_data_shapes=None, max_label_shapes=None):
        super(MutableModule, self).__init__(logger=logger)
        self._symbol = symbol
        self._data_names = data_names
        self._label_names = label_names
        self._context = context
        self._work_load_list = work_load_list

        self._curr_module = None
        self._max_data_shapes = max_data_shapes
        self._max_label_shapes = max_label_shapes
        if self._max_data_shapes is None:
            self._max_data_shapes = []
        if self._max_label_shapes is None:
            self._max_label_shapes = []
Developer ID: Alexbert1, Project: mxnet, Lines: 17, Source: module.py



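Note: The mxnet.context.cpu examples in this article were compiled by 纯净天空 from source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by their developers, and copyright in the source code remains with the original authors. Refer to each project's license before distributing or using the code, and do not republish this article without permission.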

