This article collects and summarizes typical usage examples of the torch.softmax function in Python. If you have been wondering what exactly torch.softmax does and how to call it, the curated examples below should help.
Twenty code examples of the softmax function are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your votes help the system recommend better Python code examples.
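Before the examples, here is a minimal, self-contained sketch (the tensor values are made up for illustration): torch.softmax(input, dim) exponentiates the entries and normalizes along dim, so every slice along that dimension sums to 1.

import torch

logits = torch.tensor([[1.0, 2.0, 3.0],
                       [0.5, 0.5, 0.5]])   # [B, C] unnormalized scores
probs = torch.softmax(logits, dim=1)       # normalize over the class dimension
print(probs)                               # each row is a probability vector
print(probs.sum(dim=1))                    # tensor([1., 1.])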
Example 1: sample_relax_given_class

def sample_relax_given_class(logits, samp):

    cat = Categorical(logits=logits)

    u = torch.rand(B,C).clamp(1e-8, 1.-1e-8)
    gumbels = -torch.log(-torch.log(u))
    z = logits + gumbels

    b = samp #torch.argmax(z, dim=1)
    logprob = cat.log_prob(b).view(B,1)

    u_b = torch.gather(input=u, dim=1, index=b.view(B,1))
    z_tilde_b = -torch.log(-torch.log(u_b))

    z_tilde = -torch.log((- torch.log(u) / torch.softmax(logits, dim=1)) - torch.log(u_b))
    z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)
    z = z_tilde

    u_b = torch.gather(input=u, dim=1, index=b.view(B,1))
    z_tilde_b = -torch.log(-torch.log(u_b))

    u = torch.rand(B,C).clamp(1e-8, 1.-1e-8)
    z_tilde = -torch.log((- torch.log(u) / torch.softmax(logits, dim=1)) - torch.log(u_b))
    z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)

    return z, z_tilde, logprob

Author: chriscremer, Project: Other_Code, Lines: 29, Source: plotting_cat_grads_dist.py
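Example 1 (and the variants below) builds on the Gumbel-max trick used in REBAR/RELAX-style estimators: adding independent Gumbel(0, 1) noise to the logits and taking the argmax yields an exact sample from Categorical(torch.softmax(logits, dim=1)). A stripped-down sketch of that relationship (the shapes B and C are illustrative):

import torch

B, C = 4, 5
logits = torch.randn(B, C)
u = torch.rand(B, C).clamp(1e-8, 1. - 1e-8)
gumbels = -torch.log(-torch.log(u))            # Gumbel(0, 1) noise
b = torch.argmax(logits + gumbels, dim=1)      # distributed as Categorical(probs)
probs = torch.softmax(logits, dim=1)           # the matching class probabilities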
Example 2: sample_relax

def sample_relax(logits): #, k=1):

    # u = torch.rand(B,C).clamp(1e-8, 1.-1e-8) #.cuda()
    u = torch.rand(B,C).clamp(1e-12, 1.-1e-12) #.cuda()
    gumbels = -torch.log(-torch.log(u))
    z = logits + gumbels
    b = torch.argmax(z, dim=1)

    cat = Categorical(logits=logits)
    logprob = cat.log_prob(b).view(B,1)

    v_k = torch.rand(B,1).clamp(1e-12, 1.-1e-12)
    z_tilde_b = -torch.log(-torch.log(v_k))
    # this way seems biased even though it shouldn't be
    # v_k = torch.gather(input=u, dim=1, index=b.view(B,1))
    # z_tilde_b = torch.gather(input=z, dim=1, index=b.view(B,1))

    v = torch.rand(B,C).clamp(1e-12, 1.-1e-12) #.cuda()
    probs = torch.softmax(logits, dim=1).repeat(B,1)
    # print (probs.shape, torch.log(v_k).shape, torch.log(v).shape)
    # print (v.shape)

    z_tilde = -torch.log((- torch.log(v) / probs) - torch.log(v_k))
    # print (z_tilde)
    # print (z_tilde_b)
    z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)
    # print (z_tilde)

    return z, b, logprob, z_tilde

Author: chriscremer, Project: Other_Code, Lines: 34, Source: plotting_cat_grads_dist_4.py
Example 3: updateOutput

def updateOutput(self, input):
    if self.mininput is None:
        self.mininput = input.new()
    self.mininput.resize_as_(input).copy_(input).mul_(-1)
    self.output = torch.softmax(
        self.mininput,
        self._get_dim(input)
    )
    return self.output

Author: RichieMay, Project: pytorch, Lines: 9, Source: SoftMin.py
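Example 3 implements SoftMin by negating the input before calling torch.softmax, i.e. softmin(x) = softmax(-x). A quick sanity check of that identity (assuming a PyTorch version that provides torch.nn.functional.softmin):

import torch
import torch.nn.functional as F

x = torch.randn(3, 4)
assert torch.allclose(torch.softmax(-x, dim=1), F.softmin(x, dim=1))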
Example 4: simplax

def simplax(surrogate, x, logits, mixtureweights, k=1):
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)

    cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())

    outputs = {}
    net_loss = 0
    surr_loss = 0
    for jj in range(k):
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        logq = cat.log_prob(cluster_S.detach()).view(B,1)

        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B,1)
        logpxz = logpx_given_z + logpz #[B,1]
        f = logpxz - logq - 1.

        surr_input = torch.cat([cluster_S, x, logits], dim=1) #[B,21]
        surr_pred = surrogate.net(surr_input)

        net_loss += - torch.mean((f.detach() - surr_pred.detach()) * logq + surr_pred)

        # surr_loss += torch.mean(torch.abs(f.detach()-1.-surr_pred))
        # grad_logq = torch.mean( torch.autograd.grad([torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        # grad_surr = torch.mean( torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0], dim=1, keepdim=True)
        grad_logq = torch.autograd.grad([torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0]
        grad_surr = torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0]
        surr_loss = torch.mean(((f.detach() - surr_pred) * grad_logq + grad_surr)**2)

        surr_dif = torch.mean(torch.abs(f.detach() - surr_pred))
        # surr_loss = torch.mean(torch.abs(f.detach() - surr_pred))

        grad_path = torch.autograd.grad([torch.mean(surr_pred)], [logits], create_graph=True, retain_graph=True)[0]
        grad_score = torch.autograd.grad([torch.mean((f.detach() - surr_pred.detach()) * logq)], [logits], create_graph=True, retain_graph=True)[0]
        grad_path = torch.mean(torch.abs(grad_path))
        grad_score = torch.mean(torch.abs(grad_score))

    net_loss = net_loss / k
    surr_loss = surr_loss / k

    outputs['net_loss'] = net_loss
    outputs['f'] = f
    outputs['logpx_given_z'] = logpx_given_z
    outputs['logpz'] = logpz
    outputs['logq'] = logq
    outputs['surr_loss'] = surr_loss
    outputs['surr_dif'] = surr_dif
    outputs['grad_path'] = grad_path
    outputs['grad_score'] = grad_score

    return outputs #net_loss, f, logpx_given_z, logpz, logq, surr_loss, surr_dif, grad_path, grad_score

Author: chriscremer, Project: Other_Code, Lines: 57, Source: gmm_cleaned_v5.py
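Example 4 builds its surrogate loss out of gradients: torch.autograd.grad(..., create_graph=True) returns a gradient that is itself part of the autograd graph, so a loss defined on that gradient can be backpropagated a second time. A stripped-down sketch of the same pattern (the quadratic objective here is only for illustration, not the author's loss):

import torch

logits = torch.randn(4, 3, requires_grad=True)
score = torch.mean(torch.softmax(logits, dim=1) ** 2)

# First-order gradient that stays in the graph:
(g,) = torch.autograd.grad([score], [logits], create_graph=True, retain_graph=True)

meta_loss = torch.mean(g ** 2)   # a loss defined on the gradient itself
meta_loss.backward()             # second-order backprop, populates logits.grad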
Example 5: sample_relax_given_b

def sample_relax_given_b(logits, b):

    u_b = torch.rand(B,1).clamp(1e-10, 1.-1e-10).cuda()
    z_tilde_b = -torch.log(-torch.log(u_b))

    u = torch.rand(B,C).clamp(1e-10, 1.-1e-10).cuda()
    z_tilde = -torch.log((- torch.log(u) / torch.softmax(logits, dim=1)) - torch.log(u_b))
    z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)

    return z_tilde

Author: chriscremer, Project: Other_Code, Lines: 10, Source: gmm_cleaned_v5.py
Example 6: logprob_givenmixtureeweights

def logprob_givenmixtureeweights(x, needsoftmax_mixtureweight):
    mixture_weights = torch.softmax(needsoftmax_mixtureweight, dim=0)
    probs_sum = 0  # = []
    for c in range(n_components):
        m = Normal(torch.tensor([c*10.]).float(), torch.tensor([5.0]).float())
        # for x in xs:
        component_i = torch.exp(m.log_prob(x)) * mixture_weights[c] #.numpy()
        # probs.append(probs)
        probs_sum += component_i
    logprob = torch.log(probs_sum)
    return logprob

Author: chriscremer, Project: Other_Code, Lines: 12, Source: gmm_batch_v2.py
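Example 6 exponentiates each component density, sums, and only then takes the log, which can underflow when every component assigns tiny probability to x. A numerically safer way to compute the same mixture log-likelihood (a sketch under the same mean/std convention as the example, not the author's code) keeps everything in log space with torch.logsumexp:

import torch
from torch.distributions import Normal

def mixture_logprob(x, needsoftmax_mixtureweight, n_components=20):
    # log p(x) = logsumexp_c [ log pi_c + log N(x | 10*c, 5) ]
    log_weights = torch.log_softmax(needsoftmax_mixtureweight, dim=0)
    means = torch.arange(n_components).float() * 10.
    comps = Normal(means, torch.full((n_components,), 5.0))
    return torch.logsumexp(log_weights + comps.log_prob(x), dim=-1)  # x: scalar tensor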
Example 7: sample_relax_given_class_k

def sample_relax_given_class_k(logits, samp, k):

    cat = Categorical(logits=logits)
    b = samp #torch.argmax(z, dim=1)
    logprob = cat.log_prob(b).view(B,1)

    zs = []
    z_tildes = []
    for i in range(k):
        u = torch.rand(B,C).clamp(1e-8, 1.-1e-8)
        gumbels = -torch.log(-torch.log(u))
        z = logits + gumbels

        u_b = torch.gather(input=u, dim=1, index=b.view(B,1))
        z_tilde_b = -torch.log(-torch.log(u_b))

        z_tilde = -torch.log((- torch.log(u) / torch.softmax(logits, dim=1)) - torch.log(u_b))
        z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)
        z = z_tilde

        u_b = torch.gather(input=u, dim=1, index=b.view(B,1))
        z_tilde_b = -torch.log(-torch.log(u_b))

        u = torch.rand(B,C).clamp(1e-8, 1.-1e-8)
        z_tilde = -torch.log((- torch.log(u) / torch.softmax(logits, dim=1)) - torch.log(u_b))
        z_tilde.scatter_(dim=1, index=b.view(B,1), src=z_tilde_b)

        zs.append(z)
        z_tildes.append(z_tilde)

    zs = torch.stack(zs)
    z_tildes = torch.stack(z_tildes)
    z = torch.mean(zs, dim=0)
    z_tilde = torch.mean(z_tildes, dim=0)

    return z, z_tilde, logprob

Author: chriscremer, Project: Other_Code, Lines: 39, Source: plotting_cat_grads_dist.py
Example 8: show_surr_preds

def show_surr_preds():
    batch_size = 1

    rows = 3
    cols = 1
    fig = plt.figure(figsize=(10+cols,4+rows), facecolor='white') #, dpi=150)

    for i in range(rows):
        x = sample_true(1).cuda() #.view(1,1)
        logits = encoder.net(x)
        probs = torch.softmax(logits, dim=1)
        cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
        check_nan(logprob_cluster)

        z = cluster_S

        n_evals = 40
        x1 = np.linspace(-9,205, n_evals)
        x = torch.from_numpy(x1).view(n_evals,1).float().cuda()
        z = z.repeat(n_evals,1)
        cluster_H = cluster_H.repeat(n_evals,1)
        xz = torch.cat([z,x], dim=1)

        logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
        f = logpxz - logprob_cluster

        surr_pred = surrogate.net(xz)
        surr_pred = surr_pred.data.cpu().numpy()
        f = f.data.cpu().numpy()

        col = 0
        row = i
        # print (row)
        ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)

        ax.plot(x1, surr_pred, label='Surr')
        ax.plot(x1, f, label='f')
        ax.set_title(str(cluster_H[0]))
        ax.legend()

    # save_dir = home+'/Documents/Grad_Estimators/GMM/'
    plt_path = exp_dir+'gmm_surr.png'
    plt.savefig(plt_path)
    print ('saved training plot', plt_path)
    plt.close()

Author: chriscremer, Project: Other_Code, Lines: 51, Source: gmm_batch_v2.py
Example 9: plot_dist

def plot_dist(x=None):
    if x is None:
        x1 = sample_true(1).cuda()
    else:
        x1 = x[0].cpu().numpy() #.view(1,1)
    # print (x)

    mixture_weights = torch.softmax(needsoftmax_mixtureweight, dim=0)

    rows = 1
    cols = 1
    fig = plt.figure(figsize=(10+cols,4+rows), facecolor='white') #, dpi=150)

    col = 0
    row = 0
    ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)

    xs = np.linspace(-9,205, 300)
    sum_ = np.zeros(len(xs))

    C = 20
    for c in range(C):
        m = Normal(torch.tensor([c*10.]).float(), torch.tensor([5.0]).float())
        ys = []
        for x in xs:
            # component_i = (torch.exp(m.log_prob(x)) * ((c+5.) / 290.)).numpy()
            component_i = (torch.exp(m.log_prob(x)) * mixture_weights[c]).detach().cpu().numpy()
            ys.append(component_i)

        ys = np.reshape(np.array(ys), [-1])
        sum_ += ys
        ax.plot(xs, ys, label='')

    ax.plot(xs, sum_, label='')

    # print (x)
    ax.plot([x1, x1+.001], [0., .002])

    # save_dir = home+'/Documents/Grad_Estimators/GMM/'
    plt_path = exp_dir+'gmm_plot_dist.png'
    plt.savefig(plt_path)
    print ('saved training plot', plt_path)
    plt.close()

Author: chriscremer, Project: Other_Code, Lines: 48, Source: gmm_batch_v2.py
Example 10: true_posterior

def true_posterior(x, needsoftmax_mixtureweight):
    mixture_weights = torch.softmax(needsoftmax_mixtureweight, dim=0)
    probs_ = []
    for c in range(n_components):
        m = Normal(torch.tensor([c*10.]).float().cuda(), torch.tensor([5.0]).float().cuda())
        component_i = torch.exp(m.log_prob(x)) * mixture_weights[c] #.numpy()
        # print (component_i.shape)
        probs_.append(component_i[0])

    probs_ = torch.stack(probs_)
    probs_ = probs_ / torch.sum(probs_)
    # print (probs_.shape)
    # logprob = torch.log(probs_sum)
    return probs_

Author: chriscremer, Project: Other_Code, Lines: 16, Source: gmm_batch_v2.py
Example 11: reinforce_baseline

def reinforce_baseline(surrogate, x, logits, mixtureweights, k=1, get_grad=False):
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)

    outputs = {}
    cat = Categorical(probs=probs)
    grads = []
    # net_loss = 0
    for jj in range(k):
        cluster_H = cat.sample()
        outputs['logq'] = logq = cat.log_prob(cluster_H).view(B,1)

        outputs['logpx_given_z'] = logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        outputs['logpz'] = logpz = torch.log(mixtureweights[cluster_H]).view(B,1)
        logpxz = logpx_given_z + logpz #[B,1]

        surr_pred = surrogate.net(x)

        outputs['f'] = f = logpxz - logq - 1.
        # outputs['net_loss'] = net_loss = net_loss - torch.mean((f.detach()) * logq)
        outputs['net_loss'] = net_loss = - torch.mean((f.detach() - surr_pred.detach()) * logq)
        # net_loss += - torch.mean( -logq.detach()*logq)

        # surr_loss = torch.mean(torch.abs(f.detach() - surr_pred))
        grad_logq = torch.autograd.grad([torch.mean(logq)], [logits], create_graph=True, retain_graph=True)[0]
        surr_loss = torch.mean(((f.detach() - surr_pred) * grad_logq)**2)

        if get_grad:
            grad = torch.autograd.grad([net_loss], [logits], create_graph=True, retain_graph=True)[0]
            grads.append(grad)

    # net_loss = net_loss / k

    if get_grad:
        grads = torch.stack(grads)
        # print (grads.shape)
        outputs['grad_avg'] = torch.mean(torch.mean(grads, dim=0), dim=0)
        outputs['grad_std'] = torch.std(grads, dim=0)[0]

    outputs['surr_loss'] = surr_loss

    # return net_loss, f, logpx_given_z, logpz, logq
    return outputs

Author: chriscremer, Project: Other_Code, Lines: 44, Source: gmm_cleaned_v5.py
Example 12: logprob_undercomponent

def logprob_undercomponent(x, component, needsoftmax_mixtureweight, cuda=False):
    # c = component
    B = x.shape[0]
    # print (needsoftmax_mixtureweight.shape)
    mixture_weights = torch.softmax(needsoftmax_mixtureweight, dim=0)
    # print (mixture_weights.shape)

    # probs_sum = 0 # = []
    # for c in range(n_components):
    #     m = Normal(torch.tensor([c*10.]).float().cuda(), torch.tensor([5.0]).float()) #.cuda())
    mean = (component.float()*10.).view(B,1)
    std = (torch.ones([B]) * 5.).view(B,1)
    # print (mean.shape) #[B]
    if not cuda:
        m = Normal(mean, std) #.cuda())
    else:
        m = Normal(mean.cuda(), std.cuda())

    # for x in xs:
    #     component_i = torch.exp(m.log_prob(x)) * mixture_weights[c] #.numpy()
    # print (m.log_prob(x))
    # print (torch.log(mixture_weights[c]))
    # print (x.shape)
    logpx_given_z = m.log_prob(x)
    logpz = torch.log(mixture_weights[component]).view(B,1)

    # print (px_given_z.shape)
    # print (component)
    # print (mixture_weights)
    # print (mixture_weights[component])
    # print (torch.log(mixture_weights[component]).shape)
    # print (logpx_given_z.shape)
    # print (logpz.shape)
    logprob = logpx_given_z + logpz
    # print (logprob.shape)

    # probs.append(probs)
    # probs_sum += component_i
    # logprob = torch.log(component_i)
    return logprob

Author: chriscremer, Project: Other_Code, Lines: 42, Source: gmm_batch.py
Example 13: inference_error

def inference_error():
    x = sample_true(1).cuda()
    trueposterior = true_posterior(x, needsoftmax_mixtureweight).view(n_components)

    logits = encoder.net(x)
    probs = torch.softmax(logits, dim=1).view(n_components)

    # print (trueposterior)
    # print (probs)
    # print ((trueposterior-probs)**2)
    # print (trueposterior.shape)
    # print (probs.shape)
    # print (L2_mixtureweights(trueposterior.data.cpu().numpy(), probs.data.cpu().numpy()))
    return L2_mixtureweights(trueposterior.data.cpu().numpy(), probs.data.cpu().numpy())

Author: chriscremer, Project: Other_Code, Lines: 20, Source: gmm_batch_v2.py
Example 14: inference_error

def inference_error():
    error_sum = 0
    kl_sum = 0
    n = 10
    for i in range(n):
        # if x is None:
        x = sample_true(1).cuda()
        trueposterior = true_posterior(x, needsoftmax_mixtureweight).view(n_components)

        logits = encoder.net(x)
        probs = torch.softmax(logits/100., dim=1).view(n_components)

        error = L2_mixtureweights(trueposterior.data.cpu().numpy(), probs.data.cpu().numpy())
        kl = KL_mixutreweights(trueposterior.data.cpu().numpy(), probs.data.cpu().numpy())

        error_sum += error
        kl_sum += kl

    return error_sum/n, kl_sum/n

Author: chriscremer, Project: Other_Code, Lines: 21, Source: gmm_batch.py
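Note that Example 14 divides the logits by 100 before the softmax. This is temperature scaling: a temperature larger than 1 flattens the distribution toward uniform, while a temperature smaller than 1 sharpens it. A small illustration with made-up values:

import torch

logits = torch.tensor([[2.0, 0.0, -2.0]])
print(torch.softmax(logits, dim=1))          # peaked on the first class
print(torch.softmax(logits / 100., dim=1))   # nearly uniform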
Example 15: logprob_undercomponent

def logprob_undercomponent(x, component, needsoftmax_mixtureweight, cuda=False):
    c = component
    # print (needsoftmax_mixtureweight.shape)
    mixture_weights = torch.softmax(needsoftmax_mixtureweight, dim=0)

    # probs_sum = 0 # = []
    # for c in range(n_components):
    #     m = Normal(torch.tensor([c*10.]).float().cuda(), torch.tensor([5.0]).float()) #.cuda())
    if not cuda:
        m = Normal(torch.tensor([c*10.]).float(), torch.tensor([5.0]).float()) #.cuda())
    else:
        m = Normal(torch.tensor([c*10.]).float().cuda(), torch.tensor([5.0]).float().cuda())

    # for x in xs:
    #     component_i = torch.exp(m.log_prob(x)) * mixture_weights[c] #.numpy()
    # print (m.log_prob(x))
    # print (torch.log(mixture_weights[c]))
    logprob = m.log_prob(x) + torch.log(mixture_weights[c])

    # probs.append(probs)
    # probs_sum += component_i
    # logprob = torch.log(component_i)
    return logprob

Author: chriscremer, Project: Other_Code, Lines: 21, Source: gmm.py
Example 16: reinforce

def reinforce(x, logits, mixtureweights, k=1):
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)

    cat = Categorical(probs=probs)
    net_loss = 0
    for jj in range(k):
        cluster_H = cat.sample()
        logq = cat.log_prob(cluster_H).view(B,1)

        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B,1)
        logpxz = logpx_given_z + logpz #[B,1]
        f = logpxz - logq
        net_loss += - torch.mean((f.detach() - 1.) * logq)
        # net_loss += - torch.mean( -logq.detach()*logq)

    net_loss = net_loss / k

    return net_loss, f, logpx_given_z, logpz, logq

Author: chriscremer, Project: Other_Code, Lines: 22, Source: gmm_cleaned_v3.py
Example 17: forward

def forward(self, hidden, attn, src_map):
    """
    Compute a distribution over the target dictionary
    extended by the dynamic dictionary implied by copying
    source words.

    Args:
        hidden (FloatTensor): hidden outputs ``(batch x tlen, input_size)``
        attn (FloatTensor): attn for each ``(batch x tlen, slen)``
        src_map (FloatTensor):
            A sparse indicator matrix mapping each source word to
            its index in the "extended" vocabulary.
            ``(src_len, batch, extra_words)``
    """

    # CHECKS
    batch_by_tlen, _ = hidden.size()
    batch_by_tlen_, slen = attn.size()
    slen_, batch, cvocab = src_map.size()
    aeq(batch_by_tlen, batch_by_tlen_)
    aeq(slen, slen_)

    # Original probabilities.
    logits = self.linear(hidden)
    logits[:, self.pad_idx] = -float('inf')
    prob = torch.softmax(logits, 1)

    # Probability of copying p(z=1) batch.
    p_copy = torch.sigmoid(self.linear_copy(hidden))
    # Probability of not copying: p_{word}(w) * (1 - p(z))
    out_prob = torch.mul(prob, 1 - p_copy)
    mul_attn = torch.mul(attn, p_copy)
    copy_prob = torch.bmm(
        mul_attn.view(-1, batch, slen).transpose(0, 1),
        src_map.transpose(0, 1)
    ).transpose(0, 1)
    copy_prob = copy_prob.contiguous().view(-1, cvocab)

    return torch.cat([out_prob, copy_prob], 1)

Author: Unbabel, Project: OpenNMT-py, Lines: 38, Source: copy_generator.py
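In Example 17, setting the pad logit to -inf before torch.softmax guarantees that the padding token gets exactly zero probability, since exp(-inf) = 0 and the remaining entries are renormalized. A minimal illustration (pad_idx is chosen arbitrarily here):

import torch

logits = torch.randn(2, 6)
pad_idx = 0
logits[:, pad_idx] = -float('inf')
prob = torch.softmax(logits, dim=1)
print(prob[:, pad_idx])     # tensor([0., 0.])
print(prob.sum(dim=1))      # each row still sums to 1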
Example 18: reinforce_pz

def reinforce_pz(x, logits, mixtureweights, k=1):
    B = logits.shape[0]
    probs = torch.softmax(logits, dim=1)

    cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([1.]).cuda())
    net_loss = 0
    for jj in range(k):
        cluster_S = cat.rsample()
        cluster_H = H(cluster_S)
        logq = cat.log_prob(cluster_S.detach()).view(B,1)

        logpx_given_z = logprob_undercomponent(x, component=cluster_H)
        logpz = torch.log(mixtureweights[cluster_H]).view(B,1)
        logpxz = logpx_given_z + logpz #[B,1]
        f = logpxz - logq - 1.
        net_loss += - torch.mean((f.detach()) * logq)
        # net_loss += - torch.mean( -logq.detach()*logq)

    net_loss = net_loss / k

    return net_loss, f, logpx_given_z, logpz, logq

Author: chriscremer, Project: Other_Code, Lines: 23, Source: gmm_cleaned_v5.py
Example 19: plot_posteriors

def plot_posteriors(x=None, name=''):
    if x is None:
        x = sample_true(1).cuda()
    else:
        x = x[0].view(1,1)

    trueposterior = true_posterior(x, needsoftmax_mixtureweight).view(n_components)

    logits = encoder.net(x)
    probs = torch.softmax(logits, dim=1).view(n_components)

    trueposterior = trueposterior.data.cpu().numpy()
    qz = probs.data.cpu().numpy()

    rows = 1
    cols = 1
    fig = plt.figure(figsize=(8+cols,8+rows), facecolor='white') #, dpi=150)

    col = 0
    row = 0
    ax = plt.subplot2grid((rows,cols), (row,col), frameon=False, colspan=1, rowspan=1)

    width = .3
    ax.bar(range(len(qz)), trueposterior, width=width, label='True')
    ax.bar(np.array(range(len(qz)))+width, qz, width=width, label='q')
    # ax.bar(np.array(range(len(q_b)))+width+width, q_b, width=width)
    ax.legend()
    ax.grid(True, alpha=.3)
    ax.set_title(str(x))

    # save_dir = home+'/Documents/Grad_Estimators/GMM/'
    plt_path = exp_dir+'posteriors' + name + '.png'
    plt.savefig(plt_path)
    print ('saved training plot', plt_path)
    plt.close()

Author: chriscremer, Project: Other_Code, Lines: 37, Source: gmm_batch_v2.py
Example 20: get_loss

def get_loss():
    x = sample_true(batch_size).cuda() #.view(1,1)
    logits = encoder.net(x)
    probs = torch.softmax(logits, dim=1)
    cat = RelaxedOneHotCategorical(probs=probs, temperature=torch.tensor([temp]).cuda())
    cluster_S = cat.rsample()
    cluster_H = H(cluster_S)
    # cluster_onehot = torch.zeros(n_components)
    # cluster_onehot[cluster_H] = 1.
    logprob_cluster = cat.log_prob(cluster_S.detach()).view(batch_size,1)
    check_nan(logprob_cluster)

    logpxz = logprob_undercomponent(x, component=cluster_H, needsoftmax_mixtureweight=needsoftmax_mixtureweight, cuda=True)
    f = logpxz - logprob_cluster

    surr_input = torch.cat([cluster_S, x], dim=1) #[B,21]
    surr_pred = surrogate.net(surr_input)
    # net_loss = - torch.mean((f.detach()-surr_pred.detach()) * logprob_cluster + surr_pred)
    # loss = - torch.mean(f)
    surr_loss = torch.mean(torch.abs(f.detach()-surr_pred))

    return surr_loss

Author: chriscremer, Project: Other_Code, Lines: 24, Source: gmm_batch_v2.py
Note: the torch.softmax examples in this article were compiled by 纯净天空 from GitHub, MSDocs, and other source-code and documentation platforms. The snippets were selected from open-source projects contributed by their respective authors, who retain the copyright; consult each project's license before redistributing or reusing the code, and do not reproduce this page without permission.