
Python utils.broadcast_all Function Code Examples


This article collects typical usage examples of the torch.distributions.utils.broadcast_all function in Python. If you have been wondering what broadcast_all does, how to call it, or what real-world usage looks like, the curated code examples below may help.



A total of 20 code examples of the broadcast_all function are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the site recommend better Python code examples.
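Before the examples, here is a minimal stand-alone sketch of what broadcast_all does, assuming a recent PyTorch release: it accepts any mix of Python numbers and tensors, converts the numbers to tensors, and broadcasts everything to one common shape, returning one tensor per argument. The variable names (loc/scale, low/high) are purely illustrative and not taken from the examples below.

import torch
from torch.distributions.utils import broadcast_all

# A Python scalar and a tensor: the scalar is promoted to a tensor and
# both results share the common broadcast shape (3,).
loc, scale = broadcast_all(0.0, torch.ones(3))
print(loc.shape, scale.shape)    # torch.Size([3]) torch.Size([3])

# Two tensors with compatible shapes are expanded to the common shape (2, 3).
low, high = broadcast_all(torch.zeros(2, 1), torch.tensor([1.0, 2.0, 3.0]))
print(low.shape, high.shape)     # torch.Size([2, 3]) torch.Size([2, 3])

This is exactly the pattern the distribution constructors below rely on: parameters may be passed as plain numbers or as tensors of different (broadcastable) shapes, and broadcast_all normalizes them to tensors of one batch shape.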

Example 1: __init__

 def __init__(self, probs=None, logits=None):
     if (probs is None) == (logits is None):
         raise ValueError("Either `probs` or `logits` must be specified, but not both.")
     if probs is not None:
         self.probs, = broadcast_all(probs)
     else:
         self.logits, = broadcast_all(logits)
     probs_or_logits = probs if probs is not None else logits
     if isinstance(probs_or_logits, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = probs_or_logits.size()
     super(Bernoulli, self).__init__(batch_shape)
Contributor: lxlhh, Project: pytorch, Lines: 13, Source: bernoulli.py


Example 2: __init__

 def __init__(self, probs=None, logits=None, validate_args=None):
     if (probs is None) == (logits is None):
         raise ValueError("Either `probs` or `logits` must be specified, but not both.")
     if probs is not None:
         is_scalar = isinstance(probs, Number)
         self.probs, = broadcast_all(probs)
     else:
         is_scalar = isinstance(logits, Number)
         self.logits, = broadcast_all(logits)
     self._param = self.probs if probs is not None else self.logits
     if is_scalar:
         batch_shape = torch.Size()
     else:
         batch_shape = self._param.size()
     super(Bernoulli, self).__init__(batch_shape, validate_args=validate_args)
Contributor: gtgalone, Project: pytorch, Lines: 15, Source: bernoulli.py


Example 3: __init__

 def __init__(self, probs=None, logits=None, validate_args=None):
     if (probs is None) == (logits is None):
         raise ValueError("Either `probs` or `logits` must be specified, but not both.")
     if probs is not None:
         self.probs, = broadcast_all(probs)
         if not self.probs.gt(0).all():
             raise ValueError('All elements of probs must be greater than 0')
     else:
         self.logits, = broadcast_all(logits)
     probs_or_logits = probs if probs is not None else logits
     if isinstance(probs_or_logits, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = probs_or_logits.size()
     super(Geometric, self).__init__(batch_shape, validate_args=validate_args)
Contributor: RichieMay, Project: pytorch, Lines: 15, Source: geometric.py


Example 4: __init__

 def __init__(self, loc, scale):
     self.loc, self.scale = broadcast_all(loc, scale)
     if isinstance(loc, Number) and isinstance(scale, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.loc.size()
     super(Laplace, self).__init__(batch_shape)
Contributor: MaheshBhosale, Project: pytorch, Lines: 7, Source: laplace.py


Example 5: __init__

 def __init__(self, rate, validate_args=None):
     self.rate, = broadcast_all(rate)
     if isinstance(rate, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.rate.size()
     super(Poisson, self).__init__(batch_shape, validate_args=validate_args)
Contributor: RichieMay, Project: pytorch, Lines: 7, Source: poisson.py


Example 6: __init__

 def __init__(self, alpha, beta):
     self.alpha, self.beta = broadcast_all(alpha, beta)
     if isinstance(alpha, Number) and isinstance(beta, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.alpha.size()
     super(Gamma, self).__init__(batch_shape)
Contributor: lxlhh, Project: pytorch, Lines: 7, Source: gamma.py


Example 7: __init__

 def __init__(self, scale, alpha):
     self.scale, self.alpha = broadcast_all(scale, alpha)
     if isinstance(scale, Number) and isinstance(alpha, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.scale.size()
     super(Pareto, self).__init__(batch_shape)
Contributor: bhuWenDongchao, Project: pytorch, Lines: 7, Source: pareto.py


Example 8: __init__

 def __init__(self, low, high):
     self.low, self.high = broadcast_all(low, high)
     if isinstance(low, Number) and isinstance(high, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.low.size()
     super(Uniform, self).__init__(batch_shape)
Contributor: bhuWenDongchao, Project: pytorch, Lines: 7, Source: uniform.py


Example 9: __init__

 def __init__(self, concentration, rate):
     self.concentration, self.rate = broadcast_all(concentration, rate)
     if isinstance(concentration, Number) and isinstance(rate, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.concentration.size()
     super(Gamma, self).__init__(batch_shape)
Contributor: MaheshBhosale, Project: pytorch, Lines: 7, Source: gamma.py


Example 10: __init__

 def __init__(self, loc, scale, validate_args=None):
     self.loc, self.scale = broadcast_all(loc, scale)
     if isinstance(loc, Number) and isinstance(scale, Number):
         batch_shape = torch.Size()
     else:
         batch_shape = self.loc.size()
     super(Normal, self).__init__(batch_shape, validate_args=validate_args)
Contributor: RichieMay, Project: pytorch, Lines: 7, Source: normal.py


Example 11: __init__

 def __init__(self, alpha, beta):
     if isinstance(alpha, Number) and isinstance(beta, Number):
         alpha_beta = torch.Tensor([alpha, beta])
     else:
         alpha, beta = broadcast_all(alpha, beta)
         alpha_beta = torch.stack([alpha, beta], -1)
     self._dirichlet = Dirichlet(alpha_beta)
     super(Beta, self).__init__(self._dirichlet._batch_shape)
Contributor: lxlhh, Project: pytorch, Lines: 8, Source: beta.py


Example 12: __init__

 def __init__(self, concentration1, concentration0, validate_args=None):
     if isinstance(concentration1, Number) and isinstance(concentration0, Number):
         concentration1_concentration0 = torch.tensor([float(concentration1), float(concentration0)])
     else:
         concentration1, concentration0 = broadcast_all(concentration1, concentration0)
         concentration1_concentration0 = torch.stack([concentration1, concentration0], -1)
     self._dirichlet = Dirichlet(concentration1_concentration0)
     super(Beta, self).__init__(self._dirichlet._batch_shape, validate_args=validate_args)
Contributor: gtgalone, Project: pytorch, Lines: 8, Source: beta.py


Example 13: log_prob

 def log_prob(self, value):
     self._validate_log_prob_arg(value)
     logits, value = broadcast_all(self.logits.clone(), value)
     log_factorial_n = torch.lgamma(value.sum(-1) + 1)
     log_factorial_xs = torch.lgamma(value + 1).sum(-1)
     logits[(value == 0) & (logits == -float('inf'))] = 0
     log_powers = (logits * value).sum(-1)
     return log_factorial_n - log_factorial_xs + log_powers
Contributor: bhuWenDongchao, Project: pytorch, Lines: 8, Source: multinomial.py


Example 14: __init__

 def __init__(self, concentration1, concentration0):
     if isinstance(concentration1, Number) and isinstance(concentration0, Number):
         concentration1_concentration0 = variable([concentration1, concentration0])
     else:
         concentration1, concentration0 = broadcast_all(concentration1, concentration0)
         concentration1_concentration0 = torch.stack([concentration1, concentration0], -1)
     self._dirichlet = Dirichlet(concentration1_concentration0)
     super(Beta, self).__init__(self._dirichlet._batch_shape)
Contributor: Jsmilemsj, Project: pytorch, Lines: 8, Source: beta.py


Example 15: __init__

    def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
        if (probs is None) == (logits is None):
            raise ValueError("Either `probs` or `logits` must be specified, but not both.")
        if probs is not None:
            self.total_count, self.probs, = broadcast_all(total_count, probs)
            self.total_count = self.total_count.type_as(self.logits)
            is_scalar = isinstance(self.probs, Number)
        else:
            self.total_count, self.logits, = broadcast_all(total_count, logits)
            self.total_count = self.total_count.type_as(self.logits)
            is_scalar = isinstance(self.logits, Number)

        self._param = self.probs if probs is not None else self.logits
        if is_scalar:
            batch_shape = torch.Size()
        else:
            batch_shape = self._param.size()
        super(Binomial, self).__init__(batch_shape, validate_args=validate_args)
Contributor: RichieMay, Project: pytorch, Lines: 18, Source: binomial.py


Example 16: log_prob

 def log_prob(self, value):
     K = self._categorical._num_events
     if self._validate_args:
         self._validate_sample(value)
     logits, value = broadcast_all(self.logits, value)
     log_scale = (self.temperature.new(self.temperature.shape).fill_(K).lgamma() -
                  self.temperature.log().mul(-(K - 1)))
     score = logits - value.mul(self.temperature)
     score = (score - _log_sum_exp(score)).sum(-1)
     return score + log_scale
Contributor: gtgalone, Project: pytorch, Lines: 10, Source: relaxed_categorical.py


Example 17: __init__

    def __init__(self, df1, df2, validate_args=None):
        self.df1, self.df2 = broadcast_all(df1, df2)
        self._gamma1 = Gamma(self.df1 * 0.5, self.df1)
        self._gamma2 = Gamma(self.df2 * 0.5, self.df2)

        if isinstance(df1, Number) and isinstance(df2, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.df1.size()
        super(FisherSnedecor, self).__init__(batch_shape, validate_args=validate_args)
Contributor: RichieMay, Project: pytorch, Lines: 10, Source: fishersnedecor.py


Example 18: __init__

    def __init__(self, total_count=1, probs=None, logits=None):
        if not isinstance(total_count, Number):
            raise NotImplementedError('inhomogeneous total_count is not supported')
        self.total_count = total_count
        if (probs is None) == (logits is None):
            raise ValueError("Either `probs` or `logits` must be specified, but not both.")
        if probs is not None:
            is_scalar = isinstance(probs, Number)
            self.probs, = broadcast_all(probs)
        else:
            is_scalar = isinstance(logits, Number)
            self.logits, = broadcast_all(logits)

        self._param = self.probs if probs is not None else self.logits
        if is_scalar:
            batch_shape = torch.Size()
        else:
            batch_shape = self._param.size()
        super(Binomial, self).__init__(batch_shape)
Contributor: Jsmilemsj, Project: pytorch, Lines: 19, Source: binomial.py


Example 19: __init__

    def __init__(self, low, high, validate_args=None):
        self.low, self.high = broadcast_all(low, high)

        if isinstance(low, Number) and isinstance(high, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.low.size()
        super(Uniform, self).__init__(batch_shape, validate_args=validate_args)

        if self._validate_args and not torch.lt(self.low, self.high).all():
            raise ValueError("Uniform is not defined when low >= high")
Contributor: gtgalone, Project: pytorch, Lines: 11, Source: uniform.py


Example 20: __init__

 def __init__(self, loc, scale, validate_args=None):
     self.loc, self.scale = broadcast_all(loc, scale)
     finfo = _finfo(self.loc)
     if isinstance(loc, Number) and isinstance(scale, Number):
         batch_shape = torch.Size()
         base_dist = Uniform(finfo.tiny, 1 - finfo.eps)
     else:
         batch_shape = self.scale.size()
         base_dist = Uniform(self.loc.new(self.loc.size()).fill_(finfo.tiny), 1 - finfo.eps)
     transforms = [ExpTransform().inv, AffineTransform(loc=0, scale=-torch.ones_like(self.scale)),
                   ExpTransform().inv, AffineTransform(loc=loc, scale=-self.scale)]
     super(Gumbel, self).__init__(base_dist, transforms, validate_args=validate_args)
Contributor: gtgalone, Project: pytorch, Lines: 12, Source: gumbel.py



Note: The torch.distributions.utils.broadcast_all examples in this article were compiled by 纯净天空 from source-code and documentation hosting platforms such as GitHub and MSDocs. The snippets are selected from open-source projects contributed by their respective developers, and copyright remains with the original authors; please follow each project's license when distributing or using the code, and do not reproduce it without permission.

