• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python torch.load函数代码示例

原作者: 见各示例下方注明的开发者ID 来自: 见各示例下方注明的项目名称 收藏 邀请

本文整理汇总了 Python 中 torch.load 函数的典型用法代码示例。如果您想了解 torch.load 函数的具体用法、调用方式或实际使用示例，这里精选的代码示例或许可以为您提供帮助。



在下文中一共展示了load函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: _load

def _load(checkpoint_path):
    """Load and return the object saved at ``checkpoint_path``.

    Reads the module-level ``use_cuda`` flag. When it is truthy, tensors are
    restored to the devices they were saved on. Otherwise every storage is
    passed through the identity ``map_location`` callback, which keeps it on
    the CPU so GPU-saved checkpoints still load.
    """
    cpu_remap = None if use_cuda else (lambda storage, loc: storage)
    if cpu_remap is None:
        return torch.load(checkpoint_path)
    return torch.load(checkpoint_path, map_location=cpu_remap)
开发者ID:Saiuz,项目名称:autokeras,代码行数:7,代码来源:model_helper.py


示例2: init_model

def init_model(word2id, opt):
    """Build the Seq2Seq attention model and optionally restore its weights.

    Args:
        word2id: vocabulary mapping. Only the id of ``pykp.io.PAD_WORD`` is
            read here; it is used as the source and target padding token.
        opt: options object supplying the model hyper-parameters. When
            ``opt.train_from`` is truthy, it is a path to a saved state dict
            that is loaded into the model.

    Returns:
        The constructed ``Seq2SeqLSTMAttention`` model.
    """
    pad_id = word2id[pykp.io.PAD_WORD]
    model = Seq2SeqLSTMAttention(
        emb_dim=opt.word_vec_size,
        vocab_size=opt.vocab_size,
        src_hidden_dim=opt.rnn_size,
        trg_hidden_dim=opt.rnn_size,
        ctx_hidden_dim=opt.rnn_size,
        attention_mode='dot',
        batch_size=opt.batch_size,
        bidirectional=opt.bidirectional,
        pad_token_src=pad_id,
        pad_token_trg=pad_id,
        nlayers_src=opt.enc_layers,
        nlayers_trg=opt.dec_layers,
        dropout=opt.dropout,
        teacher_forcing_ratio=opt.teacher_forcing_ratio,
        scheduled_sampling=opt.scheduled_sampling,
        scheduled_sampling_batches=opt.scheduled_sampling_batches
    )

    logging.info('======================  Model Parameters  =========================')
    if opt.train_from:
        logging.info("loading previous checkpoint from %s", opt.train_from)
        # Without CUDA, keep every storage on the CPU so GPU-saved checkpoints
        # still load; with CUDA, restore tensors to their saved devices.
        map_location = None if torch.cuda.is_available() else (lambda storage, loc: storage)
        # A context manager guarantees the checkpoint file handle is closed.
        with open(opt.train_from, 'rb') as checkpoint_file:
            state_dict = torch.load(checkpoint_file, map_location=map_location)
        model.load_state_dict(state_dict)
    utils.tally_parameters(model)

    return model
开发者ID:zhhengcs,项目名称:seq2seq-keyphrase-pytorch,代码行数:32,代码来源:train(old,no+copy,max+entropy+loss).py


示例3: generate

def generate(**kwargs):
    """Randomly generate anime avatars and keep the ones netd scores highest.

    Each keyword argument overrides the same-named attribute of the
    module-level ``opt`` config. Note that this mutates ``opt`` globally.

    ``opt.gen_search_num`` noise vectors are sampled from
    N(``opt.gen_mean``, ``opt.gen_std``) and passed through the generator.
    The resulting images are scored by the discriminator. The
    ``opt.gen_num`` highest-scoring images are written to ``opt.gen_img``.
    """
    for key, value in kwargs.items():
        setattr(opt, key, value)

    device = t.device('cuda') if opt.gpu else t.device('cpu')

    netg, netd = NetG(opt).eval(), NetD(opt).eval()
    noises = t.randn(opt.gen_search_num, opt.nz, 1, 1).normal_(opt.gen_mean, opt.gen_std)
    noises = noises.to(device)

    # Load the weights onto the CPU first, then move both networks to the target device.
    map_location = lambda storage, loc: storage
    netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))
    netd.to(device)
    netg.to(device)

    # Pure inference: disable autograd so no graph is kept for the whole search batch.
    with t.no_grad():
        # Generate images and compute each image's discriminator score.
        fake_img = netg(noises)
        scores = netd(fake_img)

    # topk returns (values, indices). Advanced indexing gathers the
    # best-scoring images into one stacked tensor.
    best_indices = scores.topk(opt.gen_num)[1]
    best_imgs = fake_img[best_indices]
    # Save the selected images.
    tv.utils.save_image(best_imgs, opt.gen_img, normalize=True, range=(-1, 1))
开发者ID:672401341,项目名称:pytorch-book,代码行数:31,代码来源:main.py


示例4: load_model

    def load_model(self):
        """Restore selector weights, training counters and optimizer state.

        Returns immediately when no ``<corpus>-selector-*.pth`` file exists in
        ``args.save_dir``. Otherwise it loads the checkpoint for iteration
        ``args.load_iter``. When that is ``None``, it loads the checkpoint
        with the highest iteration number found on disk.
        """
        pattern = os.path.join(args.save_dir, args.corpus) + '-selector-*.pth'
        checkpoint_files = glob.glob(pattern)
        if not checkpoint_files:
            return

        if args.load_iter is None:
            # File names end in "-<iter>.pth"; resume from the largest iteration.
            start_iter = max(int(path.split('-')[-1].split('.')[0]) for path in checkpoint_files)
        else:
            start_iter = args.load_iter

        name = args.corpus + '-selector-{}.pth'.format(start_iter)
        model_file_path = os.path.join(args.save_dir, name)
        print("loading model", model_file_path)

        on_cuda = opt.device == torch.device('cuda')
        if on_cuda:
            checkpoint = torch.load(model_file_path)
        else:
            checkpoint = torch.load(model_file_path, map_location=opt.device)

        self._epoch = checkpoint['epoch']
        self._iter = checkpoint['iter']
        self.running_avg_loss = checkpoint['current_loss']
        self.min_loss = checkpoint['min_loss']

        self.model.sentence_selector.load_state_dict(checkpoint['selector_state_dict'])

        if not args.is_coverage:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            if on_cuda:
                # Move every tensor held in the optimizer's per-parameter state onto the GPU.
                for param_state in self.optimizer.state.values():
                    for key, value in list(param_state.items()):
                        if torch.is_tensor(value):
                            param_state[key] = value.cuda()
开发者ID:coder352,项目名称:shellscript,代码行数:34,代码来源:train_selector.py


示例5: load_checkpoint

def load_checkpoint(checkpoint):
    """Load and return the object stored at the path ``checkpoint``.

    When CUDA is unavailable, every storage is passed through an identity
    ``map_location`` callback, which keeps it on the CPU so GPU-saved
    checkpoints still load. Otherwise tensors go back to their saved devices.
    """
    if not torch.cuda.is_available():
        return torch.load(checkpoint, map_location=lambda storage, loc: storage)
    return torch.load(checkpoint)
开发者ID:Wilson-Sunshine,项目名称:Udacity_AI_Program_Basic,代码行数:7,代码来源:predict.py


示例6: run

def run(args, run_args, rank=0, world_size=1):
    """Prepare iterators, build the model and optimizer, then start training.

    Args:
        args: parsed options.
        run_args: 4-tuple ``(field, train_sets, val_sets, save_dict)``. When
            ``save_dict`` is not None, the checkpoint at
            ``os.path.join(args.save, args.load)`` is reloaded into the model.
            With ``args.resume`` set, this rank's optimizer file is also
            reloaded.
        rank: index of this worker. It selects the per-rank optimizer file
            and only rank 0 receives the SummaryWriter.
        world_size: total number of workers.
    """
    set_seed(args, rank=rank)
    logger = initialize_logger(args, rank)
    field, train_sets, val_sets, save_dict = run_args

    logger.start = time.time()

    logger.info('Preparing iterators')
    train_iters = []
    for task_name, dataset, batch_tokens in zip(args.train_tasks, train_sets, args.train_batch_tokens):
        iterator = to_iter(args, world_size, batch_tokens, dataset, token_testing=args.token_testing)
        train_iters.append((task_name, iterator))
    val_iters = []
    for task_name, dataset, batch_size in zip(args.val_tasks, val_sets, args.val_batch_size):
        # SQL validation tasks are iterated unsorted.
        sort_mode = False if 'sql' in task_name else None
        iterator = to_iter(args, world_size, batch_size, dataset, train=False,
                           token_testing=args.token_testing, sort=sort_mode)
        val_iters.append((task_name, iterator))

    logger.info('Initializing Writer')
    writer = SummaryWriter(log_dir=args.log_dir)

    model = init_model(args, field, logger, world_size)
    opt = init_opt(args, model)
    start_iteration = 1

    if save_dict is not None:
        checkpoint_path = os.path.join(args.save, args.load)
        logger.info(f'Loading model from {checkpoint_path}')
        save_dict = torch.load(checkpoint_path)
        model.load_state_dict(save_dict['model_state_dict'])
        if args.resume:
            optim_name = f'{os.path.splitext(args.load)[0]}_rank_{rank}_optim.pth'
            logger.info(f'Resuming Training from {optim_name}')
            opt.load_state_dict(torch.load(os.path.join(args.save, optim_name)))
            # The iteration is parsed from the second "_"-separated field of the
            # checkpoint's base name. Presumably the file is named like
            # "<prefix>_<iteration>..."; TODO confirm.
            stem = os.path.splitext(os.path.basename(args.load))[0]
            start_iteration = int(stem.split('_')[1])

    logger.info('Begin Training')
    train(args, model, opt, train_iters, args.train_iterations, field,
          val_iters=val_iters, rank=rank, world_size=world_size,
          log_every=args.log_every, val_every=args.val_every,
          rounds=len(train_iters) > 1,
          writer=writer if rank == 0 else None,
          save_every=args.save_every, start_iteration=start_iteration)
开发者ID:AhlamMD,项目名称:decaNLP,代码行数:34,代码来源:train.py


示例7: restore_model

 def restore_model(self, resume_iters):
     """Load the generator and discriminator weights saved at ``resume_iters``.

     Reads ``<resume_iters>-G.ckpt`` and ``<resume_iters>-D.ckpt`` from
     ``self.model_save_dir``. An identity ``map_location`` keeps every
     storage on the CPU while loading.
     """
     print('Loading the trained models from step {}...'.format(resume_iters))
     keep_on_cpu = lambda storage, loc: storage
     for tag, network in (('G', self.G), ('D', self.D)):
         ckpt_path = os.path.join(self.model_save_dir, '{}-{}.ckpt'.format(resume_iters, tag))
         network.load_state_dict(torch.load(ckpt_path, map_location=keep_on_cpu))
开发者ID:JacobLee121,项目名称:StarGAN,代码行数:7,代码来源:solver.py


示例8: get_pretrained_net

def _download_if_missing(filename, label, url):
    """Fetch ``url`` into ``filename`` with wget unless the file already exists."""
    if not os.path.exists(filename):
        print('Downloading {}'.format(label))
        os.system('wget -O {} --no-check-certificate -nc {}'.format(filename, url))


def get_pretrained_net(name):
    """Loads pretrained network.

    Supported names:
      * ``'alexnet_caffe'``: returns the object stored in ``alexnet-torch_py3.pth``.
      * ``'vgg19_caffe'`` / ``'vgg16_caffe'``: downloads the weight file if
        missing, then returns ``get_vgg19_caffe()`` / ``get_vgg16_caffe()``.
      * ``'vgg19_pytorch_modified'``: builds ``VGGModified`` and loads
        ``vgg_pytorch_modified.pkl['state_dict']``. That file is not
        downloaded here.

    Raises:
        ValueError: if ``name`` is not one of the supported networks. This
            replaces ``assert False``, which is stripped under ``python -O``.
    """
    if name == 'alexnet_caffe':
        _download_if_missing('alexnet-torch_py3.pth', 'AlexNet',
                             'https://box.skoltech.ru/index.php/s/77xSWvrDN0CiQtK/download')
        return torch.load('alexnet-torch_py3.pth')
    if name == 'vgg19_caffe':
        _download_if_missing('vgg19-caffe-py3.pth', 'VGG-19',
                             'https://box.skoltech.ru/index.php/s/HPcOFQTjXxbmp4X/download')
        return get_vgg19_caffe()
    if name == 'vgg16_caffe':
        _download_if_missing('vgg16-caffe-py3.pth', 'VGG-16',
                             'https://box.skoltech.ru/index.php/s/TUZ62HnPKWdxyLr/download')
        return get_vgg16_caffe()
    if name == 'vgg19_pytorch_modified':
        model = VGGModified(vgg19(pretrained=False), 0.2)
        model.load_state_dict(torch.load('vgg_pytorch_modified.pkl')['state_dict'])
        return model
    raise ValueError('Unknown pretrained network: {!r}'.format(name))
开发者ID:1exx,项目名称:deep-image-prior,代码行数:32,代码来源:perceptual_loss.py


示例9: get_vanilla_vgg_features

def get_vanilla_vgg_features(cut_idx=-1):
    """Return VGG-19's convolutional layers, optionally followed by its classifier.

    On first use (when ``vgg_features.pth`` is absent), VGG-19 weights are
    downloaded with wget. They are loaded into ``models.vgg19()``, whose
    classifier is prefixed with a ``View()`` module. The ``features`` and
    ``classifier`` submodules are then saved as whole modules to
    ``vgg_features.pth`` and ``vgg_classifier.pth`` in the current directory.

    Args:
        cut_idx: when greater than 36, the cached classifier layers are
            appended after the feature layers. Otherwise only the features are
            returned. The value is not otherwise used to truncate the network.

    Returns:
        The network in eval mode.
    """
    if not os.path.exists('vgg_features.pth'):
        os.system(
            'wget --no-check-certificate -N https://s3-us-west-2.amazonaws.com/jcjohns-models/vgg19-d01eb7cb.pth')
        vgg_weights = torch.load('vgg19-d01eb7cb.pth')
        # Fix compatibility issues: rename the last classifier layer's keys to
        # the index it gets once View() is prepended below. Use a name that
        # does not shadow the builtin ``map``.
        key_renames = {'classifier.6.weight': 'classifier.7.weight',
                       'classifier.6.bias': 'classifier.7.bias'}
        # ``.items()`` replaces the Python-2-only ``.iteritems()``.
        vgg_weights = OrderedDict((key_renames.get(k, k), v) for k, v in vgg_weights.items())

        model = models.vgg19()
        model.classifier = nn.Sequential(View(), *model.classifier._modules.values())

        model.load_state_dict(vgg_weights)

        torch.save(model.features, 'vgg_features.pth')
        torch.save(model.classifier, 'vgg_classifier.pth')

    vgg = torch.load('vgg_features.pth')
    if cut_idx > 36:
        vgg_classifier = torch.load('vgg_classifier.pth')
        # Dict value views cannot be concatenated with ``+`` on Python 3; build lists.
        layers = list(vgg._modules.values()) + list(vgg_classifier._modules.values())
        vgg = nn.Sequential(*layers)

    vgg.eval()

    return vgg
开发者ID:1exx,项目名称:deep-image-prior,代码行数:28,代码来源:feature_inversion_utils.py


示例10: load

 def load(self, filename, legacy=False, ignore_d=False):
     """Restore the networks and, for non-legacy files, optimizers and epoch.

     Args:
         filename: checkpoint path passed to ``torch.load``.
         legacy: if True, the file holds a ``(g_state, d_state)`` pair.
           Otherwise it holds a dict with keys ``'g'``, ``'d'``,
           ``'optim_<key>'`` for every key of ``self.optim``, and ``'epoch'``.
         ignore_d: if `True`, then don't load in the discriminator. For
           non-legacy files, the ``'d'`` optimizer entry is skipped as well.
     """
     # Without CUDA, an identity callback keeps storages on the CPU.
     map_location = None if self.use_cuda else (lambda storage, loc: storage)
     checkpoint = torch.load(filename, map_location=map_location)
     if legacy:
         g_state, d_state = checkpoint
         self.g.load_state_dict(g_state)
         if not ignore_d:
             self.d.load_state_dict(d_state)
         return
     self.g.load_state_dict(checkpoint['g'])
     if not ignore_d:
         self.d.load_state_dict(checkpoint['d'])
     for key in self.optim:
         if not (ignore_d and key == 'd'):
             self.optim[key].load_state_dict(checkpoint['optim_' + key])
     self.last_epoch = checkpoint['epoch']
开发者ID:kazk1018,项目名称:manifold_mixup,代码行数:26,代码来源:base.py


示例11: load_network_stageI

    def load_network_stageI(self):
        """Build the Stage-I generator and discriminator, loading weights if configured.

        Both networks are initialised with ``weights_init`` and printed. When
        ``cfg.NET_G`` / ``cfg.NET_D`` is a non-empty path, that state dict is
        loaded with storages kept on the CPU. Both networks are moved to the
        GPU when ``cfg.CUDA`` is set.

        Returns:
            ``(netG, netD)``.
        """
        from model import STAGE1_G, STAGE1_D
        built = []
        for net_cls in (STAGE1_G, STAGE1_D):
            net = net_cls()
            net.apply(weights_init)
            print(net)
            built.append(net)
        netG, netD = built

        keep_on_cpu = lambda storage, loc: storage
        for net, weights_path in ((netG, cfg.NET_G), (netD, cfg.NET_D)):
            if weights_path != '':
                net.load_state_dict(torch.load(weights_path, map_location=keep_on_cpu))
                print('Load from: ', weights_path)
        if cfg.CUDA:
            netG.cuda()
            netD.cuda()
        return netG, netD
开发者ID:tensoralex,项目名称:StackGAN-Pytorch,代码行数:25,代码来源:trainer.py


示例12: __init__

    def __init__(self,
                 root, mnist_root="data",
                 train=True,
                 transform=None, target_transform=None,
                 download=False):
        """Init MNIST-M dataset.

        Args:
            root: directory containing the processed MNIST-M files (``~`` expanded).
            mnist_root: directory for the source MNIST data (``~`` expanded).
            train: load the training split if True, otherwise the test split.
            transform, target_transform: stored on the instance.
            download: when True, ``self.download()`` runs before the existence check.

        Raises:
            RuntimeError: if the processed dataset files cannot be found.
        """
        super(MNISTM, self).__init__()
        self.root = os.path.expanduser(root)
        self.mnist_root = os.path.expanduser(mnist_root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        # Each processed file holds a (data, labels) pair for one split.
        split_file = self.training_file if self.train else self.test_file
        data, labels = torch.load(os.path.join(self.root, self.processed_folder, split_file))
        if self.train:
            self.train_data, self.train_labels = data, labels
        else:
            self.test_data, self.test_labels = data, labels
开发者ID:hjpwhu,项目名称:PyTorch-GAN,代码行数:30,代码来源:mnistm.py


示例13: load_models

def load_models(load_path):
    """Rebuild the ARAE autoencoder and GAN networks from a saved run directory.

    ``load_path`` must contain:
      * ``args.json``: the model hyper-parameters.
      * ``vocab.json``: the word -> index mapping.
      * the state dicts ``autoencoder_model.pt``, ``gan_gen_model.pt`` and
        ``gan_disc_model.pt``.

    Returns:
        ``(model_args, idx2word, autoencoder, gan_gen, gan_disc)``.
    """
    # Context managers close the JSON files as soon as they are parsed.
    with open("{}/args.json".format(load_path), "r") as args_file:
        model_args = json.load(args_file)
    with open("{}/vocab.json".format(load_path), "r") as vocab_file:
        word2idx = json.load(vocab_file)
    idx2word = {v: k for k, v in word2idx.items()}

    autoencoder = Seq2Seq(emsize=model_args['emsize'],
                          nhidden=model_args['nhidden'],
                          ntokens=model_args['ntokens'],
                          nlayers=model_args['nlayers'],
                          hidden_init=model_args['hidden_init'])
    gan_gen = MLP_G(ninput=model_args['z_size'],
                    noutput=model_args['nhidden'],
                    layers=model_args['arch_g'])
    gan_disc = MLP_D(ninput=model_args['nhidden'],
                     noutput=1,
                     layers=model_args['arch_d'])

    # The space separates the message from the path in the output.
    print('Loading models from ' + load_path)
    ae_path = os.path.join(load_path, "autoencoder_model.pt")
    gen_path = os.path.join(load_path, "gan_gen_model.pt")
    disc_path = os.path.join(load_path, "gan_disc_model.pt")

    autoencoder.load_state_dict(torch.load(ae_path))
    gan_gen.load_state_dict(torch.load(gen_path))
    gan_disc.load_state_dict(torch.load(disc_path))
    return model_args, idx2word, autoencoder, gan_gen, gan_disc
开发者ID:wangwang110,项目名称:ARAE,代码行数:26,代码来源:models.py


示例14: demo

def demo(data, save, depth=40, growth_rate=12, batch_size=256):
    """
    Applies temperature scaling to a trained model.

    Takes a pretrained DenseNet-CIFAR100 model, and a validation set
    (parameterized by indices on train set).
    Applies temperature scaling, and saves a temperature scaled version.

    NB: the "save" parameter references a DIRECTORY, not a file.
    In that directory, there should be two files:
    - model.pth (model state dict)
    - valid_indices.pth (a list of indices corresponding to the validation set).

    data (str) - path to directory where data should be loaded from/downloaded
    save (str) - directory with necessary files (see above)
    depth (int) - network depth; (depth - 4) must be divisible by 6, since the
        model uses three dense blocks of (depth - 4) // 6 layers each
    growth_rate (int) - DenseNet growth rate
    batch_size (int) - batch size of the validation loader

    Raises RuntimeError if an input file is missing, and ValueError for an
    invalid depth.
    """
    # Load model state dict
    model_filename = os.path.join(save, 'model.pth')
    if not os.path.exists(model_filename):
        raise RuntimeError('Cannot find file %s to load' % model_filename)
    state_dict = torch.load(model_filename)

    # Load validation indices
    valid_indices_filename = os.path.join(save, 'valid_indices.pth')
    if not os.path.exists(valid_indices_filename):
        raise RuntimeError('Cannot find file %s to load' % valid_indices_filename)
    valid_indices = torch.load(valid_indices_filename)

    # Regenerate validation set loader (CIFAR-100 channel mean/std normalisation)
    mean = [0.5071, 0.4867, 0.4408]
    stdv = [0.2675, 0.2565, 0.2761]
    test_transforms = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(mean=mean, std=stdv),
    ])
    valid_set = tv.datasets.CIFAR100(data, train=True, transform=test_transforms, download=True)
    valid_loader = torch.utils.data.DataLoader(valid_set, pin_memory=True, batch_size=batch_size,
                                               sampler=SubsetRandomSampler(valid_indices))

    # Load original model. block_config splits (depth - 4) into three blocks of
    # (depth - 4) // 6 layers, so validate divisibility by 6. Checking only by 3
    # would accept depths that silently build a different network.
    if (depth - 4) % 6:
        raise ValueError('Invalid depth %d: (depth - 4) must be divisible by 6' % depth)
    block_config = [(depth - 4) // 6 for _ in range(3)]
    orig_model = DenseNetEfficientMulti(
        growth_rate=growth_rate,
        block_config=block_config,
        num_classes=100
    ).cuda()
    orig_model.load_state_dict(state_dict)

    # Now we're going to wrap the model with a decorator that adds temperature scaling
    model = ModelWithTemperature(orig_model)

    # Tune the model temperature, and save the results
    model.set_temperature(valid_loader)
    model_filename = os.path.join(save, 'model_with_temperature.pth')
    torch.save(model.state_dict(), model_filename)
    print('Temperature scaled model saved to %s' % model_filename)
    print('Done!')
开发者ID:zhenglm,项目名称:temperature_scaling,代码行数:59,代码来源:demo.py


示例15: __init__

 def __init__(self, file, labelFile):
     """Load samples and labels from two ``torch.save`` files and normalise the samples.

     Each sample is viewed as one flattened row (``view(1, -1)``) and
     normalised with mean 0.1307 and std 0.3081. The whole tensor is then
     reshaped to ``(-1, 1, 28, 28)``.
     """
     self.train = torch.load(file)
     self.label = torch.load(labelFile)
     self.len = len(self.train)  # number of data points
     normalize = transforms.Normalize((0.1307,), (0.3081,))
     for idx in range(self.len):
         self.train[idx] = normalize(self.train[idx].view(1, -1))
     self.train = self.train.view(-1, 1, 28, 28)
开发者ID:RobinROAR,项目名称:TensorflowTutorialsCode,代码行数:8,代码来源:utils.py


示例16: main_test

def main_test():
    """Evaluate the saved image/text networks on the COCO validation split.

    Loads ``img_net.pt`` and ``text_net.pt`` with ``torch.load``, reads one
    id per line from ``test_ids.txt``, builds the feature datasets from the
    hard-coded COCO directories, and passes everything to ``test``.
    """
    img_net, text_net = torch.load('img_net.pt'), torch.load('text_net.pt')
    # ``file()`` exists only in Python 2. ``open`` inside ``with`` works on
    # Python 3 and closes the handle.
    with open('test_ids.txt') as id_file:
        tiidlst = [line.strip() for line in id_file]
    img_dir = '/home/datasets/coco/raw/vgg19_feat/val/'
    text_dir = '/home/datasets/coco/raw/annotation_text/hglmm_pca_npy/val/'
    img_feat_dataset = COCOImgFeatDataset(tiidlst, img_dir)
    text_feat_dataset = COCOTextFeatDataset(tiidlst, text_dir)
    test(tiidlst, img_feat_dataset, text_feat_dataset, img_net, text_net)
开发者ID:tyhu,项目名称:PyAI,代码行数:8,代码来源:train.py


示例17: load_train_valid_data

def load_train_valid_data(opt):
    """Load the preprocessed keyphrase data and wrap it in torchtext iterators.

    Args:
        opt: options object. It reads:
            * ``opt.data``: a saved dict whose ``'train'`` / ``'valid'``
              entries are sequences of examples with ``'src'`` / ``'trg'`` keys.
            * ``opt.vocab``: a saved ``(word2id, id2word, vocab)`` triple.
            * ``opt.batch_size`` and ``opt.vocab_size``.
            As a side effect, ``opt.word2id``, ``opt.id2word`` and
            ``opt.vocab`` are overwritten with the loaded vocabulary.

    Returns:
        ``(training_data_loader, validation_data_loader, word2id, id2word, vocab)``.
    """
    logging.info("Loading train and validate data from '%s'", opt.data)

    logging.info("Loading train/valid from disk: %s", opt.data)
    # torch.load's second positional parameter is ``map_location``, not a file
    # mode, so no open-mode string may be passed here.
    data_dict = torch.load(opt.data)

    train_src = np.asarray([d['src'] for d in data_dict['train']])
    train_trg = np.asarray([d['trg'] for d in data_dict['train']])
    valid_src = np.asarray([d['src'] for d in data_dict['valid']])
    valid_trg = np.asarray([d['trg'] for d in data_dict['valid']])

    word2id, id2word, vocab = torch.load(opt.vocab)

    def make_field():
        # use_vocab=False: the special tokens are given as ids from word2id,
        # so presumably the examples are already id sequences.
        return torchtext.data.Field(
            use_vocab=False,
            init_token=word2id[pykp.io.BOS_WORD],
            eos_token=word2id[pykp.io.EOS_WORD],
            pad_token=word2id[pykp.io.PAD_WORD],
            batch_first=True
        )

    src_field = make_field()
    trg_field = make_field()

    train = KeyphraseDatasetTorchText(list(zip(train_src, train_trg)), [('src', src_field), ('trg', trg_field)])
    valid = KeyphraseDatasetTorchText(list(zip(valid_src, valid_trg)), [('src', src_field), ('trg', trg_field)])

    # Legacy torchtext device convention: None = current GPU, -1 = CPU.
    iterator_device = None if torch.cuda.is_available() else -1
    training_data_loader = torchtext.data.BucketIterator(
        dataset=train, batch_size=opt.batch_size, train=True, shuffle=True,
        repeat=False, sort=True, device=iterator_device)
    validation_data_loader = torchtext.data.BucketIterator(
        dataset=valid, batch_size=opt.batch_size, train=False, shuffle=False,
        repeat=False, sort=False, device=iterator_device)

    opt.word2id = word2id
    opt.id2word = id2word
    opt.vocab = vocab

    logging.info('======================  Dataset  =========================')
    logging.info('#(training data pairs)=%d' % len(training_data_loader.dataset))
    logging.info('#(validation data pairs)=%d' % len(validation_data_loader.dataset))
    logging.info('#(vocab)=%d' % len(vocab))
    logging.info('#(vocab used)=%d' % opt.vocab_size)

    return training_data_loader, validation_data_loader, word2id, id2word, vocab
开发者ID:zhhengcs,项目名称:seq2seq-keyphrase-pytorch,代码行数:58,代码来源:train(old,no+copy,max+entropy+loss).py


示例18: get_outputs

def get_outputs(image_dir, filename):
    """Run several pretrained classifiers over ``image_dir`` and save their outputs.

    For each model, the softmax of every batch's outputs is concatenated
    into ``res[name]``. The ids yielded by ``TestData`` are flattened into
    ``res_labels[name]``. ``res`` is saved to ``filename`` and
    ``res_labels`` to ``filename + '_label'``.

    Returns:
        dict mapping model name -> concatenated softmax outputs.
    """
    models_name = ['resnet152', 'vgg19_bn', 'densenet161', 'nasnetalarge']
    checkpoint_files = {
        'densenet161': 'model_pretrained_densenet161.pkl',
        'resnet152': 'model_pretrained_resnet152.pkl',
        'vgg19_bn': 'model_pretrained_vgg19.pkl',
        'nasnetalarge': 'model_pretrained_nasnet.pkl',
    }
    use_gpu = torch.cuda.is_available()
    res = {}
    res_labels = {}
    for name in models_name:
        # ``Resize`` replaces ``Scale``, its deprecated alias in torchvision.
        if name == 'nasnetalarge':
            # The reference model is built only to read its normalisation
            # constants; the fine-tuned weights are loaded from disk below.
            reference = pretrainedmodels.nasnetalarge(num_classes=1000, pretrained='imagenet')
            data_transforms = transforms.Compose([
                transforms.Resize(377),
                transforms.CenterCrop(331),
                transforms.ToTensor(),
                transforms.Normalize(mean=reference.mean,
                                     std=reference.std)])
        else:
            data_transforms = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])])
        model_ft = torch.load(checkpoint_files[name])
        model_ft.eval()

        test_dataset = TestData(image_dir, data_transforms)
        test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=4)
        since = time.time()
        temp = []
        temp_list = []
        # Inference only: disabling autograd avoids building graphs per batch.
        with torch.no_grad():
            for i, (inputs, cid) in enumerate(test_dataloader):
                temp_list.append(cid)
                if use_gpu:
                    inputs = inputs.cuda()

                outputs = model_ft(inputs)
                temp.append(softmax(outputs.cpu().numpy()))
                if i % 200 == 199:
                    print('iter:{}'.format(i+1))
        time_elapsed = time.time() - since
        print('Testing complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        res[name] = np.concatenate(temp)
        res_labels[name] = [y for x in temp_list for y in x]
        print('{} finish'.format(name))

    torch.save(res, filename)
    torch.save(res_labels, filename + '_label')
    return res
开发者ID:LinfeiHe,项目名称:schoolwork,代码行数:58,代码来源:generate_outputs.py


示例19: setup

def setup(args, inject_train=None, inject_dev=None, inject_test=None):
    """Build the SST data pipeline and the classification model.

    Args:
        args: options object. It reads ``gpu``, ``batch_size``, ``wv_type``,
            ``wv_cache``, ``data_cache``, ``embed_size``, ``resume_snapshot``
            and ``model_class``. It writes ``vocab_size``, ``embed_size`` and
            ``output_size``.
        inject_train, inject_dev, inject_test: optional callables applied to
            the matching split before the vocabularies are built.

    Returns:
        ``(args, TEXT, LABEL, train_iter, dev_iter, test_iter, model)``.
    """
    torch.cuda.set_device(args.gpu)

    # --- data --------------------------------------------------------------
    TEXT = data.Field()
    LABEL = data.Field(sequential=False)

    # Examples labelled 'neutral' are filtered out.
    train_set, dev_set, test_set = datasets.SST.splits(
        TEXT, LABEL, fine_grained=False, train_subtrees=True,
        filter_pred=lambda example: example.label != 'neutral')

    # Apply each optional injector to its own split (train, dev, test order).
    train_set, dev_set, test_set = [
        split if injector is None else injector(split)
        for split, injector in ((train_set, inject_train),
                                (dev_set, inject_dev),
                                (test_set, inject_test))
    ]

    TEXT.build_vocab(train_set)
    LABEL.build_vocab(train_set)

    train_iter, dev_iter, test_iter = data.BucketIterator.splits(
        (train_set, dev_set, test_set),
        batch_size=args.batch_size, device=args.gpu)

    # --- word vectors: reuse the cache file when it exists -----------------
    if args.wv_type:
        if not os.path.isfile(args.wv_cache):
            TEXT.vocab.load_vectors(wv_dir=args.data_cache,
                                    wv_type=args.wv_type, wv_dim=args.embed_size)
            makedirs(os.path.dirname(args.wv_cache))
            torch.save(TEXT.vocab.vectors, args.wv_cache)
        else:
            TEXT.vocab.vectors = torch.load(args.wv_cache)

    args.vocab_size = len(TEXT.vocab)
    args.embed_size = TEXT.vocab.vectors.size(1)
    args.output_size = len(LABEL.vocab)
    for caption, value in (('vocab size', args.vocab_size),
                           ('embed size', args.embed_size),
                           ('output size', args.output_size)):
        print(caption, value)

    # --- model ---------------------------------------------------------------
    if not args.resume_snapshot:
        model = globals()[args.model_class](args)
        if args.wv_type:
            model.embed.weight.data = TEXT.vocab.vectors
        if args.gpu >= 0:
            model.cuda()
    else:
        print('loading snapshot', args.resume_snapshot)
        # Every storage in the snapshot is moved onto GPU ``args.gpu``.
        model = torch.load(args.resume_snapshot,
                           map_location=lambda storage, loc: storage.cuda(args.gpu))

    return args, TEXT, LABEL, train_iter, dev_iter, test_iter, model
开发者ID:ihsgnef,项目名称:imdb_word_replace,代码行数:58,代码来源:inject.py


示例20: load_pretrained

 def load_pretrained(self):
     """Load the saved D_cVAE, D_cLR, G and E weights and the resume epoch.

     Each state dict is read from ``<self.weight_dir>/<name>.pkl``. The
     start epoch is the integer on the first line of ``log.txt`` in the
     current working directory.
     """
     for net_name in ('D_cVAE', 'D_cLR', 'G', 'E'):
         weights_path = os.path.join(self.weight_dir, net_name + '.pkl')
         getattr(self, net_name).load_state_dict(torch.load(weights_path))

     # A context manager closes the log file instead of leaking the handle.
     with open('log.txt', 'r') as log_file:
         self.start_epoch = int(log_file.readline())
开发者ID:Pandinosaurus,项目名称:BicycleGAN-pytorch,代码行数:9,代码来源:solver.py



注:本文中的torch.load函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python torch.log函数代码示例发布时间:2022-05-27
下一篇:
Python torch.is_tensor函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap