本文整理汇总了Python中torch.load函数的典型用法代码示例。如果您正苦于以下问题:Python load函数的具体用法?Python load怎么用?Python load使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了load函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: _load
def _load(checkpoint_path):
    """Load a checkpoint, remapping GPU storages onto the CPU when CUDA is off."""
    # torch.load(path) is equivalent to torch.load(path, map_location=None),
    # so one call with a conditional map_location covers both cases.
    map_location = None if use_cuda else (lambda storage, loc: storage)
    return torch.load(checkpoint_path, map_location=map_location)
开发者ID:Saiuz,项目名称:autokeras,代码行数:7,代码来源:model_helper.py
示例2: init_model
def init_model(word2id, opt):
    """Build a Seq2SeqLSTMAttention model from options, optionally restoring a checkpoint.

    Args:
        word2id: vocabulary mapping used to resolve the PAD token id.
        opt: option namespace carrying model hyper-parameters and `train_from`.

    Returns:
        The constructed (and possibly checkpoint-initialized) model.
    """
    model = Seq2SeqLSTMAttention(
        emb_dim=opt.word_vec_size,
        vocab_size=opt.vocab_size,
        src_hidden_dim=opt.rnn_size,
        trg_hidden_dim=opt.rnn_size,
        ctx_hidden_dim=opt.rnn_size,
        attention_mode='dot',
        batch_size=opt.batch_size,
        bidirectional=opt.bidirectional,
        pad_token_src=word2id[pykp.io.PAD_WORD],
        pad_token_trg=word2id[pykp.io.PAD_WORD],
        nlayers_src=opt.enc_layers,
        nlayers_trg=opt.dec_layers,
        dropout=opt.dropout,
        teacher_forcing_ratio=opt.teacher_forcing_ratio,
        scheduled_sampling=opt.scheduled_sampling,
        scheduled_sampling_batches=opt.scheduled_sampling_batches,
    )
    logging.info('====================== Model Parameters =========================')
    if opt.train_from:
        logging.info("loading previous checkpoint from %s" % opt.train_from)
        # BUG FIX: the original called torch.load(open(...)) and never closed
        # the file handle; use a context manager so it is always released.
        with open(opt.train_from, 'rb') as ckpt_file:
            if torch.cuda.is_available():
                state_dict = torch.load(ckpt_file)
            else:
                # Remap GPU storages onto the CPU when CUDA is unavailable.
                state_dict = torch.load(ckpt_file, map_location=lambda storage, loc: storage)
        model.load_state_dict(state_dict)
    utils.tally_parameters(model)
    return model
开发者ID:zhhengcs,项目名称:seq2seq-keyphrase-pytorch,代码行数:32,代码来源:train(old,no+copy,max+entropy+loss).py
示例3: generate
def generate(**kwargs):
    """Randomly generate anime avatars and keep the images netd scores highest."""
    # Fold any keyword overrides into the global option object.
    for key, value in kwargs.items():
        setattr(opt, key, value)
    device = t.device('cuda') if opt.gpu else t.device('cpu')
    netg, netd = NetG(opt).eval(), NetD(opt).eval()
    # Draw candidate latent vectors and move them to the target device.
    noises = t.randn(opt.gen_search_num, opt.nz, 1, 1).normal_(opt.gen_mean, opt.gen_std).to(device)
    # Checkpoints may have been saved on GPU; remap storages onto the CPU first.
    cpu_map = lambda storage, loc: storage
    netd.load_state_dict(t.load(opt.netd_path, map_location=cpu_map))
    netg.load_state_dict(t.load(opt.netg_path, map_location=cpu_map))
    netd.to(device)
    netg.to(device)
    # Generate images and score each one with the discriminator.
    fake_img = netg(noises)
    scores = netd(fake_img).detach()
    # Keep only the top-scoring images.
    indexs = scores.topk(opt.gen_num)[1]
    result = [fake_img.data[ii] for ii in indexs]
    # Save the selected images as one normalized grid.
    tv.utils.save_image(t.stack(result), opt.gen_img, normalize=True, range=(-1, 1))
开发者ID:672401341,项目名称:pytorch-book,代码行数:31,代码来源:main.py
示例4: load_model
def load_model(self):
    """Restore the newest (or explicitly requested) selector checkpoint.

    Silently returns when no checkpoint exists. Restores epoch/iteration
    counters, loss bookkeeping, selector weights and (unless coverage mode)
    the optimizer state, moving optimizer tensors to the GPU when needed.
    """
    pattern = os.path.join(args.save_dir, args.corpus) + '-selector-*.pth'
    if len(glob.glob(pattern)) == 0:
        return
    if args.load_iter is None:
        # Pick the checkpoint with the highest iteration number.
        f_list = glob.glob(pattern)
        iter_list = [int(i.split('-')[-1].split('.')[0]) for i in f_list]
        start_iter = max(iter_list)
    else:
        start_iter = args.load_iter
    name = args.corpus + '-selector-{}.pth'.format(start_iter)
    model_file_path = os.path.join(args.save_dir, name)
    print("loading model", model_file_path)
    if opt.device == torch.device('cuda'):
        state = torch.load(model_file_path)
    else:
        state = torch.load(model_file_path, map_location=opt.device)
    self._epoch = state['epoch']
    self._iter = state['iter']
    self.running_avg_loss = state['current_loss']
    self.min_loss = state['min_loss']
    self.model.sentence_selector.load_state_dict(state['selector_state_dict'])
    if not args.is_coverage:
        self.optimizer.load_state_dict(state['optimizer'])
        if opt.device == torch.device('cuda'):
            # BUG FIX: the original reused the name `state` for this loop,
            # shadowing the loaded checkpoint dict; use a distinct name.
            for opt_state in list(self.optimizer.state.values()):
                for k, v in list(opt_state.items()):
                    if torch.is_tensor(v):
                        opt_state[k] = v.cuda()
开发者ID:coder352,项目名称:shellscript,代码行数:34,代码来源:train_selector.py
示例5: load_checkpoint
def load_checkpoint(checkpoint):
    """Load a saved checkpoint file, falling back to CPU when CUDA is absent."""
    if torch.cuda.is_available():
        return torch.load(checkpoint)
    # No GPU: remap every storage onto the CPU.
    return torch.load(checkpoint, map_location=lambda storage, loc: storage)
开发者ID:Wilson-Sunshine,项目名称:Udacity_AI_Program_Basic,代码行数:7,代码来源:predict.py
示例6: run
def run(args, run_args, rank=0, world_size=1):
    """Worker entry point: build iterators, restore any checkpoint, then train."""
    set_seed(args, rank=rank)
    logger = initialize_logger(args, rank)
    field, train_sets, val_sets, save_dict = run_args
    logger.start = time.time()
    logger.info(f'Preparing iterators')
    train_iters = [
        (task, to_iter(args, world_size, tokens, split, token_testing=args.token_testing))
        for task, split, tokens in zip(args.train_tasks, train_sets, args.train_batch_tokens)]
    # SQL-style tasks must keep their natural order (sort=False).
    val_iters = [
        (task, to_iter(args, world_size, tokens, split, train=False,
                       token_testing=args.token_testing,
                       sort=False if 'sql' in task else None))
        for task, split, tokens in zip(args.val_tasks, val_sets, args.val_batch_size)]
    logger.info(f'Initializing Writer')
    writer = SummaryWriter(log_dir=args.log_dir)
    model = init_model(args, field, logger, world_size)
    opt = init_opt(args, model)
    start_iteration = 1
    if save_dict is not None:
        # Restore model weights; also optimizer state when resuming a run.
        logger.info(f'Loading model from {os.path.join(args.save, args.load)}')
        save_dict = torch.load(os.path.join(args.save, args.load))
        model.load_state_dict(save_dict['model_state_dict'])
        if args.resume:
            logger.info(f'Resuming Training from {os.path.splitext(args.load)[0]}_rank_{rank}_optim.pth')
            opt.load_state_dict(torch.load(os.path.join(args.save, f'{os.path.splitext(args.load)[0]}_rank_{rank}_optim.pth')))
            # Checkpoint filenames encode the iteration: <prefix>_<iteration>.<ext>
            start_iteration = int(os.path.splitext(os.path.basename(args.load))[0].split('_')[1])
    logger.info(f'Begin Training')
    train(args, model, opt, train_iters, args.train_iterations, field, val_iters=val_iters,
          rank=rank, world_size=world_size,
          log_every=args.log_every, val_every=args.val_every, rounds=len(train_iters) > 1,
          writer=writer if rank == 0 else None, save_every=args.save_every,
          start_iteration=start_iteration)
开发者ID:AhlamMD,项目名称:decaNLP,代码行数:34,代码来源:train.py
示例7: restore_model
def restore_model(self, resume_iters):
    """Restore the trained generator and discriminator saved at `resume_iters`."""
    print('Loading the trained models from step {}...'.format(resume_iters))
    # Checkpoints are always loaded onto the CPU first.
    cpu_map = lambda storage, loc: storage
    G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(resume_iters))
    D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters))
    self.G.load_state_dict(torch.load(G_path, map_location=cpu_map))
    self.D.load_state_dict(torch.load(D_path, map_location=cpu_map))
开发者ID:JacobLee121,项目名称:StarGAN,代码行数:7,代码来源:solver.py
示例8: get_pretrained_net
def get_pretrained_net(name):
    """Load a pretrained network by name, downloading the weights on first use.

    Supported names: 'alexnet_caffe', 'vgg19_caffe', 'vgg16_caffe',
    'vgg19_pytorch_modified'.

    Raises:
        ValueError: for an unrecognized name. (The original used
        ``assert False``, which is stripped under ``python -O`` and would
        silently return None.)
    """
    if name == 'alexnet_caffe':
        if not os.path.exists('alexnet-torch_py3.pth'):
            print('Downloading AlexNet')
            os.system('wget -O alexnet-torch_py3.pth --no-check-certificate -nc https://box.skoltech.ru/index.php/s/77xSWvrDN0CiQtK/download')
        return torch.load('alexnet-torch_py3.pth')
    elif name == 'vgg19_caffe':
        if not os.path.exists('vgg19-caffe-py3.pth'):
            print('Downloading VGG-19')
            os.system('wget -O vgg19-caffe-py3.pth --no-check-certificate -nc https://box.skoltech.ru/index.php/s/HPcOFQTjXxbmp4X/download')
        return get_vgg19_caffe()
    elif name == 'vgg16_caffe':
        if not os.path.exists('vgg16-caffe-py3.pth'):
            print('Downloading VGG-16')
            os.system('wget -O vgg16-caffe-py3.pth --no-check-certificate -nc https://box.skoltech.ru/index.php/s/TUZ62HnPKWdxyLr/download')
        return get_vgg16_caffe()
    elif name == 'vgg19_pytorch_modified':
        # Weights are expected on disk; only the state_dict is loaded here.
        model = VGGModified(vgg19(pretrained=False), 0.2)
        model.load_state_dict(torch.load('vgg_pytorch_modified.pkl')['state_dict'])
        return model
    raise ValueError('unknown pretrained net: {}'.format(name))
开发者ID:1exx,项目名称:deep-image-prior,代码行数:32,代码来源:perceptual_loss.py
示例9: get_vanilla_vgg_features
def get_vanilla_vgg_features(cut_idx=-1):
    """Return the VGG-19 feature extractor, building the cached .pth on first use.

    Args:
        cut_idx: when > 36, the classifier layers are appended after the
            convolutional features.

    Returns:
        An nn.Sequential in eval mode.
    """
    if not os.path.exists('vgg_features.pth'):
        os.system(
            'wget --no-check-certificate -N https://s3-us-west-2.amazonaws.com/jcjohns-models/vgg19-d01eb7cb.pth')
        vgg_weights = torch.load('vgg19-d01eb7cb.pth')
        # Fix checkpoint/model key mismatch introduced by the extra View() layer.
        # BUG FIX: renamed `map` (shadowed builtin) and replaced the Python-2
        # `iteritems()` call, which raises AttributeError on Python 3.
        key_map = {'classifier.6.weight': u'classifier.7.weight',
                   'classifier.6.bias': u'classifier.7.bias'}
        vgg_weights = OrderedDict([(key_map.get(k, k), v) for k, v in vgg_weights.items()])
        model = models.vgg19()
        model.classifier = nn.Sequential(View(), *model.classifier._modules.values())
        model.load_state_dict(vgg_weights)
        torch.save(model.features, 'vgg_features.pth')
        torch.save(model.classifier, 'vgg_classifier.pth')
    vgg = torch.load('vgg_features.pth')
    if cut_idx > 36:
        vgg_classifier = torch.load('vgg_classifier.pth')
        # BUG FIX: odict_values objects do not support `+` on Python 3;
        # convert both to lists before concatenating.
        vgg = nn.Sequential(*(list(vgg._modules.values()) +
                              list(vgg_classifier._modules.values())))
    vgg.eval()
    return vgg
开发者ID:1exx,项目名称:deep-image-prior,代码行数:28,代码来源:feature_inversion_utils.py
示例10: load
def load(self, filename, legacy=False, ignore_d=False):
    """Restore generator/discriminator (and optimizer) state from `filename`.

    ignore_d: if `True`, then don't load in the discriminator.
    legacy: checkpoints saved as a bare (g, d) state-dict tuple.
    """
    # On CPU-only machines, remap GPU storages onto the CPU.
    map_location = (lambda storage, loc: storage) if not self.use_cuda else None
    if legacy:
        g, d = torch.load(filename, map_location=map_location)
        self.g.load_state_dict(g)
        if not ignore_d:
            self.d.load_state_dict(d)
        return
    dd = torch.load(filename, map_location=map_location)
    self.g.load_state_dict(dd['g'])
    if not ignore_d:
        self.d.load_state_dict(dd['d'])
    # Restore each optimizer (skipping the discriminator's when ignored).
    for key in self.optim:
        if ignore_d and key == 'd':
            continue
        self.optim[key].load_state_dict(dd['optim_' + key])
    self.last_epoch = dd['epoch']
开发者ID:kazk1018,项目名称:manifold_mixup,代码行数:26,代码来源:base.py
示例11: load_network_stageI
def load_network_stageI(self):
    """Build the Stage-I generator/discriminator pair, restoring saved weights when configured."""
    from model import STAGE1_G, STAGE1_D
    netG = STAGE1_G()
    netG.apply(weights_init)
    print(netG)
    netD = STAGE1_D()
    netD.apply(weights_init)
    print(netD)
    # Checkpoints are always remapped onto the CPU before loading.
    cpu_map = lambda storage, loc: storage
    if cfg.NET_G != '':
        netG.load_state_dict(torch.load(cfg.NET_G, map_location=cpu_map))
        print('Load from: ', cfg.NET_G)
    if cfg.NET_D != '':
        netD.load_state_dict(torch.load(cfg.NET_D, map_location=cpu_map))
        print('Load from: ', cfg.NET_D)
    if cfg.CUDA:
        netG.cuda()
        netD.cuda()
    return netG, netD
开发者ID:tensoralex,项目名称:StackGAN-Pytorch,代码行数:25,代码来源:trainer.py
示例12: __init__
def __init__(self,
             root, mnist_root="data",
             train=True,
             transform=None, target_transform=None,
             download=False):
    """Init MNIST-M dataset."""
    super(MNISTM, self).__init__()
    self.root = os.path.expanduser(root)
    self.mnist_root = os.path.expanduser(mnist_root)
    self.transform = transform
    self.target_transform = target_transform
    self.train = train  # training set or test set
    if download:
        self.download()
    if not self._check_exists():
        raise RuntimeError('Dataset not found. You can use download=True to download it')
    # Load the processed tensors for the requested split.
    split_file = self.training_file if self.train else self.test_file
    data, labels = torch.load(os.path.join(self.root,
                                           self.processed_folder,
                                           split_file))
    if self.train:
        self.train_data, self.train_labels = data, labels
    else:
        self.test_data, self.test_labels = data, labels
开发者ID:hjpwhu,项目名称:PyTorch-GAN,代码行数:30,代码来源:mnistm.py
示例13: load_models
def load_models(load_path):
    """Rebuild the autoencoder, generator and discriminator saved under `load_path`.

    Returns:
        (model_args, idx2word, autoencoder, gan_gen, gan_disc)
    """
    # BUG FIX: json.load(open(...)) leaked file handles; use context managers.
    with open("{}/args.json".format(load_path), "r") as f:
        model_args = json.load(f)
    with open("{}/vocab.json".format(load_path), "r") as f:
        word2idx = json.load(f)
    idx2word = {v: k for k, v in word2idx.items()}
    autoencoder = Seq2Seq(emsize=model_args['emsize'],
                          nhidden=model_args['nhidden'],
                          ntokens=model_args['ntokens'],
                          nlayers=model_args['nlayers'],
                          hidden_init=model_args['hidden_init'])
    gan_gen = MLP_G(ninput=model_args['z_size'],
                    noutput=model_args['nhidden'],
                    layers=model_args['arch_g'])
    gan_disc = MLP_D(ninput=model_args['nhidden'],
                     noutput=1,
                     layers=model_args['arch_d'])
    # BUG FIX: message previously read "...models from/path" (missing space).
    print('Loading models from ' + load_path)
    ae_path = os.path.join(load_path, "autoencoder_model.pt")
    gen_path = os.path.join(load_path, "gan_gen_model.pt")
    disc_path = os.path.join(load_path, "gan_disc_model.pt")
    autoencoder.load_state_dict(torch.load(ae_path))
    gan_gen.load_state_dict(torch.load(gen_path))
    gan_disc.load_state_dict(torch.load(disc_path))
    return model_args, idx2word, autoencoder, gan_gen, gan_disc
开发者ID:wangwang110,项目名称:ARAE,代码行数:26,代码来源:models.py
示例14: demo
def demo(data, save, depth=40, growth_rate=12, batch_size=256):
    """
    Applies temperature scaling to a trained model.
    Takes a pretrained DenseNet-CIFAR100 model, and a validation set
    (parameterized by indices on train set).
    Applies temperature scaling, and saves a temperature scaled version.
    NB: the "save" parameter references a DIRECTORY, not a file.
    In that directory, there should be two files:
    - model.pth (model state dict)
    - valid_indices.pth (a list of indices corresponding to the validation set).
    data (str) - path to directory where data should be loaded from/downloaded
    save (str) - directory with necessary files (see above)
    """
    # Load model state dict
    model_filename = os.path.join(save, 'model.pth')
    if not os.path.exists(model_filename):
        raise RuntimeError('Cannot find file %s to load' % model_filename)
    state_dict = torch.load(model_filename)
    # Load validation indices
    valid_indices_filename = os.path.join(save, 'valid_indices.pth')
    if not os.path.exists(valid_indices_filename):
        raise RuntimeError('Cannot find file %s to load' % valid_indices_filename)
    valid_indices = torch.load(valid_indices_filename)
    # Regenerate validation set loader (CIFAR-100 channel statistics).
    mean = [0.5071, 0.4867, 0.4408]
    stdv = [0.2675, 0.2565, 0.2761]
    test_transforms = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(mean=mean, std=stdv),
    ])
    valid_set = tv.datasets.CIFAR100(data, train=True, transform=test_transforms, download=True)
    valid_loader = torch.utils.data.DataLoader(valid_set, pin_memory=True, batch_size=batch_size,
                                               sampler=SubsetRandomSampler(valid_indices))
    # Rebuild the original model architecture; DenseNet depth must be 3n+4.
    if (depth - 4) % 3:
        raise Exception('Invalid depth')
    block_config = [(depth - 4) // 6 for _ in range(3)]
    orig_model = DenseNetEfficientMulti(
        growth_rate=growth_rate,
        block_config=block_config,
        num_classes=100
    ).cuda()
    orig_model.load_state_dict(state_dict)
    # Now we're going to wrap the model with a decorator that adds temperature scaling
    model = ModelWithTemperature(orig_model)
    # Tune the model temperature, and save the results
    model.set_temperature(valid_loader)
    model_filename = os.path.join(save, 'model_with_temperature.pth')
    torch.save(model.state_dict(), model_filename)
    # BUG FIX: message previously read "model sved to".
    print('Temperature scaled model saved to %s' % model_filename)
    print('Done!')
开发者ID:zhenglm,项目名称:temperature_scaling,代码行数:59,代码来源:demo.py
示例15: __init__
def __init__(self, file, labelFile):
    """Load image tensors and labels from disk, normalizing every image."""
    self.train = torch.load(file)
    self.label = torch.load(labelFile)
    self.len = len(self.train)  # number of data points
    # Hoisted out of the loop: the transform is stateless and reusable.
    normalize = transforms.Normalize((0.1307,), (0.3081,))
    for idx in range(self.len):
        # Flatten each image to one row for the per-channel transform.
        self.train[idx] = normalize(self.train[idx].view(1, -1))
    self.train = self.train.view(-1, 1, 28, 28)
开发者ID:RobinROAR,项目名称:TensorflowTutorialsCode,代码行数:8,代码来源:utils.py
示例16: main_test
def main_test():
    """Load trained image/text encoders and evaluate them on the COCO test ids."""
    img_net, text_net = torch.load('img_net.pt'), torch.load('text_net.pt')
    # BUG FIX: `file(...)` is a Python-2 builtin (NameError on Python 3) and
    # leaked the handle; use open() inside a context manager instead.
    with open('test_ids.txt') as f:
        tiidlst = [l.strip() for l in f]
    img_dir = '/home/datasets/coco/raw/vgg19_feat/val/'
    text_dir = '/home/datasets/coco/raw/annotation_text/hglmm_pca_npy/val/'
    img_feat_dataset = COCOImgFeatDataset(tiidlst, img_dir)
    text_feat_dataset = COCOTextFeatDataset(tiidlst, text_dir)
    test(tiidlst, img_feat_dataset, text_feat_dataset, img_net, text_net)
开发者ID:tyhu,项目名称:PyAI,代码行数:8,代码来源:train.py
示例17: load_train_valid_data
def load_train_valid_data(opt):
    """Load preprocessed train/valid pairs plus vocabulary and wrap them in bucket iterators.

    Returns:
        (training_data_loader, validation_data_loader, word2id, id2word, vocab)
    """
    logging.info("Loading train and validate data from '%s'" % opt.data)
    logging.info("Loading train/valid from disk: %s" % (opt.data))
    # BUG FIX: torch.load's second positional argument is map_location, not a
    # file mode; the original passed 'wb', which is meaningless (and rejected
    # by newer PyTorch). Load with the default map_location.
    data_dict = torch.load(opt.data)
    train_src = np.asarray([d['src'] for d in data_dict['train']])
    train_trg = np.asarray([d['trg'] for d in data_dict['train']])
    valid_src = np.asarray([d['src'] for d in data_dict['valid']])
    valid_trg = np.asarray([d['trg'] for d in data_dict['valid']])
    word2id, id2word, vocab = torch.load(opt.vocab)
    # Fields reuse the preprocessed integer ids, so no vocab is built here.
    src_field = torchtext.data.Field(
        use_vocab=False,
        init_token=word2id[pykp.io.BOS_WORD],
        eos_token=word2id[pykp.io.EOS_WORD],
        pad_token=word2id[pykp.io.PAD_WORD],
        batch_first=True
    )
    trg_field = torchtext.data.Field(
        use_vocab=False,
        init_token=word2id[pykp.io.BOS_WORD],
        eos_token=word2id[pykp.io.EOS_WORD],
        pad_token=word2id[pykp.io.PAD_WORD],
        batch_first=True
    )
    train = KeyphraseDatasetTorchText(list(zip(train_src, train_trg)), [('src', src_field), ('trg', trg_field)])
    valid = KeyphraseDatasetTorchText(list(zip(valid_src, valid_trg)), [('src', src_field), ('trg', trg_field)])
    # device=None lets torchtext pick the GPU; -1 forces CPU.
    if torch.cuda.is_available():
        training_data_loader = torchtext.data.BucketIterator(dataset=train, batch_size=opt.batch_size, train=True, shuffle=True, repeat=False, sort=True, device=None)
        validation_data_loader = torchtext.data.BucketIterator(dataset=valid, batch_size=opt.batch_size, train=False, shuffle=False, repeat=False, sort=False, device=None)
    else:
        training_data_loader = torchtext.data.BucketIterator(dataset=train, batch_size=opt.batch_size, train=True, shuffle=True, repeat=False, sort=True, device=-1)
        validation_data_loader = torchtext.data.BucketIterator(dataset=valid, batch_size=opt.batch_size, train=False, shuffle=False, repeat=False, sort=False, device=-1)
    opt.word2id = word2id
    opt.id2word = id2word
    opt.vocab = vocab
    logging.info('====================== Dataset =========================')
    logging.info('#(training data pairs)=%d' % len(training_data_loader.dataset))
    logging.info('#(validation data pairs)=%d' % len(validation_data_loader.dataset))
    logging.info('#(vocab)=%d' % len(vocab))
    logging.info('#(vocab used)=%d' % opt.vocab_size)
    return training_data_loader, validation_data_loader, word2id, id2word, vocab
开发者ID:zhhengcs,项目名称:seq2seq-keyphrase-pytorch,代码行数:58,代码来源:train(old,no+copy,max+entropy+loss).py
示例18: get_outputs
def get_outputs(image_dir, filename):
    """Run four fine-tuned classifiers over image_dir, saving per-model softmax outputs and ids."""
    models_name = ['resnet152', 'vgg19_bn', 'densenet161', 'nasnetalarge']
    res = {}
    res_labels = {}
    for name in models_name:
        # Restore the fine-tuned network for this architecture.
        if name == 'densenet161':
            model_ft = torch.load('model_pretrained_densenet161.pkl')
        elif name == 'resnet152':
            model_ft = torch.load('model_pretrained_resnet152.pkl')
        elif name == 'vgg19_bn':
            model_ft = torch.load('model_pretrained_vgg19.pkl')
        # Default preprocessing shared by the torchvision architectures.
        data_transforms = transforms.Compose([
            transforms.Scale(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])])
        if name == 'nasnetalarge':
            # NASNet needs its own input size and normalization statistics.
            model_ft = pretrainedmodels.nasnetalarge(num_classes=1000, pretrained='imagenet')
            data_transforms = transforms.Compose([
                transforms.Scale(377),
                transforms.CenterCrop(331),
                transforms.ToTensor(),
                transforms.Normalize(mean=model_ft.mean,
                                     std=model_ft.std)])
            model_ft = torch.load('model_pretrained_nasnet.pkl')
        use_gpu = torch.cuda.is_available()
        model_ft.eval()
        test_dataset = TestData(image_dir, data_transforms)
        test_dataloader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=4)
        since = time.time()
        prob_chunks = []
        id_chunks = []
        for i, batch in enumerate(test_dataloader):
            inputs, cid = batch
            id_chunks.append(cid)
            inputs = Variable(inputs.cuda()) if use_gpu else Variable(inputs)
            outputs = model_ft(inputs)
            prob_chunks.append(softmax(outputs.data.cpu().numpy()))
            if i % 200 == 199:
                print('iter:{}'.format(i+1))
        time_elapsed = time.time() - since
        print('Testing complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        res[name] = np.concatenate(prob_chunks)
        # Flatten the per-batch id lists into one list per model.
        res_labels[name] = [y for x in id_chunks for y in x]
        print('{} finish'.format(name))
    torch.save(res, filename)
    torch.save(res_labels, filename + '_label')
    return res
开发者ID:LinfeiHe,项目名称:schoolwork,代码行数:58,代码来源:generate_outputs.py
示例19: setup
def setup(args, inject_train=None, inject_dev=None, inject_test=None):
    """Prepare SST data, vocab, word vectors and the model for training."""
    torch.cuda.set_device(args.gpu)
    ### setup data
    TEXT = data.Field()
    LABEL = data.Field(sequential=False)
    train_set, dev_set, test_set = datasets.SST.splits(
        TEXT, LABEL, fine_grained=False, train_subtrees=True,
        filter_pred=lambda x: x.label != 'neutral')
    # Optionally inject special place holders into each split.
    if inject_train is not None:
        train_set = inject_train(train_set)
    if inject_dev is not None:
        dev_set = inject_dev(dev_set)
    if inject_test is not None:
        test_set = inject_test(test_set)
    TEXT.build_vocab(train_set)
    LABEL.build_vocab(train_set)
    train_iter, dev_iter, test_iter = data.BucketIterator.splits(
        (train_set, dev_set, test_set),
        batch_size=args.batch_size, device=args.gpu)
    # Load word vectors, caching them on disk for later runs.
    if args.wv_type:
        if os.path.isfile(args.wv_cache):
            TEXT.vocab.vectors = torch.load(args.wv_cache)
        else:
            TEXT.vocab.load_vectors(wv_dir=args.data_cache,
                                    wv_type=args.wv_type, wv_dim=args.embed_size)
            makedirs(os.path.dirname(args.wv_cache))
            torch.save(TEXT.vocab.vectors, args.wv_cache)
    args.vocab_size = len(TEXT.vocab)
    args.embed_size = TEXT.vocab.vectors.size(1)
    args.output_size = len(LABEL.vocab)
    print('vocab size', args.vocab_size)
    print('embed size', args.embed_size)
    print('output size', args.output_size)
    ### setup model
    if args.resume_snapshot:
        # Resume: the snapshot contains the whole model; map it onto our GPU.
        print('loading snapshot', args.resume_snapshot)
        model = torch.load(args.resume_snapshot,
                           map_location=lambda storage, location: storage.cuda(args.gpu))
    else:
        model = globals()[args.model_class](args)
        if args.wv_type:
            model.embed.weight.data = TEXT.vocab.vectors
    if args.gpu >= 0:
        model.cuda()
    return args, TEXT, LABEL, train_iter, dev_iter, test_iter, model
开发者ID:ihsgnef,项目名称:imdb_word_replace,代码行数:58,代码来源:inject.py
示例20: load_pretrained
def load_pretrained(self):
    """Restore the two discriminators, generator and encoder, plus the starting epoch.

    Reads weights from `self.weight_dir` and the epoch number from the first
    line of 'log.txt' in the working directory.
    """
    self.D_cVAE.load_state_dict(torch.load(os.path.join(self.weight_dir, 'D_cVAE.pkl')))
    self.D_cLR.load_state_dict(torch.load(os.path.join(self.weight_dir, 'D_cLR.pkl')))
    self.G.load_state_dict(torch.load(os.path.join(self.weight_dir, 'G.pkl')))
    self.E.load_state_dict(torch.load(os.path.join(self.weight_dir, 'E.pkl')))
    # BUG FIX: the log file was opened but never closed; use a context manager.
    with open('log.txt', 'r') as log_file:
        line = log_file.readline()
    self.start_epoch = int(line)
开发者ID:Pandinosaurus,项目名称:BicycleGAN-pytorch,代码行数:9,代码来源:solver.py
注:本文中的torch.load函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论