
Python utils.save_image Function Code Examples


This article collects and summarizes typical usage examples of Python's torchvision.utils.save_image function. If you have been wondering what save_image does, how to call it, or what real-world code that uses it looks like, the curated examples below should help.



A total of 20 save_image code examples are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code samples.
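
Before diving into the project excerpts, here is a minimal, self-contained sketch of how save_image is typically called. The tensor shape and file names are illustrative assumptions, not taken from the examples below.

import torch
from torchvision.utils import save_image

# A batch of 16 RGB "images" with values already in [0, 1].
images = torch.rand(16, 3, 64, 64)

# Write the whole batch as a single image grid, 4 images per row.
save_image(images, 'grid.png', nrow=4)

# GAN outputs are often normalized to [-1, 1]; normalize=True rescales
# the tensor to [0, 1] before writing, as several examples below do.
save_image(images * 2 - 1, 'grid_normalized.png', nrow=4, normalize=True)

Most of the examples that follow use exactly this pattern: concatenate real and generated images into one tensor, then hand it to save_image with nrow, padding, or normalize set as needed.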

Example 1: test

    def test(self):
        """Translate images using StarGAN trained on a single dataset."""
        # Load the trained generator.
        self.restore_model(self.test_iters)
        
        # Set data loader.
        if self.dataset == 'CelebA':
            data_loader = self.celeba_loader
        elif self.dataset == 'RaFD':
            data_loader = self.rafd_loader
        
        with torch.no_grad():
            for i, (x_real, c_org) in enumerate(data_loader):

                # Prepare input images and target domain labels.
                x_real = x_real.to(self.device)
                c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs)

                # Translate images.
                x_fake_list = [x_real]
                for c_trg in c_trg_list:
                    x_fake_list.append(self.G(x_real, c_trg))

                # Save the translated images.
                x_concat = torch.cat(x_fake_list, dim=3)
                result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
                save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
                print('Saved real and fake images into {}...'.format(result_path))
Developer: JacobLee121, Project: StarGAN, Lines: 28, Source: solver.py


Example 2: test_multi

    def test_multi(self):
        """Translate images using StarGAN trained on multiple datasets."""
        # Load the trained generator.
        self.restore_model(self.test_iters)
        
        with torch.no_grad():
            for i, (x_real, c_org) in enumerate(self.celeba_loader):

                # Prepare input images and target domain labels.
                x_real = x_real.to(self.device)
                c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs)
                c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD')
                zero_celeba = torch.zeros(x_real.size(0), self.c_dim).to(self.device)            # Zero vector for CelebA.
                zero_rafd = torch.zeros(x_real.size(0), self.c2_dim).to(self.device)             # Zero vector for RaFD.
                mask_celeba = self.label2onehot(torch.zeros(x_real.size(0)), 2).to(self.device)  # Mask vector: [1, 0].
                mask_rafd = self.label2onehot(torch.ones(x_real.size(0)), 2).to(self.device)     # Mask vector: [0, 1].

                # Translate images.
                x_fake_list = [x_real]
                for c_celeba in c_celeba_list:
                    c_trg = torch.cat([c_celeba, zero_rafd, mask_celeba], dim=1)
                    x_fake_list.append(self.G(x_real, c_trg))
                for c_rafd in c_rafd_list:
                    c_trg = torch.cat([zero_celeba, c_rafd, mask_rafd], dim=1)
                    x_fake_list.append(self.G(x_real, c_trg))

                # Save the translated images.
                x_concat = torch.cat(x_fake_list, dim=3)
                result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
                save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
                print('Saved real and fake images into {}...'.format(result_path))
Developer: JacobLee121, Project: StarGAN, Lines: 31, Source: solver.py


Example 3: sampleTrue

def sampleTrue(dataset, imageSize, dataroot, sampleSize, batchSize, saveFolder, workers=4):
    print('sampling real images ...')
    saveFolder = saveFolder + '0/'

    dataset = make_dataset(dataset, dataroot, imageSize)
    dataloader = torch.utils.data.DataLoader(
        dataset, shuffle=True, batch_size=batchSize, num_workers=int(workers))

    if not os.path.exists(saveFolder):
        try:
            os.makedirs(saveFolder)
        except OSError:
            pass

    iter = 0
    for i, data in enumerate(dataloader, 0):
        img, _ = data
        for j in range(0, len(img)):

            vutils.save_image(img[j].mul(0.5).add(
                0.5), saveFolder + giveName(iter) + ".png")
            iter += 1
            if iter >= sampleSize:
                break
        if iter >= sampleSize:
            break
Developer: RobinROAR, Project: TensorflowTutorialsCode, Lines: 26, Source: metric.py


Example 4: save_img_results

def save_img_results(imgs_tcpu, fake_imgs, num_imgs,
                     count, image_dir, summary_writer):
    num = cfg.TRAIN.VIS_COUNT

    # The range of real_img (i.e., self.imgs_tcpu[i][0:num])
    # is changed to [0, 1] by function vutils.save_image
    real_img = imgs_tcpu[-1][0:num]
    vutils.save_image(
        real_img, '%s/real_samples.png' % (image_dir),
        normalize=True)
    real_img_set = vutils.make_grid(real_img).numpy()
    real_img_set = np.transpose(real_img_set, (1, 2, 0))
    real_img_set = real_img_set * 255
    real_img_set = real_img_set.astype(np.uint8)
    sup_real_img = summary.image('real_img', real_img_set)
    summary_writer.add_summary(sup_real_img, count)

    for i in range(num_imgs):
        fake_img = fake_imgs[i][0:num]
        # The range of fake_img.data (i.e., self.fake_imgs[i][0:num])
        # is still [-1, 1]...
        vutils.save_image(
            fake_img.data, '%s/count_%09d_fake_samples%d.png' %
            (image_dir, count, i), normalize=True)

        fake_img_set = vutils.make_grid(fake_img.data).cpu().numpy()

        fake_img_set = np.transpose(fake_img_set, (1, 2, 0))
        fake_img_set = (fake_img_set + 1) * 255 / 2
        fake_img_set = fake_img_set.astype(np.uint8)

        sup_fake_img = summary.image('fake_img%d' % i, fake_img_set)
        summary_writer.add_summary(sup_fake_img, count)
        summary_writer.flush()
Developer: tensoralex, Project: StackGAN-v2, Lines: 34, Source: trainer.py


Example 5: plot_rec

def plot_rec(x, netEC, netEP, netD):
    x_c = x[0]
    x_p = x[np.random.randint(1, opt.max_step)]

    h_c = netEC(x_c)
    h_p = netEP(x_p)

    # print('h_c shape: ', h_c.shape)
    # print('h p shape: ', h_p.shape)
    rec = netD([h_c, h_p])

    x_c, x_p, rec = x_c.data, x_p.data, rec.data
    fname = '%s/rec/rec_test.png' % (opt.log_dir)

    comparison = None
    for i in range(len(x_c)):
        if comparison is None:
            comparison = torch.stack([x_c[i], x_p[i], rec[i]])
        else:
            new_comparison = torch.stack([x_c[i], x_p[i], rec[i]])
            comparison = torch.cat([comparison, new_comparison])
    print('comparison: ', comparison.shape)

    # row_sz = 5
    # nplot = 20
    # for i in range(0, nplot - row_sz, row_sz):
    #     row = [[xc, xp, xr] for xc, xp, xr in zip(x_c[i:i + row_sz], x_p[i:i + row_sz], rec[i:i + row_sz])]
    #     print('row: ', row)
    #     to_plot.append(list(itertools.chain(*row)))
    # print(len(to_plot[0]))
    # utils.save_tensors_image(fname, comparison)
    if not os.path.exists(os.path.dirname(fname)):
        os.makedirs(os.path.dirname(fname))
    save_image(comparison.cpu(), fname, nrow=3)
Developer: ZhenyueQin, Project: drnet-py, Lines: 34, Source: drnet_test_field.py


Example 6: test

    def test(self):
        """Facial attribute transfer on CelebA or facial expression synthesis on RaFD."""
        # Load trained parameters
        G_path = os.path.join(self.model_save_path, '{}_G.pth'.format(self.test_model))
        self.G.load_state_dict(torch.load(G_path))
        self.G.eval()

        if self.dataset == 'CelebA':
            data_loader = self.celebA_loader
        else:
            data_loader = self.rafd_loader

        for i, (real_x, org_c) in enumerate(data_loader):
            real_x = self.to_var(real_x, volatile=True)

            if self.dataset == 'CelebA':
                target_c_list = self.make_celeb_labels(org_c)
            else:
                target_c_list = []
                for j in range(self.c_dim):
                    target_c = self.one_hot(torch.ones(real_x.size(0)) * j, self.c_dim)
                    target_c_list.append(self.to_var(target_c, volatile=True))

            # Start translations
            fake_image_list = [real_x]
            for target_c in target_c_list:
                fake_image_list.append(self.G(real_x, target_c))
            fake_images = torch.cat(fake_image_list, dim=3)
            save_path = os.path.join(self.result_path, '{}_fake.png'.format(i+1))
            save_image(self.denorm(fake_images.data), save_path, nrow=1, padding=0)
            print('Translated test images and saved into "{}"..!'.format(save_path))
Developer: rafalsc, Project: StarGAN, Lines: 31, Source: solver.py


Example 7: saver

 def saver(state):
     if state[torchbearer.BATCH] == 0:
         data = state[torchbearer.X]
         recon_batch = state[torchbearer.Y_PRED]
         comparison = torch.cat([data[:num_images],
                                 recon_batch.view(128, 1, 28, 28)[:num_images]])
         save_image(comparison.cpu(),
                    str(folder) + 'reconstruction_' + str(state[torchbearer.EPOCH]) + '.png', nrow=num_images)
Developer: little1tow, Project: torchbearer, Lines: 8, Source: vae.py


Example 8: sample_image

def sample_image(n_row, batches_done):
    """Saves a grid of generated digits ranging from 0 to n_classes"""
    # Sample noise
    z = Variable(FloatTensor(np.random.normal(0, 1, (n_row**2, opt.latent_dim))))
    # Get labels ranging from 0 to n_classes for n rows
    labels = np.array([num for _ in range(n_row) for num in range(n_row)])
    labels = Variable(LongTensor(labels))
    gen_imgs = generator(z, labels)
    save_image(gen_imgs.data, 'images/%d.png' % batches_done, nrow=n_row, normalize=True)
Developer: hjpwhu, Project: PyTorch-GAN, Lines: 9, Source: cgan.py


Example 9: _train

    def _train(self, epoch):
        """Perform the actual train."""
        # put model into train mode
        self.d_model.train()
        # TODO: why?
        cp_loader = deepcopy(self.train_loader)
        if self.verbose:
            progress_bar = tqdm(total=len(cp_loader),
                                desc='Current Epoch',
                                file=sys.stdout,
                                leave=False,
                                ncols=75,
                                position=0,
                                unit=' Batch')
        else:
            progress_bar = None
        real_label = 1
        fake_label = 0
        for batch_idx, inputs in enumerate(cp_loader):
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            # train with real
            self.optimizer_d.zero_grad()
            inputs = inputs.to(self.device)
            batch_size = inputs.size(0)
            outputs = self.d_model(inputs)

            label = torch.full((batch_size,), real_label, device=self.device)
            loss_d_real = self.loss_function(outputs, label)
            loss_d_real.backward()

            # train with fake
            noise = torch.randn((batch_size, self.g_model.nz, 1, 1,), device=self.device)
            fake_outputs = self.g_model(noise)
            label.fill_(fake_label)
            outputs = self.d_model(fake_outputs.detach())
            loss_g_fake = self.loss_function(outputs, label)
            loss_g_fake.backward()
            self.optimizer_d.step()
            # (2) Update G network: maximize log(D(G(z)))
            self.g_model.zero_grad()
            label.fill_(real_label)
            outputs = self.d_model(fake_outputs)
            loss_g = self.loss_function(outputs, label)
            loss_g.backward()
            self.optimizer_g.step()

            if self.verbose:
                if batch_idx % 10 == 0:
                    progress_bar.update(10)
            if self.out_f is not None and batch_idx % 100 == 0:
                fake = self.g_model(self.sample_noise)
                vutils.save_image(
                    fake.detach(),
                    '%s/fake_samples_epoch_%03d.png' % (self.out_f, epoch),
                    normalize=True)
        if self.verbose:
            progress_bar.close()
Developer: Saiuz, Project: autokeras, Lines: 57, Source: model_trainer.py


Example 10: sample_images

def sample_images(batches_done):
    """Saves a generated sample from the test set"""
    imgs = next(iter(val_dataloader))
    real_A = Variable(imgs['A'].type(Tensor))
    fake_B = G_AB(real_A)
    real_B = Variable(imgs['B'].type(Tensor))
    fake_A = G_BA(real_B)
    img_sample = torch.cat((real_A.data, fake_B.data,
                            real_B.data, fake_A.data), 0)
    save_image(img_sample, 'images/%s/%s.png' % (opt.dataset_name, batches_done), nrow=5, normalize=True)
Developer: hjpwhu, Project: PyTorch-GAN, Lines: 10, Source: cyclegan.py


Example 11: save_image

def save_image(img):
    post = transforms.Compose([transforms.Lambda(lambda x: x.mul_(1./255)),
         transforms.Normalize(mean=[-0.40760392, -0.45795686, -0.48501961], std=[1,1,1]),
         transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])]), #turn to RGB
         ])
    img = post(img)
    img = img.clamp_(0,1)
    vutils.save_image(img,
                '%s/transfer.png' % (opt.outf),
                normalize=True)
    return
Developer: HadXu, Project: machine-learning, Lines: 11, Source: train.py


Example 12: reconstruction_loss

    def reconstruction_loss(self, images, input, size_average=True):
        # Get the lengths of capsule outputs.
        v_mag = torch.sqrt((input**2).sum(dim=2))

        # Get index of longest capsule output.
        _, v_max_index = v_mag.max(dim=1)
        v_max_index = v_max_index.data

        # Use just the winning capsule's representation (and zeros for other capsules) to reconstruct input image.
        batch_size = input.size(0)
        all_masked = [None] * batch_size
        for batch_idx in range(batch_size):
            # Get one sample from the batch.
            input_batch = input[batch_idx]

            # Copy only the maximum capsule index from this batch sample.
            # This masks out (leaves as zero) the other capsules in this sample.
            batch_masked = Variable(torch.zeros(input_batch.size())).cuda()
            batch_masked[v_max_index[batch_idx]] = input_batch[v_max_index[batch_idx]]
            all_masked[batch_idx] = batch_masked

        # Stack masked capsules over the batch dimension.
        masked = torch.stack(all_masked, dim=0)

        # Reconstruct input image.
        masked = masked.view(input.size(0), -1)
        output = self.relu(self.reconstruct0(masked))
        output = self.relu(self.reconstruct1(output))
        output = self.sigmoid(self.reconstruct2(output))
        output = output.view(-1, self.image_channels, self.image_height, self.image_width)

        # Save reconstructed images occasionally.
        if self.reconstructed_image_count % 10 == 0:
            if output.size(1) == 2:
                # handle two-channel images
                zeros = torch.zeros(output.size(0), 1, output.size(2), output.size(3))
                output_image = torch.cat([zeros, output.data.cpu()], dim=1)
            else:
                # assume RGB or grayscale
                output_image = output.data.cpu()
            vutils.save_image(output_image, "reconstruction.png")
        self.reconstructed_image_count += 1

        # The reconstruction loss is the sum squared difference between the input image and reconstructed image.
        # Multiplied by a small number so it doesn't dominate the margin (class) loss.
        error = (output - images).view(output.size(0), -1)
        error = error**2
        error = torch.sum(error, dim=1) * 0.0005

        # Average over batch
        if size_average:
            error = error.mean()

        return error
Developer: weridmaid, Project: pytorch-capsule, Lines: 54, Source: capsule_network.py


Example 13: sample_images

def sample_images(batches_done):
    """Saves a generated sample from the test set"""
    imgs = next(iter(val_dataloader))
    X1 = Variable(imgs['A'].type(Tensor))
    X2 = Variable(imgs['B'].type(Tensor))
    _, Z1 = E1(X1)
    _, Z2 = E2(X2)
    fake_X1 = G1(Z2)
    fake_X2 = G2(Z1)
    img_sample = torch.cat((X1.data, fake_X2.data,
                            X2.data, fake_X1.data), 0)
    save_image(img_sample, 'images/%s/%s.png' % (opt.dataset_name, batches_done), nrow=5, normalize=True)
Developer: hjpwhu, Project: PyTorch-GAN, Lines: 12, Source: unit.py


Example 14: save_images

def save_images(netG, fixed_noise, outputDir, epoch):
    '''
    Generates a batch of images from the given 'noise'.
    Saves 64 of the generated samples to 'outputDir' system path.
    Inputs are the network (netG), a 'noise' input, system path to which images will be saved (outputDir) and current 'epoch'.
    '''
    noise = Variable(fixed_noise)
    netG.eval()
    fake = netG(noise)
    netG.train()
    vutils.save_image(
        fake.data[0:64, :, :, :], '%s/fake_samples_epoch_%03d.png' % (outputDir, epoch), nrow=8)
Developer: apsvieira, Project: Projeto_final_ia368z, Lines: 12, Source: classifier_DCGAN_review.py


Example 15: reconstruct_test

 def reconstruct_test(self,epoch):
     for i,batch in enumerate(self.test_dataloader):
         images = batch['image']
         images = images.float()
         bumps = batch['bump']
         bumps = bumps.float()
         masks = batch['mask']
         masks = masks.float()
         images = Variable(images.cuda())
         recon_mask, recon = self.Gnet.forward(images)
         output = torch.cat((masks,recon_mask.data.cpu(),bumps,recon.data.cpu()),dim=3)
         utils.save_image(output, net.outpath + '/'+str(epoch)+'.'+str(i)+'.jpg',nrow=4, normalize=True)
Developer: whztt07, Project: extreme_3d_faces, Lines: 12, Source: bumpMapRegressor.py


Example 16: save

    def save(self, source, iteration):

        save_dir = os.path.join(self.model_dir, "gen_images")

        if os.path.exists(save_dir) == False:
            os.mkdir(save_dir)
        images_file = os.path.join(save_dir, "image_{}.png".format(iteration))

        if self.cuda:
            source = source.cuda()

        source = Variable(source)
        outputs = self.gen_model(source)
        vutils.save_image(outputs.cpu().data, images_file, normalize=True)
Developer: mamonraab, Project: TorchFusion, Lines: 14, Source: models.py


Example 17: pp_interp

def pp_interp(net, alpha):
    """
    Only works with model_resnet_preproc.py as your
      architecture!!!
    """
    conv2d = net.d.preproc
    deconv2d = nn.ConvTranspose2d(16, 3, 3, stride=1, padding=1)
    deconv2d = deconv2d.cuda()
    deconv2d.weight = conv2d.weight

    gz1 = net.sample(bs=128)
    gz2 = net.sample(bs=128)

    #alpha = net.sample_lambda(gz1.size(0))
    gz_mix = alpha*gz1 + (1.-alpha)*gz2

    save_image(gz1*0.5 + 0.5, filename="gz1.png")
    save_image(gz2*0.5 + 0.5, filename="gz2.png")
    save_image(gz_mix*0.5 + 0.5, filename="gz_mix.png")

    # Ok, do the mixup in hidden space.

    gz1_h = conv2d(gz1)
    gz2_h = conv2d(gz2)
    #alpha = 0.05
    gz_mix_h = alpha*gz1_h + (1.-alpha)*gz2_h
    gz_mix_h_dec = deconv2d(gz_mix_h)
    save_image(gz_mix_h_dec*0.5 + 0.5, filename="gz_mix_h_dec.png")

    print(conv2d.weight == deconv2d.weight)

    
    import pdb
    pdb.set_trace()
Developer: kazk1018, Project: manifold_mixup, Lines: 34, Source: interactive.py


Example 18: closure

	def closure():
		optimizer.zero_grad()
		out = resnet(generated)
		style_loss = [GramMSELoss().cuda()(out[i],style_target[i])*style_weight[i] for i in range(len(style_target))]
		content_loss = nn.MSELoss().cuda()(out[content_layer_num],content_target)
		total_loss = 1000 * sum(style_loss) + sum(content_loss)
		total_loss.backward()

		if iteration[0] % 100 == 0:
			print(total_loss)
			v_utils.save_image(image_postprocess(generated.data),"./gen_{}.png".format(iteration[0]))
		iteration[0] += 1

		return total_loss
Developer: c00lrain, Project: PyTorch-FastCampus, Lines: 14, Source: StyleTransfer_LBFGS_gpu.py


Example 19: generate

 def generate(self, input_sample=None):
     if input_sample is None:
         input_sample = torch.randn(self.gen_training_result[1], self.nz, 1, 1, device=self.device)
     if not isinstance(input_sample, torch.Tensor) and \
             isinstance(input_sample, np.ndarray):
         input_sample = torch.from_numpy(input_sample)
     if not isinstance(input_sample, torch.Tensor) and \
             not isinstance(input_sample, np.ndarray):
         raise TypeError("Input should be a torch.tensor or a numpy.ndarray")
     self.net_g.eval()
     with torch.no_grad():
         input_sample = input_sample.to(self.device)
         generated_fake = self.net_g(input_sample)
     vutils.save_image(generated_fake.detach(),
                       '%s/evaluation.png' % self.gen_training_result[0],
                       normalize=True)
Developer: Saiuz, Project: autokeras, Lines: 16, Source: gan.py


Example 20: test

def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                      recon_batch.view(args.batch_size, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                         'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
Developer: chiminghui, Project: examples, Lines: 17, Source: main.py



Note: The torchvision.utils.save_image examples in this article were compiled by 纯净天空 from source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by their respective developers, and copyright remains with the original authors. Please consult each project's license before distributing or reusing the code, and do not reproduce this article without permission.

