本文整理汇总了Python中tests.common.assert_equal函数的典型用法代码示例。如果您正苦于以下问题：Python assert_equal函数的具体用法？Python assert_equal怎么用？Python assert_equal使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了assert_equal函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: test_mean_and_var
def test_mean_and_var(self):
    """Check empirical moments of repeated ``dist.Delta(self.v)`` draws.

    Draws ``self.n_samples`` samples, converts each to a NumPy array, and
    compares their NumPy mean and variance against ``self.analytic_mean`` and
    ``self.analytic_var``.
    """
    draws = []
    for _ in range(self.n_samples):
        draw = dist.Delta(self.v).sample()
        draws.append(draw.detach().cpu().numpy())
    empirical_mean, empirical_var = np.mean(draws), np.var(draws)
    assert_equal(empirical_mean, self.analytic_mean)
    assert_equal(empirical_var, self.analytic_var)
开发者ID:lewisKit,项目名称:pyro,代码行数:7,代码来源:test_delta.py
示例2: test_batch_log_dims
def test_batch_log_dims(dim, vs, one_hot, ps):
    """Check ``dist.categorical.batch_log_pdf`` over its enumerated support.

    The expected value is ``log(ps)`` nested via ``wrap_nested`` (presumably
    ``dim - 1`` levels deep; TODO confirm) and reshaped to ``(3,) + (1,) * dim``.
    """
    # Build the expectation from the caller's ``ps`` BEFORE
    # ``modify_params_using_dims`` rebinds ``ps``/``vs``.
    log_ps = list(np.log(ps))
    nested = np.array(wrap_nested(log_ps, dim - 1))
    expected_log_pdf = nested.reshape(3, *([1] * dim))

    ps, vs = modify_params_using_dims(ps, vs, dim)
    categorical = dist.categorical
    support = categorical.enumerate_support(ps, vs, one_hot=one_hot)
    actual = categorical.batch_log_pdf(support, ps, vs, one_hot=one_hot)
    assert_equal(actual.data.cpu().numpy(), expected_log_pdf)
开发者ID:Magica-Chen,项目名称:pyro,代码行数:7,代码来源:test_categorical_dimensions.py
示例3: test_mask
def test_mask(batch_dim, event_dim, mask_dim):
    """Check ``base_dist.mask(mask)`` against the unmasked distribution.

    Shapes and moments must be unchanged. ``log_prob`` and ``score_parts``
    must equal the base values multiplied by ``mask``. ``mask`` covers the
    trailing ``mask_dim`` dimensions of the batch shape.
    """
    # Construct base distribution.
    shape = torch.Size([2, 3, 4, 5, 6][:batch_dim + event_dim])
    batch_shape = shape[:batch_dim]
    mask_shape = batch_shape[batch_dim - mask_dim:]
    base_dist = Bernoulli(0.1).expand_by(shape).independent(event_dim)

    # Construct masked distribution.
    mask = checker_mask(mask_shape)
    dist = base_dist.mask(mask)

    # Check shape.
    sample = base_dist.sample()
    assert dist.batch_shape == base_dist.batch_shape
    assert dist.event_shape == base_dist.event_shape
    # Draw from the masked distribution itself; comparing ``sample.shape``
    # with itself would always pass and check nothing.
    assert dist.sample().shape == sample.shape
    assert dist.log_prob(sample).shape == base_dist.log_prob(sample).shape

    # Check values.
    assert_equal(dist.mean, base_dist.mean)
    assert_equal(dist.variance, base_dist.variance)
    assert_equal(dist.log_prob(sample), base_dist.log_prob(sample) * mask)
    assert_equal(dist.score_parts(sample), base_dist.score_parts(sample) * mask, prec=0)
    if not dist.event_shape:
        assert_equal(dist.enumerate_support(), base_dist.enumerate_support())
开发者ID:lewisKit,项目名称:pyro,代码行数:25,代码来源:test_mask.py
示例4: test_decorator_interface_primitives
def test_decorator_interface_primitives():
    """Exercise ``poutine.trace`` as a bare decorator and as a parameterized one.

    The bare form must produce a ``"flat"`` trace. ``graph_type="dense"`` must
    produce a ``"dense"`` trace. Replaying the dense trace must reproduce the
    recorded value of site ``"a"``.
    """
    def _sample_sites():
        # Shared body of both decorated models below.
        pyro.param("p", torch.zeros(1, requires_grad=True))
        pyro.sample("a", Bernoulli(torch.tensor([0.5])),
                    infer={"enumerate": "parallel"})
        pyro.sample("b", Bernoulli(torch.tensor([0.5])))

    @poutine.trace
    def flat_model():
        _sample_sites()

    flat_trace = flat_model.get_trace()
    assert isinstance(flat_trace, poutine.Trace)
    assert flat_trace.graph_type == "flat"

    @poutine.trace(graph_type="dense")
    def dense_model():
        _sample_sites()

    dense_trace = dense_model.get_trace()
    assert isinstance(dense_trace, poutine.Trace)
    assert dense_trace.graph_type == "dense"

    replayed = poutine.trace(poutine.replay(dense_model, trace=dense_trace)).get_trace()
    assert_equal(replayed.nodes["a"]["value"], dense_trace.nodes["a"]["value"])
开发者ID:lewisKit,项目名称:pyro,代码行数:27,代码来源:test_poutines.py
示例5: test_iter_discrete_traces_vector
def test_iter_discrete_traces_vector(graph_type):
    """Enumerate a batched Bernoulli/Categorical model with ``iter_discrete_traces``.

    Expects ``2 * ps.size(-1)`` traces. Each trace's ``scale`` must equal the
    probability of its discrete choices, computed as ``exp`` of the summed
    per-site log-pdfs.
    """
    pyro.clear_param_store()

    def model():
        p = pyro.param("p", Variable(torch.Tensor([[0.05], [0.15]])))
        ps = pyro.param("ps", Variable(torch.Tensor([[0.1, 0.2, 0.3, 0.4],
                                                     [0.4, 0.3, 0.2, 0.1]])))
        x = pyro.sample("x", dist.Bernoulli(p))
        y = pyro.sample("y", dist.Categorical(ps, one_hot=False))
        assert x.size() == (2, 1)
        assert y.size() == (2, 1)
        return dict(x=x, y=y)

    traces = list(iter_discrete_traces(graph_type, model))

    p = pyro.param("p").data
    ps = pyro.param("ps").data
    # Two Bernoulli outcomes times ps.size(-1) categorical outcomes.
    assert len(traces) == 2 * ps.size(-1)

    for scale, trace in traces:
        x = trace.nodes["x"]["value"].data.squeeze().long()[0]
        y = trace.nodes["y"]["value"].data.squeeze().long()[0]
        # Joint log-probability of the two sites is the SUM of their
        # log-pdfs. Multiplying two negative log-pdfs would give exp(...) > 1,
        # which can never be a probability scale.
        expected_scale = torch.exp(dist.Bernoulli(p).log_pdf(x) +
                                   dist.Categorical(ps, one_hot=False).log_pdf(y))
        expected_scale = expected_scale.data.view(-1)[0]
        assert_equal(scale, expected_scale)
开发者ID:Magica-Chen,项目名称:pyro,代码行数:26,代码来源:test_enum.py
示例6: test_quantiles
def test_quantiles(auto_class, Elbo):
    """Train ``auto_class(model)`` for 100 SVI steps, then sanity-check its output.

    Checks the median and the 10%/50%/90% quantiles for the sites
    ``x ~ Normal(0, 1)``, ``y ~ LogNormal(0, 1)`` and ``z ~ Beta(2, 2)``.
    The bounds are loose range/ordering checks, not exact values.
    """
    def model():
        pyro.sample("x", dist.Normal(0.0, 1.0))
        pyro.sample("y", dist.LogNormal(0.0, 1.0))
        pyro.sample("z", dist.Beta(2.0, 2.0))

    guide = auto_class(model)
    svi = SVI(model, guide, Adam({'lr': 0.01}), Elbo(strict_enumeration_warning=False))
    for _ in range(100):
        svi.step()

    quantiles = guide.quantiles([0.1, 0.5, 0.9])
    median = guide.median()
    # Index 1 is the 0.5 quantile, which must agree with the median.
    for site in ("x", "y", "z"):
        assert_equal(median[site], quantiles[site][1])

    # Plain Python floats for the scalar comparisons below.
    q = {site: [v.item() for v in values] for site, values in quantiles.items()}

    x = q["x"]  # Normal: additive gaps between quantiles
    assert -3.0 < x[0]
    assert x[0] + 1.0 < x[1]
    assert x[1] + 1.0 < x[2]
    assert x[2] < 3.0

    y = q["y"]  # LogNormal: positive, so multiplicative gaps
    assert 0.01 < y[0]
    assert y[0] * 2.0 < y[1]
    assert y[1] * 2.0 < y[2]
    assert y[2] < 100.0

    z = q["z"]  # Beta: all quantiles strictly inside (0, 1)
    assert 0.01 < z[0]
    assert z[0] + 0.1 < z[1]
    assert z[1] + 0.1 < z[2]
    assert z[2] < 0.99
开发者ID:lewisKit,项目名称:pyro,代码行数:32,代码来源:test_advi.py
示例7: test_optimizers
def test_optimizers(factory):
    """Check that the optimizer built by ``factory`` maximizes a Gaussian likelihood.

    Runs 100 steps on params ``x``, ``y`` and ``z`` observed under
    ``MultivariateNormal(loc, cov)``. Each param must end within ``prec=1e-2``
    of ``loc`` broadcast to its shape.
    """
    optim = factory()

    def model(loc, cov):
        x = pyro.param("x", torch.randn(2))
        y = pyro.param("y", torch.randn(3, 2))
        z = pyro.param("z", torch.randn(4, 2).abs(), constraint=constraints.greater_than(-1))
        pyro.sample("obs_x", dist.MultivariateNormal(loc, cov), obs=x)
        with pyro.iarange("y_iarange", 3):
            pyro.sample("obs_y", dist.MultivariateNormal(loc, cov), obs=y)
        with pyro.iarange("z_iarange", 4):
            pyro.sample("obs_z", dist.MultivariateNormal(loc, cov), obs=z)

    loc = torch.tensor([-0.5, 0.5])
    cov = torch.tensor([[1.0, 0.09], [0.09, 0.1]])
    for _ in range(100):
        tr = poutine.trace(model).get_trace(loc, cov)
        loss = -tr.log_prob_sum()
        # Step on the unconstrained tensors behind each param ("z" is constrained).
        params = {name: pyro.param(name).unconstrained() for name in ["x", "y", "z"]}
        optim.step(loss, params)

    for name in ["x", "y", "z"]:
        actual = pyro.param(name)
        expected = loc.expand(actual.shape)
        assert_equal(actual, expected, prec=1e-2,
                     msg='{} incorrect: {} vs {}'.format(name, actual, expected))
开发者ID:lewisKit,项目名称:pyro,代码行数:26,代码来源:test_multi.py
示例8: test_bern_elbo_gradient
def test_bern_elbo_gradient(enum_discrete, trace_graph):
    """Compare surrogate-loss ELBO gradients with finite-difference estimates.

    The model is a Bernoulli(0.25) site, and the guide is a Bernoulli site
    with learnable ``p``. Gradients must agree within ``prec=0.1``.
    """
    pyro.clear_param_store()
    num_particles = 2000

    def model():
        p = Variable(torch.Tensor([0.25]))
        pyro.sample("z", dist.Bernoulli(p))

    def guide():
        p = pyro.param("p", Variable(torch.Tensor([0.5]), requires_grad=True))
        pyro.sample("z", dist.Bernoulli(p))

    print("Computing gradients using surrogate loss")
    Elbo = TraceGraph_ELBO if trace_graph else Trace_ELBO
    # A single particle when discrete sites are enumerated, otherwise num_particles.
    elbo = Elbo(enum_discrete=enum_discrete,
                num_particles=(1 if enum_discrete else num_particles))
    with xfail_if_not_implemented():
        elbo.loss_and_grads(model, guide)
    params = sorted(pyro.get_param_store().get_all_param_names())
    assert params, "no params found"
    actual_grads = {name: pyro.param(name).grad.clone() for name in params}

    print("Computing gradients using finite difference")
    elbo = Trace_ELBO(num_particles=num_particles)
    expected_grads = finite_difference(lambda: elbo.loss(model, guide))
    for name in params:
        # Label the two tensors; they used to be printed back-to-back with no separator.
        print("{} {}\n  actual:   {}\n  expected: {}".format(
            name, "-" * 30, actual_grads[name].data, expected_grads[name].data))
    assert_equal(actual_grads, expected_grads, prec=0.1)
开发者ID:Magica-Chen,项目名称:pyro,代码行数:30,代码来源:test_enum.py
示例9: test_categorical_gradient_with_logits
def test_categorical_gradient_with_logits(init_tensor_type):
    """Check a Categorical with logits ``[-inf, 0]`` evaluated at data ``[0, 1]``.

    Both element 0 of the batch log-pdf and the gradient w.r.t. the ``-inf``
    logit must be exactly 0.
    """
    logits = Variable(init_tensor_type([-float('inf'), 0]), requires_grad=True)
    value = Variable(init_tensor_type([0, 1]))
    log_pdf = Categorical(logits=logits).batch_log_pdf(value)
    log_pdf.sum().backward()
    assert_equal(log_pdf.data[0], 0)
    assert_equal(logits.grad.data[0], 0)
开发者ID:Magica-Chen,项目名称:pyro,代码行数:7,代码来源:test_gradient_flow.py
示例10: test_bernoulli_with_logits_overflow_gradient
def test_bernoulli_with_logits_overflow_gradient(init_tensor_type):
    """Check a Bernoulli with a huge logit (``1e40``) observed at 1.

    The log-pdf and the logit's gradient must both be 0. Presumably ``1e40``
    overflows to inf for 32-bit tensor types — TODO confirm.
    """
    logits = Variable(init_tensor_type([1e40]), requires_grad=True)
    observed = Variable(init_tensor_type([1]))
    log_pdf = Bernoulli(logits=logits).batch_log_pdf(observed)
    log_pdf.sum().backward()
    assert_equal(log_pdf.data[0], 0)
    assert_equal(logits.grad.data[0], 0)
开发者ID:Magica-Chen,项目名称:pyro,代码行数:7,代码来源:test_gradient_flow.py
示例11: test_bernoulli_underflow_gradient
def test_bernoulli_underflow_gradient(init_tensor_type):
    """Check a Bernoulli whose probability is forced to 0, observed at 0.

    The probability is ``sigmoid(x) * 0.0``. The log-pdf and the gradient
    w.r.t. ``x`` must both be 0.
    """
    raw = Variable(init_tensor_type([0]), requires_grad=True)
    probs = sigmoid(raw) * 0.0
    observed = Variable(init_tensor_type([0]))
    log_pdf = Bernoulli(probs).batch_log_pdf(observed)
    log_pdf.sum().backward()
    assert_equal(log_pdf.data[0], 0)
    assert_equal(raw.grad.data[0], 0)
开发者ID:Magica-Chen,项目名称:pyro,代码行数:7,代码来源:test_gradient_flow.py
示例12: test_unweighted_samples
def test_unweighted_samples(batch_shape, sample_shape, dtype):
    """Sample from an ``Empirical`` holding constant tensors filled with 0..4.

    Samples must have shape ``sample_shape + batch_shape``, and the set of
    drawn values must be exactly ``{0, 1, 2, 3, 4}``. This assumes
    ``sample_shape`` is large enough that every value gets drawn.
    """
    empirical = Empirical()
    template = torch.ones(batch_shape, dtype=dtype)
    for value in range(5):
        empirical.add(template * value)  # a fresh tensor for each value
    samples = empirical.sample(sample_shape=sample_shape)
    assert_equal(samples.size(), sample_shape + batch_shape)
    drawn = set(samples.view(-1).tolist())
    assert_equal(drawn, set(range(5)))
开发者ID:lewisKit,项目名称:pyro,代码行数:7,代码来源:test_empirical.py
示例13: test_compute_downstream_costs_iarange_reuse
def test_compute_downstream_costs_iarange_reuse(dim1, dim2):
    """Check ``_compute_downstream_costs`` against the brute-force version.

    Uses ``iarange_reuse_model_guide``. Also hand-checks the cost at ``c1``:
    its model-minus-guide log-prob, plus the summed term at ``b1``, plus the
    term at ``c2``, plus the ``obs`` log-prob.
    """
    guide_trace = poutine.trace(iarange_reuse_model_guide,
                                graph_type="dense").get_trace(include_obs=False, dim1=dim1, dim2=dim2)
    replayed_model = poutine.replay(iarange_reuse_model_guide, trace=guide_trace)
    model_trace = poutine.trace(replayed_model,
                                graph_type="dense").get_trace(include_obs=True, dim1=dim1, dim2=dim2)
    guide_trace = prune_subsample_sites(guide_trace)
    model_trace = prune_subsample_sites(model_trace)
    model_trace.compute_log_prob()
    guide_trace.compute_log_prob()
    non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)

    dc, dc_nodes = _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes)
    dc_brute, dc_nodes_brute = _brute_force_compute_downstream_costs(
        model_trace, guide_trace, non_reparam_nodes)
    assert dc_nodes == dc_nodes_brute
    for site in dc:
        assert guide_trace.nodes[site]['log_prob'].size() == dc[site].size()
        assert_equal(dc[site], dc_brute[site])

    def _log_ratio(site):
        # Model log-prob minus guide log-prob at ``site``.
        return model_trace.nodes[site]['log_prob'] - guide_trace.nodes[site]['log_prob']

    expected_c1 = _log_ratio('c1')
    expected_c1 += _log_ratio('b1').sum()
    expected_c1 += _log_ratio('c2')
    expected_c1 += model_trace.nodes['obs']['log_prob']
    assert_equal(expected_c1, dc['c1'])
开发者ID:lewisKit,项目名称:pyro,代码行数:26,代码来源:test_compute_downstream_costs.py
示例14: test_hmc_conjugate_gaussian
def test_hmc_conjugate_gaussian(fixture, num_samples, warmup_steps, hmc_params,
                                expected_means, expected_precs, mean_tol, std_tol):
    """Run HMC on ``fixture.model`` and check the per-site posterior moments.

    For each site ``loc_<i>`` (i = 1..fixture.chain_len), compares the
    empirical posterior mean and standard deviation against expected values.

    ``expected_precs`` holds precisions; they are converted to standard
    deviations as ``1 / sqrt(precision)``. Agreement is measured by the RMSE
    over the ``fixture.dim`` coordinates, asserted equal to 0.0 with
    ``prec=mean_tol`` (means) or ``prec=std_tol`` (standard deviations).
    """
    pyro.get_param_store().clear()
    kernel = HMC(fixture.model, **hmc_params)
    posterior = MCMC(kernel, num_samples, warmup_steps).run(fixture.data)

    for site_idx in range(1, fixture.chain_len + 1):
        param_name = 'loc_' + str(site_idx)
        marginal = EmpiricalMarginal(posterior, sites=param_name)
        latent_loc = marginal.mean
        latent_std = marginal.variance.sqrt()
        ones = torch.ones(fixture.dim)
        expected_mean = ones * expected_means[site_idx - 1]
        expected_std = 1 / torch.sqrt(ones * expected_precs[site_idx - 1])

        # Posterior means: actual vs expected.
        logger.info('Posterior mean (actual) - {}'.format(param_name))
        logger.info(latent_loc)
        logger.info('Posterior mean (expected) - {}'.format(param_name))
        logger.info(expected_mean)
        assert_equal(rmse(latent_loc, expected_mean).item(), 0.0, prec=mean_tol)

        # Posterior standard deviations (derived from precisions): actual vs expected.
        logger.info('Posterior std (actual) - {}'.format(param_name))
        logger.info(latent_std)
        logger.info('Posterior std (expected) - {}'.format(param_name))
        logger.info(expected_std)
        assert_equal(rmse(latent_std, expected_std).item(), 0.0, prec=std_tol)
开发者ID:lewisKit,项目名称:pyro,代码行数:32,代码来源:test_hmc.py
示例15: test_elbo_bern
def test_elbo_bern(quantity, enumerate1):
    """Compare TraceEnum_ELBO on a Bernoulli model/guide with the exact KL.

    The model has ``Bernoulli(0.25)`` and the guide ``Bernoulli(q)``, each
    expanded over ``num_particles``. Depending on ``quantity``, either the
    per-particle loss is compared with ``KL(Bernoulli(q) || Bernoulli(0.25))``,
    or the per-particle gradient on ``q`` is compared with the KL's gradient.
    """
    pyro.clear_param_store()
    if enumerate1:
        num_particles, prec = 1, 0.001
    else:
        num_particles, prec = 10000, 0.1
    q = pyro.param("q", torch.tensor(0.5, requires_grad=True))
    kl = kl_divergence(dist.Bernoulli(q), dist.Bernoulli(0.25))

    def model():
        with pyro.iarange("particles", num_particles):
            pyro.sample("z", dist.Bernoulli(0.25).expand_by([num_particles]))

    @config_enumerate(default=enumerate1)
    def guide():
        q = pyro.param("q")
        with pyro.iarange("particles", num_particles):
            pyro.sample("z", dist.Bernoulli(q).expand_by([num_particles]))

    elbo = TraceEnum_ELBO(max_iarange_nesting=1,
                          strict_enumeration_warning=bool(enumerate1))

    if quantity == "loss":
        actual = elbo.loss(model, guide) / num_particles
        expected = kl.item()
        message = "\nexpected = {}".format(expected) + "\n actual = {}".format(actual)
        assert_equal(actual, expected, prec=prec, msg=message)
        return

    elbo.loss_and_grads(model, guide)
    actual = q.grad / num_particles
    expected = grad(kl, [q])[0]
    message = ("\nexpected = {}".format(expected.detach().cpu().numpy()) +
               "\n actual = {}".format(actual.detach().cpu().numpy()))
    assert_equal(actual, expected, prec=prec, msg=message)
开发者ID:lewisKit,项目名称:pyro,代码行数:35,代码来源:test_enum.py
示例16: test_elbo_hmm_in_guide
def test_elbo_hmm_in_guide(enumerate1, num_steps):
    """Check ELBO gradients for a 2-state HMM whose hidden chain is also in the guide.

    Gradients of ``transition_probs`` and ``emission_probs`` are compared with
    golden values for ``num_steps`` in {2, 3, 10, 20}. Those values only pin
    agreement between the parallel and sequential enumeration modes.
    """
    pyro.clear_param_store()
    data = torch.ones(num_steps)
    init_probs = torch.tensor([0.5, 0.5])

    def _transition_probs():
        # The same param is registered identically by the model and the guide.
        return pyro.param("transition_probs",
                          torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                          constraint=constraints.simplex)

    def model(data):
        transition_probs = _transition_probs()
        emission_probs = pyro.param("emission_probs",
                                    torch.tensor([[0.75, 0.25], [0.25, 0.75]]),
                                    constraint=constraints.simplex)
        state = None
        for t, obs in enumerate(data):
            probs = init_probs if state is None else transition_probs[state]
            state = pyro.sample("x_{}".format(t), dist.Categorical(probs))
            pyro.sample("y_{}".format(t), dist.Categorical(emission_probs[state]), obs=obs)

    @config_enumerate(default=enumerate1)
    def guide(data):
        transition_probs = _transition_probs()
        state = None
        for t, _obs in enumerate(data):
            probs = init_probs if state is None else transition_probs[state]
            state = pyro.sample("x_{}".format(t), dist.Categorical(probs))

    elbo = TraceEnum_ELBO(max_iarange_nesting=0)
    elbo.loss_and_grads(model, guide, data)

    # These golden values simply test agreement between parallel and sequential.
    expected_grads = {
        2: {
            "transition_probs": [[0.1029949, -0.1029949], [0.1029949, -0.1029949]],
            "emission_probs": [[0.75, -0.75], [0.25, -0.25]],
        },
        3: {
            "transition_probs": [[0.25748726, -0.25748726], [0.25748726, -0.25748726]],
            "emission_probs": [[1.125, -1.125], [0.375, -0.375]],
        },
        10: {
            "transition_probs": [[1.64832076, -1.64832076], [1.64832076, -1.64832076]],
            "emission_probs": [[3.75, -3.75], [1.25, -1.25]],
        },
        20: {
            "transition_probs": [[3.70781687, -3.70781687], [3.70781687, -3.70781687]],
            "emission_probs": [[7.5, -7.5], [2.5, -2.5]],
        },
    }
    for name, value in pyro.get_param_store().named_parameters():
        actual = value.grad
        expected = torch.tensor(expected_grads[num_steps][name])
        message = ('\nexpected {}.grad = {}'.format(name, expected.cpu().numpy()) +
                   '\n actual {}.grad = {}'.format(name, actual.detach().cpu().numpy()))
        assert_equal(actual, expected, msg=message)
开发者ID:lewisKit,项目名称:pyro,代码行数:59,代码来源:test_enum.py
示例17: test_batch_log_pdf
def test_batch_log_pdf(dist):
    """Compare the summed pyro ``batch_log_pdf`` with the summed scipy reference.

    The check runs once for every index in ``dist.get_batch_data_indices()``.
    ``dist`` here is the test fixture, not the distributions module.
    """
    d = dist.pyro_dist
    for idx in dist.get_batch_data_indices():
        dist_params = dist.get_dist_params(idx)
        test_data = dist.get_test_data(idx)
        logpdf_sum_pyro = unwrap_variable(torch.sum(d.batch_log_pdf(test_data, **dist_params)))[0]
        # Use the scipy reference for the same index as the pyro computation,
        # as test_log_pdf does. A fixed -1 would presumably compare every
        # batch against the last one.
        logpdf_sum_np = np.sum(dist.get_scipy_batch_logpdf(idx))
        assert_equal(logpdf_sum_pyro, logpdf_sum_np)
开发者ID:Magica-Chen,项目名称:pyro,代码行数:8,代码来源:test_distributions.py
示例18: test_log_pdf
def test_log_pdf(dist):
    """Compare pyro ``log_pdf`` with the scipy reference for every test-data index.

    ``dist`` here is the test fixture, not the distributions module.
    """
    pyro_dist = dist.pyro_dist
    for idx in dist.get_test_data_indices():
        params = dist.get_dist_params(idx)
        data = dist.get_test_data(idx)
        actual = unwrap_variable(pyro_dist.log_pdf(data, **params))[0]
        assert_equal(actual, dist.get_scipy_logpdf(idx))
开发者ID:Magica-Chen,项目名称:pyro,代码行数:8,代码来源:test_distributions.py
示例19: test_unweighted_mean_and_var
def test_unweighted_mean_and_var(size, dtype):
    """Check the moments of an unweighted ``Empirical`` over constant tensors 0..4.

    Equal weights over {0, 1, 2, 3, 4} give mean 2 and population variance
    (4 + 1 + 0 + 1 + 4) / 5 = 2, elementwise.
    """
    empirical = Empirical()
    for value in range(5):
        empirical.add(torch.ones(size, dtype=dtype) * value)
    expected = torch.ones(size) * 2
    assert_equal(empirical.mean, expected)
    assert_equal(empirical.variance, expected)
开发者ID:lewisKit,项目名称:pyro,代码行数:8,代码来源:test_empirical.py
示例20: test_double_type
def test_double_type(test_data, alpha, beta):
    """Check ``dist.Beta(alpha, beta).log_prob`` on the given inputs.

    The result must be a ``torch.DoubleTensor`` and must match
    ``scipy.stats.beta.logpdf`` within ``prec=1e-4``.
    """
    def to_numpy(t):
        return t.detach().cpu().numpy()

    log_px = dist.Beta(alpha, beta).log_prob(test_data).data
    assert isinstance(log_px, torch.DoubleTensor)
    reference = sp.beta.logpdf(to_numpy(test_data), to_numpy(alpha), to_numpy(beta))
    assert_equal(log_px.numpy(), reference, prec=1e-4)
开发者ID:lewisKit,项目名称:pyro,代码行数:9,代码来源:test_tensor_type.py
注：本文中的tests.common.assert_equal函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论