• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python torch.softmax函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了 Python 中 torch.softmax 函数的典型用法代码示例。如果您想了解 torch.softmax 的具体用法、调用方式或实际使用示例，下面精选的代码示例或许能为您提供帮助。



在下文中一共展示了softmax函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: sample_relax_given_class

def sample_relax_given_class(logits, samp):
    """Draw two conditional Gumbel relaxations given a fixed class sample.

    Relies on module-level ``B`` (rows) and ``C`` (classes) and on
    ``Categorical`` from the surrounding script.

    Args:
        logits: unnormalised class scores; softmax is taken over dim 1 and the
            result must broadcast against a ``(B, C)`` tensor.
        samp: integer class indices (used with ``gather``/``scatter_`` after
            ``view(B, 1)``).

    Returns:
        ``(z, z_tilde, logprob)``:
        * ``z``: conditional relaxed sample built from a first uniform draw;
        * ``z_tilde``: conditional relaxed sample whose non-selected entries
          come from a second, fresh uniform draw, while the entry at ``samp``
          is shared with ``z``;
        * ``logprob``: ``log q(samp)`` reshaped to ``(B, 1)``.
    """
    cat = Categorical(logits=logits)
    b = samp
    index = b.view(B, 1)
    logprob = cat.log_prob(b).view(B, 1)
    probs = torch.softmax(logits, dim=1)

    # First uniform draw; its entry at the sampled class fixes z_tilde_b.
    u = torch.rand(B, C).clamp(1e-8, 1. - 1e-8)
    u_b = torch.gather(input=u, dim=1, index=index)
    z_tilde_b = -torch.log(-torch.log(u_b))

    z = -torch.log((-torch.log(u) / probs) - torch.log(u_b))
    z.scatter_(dim=1, index=index, src=z_tilde_b)

    # Second, independent draw for the remaining classes; the selected-class
    # entry (and u_b) is reused from the first draw.
    u = torch.rand(B, C).clamp(1e-8, 1. - 1e-8)
    z_tilde = -torch.log((-torch.log(u) / probs) - torch.log(u_b))
    z_tilde.scatter_(dim=1, index=index, src=z_tilde_b)

    return z, z_tilde, logprob
开发者ID:chriscremer,项目名称:Other_Code,代码行数:29,代码来源:plotting_cat_grads_dist.py


示例2: sample_relax

def sample_relax(logits):
    """Gumbel-max sample plus a conditional relaxation given that sample.

    Uses module-level ``B`` and ``C`` and ``Categorical`` from the original
    script. ``softmax(logits)`` is repeated ``B`` times along dim 0, so
    ``logits`` presumably has a single row — TODO confirm.

    Returns:
        ``(z, b, logprob, z_tilde)``: the Gumbel-perturbed logits, their
        argmax class, ``log q(b)`` as ``(B, 1)``, and the conditional relaxed
        sample whose entry at ``b`` is an independent Gumbel draw.
    """
    eps = 1e-12
    gumbel_noise = -torch.log(-torch.log(torch.rand(B, C).clamp(eps, 1. - eps)))
    z = logits + gumbel_noise
    b = torch.argmax(z, dim=1)
    idx = b.view(B, 1)

    logprob = Categorical(logits=logits).log_prob(b).view(B, 1)

    # Independent Gumbel for the chosen class. The author noted that reusing
    # the entry of u (or z) at b instead "seems biased".
    v_k = torch.rand(B, 1).clamp(eps, 1. - eps)
    top = -torch.log(-torch.log(v_k))

    v = torch.rand(B, C).clamp(eps, 1. - eps)
    tiled_probs = torch.softmax(logits, dim=1).repeat(B, 1)
    z_tilde = -torch.log(-torch.log(v) / tiled_probs - torch.log(v_k))
    z_tilde.scatter_(dim=1, index=idx, src=top)

    return z, b, logprob, z_tilde
开发者ID:chriscremer,项目名称:Other_Code,代码行数:34,代码来源:plotting_cat_grads_dist_4.py


示例3: updateOutput

 def updateOutput(self, input):
     """Return softmin(input), computed as softmax(-input) along ``self._get_dim(input)``.

     ``self.mininput`` is a scratch buffer for the negated input, allocated
     lazily on first use; the result is also stored in ``self.output``.
     """
     buf = self.mininput
     if buf is None:
         buf = self.mininput = input.new()
     # Copy the input into the reusable buffer and negate it in place.
     buf.resize_as_(input)
     buf.copy_(input)
     buf.mul_(-1)
     dim = self._get_dim(input)
     self.output = torch.softmax(buf, dim)
     return self.output
开发者ID:RichieMay,项目名称:pytorch,代码行数:9,代码来源:SoftMin.py


示例4: simplax

def simplax(surrogate, x, logits, mixtureweights, k=1):
    """Relaxed-sample (SIMPLAX-style) losses for the encoder and its surrogate.

    Draws ``k`` relaxed one-hot samples from q(z|x) (temperature 1, CUDA).
    Each sample is scored with f = log p(x|z) + log p(z) - log q(z) - 1, and
    the function builds:

    * ``net_loss``: -mean((f - c) * log q + c), where ``c`` is the surrogate
      prediction, averaged over the k samples;
    * ``surr_loss``: mean of ``((f - c) * dlogq/dlogits + dc/dlogits) ** 2``,
      averaged over the k samples.

    Relies on ``H``, ``logprob_undercomponent`` and
    ``RelaxedOneHotCategorical`` from the surrounding script. ``logits`` must
    be part of an autograd graph, because gradients with respect to it are
    taken.

    Args:
        surrogate: object whose ``.net`` maps ``cat([z, x, logits], 1)`` to a
            prediction per row (presumably ``(B, 1)`` — TODO confirm).
        x: observations, concatenated with z and logits along dim 1.
        logits: encoder logits, ``(B, C)``; softmax is taken over dim 1.
        mixtureweights: mixture prior, indexed by the hard component.
        k: number of samples; must be at least 1.

    Returns:
        dict with ``net_loss`` and ``surr_loss`` (averaged over k), plus
        ``f``, ``logpx_given_z``, ``logpz``, ``logq``, ``surr_dif``,
        ``grad_path`` and ``grad_score`` taken from the last sample.

    Raises:
        ValueError: if ``k < 1`` (there would be no sample to report).
    """
    if k < 1:
        raise ValueError("k must be >= 1")

    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)
    cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())

    outputs = {}
    net_loss = 0
    surr_loss = 0
    for _ in range(k):
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)

        logq = cat.log_prob(cluster_S.detach()).view(B, 1)
        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B, 1)
        logpxz = logpx_given_z + logpz  # [B,1]
        f = logpxz - logq - 1.

        surr_input = torch.cat([cluster_S, x, logits], dim=1)
        surr_pred = surrogate.net(surr_input)

        net_loss += -torch.mean((f.detach() - surr_pred.detach()) * logq + surr_pred)

        grad_logq = torch.autograd.grad([torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0]
        grad_surr = torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0]
        # Accumulate, so that dividing by k below gives the mean over samples.
        surr_loss += torch.mean(((f.detach() - surr_pred) * grad_logq + grad_surr) ** 2)

        surr_dif = torch.mean(torch.abs(f.detach() - surr_pred))

        # Diagnostics: magnitudes of the pathwise (surrogate) and score-function
        # gradient terms. The pathwise gradient is exactly grad_surr, so it is
        # reused rather than recomputed.
        grad_score = torch.autograd.grad(
            [torch.mean((f.detach() - surr_pred.detach()) * logq)], [logits],
            create_graph=True, retain_graph=True)[0]
        grad_path = torch.mean(torch.abs(grad_surr))
        grad_score = torch.mean(torch.abs(grad_score))

    net_loss = net_loss / k
    surr_loss = surr_loss / k

    outputs['net_loss'] = net_loss
    outputs['f'] = f
    outputs['logpx_given_z'] = logpx_given_z
    outputs['logpz'] = logpz
    outputs['logq'] = logq
    outputs['surr_loss'] = surr_loss
    outputs['surr_dif'] = surr_dif
    outputs['grad_path'] = grad_path
    outputs['grad_score'] = grad_score

    return outputs
开发者ID:chriscremer,项目名称:Other_Code,代码行数:57,代码来源:gmm_cleaned_v5.py


示例5: sample_relax_given_b

    def sample_relax_given_b(logits, b):
        """Conditional Gumbel relaxation given the hard class ``b`` (on CUDA).

        Uses the enclosing ``B`` and ``C``. The entry at ``b`` is an
        independent Gumbel draw; the other entries are conditioned on it.
        """
        eps = 1e-10
        u_b = torch.rand(B, 1).clamp(eps, 1. - eps).cuda()
        top = -torch.log(-torch.log(u_b))

        u = torch.rand(B, C).clamp(eps, 1. - eps).cuda()
        p = torch.softmax(logits, dim=1)
        z_tilde = -torch.log(-torch.log(u) / p - torch.log(u_b))
        z_tilde.scatter_(dim=1, index=b.view(B, 1), src=top)
        return z_tilde
开发者ID:chriscremer,项目名称:Other_Code,代码行数:10,代码来源:gmm_cleaned_v5.py


示例6: logprob_givenmixtureeweights

def logprob_givenmixtureeweights(x, needsoftmax_mixtureweight):
    """Log-density of ``x`` under a 1-D Gaussian mixture.

    Component ``c`` is Normal(mean=10*c, std=5), weighted by
    ``softmax(needsoftmax_mixtureweight)[c]``, for ``c`` in
    ``range(n_components)``. ``n_components`` is a module-level global.
    """
    weights = torch.softmax(needsoftmax_mixtureweight, dim=0)
    scale = torch.tensor([5.0]).float()
    density = sum(
        torch.exp(Normal(torch.tensor([c * 10.]).float(), scale).log_prob(x)) * weights[c]
        for c in range(n_components)
    )
    return torch.log(density)
开发者ID:chriscremer,项目名称:Other_Code,代码行数:12,代码来源:gmm_batch_v2.py


示例7: sample_relax_given_class_k

def sample_relax_given_class_k(logits, samp, k):
    """Average ``k`` pairs of conditional Gumbel relaxations given a class sample.

    Each iteration builds ``z`` from one uniform draw and ``z_tilde`` from a
    fresh draw. Both share the selected-class entry taken from the first
    draw. The results are averaged over the ``k`` iterations. Relies on
    module-level ``B``, ``C`` and ``Categorical``.

    Args:
        logits: unnormalised class scores (softmax over dim 1, broadcastable
            against ``(B, C)``).
        samp: integer class indices, used after ``view(B, 1)``.
        k: number of iterations to average; ``torch.stack`` raises if it is 0.

    Returns:
        ``(z, z_tilde, logprob)``: mean ``z``, mean ``z_tilde``, and
        ``log q(samp)`` as ``(B, 1)``.
    """
    cat = Categorical(logits=logits)
    b = samp
    index = b.view(B, 1)
    logprob = cat.log_prob(b).view(B, 1)
    # Loop-invariant; previously recomputed twice per iteration.
    probs = torch.softmax(logits, dim=1)

    zs = []
    z_tildes = []
    for _ in range(k):
        u = torch.rand(B, C).clamp(1e-8, 1. - 1e-8)
        u_b = torch.gather(input=u, dim=1, index=index)
        z_tilde_b = -torch.log(-torch.log(u_b))

        z = -torch.log((-torch.log(u) / probs) - torch.log(u_b))
        z.scatter_(dim=1, index=index, src=z_tilde_b)

        # Fresh draw for the non-selected classes; u_b / z_tilde_b are reused.
        u = torch.rand(B, C).clamp(1e-8, 1. - 1e-8)
        z_tilde = -torch.log((-torch.log(u) / probs) - torch.log(u_b))
        z_tilde.scatter_(dim=1, index=index, src=z_tilde_b)

        zs.append(z)
        z_tildes.append(z_tilde)

    z = torch.mean(torch.stack(zs), dim=0)
    z_tilde = torch.mean(torch.stack(z_tildes), dim=0)
    return z, z_tilde, logprob
开发者ID:chriscremer,项目名称:Other_Code,代码行数:39,代码来源:plotting_cat_grads_dist.py


示例8: show_surr_preds

def show_surr_preds():
    """Plot surrogate predictions against the objective f over a grid of x.

    For each of three rows, the function:
    1. draws one x with ``sample_true``;
    2. samples a relaxed cluster assignment from the encoder's q(z|x);
    3. evaluates f = log p(x, z) - log q(z) and the surrogate's prediction on
       40 evenly spaced x values in [-9, 205], holding the sampled z fixed.

    The figure is saved to ``exp_dir + 'gmm_surr.png'``. Relies on
    module-level globals (encoder, surrogate, sample_true, H, check_nan,
    logprob_undercomponent, needsoftmax_mixtureweight, exp_dir) and on CUDA.
    """
    batch_size = 1
    n_rows, n_cols = 3, 1
    n_evals = 40
    plt.figure(figsize=(10 + n_cols, 4 + n_rows), facecolor='white')

    for row in range(n_rows):
        x_obs = sample_true(1).cuda()
        probs = torch.softmax(encoder.net(x_obs), dim=1)
        cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size, 1)
        check_nan(logprob_cluster)

        # Evaluate on a grid of x values with the sampled z held fixed.
        grid_np = np.linspace(-9, 205, n_evals)
        grid = torch.from_numpy(grid_np).view(n_evals, 1).float().cuda()
        z_rep = cluster_S.repeat(n_evals, 1)
        h_rep = cluster_H.repeat(n_evals, 1)
        xz = torch.cat([z_rep, grid], dim=1)

        logpxz = logprob_undercomponent(grid, component=h_rep, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
        f = logpxz - logprob_cluster
        surr_np = surrogate.net(xz).data.cpu().numpy()
        f_np = f.data.cpu().numpy()

        ax = plt.subplot2grid((n_rows, n_cols), (row, 0), frameon=False, colspan=1, rowspan=1)
        ax.plot(grid_np, surr_np, label='Surr')
        ax.plot(grid_np, f_np, label='f')
        ax.set_title(str(h_rep[0]))
        ax.legend()

    plt_path = exp_dir + 'gmm_surr.png'
    plt.savefig(plt_path)
    print('saved training plot', plt_path)
    plt.close()
开发者ID:chriscremer,项目名称:Other_Code,代码行数:51,代码来源:gmm_batch_v2.py


示例9: plot_dist

def plot_dist(x=None):
    """Plot each weighted mixture component, their sum, and a marker at one x.

    Component ``c`` is Normal(mean=10*c, std=5), scaled by
    ``softmax(needsoftmax_mixtureweight)[c]``. Densities are evaluated on
    300 points in [-9, 205], and the figure is saved to
    ``exp_dir + 'gmm_plot_dist.png'``. Relies on module-level
    ``needsoftmax_mixtureweight``, ``sample_true`` and ``exp_dir``.

    Args:
        x: optional batch of observations whose first entry ``x[0]`` is
            marked on the plot. When ``None``, one point is drawn with
            ``sample_true(1)``.
    """
    # Always hand matplotlib a host NumPy value; a CUDA tensor cannot be plotted.
    source = sample_true(1) if x is None else x
    x1 = source[0].cpu().numpy()

    # Detach and move to the CPU so the weights combine with the CPU densities below.
    mixture_weights = torch.softmax(needsoftmax_mixtureweight, dim=0).detach().cpu()
    # One curve per weight, rather than a hard-coded 20.
    n_mix = mixture_weights.shape[0]

    rows, cols = 1, 1
    plt.figure(figsize=(10 + cols, 4 + rows), facecolor='white')
    ax = plt.subplot2grid((rows, cols), (0, 0), frameon=False, colspan=1, rowspan=1)

    xs = np.linspace(-9, 205, 300)
    # Evaluate each component on the whole grid at once. Normal.log_prob
    # expects a Tensor; NumPy scalars are rejected when validation is enabled.
    xs_t = torch.from_numpy(xs).float()
    sum_ = np.zeros(len(xs))
    for c in range(n_mix):
        m = Normal(torch.tensor([c * 10.]).float(), torch.tensor([5.0]).float())
        ys = (torch.exp(m.log_prob(xs_t)) * mixture_weights[c]).numpy()
        sum_ += ys
        ax.plot(xs, ys, label='')

    ax.plot(xs, sum_, label='')
    ax.plot([x1, x1 + .001], [0., .002])

    plt_path = exp_dir + 'gmm_plot_dist.png'
    plt.savefig(plt_path)
    print('saved training plot', plt_path)
    plt.close()
开发者ID:chriscremer,项目名称:Other_Code,代码行数:48,代码来源:gmm_batch_v2.py


示例10: true_posterior

def true_posterior(x, needsoftmax_mixtureweight):
    """Exact posterior over mixture components for observation ``x`` (on CUDA).

    Component ``c`` has likelihood Normal(10*c, 5) and prior weight
    ``softmax(needsoftmax_mixtureweight)[c]``. The weighted likelihoods for
    ``c`` in ``range(n_components)`` (a module-level global) are normalised
    to sum to one. The first element of each component's density is kept
    before stacking.
    """
    weights = torch.softmax(needsoftmax_mixtureweight, dim=0)
    sigma = torch.tensor([5.0]).float().cuda()
    unnormalised = torch.stack([
        (torch.exp(Normal(torch.tensor([c * 10.]).float().cuda(), sigma).log_prob(x)) * weights[c])[0]
        for c in range(n_components)
    ])
    return unnormalised / torch.sum(unnormalised)
开发者ID:chriscremer,项目名称:Other_Code,代码行数:16,代码来源:gmm_batch_v2.py


示例11: reinforce_baseline

def reinforce_baseline(surrogate, x, logits, mixtureweights, k=1, get_grad=False):
    """REINFORCE loss with a learned baseline ``surrogate.net(x)``.

    Each of the ``k`` iterations samples a hard component from q(z|x) and
    recomputes the outputs. ``net_loss`` and ``surr_loss`` therefore come from
    the last sample; they are not averaged.

    Returns:
        dict with ``logq``, ``logpx_given_z``, ``logpz``, ``f``,
        ``net_loss``, ``surr_loss``, and, when ``get_grad`` is true,
        ``grad_avg``/``grad_std`` computed from the per-sample gradients of
        ``net_loss`` with respect to ``logits``.
    """
    n = logits.shape[0]
    cat = Categorical(probs=torch.softmax(logits, dim=1))
    outputs = {}
    sampled_grads = []

    for _ in range(k):
        cluster_H = cat.sample()
        logq = cat.log_prob(cluster_H).view(n, 1)
        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(n, 1)
        baseline = surrogate.net(x)
        f = logpx_given_z + logpz - logq - 1.

        f_fixed = f.detach()
        net_loss = -torch.mean((f_fixed - baseline.detach()) * logq)
        grad_logq = torch.autograd.grad([torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0]
        surr_loss = torch.mean(((f_fixed - baseline) * grad_logq) ** 2)

        outputs.update(logq=logq, logpx_given_z=logpx_given_z, logpz=logpz, f=f, net_loss=net_loss)

        if get_grad:
            sampled_grads.append(
                torch.autograd.grad([net_loss], [logits], create_graph=True, retain_graph=True)[0])

    if get_grad:
        stacked = torch.stack(sampled_grads)
        outputs['grad_avg'] = torch.mean(torch.mean(stacked, dim=0), dim=0)
        outputs['grad_std'] = torch.std(stacked, dim=0)[0]

    outputs['surr_loss'] = surr_loss
    return outputs
开发者ID:chriscremer,项目名称:Other_Code,代码行数:44,代码来源:gmm_cleaned_v5.py


示例12: logprob_undercomponent

def logprob_undercomponent(x, component, needsoftmax_mixtureweight, cuda=False):
    """Joint log-density log p(x, z=component) under a 1-D Gaussian mixture.

    Component ``c`` is Normal(mean=10*c, std=5), and its prior weight is
    ``softmax(needsoftmax_mixtureweight)[c]``.

    Args:
        x: observations; ``B = x.shape[0]`` rows, broadcast against a
            ``(B, 1)`` mean.
        component: integer component index tensor with ``B`` elements.
        needsoftmax_mixtureweight: unnormalised 1-D mixture logits.
        cuda: if true, move the Normal's parameters to the GPU.

    Returns:
        ``Normal(10*component, 5).log_prob(x) + log(w[component]).view(B, 1)``.
    """
    B = x.shape[0]
    mixture_weights = torch.softmax(needsoftmax_mixtureweight, dim=0)
    mean = (component.float() * 10.).view(B, 1)
    # Derive std from mean so it shares mean's device and dtype. A bare
    # torch.ones([B]) is always on the CPU, which breaks when `component`
    # lives on the GPU and cuda=False.
    std = torch.full_like(mean, 5.)
    if cuda:
        mean, std = mean.cuda(), std.cuda()
    m = Normal(mean, std)

    logpx_given_z = m.log_prob(x)
    logpz = torch.log(mixture_weights[component]).view(B, 1)
    return logpx_given_z + logpz
开发者ID:chriscremer,项目名称:Other_Code,代码行数:42,代码来源:gmm_batch.py


示例13: inference_error

def inference_error():
    """L2 distance between the true posterior and the encoder's q(z|x) for one sampled x.

    Relies on module-level sample_true, true_posterior, encoder,
    needsoftmax_mixtureweight, n_components and L2_mixtureweights.
    """
    x = sample_true(1).cuda()
    p_true = true_posterior(x, needsoftmax_mixtureweight).view(n_components)
    q = torch.softmax(encoder.net(x), dim=1).view(n_components)
    return L2_mixtureweights(p_true.data.cpu().numpy(), q.data.cpu().numpy())
开发者ID:chriscremer,项目名称:Other_Code,代码行数:20,代码来源:gmm_batch_v2.py


示例14: inference_error

def inference_error():
    """Mean L2 and KL between the true posterior and q(z|x) over 10 sampled x.

    The encoder logits are divided by 100 before the softmax. Relies on
    module-level sample_true, true_posterior, encoder,
    needsoftmax_mixtureweight, n_components, L2_mixtureweights and
    KL_mixutreweights.

    Returns:
        ``(mean_l2, mean_kl)``.
    """
    n = 10
    total_l2 = 0
    total_kl = 0
    for _ in range(n):
        x = sample_true(1).cuda()
        p_true = true_posterior(x, needsoftmax_mixtureweight).view(n_components)
        q = torch.softmax(encoder.net(x) / 100., dim=1).view(n_components)
        total_l2 += L2_mixtureweights(p_true.data.cpu().numpy(), q.data.cpu().numpy())
        total_kl += KL_mixutreweights(p_true.data.cpu().numpy(), q.data.cpu().numpy())
    return total_l2 / n, total_kl / n
开发者ID:chriscremer,项目名称:Other_Code,代码行数:21,代码来源:gmm_batch.py


示例15: logprob_undercomponent

def logprob_undercomponent(x, component, needsoftmax_mixtureweight, cuda=False):
    """Return log N(x; 10*component, 5) + log softmax(needsoftmax_mixtureweight)[component].

    ``component`` is used both as a Python number (to build the mean) and as
    an index into the weight vector, so presumably it is an int — confirm.
    When ``cuda`` is true, the Normal's parameters are moved to the GPU.
    """
    weights = torch.softmax(needsoftmax_mixtureweight, dim=0)
    loc = torch.tensor([component * 10.]).float()
    scale = torch.tensor([5.0]).float()
    if cuda:
        loc, scale = loc.cuda(), scale.cuda()
    return Normal(loc, scale).log_prob(x) + torch.log(weights[component])
开发者ID:chriscremer,项目名称:Other_Code,代码行数:21,代码来源:gmm.py


示例16: reinforce

def reinforce(x, logits, mixtureweights, k=1):
    """Plain REINFORCE loss averaged over ``k`` hard samples from q(z|x).

    f = log p(x|z) + log p(z) - log q(z); each sample contributes
    -mean((f - 1) * log q), with f treated as a constant.

    Returns:
        ``(net_loss, f, logpx_given_z, logpz, logq)``. ``net_loss`` is the
        average over k; the other values come from the last sample.
    """
    n = logits.shape[0]
    cat = Categorical(probs=torch.softmax(logits, dim=1))

    total = 0
    for _ in range(k):
        cluster_H = cat.sample()
        logq = cat.log_prob(cluster_H).view(n, 1)
        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(n, 1)
        f = logpx_given_z + logpz - logq
        total += -torch.mean((f.detach() - 1.) * logq)

    return total / k, f, logpx_given_z, logpz, logq
开发者ID:chriscremer,项目名称:Other_Code,代码行数:22,代码来源:gmm_cleaned_v3.py


示例17: forward

    def forward(self, hidden, attn, src_map):
        """Mix a generation distribution with a copy distribution over source words.

        The output distribution covers the target vocabulary extended by the
        dynamic dictionary of copyable source words.

        Args:
            hidden (FloatTensor): decoder outputs ``(batch x tlen, input_size)``.
            attn (FloatTensor): attention weights ``(batch x tlen, slen)``.
            src_map (FloatTensor): sparse indicator matrix mapping each source
                position to its index in the extended vocabulary,
                ``(slen, batch, extra_words)``.

        Returns:
            FloatTensor ``(batch x tlen, vocab + extra_words)``. The first
            block is ``p_vocab * (1 - p_copy)``; the second block is
            ``p_copy * attn`` projected through ``src_map``.
        """
        n_rows, _ = hidden.size()
        n_rows_attn, slen = attn.size()
        slen_map, batch, cvocab = src_map.size()
        aeq(n_rows, n_rows_attn)
        aeq(slen, slen_map)

        # Generation distribution; the padding token gets -inf logit, hence zero probability.
        logits = self.linear(hidden)
        logits[:, self.pad_idx] = -float('inf')
        gen_prob = torch.softmax(logits, 1)

        p_copy = torch.sigmoid(self.linear_copy(hidden))
        scaled_gen = torch.mul(gen_prob, 1 - p_copy)

        # (batch, tlen, slen) @ (batch, slen, cvocab) -> (batch, tlen, cvocab), then back.
        scaled_attn = torch.mul(attn, p_copy).view(-1, batch, slen).transpose(0, 1)
        copied = torch.bmm(scaled_attn, src_map.transpose(0, 1)).transpose(0, 1)
        copied = copied.contiguous().view(-1, cvocab)
        return torch.cat([scaled_gen, copied], 1)
开发者ID:Unbabel,项目名称:OpenNMT-py,代码行数:38,代码来源:copy_generator.py


示例18: reinforce_pz

def reinforce_pz(x, logits, mixtureweights, k=1):
    """REINFORCE loss using relaxed (temperature 1, CUDA) samples hardened by ``H``.

    f = log p(x|z) + log p(z) - log q(z) - 1, where log q is evaluated at the
    detached relaxed sample. Each sample contributes -mean(f * log q), with f
    treated as a constant.

    Returns:
        ``(net_loss, f, logpx_given_z, logpz, logq)``. ``net_loss`` is the
        average over k; the other values come from the last sample.
    """
    n = logits.shape[0]
    q = RelaxedOneHotCategorical(probs=torch.softmax(logits, dim=1), temperature=torch.tensor([1.]).cuda())

    total = 0
    for _ in range(k):
        soft = q.rsample()
        hard = H(soft)
        logq = q.log_prob(soft.detach()).view(n, 1)
        logpx_given_z = logprob_undercomponent(x, component=hard)
        logpz = torch.log(mixtureweights[hard]).view(n, 1)
        f = logpx_given_z + logpz - logq - 1.
        total += -torch.mean(f.detach() * logq)

    return total / k, f, logpx_given_z, logpz, logq
开发者ID:chriscremer,项目名称:Other_Code,代码行数:23,代码来源:gmm_cleaned_v5.py


示例19: plot_posteriors

def plot_posteriors(x=None, name=''):
    """Bar-plot the true posterior next to the encoder's q(z|x) for one x.

    Args:
        x: optional batch; ``x[0]`` is used, reshaped to ``(1, 1)``. When
            ``None``, one point is drawn with ``sample_true(1)``.
        name: suffix for the saved file ``exp_dir + 'posteriors' + name + '.png'``.
    """
    x = sample_true(1).cuda() if x is None else x[0].view(1, 1)

    p_true = true_posterior(x, needsoftmax_mixtureweight).view(n_components).data.cpu().numpy()
    qz = torch.softmax(encoder.net(x), dim=1).view(n_components).data.cpu().numpy()

    n_rows = n_cols = 1
    plt.figure(figsize=(8 + n_cols, 8 + n_rows), facecolor='white')
    ax = plt.subplot2grid((n_rows, n_cols), (0, 0), frameon=False, colspan=1, rowspan=1)

    bar_w = .3
    positions = range(len(qz))
    ax.bar(positions, p_true, width=bar_w, label='True')
    ax.bar(np.array(positions) + bar_w, qz, width=bar_w, label='q')
    ax.legend()
    ax.grid(True, alpha=.3)
    ax.set_title(str(x))

    plt_path = exp_dir + 'posteriors' + name + '.png'
    plt.savefig(plt_path)
    print('saved training plot', plt_path)
    plt.close()
开发者ID:chriscremer,项目名称:Other_Code,代码行数:37,代码来源:gmm_batch_v2.py


示例20: get_loss

    def get_loss():
        """L1 loss between the surrogate's prediction and f = log p(x, z) - log q(z|x).

        Draws a fresh batch of ``batch_size`` observations and a relaxed
        sample (temperature ``temp``) from the encoder's posterior. Only the
        surrogate is trained by this loss, because f is detached. Uses
        names from the enclosing scope (batch_size, temp, encoder, surrogate,
        sample_true, H, check_nan, logprob_undercomponent,
        needsoftmax_mixtureweight).
        """
        x = sample_true(batch_size).cuda()
        q = RelaxedOneHotCategorical(
            probs=torch.softmax(encoder.net(x), dim=1),
            temperature=torch.tensor([temp]).cuda(),
        )
        z_soft = q.rsample()
        z_hard = H(z_soft)
        log_q = q.log_prob(z_soft.detach()).view(batch_size, 1)
        check_nan(log_q)

        log_joint = logprob_undercomponent(x, component=z_hard, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
        target = (log_joint - log_q).detach()
        prediction = surrogate.net(torch.cat([z_soft, x], dim=1))
        return torch.mean(torch.abs(target - prediction))
开发者ID:chriscremer,项目名称:Other_Code,代码行数:24,代码来源:gmm_batch_v2.py



注:本文中的torch.softmax函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python torch.sort函数代码示例发布时间:2022-05-27
下一篇:
Python torch.sigmoid函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap