本文整理汇总了 Python 中 torch.distributions.utils.broadcast_all 函数的典型用法代码示例。如果您想了解：Python broadcast_all 函数的具体用法是什么？怎么用？有哪些使用示例？那么这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了 broadcast_all 函数的 20 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的 Python 代码示例。
示例1: __init__
def __init__(self, probs=None, logits=None):
    """Initialise from exactly one of ``probs`` or ``logits``.

    The supplied argument is passed through ``broadcast_all`` and stored on
    the instance under its own name.

    Args:
        probs: a Number or tensor, given instead of ``logits``.
        logits: a Number or tensor, given instead of ``probs``.

    Raises:
        ValueError: if both or neither of ``probs`` / ``logits`` are given.
    """
    if (probs is None) == (logits is None):
        raise ValueError("Either `probs` or `logits` must be specified, but not both.")
    if probs is not None:
        is_scalar = isinstance(probs, Number)
        param, = broadcast_all(probs)
        self.probs = param
    else:
        is_scalar = isinstance(logits, Number)
        param, = broadcast_all(logits)
        self.logits = param
    # Derive the batch shape from the tensor broadcast_all actually produced
    # (the value that is stored), not from the raw argument, so the declared
    # batch shape and the stored parameter cannot disagree.
    batch_shape = torch.Size() if is_scalar else param.size()
    super(Bernoulli, self).__init__(batch_shape)
开发者ID:lxlhh,项目名称:pytorch,代码行数:13,代码来源:bernoulli.py
示例2: __init__
def __init__(self, probs=None, logits=None, validate_args=None):
    """Build the distribution from exactly one of ``probs`` or ``logits``.

    The chosen argument is run through ``broadcast_all`` and stored under its
    own attribute name; ``self._param`` aliases whichever one was stored.
    A plain Number yields an empty batch shape, otherwise the stored
    tensor's size is used.

    Raises:
        ValueError: when both or neither parameter are supplied.
    """
    use_probs = probs is not None
    # Exactly one of the two must be given: both-given and neither-given
    # both make these two booleans equal.
    if use_probs == (logits is not None):
        raise ValueError("Either `probs` or `logits` must be specified, but not both.")
    raw_value = probs if use_probs else logits
    scalar_input = isinstance(raw_value, Number)
    if use_probs:
        self.probs, = broadcast_all(raw_value)
        self._param = self.probs
    else:
        self.logits, = broadcast_all(raw_value)
        self._param = self.logits
    shape = self._param.size() if not scalar_input else torch.Size()
    super(Bernoulli, self).__init__(shape, validate_args=validate_args)
开发者ID:gtgalone,项目名称:pytorch,代码行数:15,代码来源:bernoulli.py
示例3: __init__
def __init__(self, probs=None, logits=None, validate_args=None):
    """Initialise from exactly one of ``probs`` or ``logits``.

    When ``probs`` is given, every element must be strictly greater than 0;
    this check runs unconditionally, independent of ``validate_args``.

    Args:
        probs: a Number or tensor, given instead of ``logits``.
        logits: a Number or tensor, given instead of ``probs``.
        validate_args: forwarded to the base-class constructor.

    Raises:
        ValueError: if both or neither of ``probs`` / ``logits`` are given,
            or if any element of ``probs`` is not greater than 0.
    """
    if (probs is None) == (logits is None):
        raise ValueError("Either `probs` or `logits` must be specified, but not both.")
    if probs is not None:
        is_scalar = isinstance(probs, Number)
        param, = broadcast_all(probs)
        self.probs = param
        if not self.probs.gt(0).all():
            raise ValueError('All elements of probs must be greater than 0')
    else:
        is_scalar = isinstance(logits, Number)
        param, = broadcast_all(logits)
        self.logits = param
    # Use the broadcast tensor that was actually stored, not the raw
    # argument, as the source of the batch shape.
    batch_shape = torch.Size() if is_scalar else param.size()
    super(Geometric, self).__init__(batch_shape, validate_args=validate_args)
开发者ID:RichieMay,项目名称:pytorch,代码行数:15,代码来源:geometric.py
示例4: __init__
def __init__(self, loc, scale):
    """Store ``loc`` and ``scale`` broadcast against each other.

    The batch shape is empty when both inputs are plain Numbers and is the
    broadcast size of ``loc`` otherwise.
    """
    self.loc, self.scale = broadcast_all(loc, scale)
    both_numbers = all(isinstance(arg, Number) for arg in (loc, scale))
    super(Laplace, self).__init__(torch.Size() if both_numbers else self.loc.size())
开发者ID:MaheshBhosale,项目名称:pytorch,代码行数:7,代码来源:laplace.py
示例5: __init__
def __init__(self, rate, validate_args=None):
    """Store ``rate`` (passed through ``broadcast_all``) and set the batch shape.

    A plain Number gives an empty batch shape; a tensor gives its own size.
    """
    (self.rate,) = broadcast_all(rate)
    shape = torch.Size() if isinstance(rate, Number) else self.rate.size()
    super(Poisson, self).__init__(shape, validate_args=validate_args)
开发者ID:RichieMay,项目名称:pytorch,代码行数:7,代码来源:poisson.py
示例6: __init__
def __init__(self, alpha, beta):
    """Broadcast ``alpha`` and ``beta`` together and register the batch shape.

    The batch shape is empty only when both inputs are plain Numbers;
    otherwise it is the broadcast size of ``alpha``.
    """
    broadcast = broadcast_all(alpha, beta)
    self.alpha, self.beta = broadcast
    if not (isinstance(alpha, Number) and isinstance(beta, Number)):
        batch_shape = self.alpha.size()
    else:
        batch_shape = torch.Size()
    super(Gamma, self).__init__(batch_shape)
开发者ID:lxlhh,项目名称:pytorch,代码行数:7,代码来源:gamma.py
示例7: __init__
def __init__(self, scale, alpha):
    """Store broadcast ``scale``/``alpha`` and pass the batch shape upward.

    Two plain Numbers give an empty batch shape; otherwise the broadcast
    size of ``scale`` is used.
    """
    self.scale, self.alpha = broadcast_all(scale, alpha)
    scalar_args = all(isinstance(item, Number) for item in (scale, alpha))
    super(Pareto, self).__init__(self.scale.size() if not scalar_args else torch.Size())
开发者ID:bhuWenDongchao,项目名称:pytorch,代码行数:7,代码来源:pareto.py
示例8: __init__
def __init__(self, low, high):
    """Store ``low`` and ``high`` after broadcasting them together.

    No ordering check between the bounds is performed here. The batch shape
    is empty for two plain Numbers, else the broadcast size of ``low``.
    """
    bounds = broadcast_all(low, high)
    self.low, self.high = bounds
    numeric = isinstance(low, Number) and isinstance(high, Number)
    batch_shape = torch.Size() if numeric else self.low.size()
    super(Uniform, self).__init__(batch_shape)
开发者ID:bhuWenDongchao,项目名称:pytorch,代码行数:7,代码来源:uniform.py
示例9: __init__
def __init__(self, concentration, rate):
    """Broadcast ``concentration`` with ``rate`` and register the batch shape.

    Empty batch shape for two plain Numbers; otherwise the broadcast size
    of ``concentration``.
    """
    self.concentration, self.rate = broadcast_all(concentration, rate)
    if isinstance(concentration, Number) and isinstance(rate, Number):
        batch = torch.Size()
        super(Gamma, self).__init__(batch)
        return
    batch = self.concentration.size()
    super(Gamma, self).__init__(batch)
开发者ID:MaheshBhosale,项目名称:pytorch,代码行数:7,代码来源:gamma.py
示例10: __init__
def __init__(self, loc, scale, validate_args=None):
    """Store broadcast ``loc``/``scale`` and forward ``validate_args``.

    The batch shape is empty when both inputs are plain Numbers, otherwise
    it is the broadcast size of ``loc``.
    """
    self.loc, self.scale = broadcast_all(loc, scale)
    any_tensor = not isinstance(loc, Number) or not isinstance(scale, Number)
    shape = self.loc.size() if any_tensor else torch.Size()
    super(Normal, self).__init__(shape, validate_args=validate_args)
开发者ID:RichieMay,项目名称:pytorch,代码行数:7,代码来源:normal.py
示例11: __init__
def __init__(self, alpha, beta):
    """Represent the distribution through a Dirichlet over ``(alpha, beta)``.

    Two plain Numbers are packed into a 1-D ``torch.Tensor``; otherwise the
    inputs are broadcast and stacked along a new last dimension. The batch
    shape is taken from the resulting Dirichlet.
    """
    scalar_pair = isinstance(alpha, Number) and isinstance(beta, Number)
    if not scalar_pair:
        alpha, beta = broadcast_all(alpha, beta)
        stacked = torch.stack([alpha, beta], -1)
    else:
        stacked = torch.Tensor([alpha, beta])
    self._dirichlet = Dirichlet(stacked)
    inherited_batch = self._dirichlet._batch_shape
    super(Beta, self).__init__(inherited_batch)
开发者ID:lxlhh,项目名称:pytorch,代码行数:8,代码来源:beta.py
示例12: __init__
def __init__(self, concentration1, concentration0, validate_args=None):
    """Delegate to a Dirichlet built from the pair of concentrations.

    Plain Numbers are converted with ``float`` into a 1-D tensor; tensors
    are broadcast together and stacked along a new last dimension.
    ``validate_args`` is forwarded to the base class.
    """
    if not (isinstance(concentration1, Number) and isinstance(concentration0, Number)):
        concentration1, concentration0 = broadcast_all(concentration1, concentration0)
        pair = torch.stack([concentration1, concentration0], -1)
    else:
        pair = torch.tensor([float(concentration1), float(concentration0)])
    self._dirichlet = Dirichlet(pair)
    super(Beta, self).__init__(self._dirichlet._batch_shape, validate_args=validate_args)
开发者ID:gtgalone,项目名称:pytorch,代码行数:8,代码来源:beta.py
示例13: log_prob
def log_prob(self, value):
    """Return the log-probability of the count vector(s) ``value``.

    Computes ``lgamma(n + 1) - sum(lgamma(x_i + 1)) + sum(logits_i * x_i)``
    over the last dimension, where ``n`` is the sum of ``value`` along that
    dimension. ``self.logits`` is never modified.

    Args:
        value: counts; broadcast against ``self.logits``.
    """
    self._validate_log_prob_arg(value)
    logits, value = broadcast_all(self.logits, value)
    # broadcast_all may return expanded views that share storage with
    # self.logits (and with themselves along broadcast dimensions). Clone
    # *after* broadcasting so the in-place masking below writes into a
    # private buffer in which every element has its own memory location.
    logits = logits.clone()
    log_factorial_n = torch.lgamma(value.sum(-1) + 1)
    log_factorial_xs = torch.lgamma(value + 1).sum(-1)
    # A zero count at a -inf logit would give 0 * -inf = nan; it should
    # contribute nothing, so zero those logits first.
    logits[(value == 0) & (logits == -float('inf'))] = 0
    log_powers = (logits * value).sum(-1)
    return log_factorial_n - log_factorial_xs + log_powers
开发者ID:bhuWenDongchao,项目名称:pytorch,代码行数:8,代码来源:multinomial.py
示例14: __init__
def __init__(self, concentration1, concentration0):
    """Delegate to a Dirichlet built from the pair of concentrations.

    Plain Numbers are packed with ``variable`` (defined elsewhere); tensors
    are broadcast together and stacked along a new last dimension. The
    batch shape is taken from the resulting Dirichlet.
    """
    if isinstance(concentration1, Number) and isinstance(concentration0, Number):
        # Cast to float so integer inputs cannot produce an integer-typed
        # parameter. NOTE(review): assumes `variable` takes its dtype from
        # the Python values it is given; confirm against its definition.
        concentration1_concentration0 = variable([float(concentration1), float(concentration0)])
    else:
        concentration1, concentration0 = broadcast_all(concentration1, concentration0)
        concentration1_concentration0 = torch.stack([concentration1, concentration0], -1)
    self._dirichlet = Dirichlet(concentration1_concentration0)
    super(Beta, self).__init__(self._dirichlet._batch_shape)
开发者ID:Jsmilemsj,项目名称:pytorch,代码行数:8,代码来源:beta.py
示例15: __init__
def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
    """Initialise from ``total_count`` and exactly one of ``probs`` / ``logits``.

    ``total_count`` is broadcast together with the supplied parameter and
    cast to that parameter's tensor type. ``self._param`` aliases whichever
    parameter was stored.

    Raises:
        ValueError: if both or neither of ``probs`` / ``logits`` are given.
    """
    if (probs is None) == (logits is None):
        raise ValueError("Either `probs` or `logits` must be specified, but not both.")
    if probs is not None:
        # Decide scalar-ness from the raw arguments: after broadcast_all the
        # stored values are always tensors, so an isinstance(..., Number)
        # test on them would never succeed.
        is_scalar = isinstance(total_count, Number) and isinstance(probs, Number)
        self.total_count, self.probs = broadcast_all(total_count, probs)
        # Cast against the parameter actually supplied (previously this read
        # self.logits, a copy-paste slip from the branch below).
        self.total_count = self.total_count.type_as(self.probs)
    else:
        is_scalar = isinstance(total_count, Number) and isinstance(logits, Number)
        self.total_count, self.logits = broadcast_all(total_count, logits)
        self.total_count = self.total_count.type_as(self.logits)
    self._param = self.probs if probs is not None else self.logits
    batch_shape = torch.Size() if is_scalar else self._param.size()
    super(Binomial, self).__init__(batch_shape, validate_args=validate_args)
开发者ID:RichieMay,项目名称:pytorch,代码行数:18,代码来源:binomial.py
示例16: log_prob
def log_prob(self, value):
    """Return the log-density of ``value`` (summed over the last dimension).

    ``K`` is the number of categories from ``self._categorical``. The
    result is the normalised score ``logits - value * temperature`` summed
    over the last dimension, plus ``lgamma(K) + (K - 1) * log(temperature)``.
    The sample is validated first when ``self._validate_args`` is set.
    """
    num_events = self._categorical._num_events
    if self._validate_args:
        self._validate_sample(value)
    logits, value = broadcast_all(self.logits, value)
    temp = self.temperature
    log_gamma_k = temp.new(temp.shape).fill_(num_events).lgamma()
    log_scale = log_gamma_k - temp.log().mul(-(num_events - 1))
    raw_score = logits - value.mul(temp)
    normalised = raw_score - _log_sum_exp(raw_score)
    return normalised.sum(-1) + log_scale
开发者ID:gtgalone,项目名称:pytorch,代码行数:10,代码来源:relaxed_categorical.py
示例17: __init__
def __init__(self, df1, df2, validate_args=None):
    """Store broadcast degrees of freedom and build one Gamma per parameter.

    Each Gamma is constructed with concentration ``df * 0.5`` and second
    argument ``df``, in the order ``df1`` then ``df2``. The batch shape is
    empty for two plain Numbers, otherwise the broadcast size of ``df1``.
    """
    self.df1, self.df2 = broadcast_all(df1, df2)
    self._gamma1, self._gamma2 = [Gamma(d * 0.5, d) for d in (self.df1, self.df2)]
    both_scalar = isinstance(df1, Number) and isinstance(df2, Number)
    batch_shape = torch.Size() if both_scalar else self.df1.size()
    super(FisherSnedecor, self).__init__(batch_shape, validate_args=validate_args)
开发者ID:RichieMay,项目名称:pytorch,代码行数:10,代码来源:fishersnedecor.py
示例18: __init__
def __init__(self, total_count=1, probs=None, logits=None):
    """Initialise with a scalar ``total_count`` and one of ``probs`` / ``logits``.

    Raises:
        NotImplementedError: if ``total_count`` is not a Number.
        ValueError: if both or neither of ``probs`` / ``logits`` are given.
    """
    if not isinstance(total_count, Number):
        raise NotImplementedError('inhomogeneous total_count is not supported')
    self.total_count = total_count
    have_probs = probs is not None
    have_logits = logits is not None
    if have_probs == have_logits:
        raise ValueError("Either `probs` or `logits` must be specified, but not both.")
    supplied = probs if have_probs else logits
    scalar_input = isinstance(supplied, Number)
    if have_probs:
        self.probs, = broadcast_all(supplied)
    else:
        self.logits, = broadcast_all(supplied)
    self._param = self.probs if have_probs else self.logits
    batch_shape = torch.Size() if scalar_input else self._param.size()
    super(Binomial, self).__init__(batch_shape)
开发者ID:Jsmilemsj,项目名称:pytorch,代码行数:19,代码来源:binomial.py
示例19: __init__
def __init__(self, low, high, validate_args=None):
    """Store broadcast bounds and, when validation is on, require low < high.

    The ordering check runs after the base-class constructor, so it follows
    the ``_validate_args`` setting that constructor establishes.

    Raises:
        ValueError: if validation is enabled and any ``low >= high``.
    """
    self.low, self.high = broadcast_all(low, high)
    numeric_bounds = isinstance(low, Number) and isinstance(high, Number)
    shape = torch.Size() if numeric_bounds else self.low.size()
    super(Uniform, self).__init__(shape, validate_args=validate_args)
    if self._validate_args:
        if not torch.lt(self.low, self.high).all():
            raise ValueError("Uniform is not defined when low>= high")
开发者ID:gtgalone,项目名称:pytorch,代码行数:11,代码来源:uniform.py
示例20: __init__
def __init__(self, loc, scale, validate_args=None):
    """Build the distribution as a transformed Uniform base distribution.

    The transform chain maps a uniform sample ``u`` to
    ``loc - scale * log(-log(u))``.

    Args:
        loc: location; a Number or tensor, broadcast with ``scale``.
        scale: scale; a Number or tensor, broadcast with ``loc``.
        validate_args: forwarded to the transformed-distribution base class.
    """
    self.loc, self.scale = broadcast_all(loc, scale)
    finfo = _finfo(self.loc)
    # The base support is kept inside [tiny, 1 - eps] so that neither log()
    # in the chain below is evaluated at 0. (A previously computed
    # `batch_shape` was never used and has been removed; the batch shape
    # comes from base_dist and the transforms.)
    if isinstance(loc, Number) and isinstance(scale, Number):
        base_dist = Uniform(finfo.tiny, 1 - finfo.eps)
    else:
        base_dist = Uniform(self.loc.new(self.loc.size()).fill_(finfo.tiny), 1 - finfo.eps)
    transforms = [ExpTransform().inv, AffineTransform(loc=0, scale=-torch.ones_like(self.scale)),
                  ExpTransform().inv, AffineTransform(loc=loc, scale=-self.scale)]
    super(Gumbel, self).__init__(base_dist, transforms, validate_args=validate_args)
开发者ID:gtgalone,项目名称:pytorch,代码行数:12,代码来源:gumbel.py
注：本文中的 torch.distributions.utils.broadcast_all 函数示例由纯净天空整理自 GitHub/MSDocs 等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的 License；未经允许，请勿转载。
请发表评论