本文整理汇总了Python中mxnet.context.cpu函数的典型用法代码示例。如果您正苦于以下问题:Python cpu函数的具体用法?Python cpu怎么用?Python cpu使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了cpu函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: get_mobilenet
def get_mobilenet(multiplier, pretrained=False, ctx=cpu(),
                  root='~/.mxnet/models', **kwargs):
    r"""Build a MobileNet model as described in
    `"MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications"
    <https://arxiv.org/abs/1704.04861>`_.

    Parameters
    ----------
    multiplier : float
        Width multiplier controlling the model size. Only multipliers no smaller than
        0.25 are supported; the channel count of each layer is the original channel
        count scaled by this value.
    pretrained : bool or str
        True loads the default pretrained weights; a str is the hashtag selecting a
        specific version of the pretrained weights.
    ctx : Context, default CPU
        Context into which the pretrained weights are loaded.
    root : str, default '~/.mxnet/models'
        Directory where model parameter files are kept.
    **kwargs
        Forwarded unchanged to the ``MobileNet`` constructor.

    Returns
    -------
    MobileNet
        When ``pretrained`` is truthy, the model also carries ``synset``, ``classes``
        and ``classes_long`` copied from ``ImageNet1kAttr``.
    """
    model = MobileNet(multiplier, **kwargs)
    if not pretrained:
        return model

    from .model_store import get_model_file
    # Weight files use a two-decimal multiplier suffix, except that '1.00' and
    # '0.50' are shortened to a single decimal ('1.0', '0.5').
    suffix = '{0:.2f}'.format(multiplier)
    if suffix in ('1.00', '0.50'):
        suffix = suffix[:-1]
    weights_path = get_model_file('mobilenet%s' % suffix, tag=pretrained, root=root)
    model.load_parameters(weights_path, ctx=ctx)

    from ..data import ImageNet1kAttr
    imagenet = ImageNet1kAttr()
    model.synset = imagenet.synset
    model.classes = imagenet.classes
    model.classes_long = imagenet.classes_long
    return model
开发者ID:mohamedelsiesyibra,项目名称:gluon-cv,代码行数:35,代码来源:mobilenet.py
示例2: get_vgg
def get_vgg(num_layers, pretrained=False, ctx=cpu(),
            root='~/.mxnet/models', **kwargs):
    r"""VGG model from the `"Very Deep Convolutional Networks for Large-Scale Image Recognition"
    <https://arxiv.org/abs/1409.1556>`_ paper.

    Parameters
    ----------
    num_layers : int
        Number of layers for the variant of VGG. Options are 11, 13, 16, 19.
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.
    **kwargs
        Forwarded to the ``VGG`` constructor. A truthy ``batch_norm`` entry also
        selects the ``_bn`` variant of the pretrained weights.

    Returns
    -------
    VGG
        When ``pretrained`` is truthy, the model also carries ``synset``, ``classes``
        and ``classes_long`` copied from ``ImageNet1kAttr``.
    """
    layers, filters = vgg_spec[num_layers]
    net = VGG(layers, filters, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        # Batch-norm variants are stored under a separate '_bn' model name.
        batch_norm_suffix = '_bn' if kwargs.get('batch_norm') else ''
        net.load_parameters(get_model_file('vgg%d%s' % (num_layers, batch_norm_suffix),
                                           tag=pretrained, root=root), ctx=ctx)
        from ..data import ImageNet1kAttr
        attrib = ImageNet1kAttr()
        net.synset = attrib.synset
        net.classes = attrib.classes
        net.classes_long = attrib.classes_long
    return net
开发者ID:mohamedelsiesyibra,项目名称:gluon-cv,代码行数:30,代码来源:vgg.py
示例3: inception_v3
def inception_v3(pretrained=False, ctx=cpu(),
                 root='~/.mxnet/models', **kwargs):
    r"""Build an Inception v3 network from
    `"Rethinking the Inception Architecture for Computer Vision"
    <http://arxiv.org/abs/1512.00567>`_.

    Parameters
    ----------
    pretrained : bool or str
        True loads the default pretrained weights; a str is the hashtag of a
        specific pretrained-weights version.
    ctx : Context, default CPU
        Context into which the pretrained weights are loaded.
    root : str, default '~/.mxnet/models'
        Directory where model parameter files are kept.
    norm_layer : object
        (via ``**kwargs``) Normalization layer used (default:
        :class:`mxnet.gluon.nn.BatchNorm`). Can be :class:`mxnet.gluon.nn.BatchNorm`
        or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    norm_kwargs : dict
        (via ``**kwargs``) Additional `norm_layer` arguments, for example
        `num_devices=4` for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.

    Returns
    -------
    Inception3
        When ``pretrained`` is truthy, also carries ``synset``, ``classes`` and
        ``classes_long`` from ``ImageNet1kAttr``.
    """
    model = Inception3(**kwargs)
    if not pretrained:
        return model
    from .model_store import get_model_file
    weights_path = get_model_file('inceptionv3', tag=pretrained, root=root)
    model.load_parameters(weights_path, ctx=ctx)
    from ..data import ImageNet1kAttr
    meta = ImageNet1kAttr()
    # Copy the ImageNet label metadata onto the model.
    for attr_name in ('synset', 'classes', 'classes_long'):
        setattr(model, attr_name, getattr(meta, attr_name))
    return model
开发者ID:xiayongtao,项目名称:gluon-cv,代码行数:33,代码来源:inception.py
示例4: inception_v3
def inception_v3(pretrained=False, ctx=cpu(),
                 root='~/.mxnet/models', **kwargs):
    r"""Inception v3 network from
    `"Rethinking the Inception Architecture for Computer Vision"
    <http://arxiv.org/abs/1512.00567>`_.

    Parameters
    ----------
    pretrained : bool or str
        True loads the default pretrained weights; a str selects a specific
        pretrained-weights version by hashtag.
    ctx : Context, default CPU
        Context into which the pretrained weights are loaded.
    root : str, default '~/.mxnet/models'
        Directory where model parameter files are kept.
    **kwargs
        Forwarded to the ``Inception3`` constructor.
    """
    net = Inception3(**kwargs)
    if pretrained:
        from .model_store import get_model_file
        params_file = get_model_file('inceptionv3', tag=pretrained, root=root)
        net.load_parameters(params_file, ctx=ctx)
        from ..data import ImageNet1kAttr
        labels = ImageNet1kAttr()
        net.synset, net.classes, net.classes_long = (
            labels.synset, labels.classes, labels.classes_long)
    return net
开发者ID:mohamedelsiesyibra,项目名称:gluon-cv,代码行数:27,代码来源:inception.py
示例5: __init__
def __init__(self, symbol, data_names, label_names,
             logger=logging, context=ctx.cpu(), work_load_list=None,
             max_data_shapes=None, max_label_shapes=None, fixed_param_prefix=None):
    """Store the configuration of a MutableModule; no executor is built here.

    Parameters
    ----------
    symbol : Symbol
        Network symbol; its ``list_arguments()`` is scanned for fixed parameters.
    data_names, label_names :
        Stored as given.
    logger : default ``logging``
        Passed to the base-class constructor.
    context : default ``ctx.cpu()``
        Stored as given.
    work_load_list : default None
        Stored as given.
    max_data_shapes, max_label_shapes : default None
        ``None`` is normalised to an empty list.
    fixed_param_prefix : list of str, default None
        Patterns selecting parameters to keep fixed. ``None`` means no patterns.
    """
    super(MutableModule, self).__init__(logger=logger)
    self._symbol = symbol
    self._data_names = data_names
    self._label_names = label_names
    self._context = context
    self._work_load_list = work_load_list
    self._curr_module = None
    self._max_data_shapes = max_data_shapes
    self._max_label_shapes = max_label_shapes
    self._fixed_param_prefix = fixed_param_prefix
    if self._max_data_shapes is None:
        self._max_data_shapes = []
    if self._max_label_shapes is None:
        self._max_label_shapes = []
    if self._fixed_param_prefix is None:
        self._fixed_param_prefix = []

    # Despite the name, a pattern matches if it occurs ANYWHERE in the argument
    # name (substring test), not only at the start. Each argument is listed at
    # most once even when several patterns match it (the previous nested loop
    # appended one duplicate per matching pattern). Symbol order is preserved.
    self._fixed_param_names = [
        name for name in self._symbol.list_arguments()
        if any(prefix in name for prefix in self._fixed_param_prefix)
    ]
开发者ID:Alven8816,项目名称:mxnet,代码行数:28,代码来源:module.py
示例6: resnet18_v1b_89
def resnet18_v1b_89(pretrained=False, root='~/.mxnet/models', ctx=cpu(0), **kwargs):
    """Build the pruned ResNetV1b-18 "2.6x" model on top of the resnet18_v1b layout.

    The pruned layer shapes are read from ``resnet18_v1b_2.6x.json`` next to this
    module and applied with ``prune_gluon_block``.

    Parameters
    ----------
    pretrained : bool or str
        True loads the default pretrained weights; a str is the hashtag of a
        specific pretrained-weights version.
    root : str, default '~/.mxnet/models'
        Directory where model parameter files are kept.
    ctx : Context, default CPU
        Context into which the pretrained weights are loaded.
    **kwargs
        Forwarded to the ``ResNetV1b`` constructor.
    """
    model = ResNetV1b(BasicBlockV1b, [2, 2, 2, 2], name_prefix='resnetv1b_', **kwargs)
    variant = 'resnet%d_v%db_%.1fx' % (18, 1, 2.6)
    shapes_path = os.path.join(os.path.dirname(__file__), variant + ".json")
    with open(shapes_path, "r") as shapes_file:
        params_shapes = json.load(shapes_file)

    if not pretrained:
        # Only reshape the layers; no parameter values are loaded.
        prune_gluon_block(model, model.name, params_shapes, params=None, pretrained=False, ctx=ctx)
        return model

    from ..model_store import get_model_file
    params_file = get_model_file(variant, tag=pretrained, root=root)
    prune_gluon_block(model, model.name, params_shapes, params=ndarray.load(params_file),
                      pretrained=True, ctx=ctx)
    from ...data import ImageNet1kAttr
    attrib = ImageNet1kAttr()
    model.synset = attrib.synset
    model.classes = attrib.classes
    model.classes_long = attrib.classes_long
    return model
开发者ID:xiayongtao,项目名称:gluon-cv,代码行数:33,代码来源:resnetv1b_pruned.py
示例7: resnet152_v1b
def resnet152_v1b(pretrained=False, root='~/.mxnet/models', ctx=cpu(0), **kwargs):
    """Build a ResNetV1b-152 network (bottleneck blocks, stage depths 3/8/36/3).

    Parameters
    ----------
    pretrained : bool or str
        True loads the default pretrained weights; a str is the hashtag of a
        specific pretrained-weights version.
    root : str, default '~/.mxnet/models'
        Directory where model parameter files are kept.
    ctx : Context, default CPU
        Context into which the pretrained weights are loaded.
    **kwargs
        Forwarded to ``ResNetV1b``; the original documentation lists ``dilated``
        (stride-8 dilation strategy), ``norm_layer`` (BatchNorm or SyncBatchNorm),
        ``last_gamma`` (zero-init gamma of the last BatchNorm in each bottleneck)
        and ``use_global_stats`` (use global BatchNorm statistics, e.g. for
        finetuning from ImageNet-pretrained models).
    """
    model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], name_prefix='resnetv1b_', **kwargs)
    if not pretrained:
        return model

    from .model_store import get_model_file
    weights_path = get_model_file('resnet152_v1b', tag=pretrained, root=root)
    model.load_parameters(weights_path, ctx=ctx)
    from ..data import ImageNet1kAttr
    imagenet = ImageNet1kAttr()
    model.synset = imagenet.synset
    model.classes = imagenet.classes
    model.classes_long = imagenet.classes_long
    return model
开发者ID:xiayongtao,项目名称:gluon-cv,代码行数:34,代码来源:resnetv1b.py
示例8: resnet152_v1s
def resnet152_v1s(pretrained=False, root='~/.mxnet/models', ctx=cpu(0), **kwargs):
    """Build a ResNetV1s-152 network: ResNetV1b-152 with a deep stem of width 64.

    Parameters
    ----------
    pretrained : bool or str
        True loads the default pretrained weights; a str is the hashtag of a
        specific pretrained-weights version.
    root : str, default '~/.mxnet/models'
        Directory where model parameter files are kept.
    ctx : Context, default CPU
        Context into which the pretrained weights are loaded.
    **kwargs
        Forwarded to ``ResNetV1b``; the original documentation lists ``dilated``
        (stride-8 dilation strategy) and ``norm_layer`` (BatchNorm or SyncBatchNorm).
    """
    net = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64,
                    name_prefix='resnetv1s_', **kwargs)
    if pretrained:
        from .model_store import get_model_file
        params_file = get_model_file('resnet152_v1s', tag=pretrained, root=root)
        net.load_parameters(params_file, ctx=ctx)
        from ..data import ImageNet1kAttr
        labels = ImageNet1kAttr()
        for field in ('synset', 'classes', 'classes_long'):
            setattr(net, field, getattr(labels, field))
    return net
开发者ID:xiayongtao,项目名称:gluon-cv,代码行数:30,代码来源:resnetv1b.py
示例9: get_densenet
def get_densenet(num_layers, pretrained=False, ctx=cpu(),
                 root='~/.mxnet/models', **kwargs):
    r"""Build a DenseNet-BC model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.

    Parameters
    ----------
    num_layers : int
        Depth of the DenseNet variant; looked up in ``densenet_spec``
        (documented options: 121, 161, 169, 201).
    pretrained : bool or str
        True loads the default pretrained weights; a str is the hashtag of a
        specific pretrained-weights version.
    ctx : Context, default CPU
        Context into which the pretrained weights are loaded.
    root : str, default '~/.mxnet/models'
        Directory where model parameter files are kept.
    norm_layer : object
        (via ``**kwargs``) Normalization layer used (default:
        :class:`mxnet.gluon.nn.BatchNorm`). Can be :class:`mxnet.gluon.nn.BatchNorm`
        or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    norm_kwargs : dict
        (via ``**kwargs``) Additional `norm_layer` arguments, for example
        `num_devices=4` for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    """
    init_features, growth, blocks = densenet_spec[num_layers]
    model = DenseNet(init_features, growth, blocks, **kwargs)
    if not pretrained:
        return model

    from .model_store import get_model_file
    weights_path = get_model_file('densenet%d' % (num_layers), tag=pretrained, root=root)
    model.load_parameters(weights_path, ctx=ctx)
    from ..data import ImageNet1kAttr
    imagenet = ImageNet1kAttr()
    model.synset = imagenet.synset
    model.classes = imagenet.classes
    model.classes_long = imagenet.classes_long
    return model
开发者ID:xiayongtao,项目名称:gluon-cv,代码行数:35,代码来源:densenet.py
示例10: get_deeplab
def get_deeplab(dataset='pascal_voc', backbone='resnet50', pretrained=False,
                root='~/.mxnet/models', ctx=cpu(0), **kwargs):
    r"""DeepLabV3 semantic segmentation model.

    Parameters
    ----------
    dataset : str, default pascal_voc
        Dataset the model is built for; its ``NUM_CLASS`` sets the number of output
        classes. Pretrained weights are looked up for 'pascal_voc', 'pascal_aug',
        'ade20k' and 'coco'.
    backbone : str, default 'resnet50'
        Backbone network name; forwarded to ``DeepLabV3`` and used in the name of
        the pretrained weights file.
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.
    **kwargs
        Forwarded to the ``DeepLabV3`` constructor.

    Examples
    --------
    >>> model = get_deeplab(dataset='pascal_voc', backbone='resnet50', pretrained=False)
    >>> print(model)
    """
    # Short dataset tags used in the pretrained model file names
    # ('pascal_aug' shares the VOC weights).
    acronyms = {
        'pascal_voc': 'voc',
        'pascal_aug': 'voc',
        'ade20k': 'ade',
        'coco': 'coco',
    }
    from ..data import datasets
    # infer number of classes
    model = DeepLabV3(datasets[dataset].NUM_CLASS, backbone=backbone, ctx=ctx, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        model.load_parameters(get_model_file('deeplab_%s_%s' % (backbone, acronyms[dataset]),
                                             tag=pretrained, root=root), ctx=ctx)
    return model
开发者ID:xiayongtao,项目名称:gluon-cv,代码行数:34,代码来源:deeplabv3.py
示例11: fetcher_loop
def fetcher_loop(data_queue, data_buffer, pin_memory=False):
    """Move batches from `data_queue` into `data_buffer` until a stop marker arrives.

    Each queue item is an ``(idx, batch)`` pair. An ``idx`` of ``None`` ends the loop.
    Otherwise the batch is moved with ``_as_in_context`` to the pinned CPU context
    (when `pin_memory` is true) or the plain CPU context, and stored under ``idx``.
    """
    while True:
        idx, batch = data_queue.get()
        if idx is None:
            return
        target_ctx = context.cpu_pinned() if pin_memory else context.cpu()
        data_buffer[idx] = _as_in_context(batch, target_ctx)
开发者ID:mohamedelsiesyibra,项目名称:gluon-cv,代码行数:11,代码来源:dataloader.py
示例12: get_params
def get_params(self, arg_params, aux_params):
    """Average each parameter over executors and copy it into `arg_params` / `aux_params`.

    Parameters
    ----------
    arg_params : mapping of name -> NDArray
        Target arrays for the names in ``self.param_names``; indexed by name.
    aux_params : mapping of name -> NDArray
        Target arrays for the names in ``self.aux_names``; indexed by name.

    Notes
    -----
    - Updates the NDArrays in `arg_params` and `aux_params` in place.
    - Each value is the mean of its per-executor copies, computed on CPU and cast
      to the dtype of the target array before copying.
    """
    def _average_into(names, per_name_copies, target):
        # For each name: copy every executor's array to CPU, average them, then
        # write the result into target[name] with target's dtype.
        for name, copies in zip(names, per_name_copies):
            mean = sum(w.copyto(ctx.cpu()) for w in copies) / len(copies)
            mean.astype(target[name].dtype).copyto(target[name])

    _average_into(self.param_names, self.param_arrays, arg_params)
    _average_into(self.aux_names, self.aux_arrays, aux_params)
开发者ID:ktr-hubrt,项目名称:Deformable-ConvNets,代码行数:20,代码来源:DataParallelExecutorGroup.py
示例13: get_simple_pose_resnet
def get_simple_pose_resnet(base_name, pretrained=False, ctx=cpu(),
                           root='~/.mxnet/models', **kwargs):
    """Build a ``SimplePoseResNet`` and optionally load its pretrained weights.

    Parameters
    ----------
    base_name :
        Passed to ``SimplePoseResNet`` and used to form the pretrained model name
        ``'simple_pose_<base_name>'`` (presumably the backbone name — confirm).
    pretrained : bool or str
        Truthy loads pretrained weights; the value is passed as ``tag`` to
        ``get_model_file``.
    ctx : Context, default CPU
        Context into which the pretrained weights are loaded.
    root : str, default '~/.mxnet/models'
        Directory where model parameter files are kept.
    **kwargs
        Forwarded to the ``SimplePoseResNet`` constructor.
    """
    model = SimplePoseResNet(base_name, **kwargs)
    if not pretrained:
        return model
    from ..model_store import get_model_file
    weights_path = get_model_file('simple_pose_%s' % (base_name), tag=pretrained, root=root)
    model.load_parameters(weights_path, ctx=ctx)
    return model
开发者ID:xiayongtao,项目名称:gluon-cv,代码行数:11,代码来源:simple_pose_resnet.py
示例14: __init__
def __init__(self, nclass, backbone='resnet50', aux=True, ctx=cpu(), pretrained_base=True,
             base_size=520, crop_size=480, **kwargs):
    """Create the FCN segmentation heads on top of the base segmentation network.

    The main head takes 2048 input channels. When ``self.aux`` is set, an auxiliary
    head with 1024 input channels is also added. Each head is initialized on `ctx`,
    and its parameters get ``lr_mult = 10``.
    """
    super(FCN, self).__init__(nclass, aux, backbone, ctx=ctx, base_size=base_size,
                              crop_size=crop_size, pretrained_base=pretrained_base, **kwargs)

    def _make_head(in_channels):
        # Build a head, initialize it on ctx and raise its learning-rate multiplier.
        head = _FCNHead(in_channels, nclass, **kwargs)
        head.initialize(ctx=ctx)
        head.collect_params().setattr('lr_mult', 10)
        return head

    with self.name_scope():
        self.head = _make_head(2048)
        if self.aux:
            self.auxlayer = _make_head(1024)
开发者ID:xiayongtao,项目名称:gluon-cv,代码行数:12,代码来源:fcn.py
示例15: get_resnet
def get_resnet(version, num_layers, pretrained=False, ctx=cpu(),
               root='~/.mxnet/models', use_se=False, **kwargs):
    r"""Build a ResNet V1 or V2 model.

    ResNet V1 follows `"Deep Residual Learning for Image Recognition"
    <http://arxiv.org/abs/1512.03385>`_; ResNet V2 follows `"Identity Mappings in
    Deep Residual Networks" <https://arxiv.org/abs/1603.05027>`_.

    Parameters
    ----------
    version : int
        ResNet version, 1 or 2 (checked with an assertion).
    num_layers : int
        Depth; must be a key of ``resnet_spec`` (documented: 18, 34, 50, 101, 152).
    pretrained : bool or str
        True loads the default pretrained weights; a str is the hashtag of a
        specific pretrained-weights version.
    ctx : Context, default CPU
        Context into which the pretrained weights are loaded.
    root : str, default '~/.mxnet/models'
        Directory where model parameter files are kept.
    use_se : bool, default False
        Selects the Squeeze-and-Excitation pretrained weights ('se_' name prefix).
    norm_layer : object
        (via ``**kwargs``) Normalization layer used (default:
        :class:`mxnet.gluon.nn.BatchNorm`). Can be :class:`mxnet.gluon.nn.BatchNorm`
        or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    norm_kwargs : dict
        (via ``**kwargs``) Additional `norm_layer` arguments, for example
        `num_devices=4` for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    """
    assert num_layers in resnet_spec, \
        "Invalid number of layers: %d. Options are %s"%(
            num_layers, str(resnet_spec.keys()))
    block_type, layers, channels = resnet_spec[num_layers]
    assert 1 <= version <= 2, \
        "Invalid resnet version: %d. Options are 1 and 2."%version
    idx = version - 1
    net = resnet_net_versions[idx](
        resnet_block_versions[idx][block_type], layers, channels, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        # SE weights live under the same name with an 'se_' prefix.
        model_name = ('se_' if use_se else '') + 'resnet%d_v%d' % (num_layers, version)
        net.load_parameters(get_model_file(model_name, tag=pretrained, root=root), ctx=ctx)
        from ..data import ImageNet1kAttr
        imagenet = ImageNet1kAttr()
        net.synset = imagenet.synset
        net.classes = imagenet.classes
        net.classes_long = imagenet.classes_long
    return net
开发者ID:xiayongtao,项目名称:gluon-cv,代码行数:52,代码来源:resnet.py
示例16: __init__
def __init__(self, nclass, backbone='resnet50', aux=True, ctx=cpu(), pretrained_base=True,
             base_size=520, crop_size=480, **kwargs):
    """Create the DeepLabV3 head (and optional auxiliary FCN head) on the base network.

    The DeepLab head is sized from ``self._up_kwargs['height']`` and ``['width']``,
    each divided by 8 (presumably the backbone output stride — confirm). When
    ``self.aux`` is set, an auxiliary ``_FCNHead`` with 1024 input channels is
    added. Both heads are initialized on `ctx` and get ``lr_mult = 10``.
    """
    super(DeepLabV3, self).__init__(nclass, aux, backbone, ctx=ctx, base_size=base_size,
                                    crop_size=crop_size, pretrained_base=pretrained_base, **kwargs)

    def _ready(block):
        # Initialize on ctx and raise the block's learning-rate multiplier.
        block.initialize(ctx=ctx)
        block.collect_params().setattr('lr_mult', 10)
        return block

    with self.name_scope():
        head_height = self._up_kwargs['height'] // 8
        head_width = self._up_kwargs['width'] // 8
        self.head = _ready(_DeepLabHead(nclass, height=head_height, width=head_width,
                                        **kwargs))
        if self.aux:
            self.auxlayer = _ready(_FCNHead(1024, nclass, **kwargs))
开发者ID:xiayongtao,项目名称:gluon-cv,代码行数:13,代码来源:deeplabv3.py
示例17: func
def func(pretrained=False, tag=None, root='~/.mxnet/models', ctx=cpu(0), **kwargs):
    r"""Quantized model.

    Imports the quantized symbol graph ``<model_name>-symbol.json`` stored next to
    this module as a ``SymbolBlock``. Optionally, it fills the graph with the
    pretrained weights of the matching float base model.

    Parameters
    ----------
    pretrained : bool
        Whether to load pretrained weights. Only its truthiness is used; the weight
        version is selected by `tag`.
    tag : str, default is None
        Optional length-8 sha1sum of parameter file. If `None`, best parameter file
        will be used.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.
    **kwargs
        Accepted for API compatibility; not used here.

    Returns
    -------
    SymbolBlock
        The quantized network. ``classes`` is copied from the float base model, and
        ``reset_class`` / ``set_nms`` are replaced by ``_not_impl``.
    """
    from ..model_zoo import get_model
    from ..model_store import get_model_file
    curr_dir = os.path.abspath(os.path.dirname(__file__))
    # `name` and `sym_prefix` are free variables, presumably bound by an enclosing
    # factory function that is not visible here.
    model_name = name.replace('mobilenet1_', 'mobilenet1.')
    model_name = model_name.replace('mobilenet0_', 'mobilenet0.')
    json_file = os.path.join(curr_dir, '{}-symbol.json'.format(model_name))
    # The float base model name is the quantized name minus its last '_' token.
    base_name = '_'.join(model_name.split('_')[:-1])
    # record=True diverts warnings emitted while building/loading into a list,
    # so they are not shown to the user.
    with warnings.catch_warnings(record=True):
        warnings.simplefilter("always")
        param_file = get_model_file(base_name, tag=tag, root=root) if pretrained else None
        net = get_model(base_name, prefix=sym_prefix)
        classes = getattr(net, 'classes', [])
        sym_net = SymbolBlock.imports(json_file, ['data'], None, ctx=ctx)
        if param_file:
            # Weights saved with save_parameters cannot be imported directly into
            # a SymbolBlock. Instead, load them into the float model, export it once
            # to a temporary params file, and load that file into sym_net.
            import tempfile
            # load_parameters is the counterpart of save_parameters; load_params is
            # its deprecated alias.
            net.load_parameters(param_file)
            net.hybridize()
            # A hybridized block must run forward once before export(); the dummy
            # input size is chosen from the base model name.
            if '512' in base_name:
                side = 512
            elif '300' in base_name:
                side = 300
            else:
                side = 224
            net(mx.nd.zeros((1, 3, side, side)))
            with tempfile.TemporaryDirectory() as tmpdirname:
                prefix = os.path.join(tmpdirname, 'tmp')
                net.export(prefix, epoch=0)
                sym_net.collect_params().load(prefix + '-0000.params')
    sym_net.classes = classes
    sym_net.reset_class = _not_impl
    sym_net.set_nms = _not_impl
    return sym_net
开发者ID:xiayongtao,项目名称:gluon-cv,代码行数:50,代码来源:quantized.py
示例18: __init__
def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
             logger=logging, context=ctx.cpu(), work_load_list=None,
             fixed_param_names=None, state_names=None):
    """Set up a Module around `symbol` without binding any executor.

    Parameters
    ----------
    symbol : Symbol
        Network symbol; its arguments, auxiliary states and outputs are recorded.
    data_names, label_names, state_names, fixed_param_names : iterable of str or None
        Converted to lists (``None`` becomes ``[]``) and checked against `symbol`
        with ``_check_input_names``.
    logger : default ``logging``
        Passed to the base-class constructor.
    context : ctx.Context or list of ctx.Context
        A single Context is wrapped in a one-element list.
    work_load_list : list, optional
        Defaults to ``[1]`` per context; its length must equal the number of
        contexts (asserted).
    """
    super(Module, self).__init__(logger=logger)

    self._context = [context] if isinstance(context, ctx.Context) else context
    if work_load_list is None:
        work_load_list = [1] * len(self._context)
    assert len(work_load_list) == len(self._context)
    self._work_load_list = work_load_list
    self._symbol = symbol

    def _as_list(names):
        return [] if names is None else list(names)

    data_names = _as_list(data_names)
    label_names = _as_list(label_names)
    state_names = _as_list(state_names)
    fixed_param_names = _as_list(fixed_param_names)

    _check_input_names(symbol, data_names, "data", True)
    _check_input_names(symbol, label_names, "label", False)
    _check_input_names(symbol, state_names, "state", True)
    _check_input_names(symbol, fixed_param_names, "fixed_param", True)

    # Learnable parameters are the symbol arguments that are not inputs.
    input_names = data_names + label_names + state_names
    self._param_names = [arg for arg in symbol.list_arguments() if arg not in input_names]
    self._fixed_param_names = fixed_param_names
    self._aux_names = symbol.list_auxiliary_states()
    self._data_names = data_names
    self._label_names = label_names
    self._state_names = state_names
    self._output_names = symbol.list_outputs()

    # Placeholders: none of these are populated by the constructor.
    self._arg_params = None
    self._aux_params = None
    self._params_dirty = False

    self._optimizer = None
    self._kvstore = None
    self._update_on_kvstore = None
    self._updater = None
    self._preload_opt_states = None
    self._grad_req = None

    self._exec_group = None
    self._data_shapes = None
    self._label_shapes = None
开发者ID:ktr-hubrt,项目名称:Deformable-ConvNets,代码行数:49,代码来源:module.py
示例19: prune_gluon_block
def prune_gluon_block(net, prefix, params_shapes, params=None, pretrained=False, ctx=cpu(0)):
    """Resize the layers of `net` in place to a pruned layout, optionally loading values.

    Lookup keys for `params_shapes` and `params` are the layer name with `prefix`
    replaced by "resnetv1d", followed by "_<param name>".

    :param net: Gluon block whose child layers are reshaped in place.
    :param prefix: name prefix of the network, replaced by "resnetv1d" to form keys.
    :param params_shapes: dict from key to pruned shape, used for Conv2D and Dense
        layers. A Conv2D's output channels are taken from its weight shape[0], and a
        Dense layer's input units from its weight shape[1].
    :param params: dict from key to parameter value; read only when `pretrained`.
    :param pretrained: if True, also load values from `params` into BatchNorm
        (beta, gamma, running_mean, running_var), Conv2D (weight) and Dense
        (weight, bias) layers.
    :param ctx: context passed to parameter initialization when loading values.
    :return: None; `net` is modified in place.
    """
    def _key(layer, param_name):
        # Key of this layer's parameter in params_shapes / params.
        return layer.name.replace(prefix, "resnetv1d") + "_" + param_name

    for layer in net._children.values():
        if pretrained and isinstance(layer, nn.BatchNorm):
            bn_params = layer._collect_params_with_prefix()
            for param_name in ('beta', 'gamma', 'running_mean', 'running_var'):
                value = params[_key(layer, param_name)]
                layer.params.get(param_name)._shape = value.shape
                bn_params[param_name]._load_init(value, ctx=ctx)

        if isinstance(layer, nn.Conv2D):
            out_channels = params_shapes[_key(layer, 'weight')][0]
            layer._channels = out_channels
            layer._kwargs['num_filter'] = out_channels
            conv_params = layer._collect_params_with_prefix()
            layer.params.get('weight')._shape = params_shapes[_key(layer, 'weight')]
            if pretrained:
                conv_params['weight']._load_init(params[_key(layer, 'weight')], ctx=ctx)

        if isinstance(layer, nn.Dense):
            layer._in_units = params_shapes[_key(layer, 'weight')][1]
            dense_params = layer._collect_params_with_prefix()
            for param_name in ('weight', 'bias'):
                layer.params.get(param_name)._shape = params_shapes[_key(layer, param_name)]
                if pretrained:
                    dense_params[param_name]._load_init(params[_key(layer, param_name)],
                                                        ctx=ctx)
        else:
            # Every non-Dense layer (including Conv2D/BatchNorm, where this is
            # presumably a no-op) is recursed into, so nested blocks are handled.
            prune_gluon_block(layer, prefix, params_shapes, params, pretrained, ctx)
开发者ID:xiayongtao,项目名称:gluon-cv,代码行数:48,代码来源:resnetv1b_pruned.py
示例20: __init__
def __init__(self, symbol, data_names, label_names,
             logger=logging, context=ctx.cpu(), work_load_list=None,
             max_data_shapes=None, max_label_shapes=None):
    """Record the MutableModule configuration; no underlying module is created yet.

    All arguments are stored as given. `logger` goes to the base-class constructor,
    and a ``None`` for `max_data_shapes` / `max_label_shapes` becomes an empty list.
    ``_curr_module`` starts out as ``None``.
    """
    super(MutableModule, self).__init__(logger=logger)
    self._symbol = symbol
    self._data_names = data_names
    self._label_names = label_names
    self._context = context
    self._work_load_list = work_load_list
    self._curr_module = None
    self._max_data_shapes = [] if max_data_shapes is None else max_data_shapes
    self._max_label_shapes = [] if max_label_shapes is None else max_label_shapes
开发者ID:Alexbert1,项目名称:mxnet,代码行数:17,代码来源:module.py
注:本文中的mxnet.context.cpu函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论