本文整理汇总了Python中torch.equal函数的典型用法代码示例。如果您正苦于以下问题：Python torch.equal函数的具体用法？torch.equal怎么用？torch.equal使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了torch.equal函数的20个代码示例，这些例子默认根据受欢迎程度排序。torch.equal(a, b) 仅当两个张量形状相同且所有元素都相等时返回 True，否则返回 False。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: test_degenerate_GPyTorchPosterior
def test_degenerate_GPyTorchPosterior(self, cuda=False):
    """Check GPyTorchPosterior when the covariance matrix is singular.

    For each floating dtype this checks the basic posterior properties, that
    ``rsample`` emits exactly one "not p.d." ``RuntimeWarning``, the shapes of
    drawn samples, that identical ``base_samples`` give identical draws, and
    the batched (``collapse_batch_dims``) case.

    Args:
        cuda: If True, run on the CUDA device instead of the CPU.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")
    for dtype in (torch.float, torch.double):
        # singular covariance matrix (rows 0 and 1 are identical)
        degenerate_covar = torch.tensor(
            [[1, 1, 0], [1, 1, 0], [0, 0, 2]], dtype=dtype, device=device
        )
        mean = torch.rand(3, dtype=dtype, device=device)
        mvn = MultivariateNormal(mean, lazify(degenerate_covar))
        posterior = GPyTorchPosterior(mvn=mvn)
        # basics
        self.assertEqual(posterior.device.type, device.type)
        self.assertEqual(posterior.dtype, dtype)
        self.assertEqual(posterior.event_shape, torch.Size([3, 1]))
        self.assertTrue(torch.equal(posterior.mean, mean.unsqueeze(-1)))
        variance_exp = degenerate_covar.diag().unsqueeze(-1)
        self.assertTrue(torch.equal(posterior.variance, variance_exp))
        # rsample
        with warnings.catch_warnings(record=True) as w:
            # Make RuntimeWarnings always visible inside this block; otherwise an
            # ambient filter (e.g. "-W ignore" or "once") could swallow the
            # expected warning and len(w) would be 0.
            warnings.simplefilter("always", category=RuntimeWarning)
            # we check that the p.d. warning is emitted - this only
            # happens once per posterior, so we need to check only once
            samples = posterior.rsample(sample_shape=torch.Size([4]))
            self.assertEqual(len(w), 1)
            self.assertTrue(issubclass(w[-1].category, RuntimeWarning))
            self.assertIn("not p.d.", str(w[-1].message))
        self.assertEqual(samples.shape, torch.Size([4, 3, 1]))
        samples2 = posterior.rsample(sample_shape=torch.Size([4, 2]))
        self.assertEqual(samples2.shape, torch.Size([4, 2, 3, 1]))
        # rsample w/ base samples: the same base samples must give the same draws
        base_samples = torch.randn(4, 3, 1, device=device, dtype=dtype)
        samples_b1 = posterior.rsample(
            sample_shape=torch.Size([4]), base_samples=base_samples
        )
        samples_b2 = posterior.rsample(
            sample_shape=torch.Size([4]), base_samples=base_samples
        )
        self.assertTrue(torch.allclose(samples_b1, samples_b2))
        base_samples2 = torch.randn(4, 2, 3, 1, device=device, dtype=dtype)
        samples2_b1 = posterior.rsample(
            sample_shape=torch.Size([4, 2]), base_samples=base_samples2
        )
        samples2_b2 = posterior.rsample(
            sample_shape=torch.Size([4, 2]), base_samples=base_samples2
        )
        self.assertTrue(torch.allclose(samples2_b1, samples2_b2))
        # collapse_batch_dims
        b_mean = torch.rand(2, 3, dtype=dtype, device=device)
        b_degenerate_covar = degenerate_covar.expand(2, *degenerate_covar.shape)
        b_mvn = MultivariateNormal(b_mean, lazify(b_degenerate_covar))
        b_posterior = GPyTorchPosterior(mvn=b_mvn)
        b_base_samples = torch.randn(4, 2, 3, 1, device=device, dtype=dtype)
        with warnings.catch_warnings(record=True) as w:
            # same reasoning as above: do not let ambient filters hide the warning
            warnings.simplefilter("always", category=RuntimeWarning)
            b_samples = b_posterior.rsample(
                sample_shape=torch.Size([4]), base_samples=b_base_samples
            )
            self.assertEqual(len(w), 1)
            self.assertTrue(issubclass(w[-1].category, RuntimeWarning))
            self.assertIn("not p.d.", str(w[-1].message))
        self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 1]))
开发者ID:saschwan,项目名称:botorch,代码行数:60,代码来源:test_gpytorch.py
示例2: test_remote_tensor_multi_var_methods
def test_remote_tensor_multi_var_methods(self):
    """Run multi-output torch functions on tensors sent to a VirtualWorker.

    Each case sends a tensor to ``remote``, calls a torch function that returns
    several tensors (``max``, ``qr``, ``kthvalue``, ``eig``, ``svd``), pulls the
    results back with ``.get()`` and compares them with hard-coded values.
    """
    hook = TorchHook(verbose=False)
    local = hook.local_worker
    # worker with id 1, registered with the local worker so tensors can be sent to it
    remote = VirtualWorker(hook, 1)
    local.add_worker(remote)
    x = torch.FloatTensor([[1, 2], [4, 3], [5, 6]])
    # NOTE(review): send()'s return value is discarded, so this relies on send()
    # turning x into a remote pointer in place -- confirm against TorchHook.
    x.send(remote)
    # row-wise maximum: values and their column indices
    y, z = torch.max(x, 1)
    assert torch.equal(y.get(), torch.FloatTensor([2, 4, 6]))
    assert torch.equal(z.get(), torch.LongTensor([1, 0, 1]))
    x = torch.FloatTensor([[0, 0], [1, 0]]).send(remote)
    # QR is only unique up to the signs of Q's columns, so these expected values
    # are tied to the backend's sign convention.
    y, z = torch.qr(x)
    assert (y.get() == torch.FloatTensor([[0, -1], [-1, 0]])).all()
    assert (z.get() == torch.FloatTensor([[-1, 0], [0, 0]])).all()
    # the 4th smallest value of [1, 2, 3, 4, 5] is 4, at index 3
    x = torch.arange(1, 6).send(remote)
    y, z = torch.kthvalue(x, 4)
    assert (y.get() == torch.FloatTensor([4])).all()
    assert (z.get() == torch.LongTensor([3])).all()
    x = torch.FloatTensor([[0, 0], [1, 1]]).send(remote)
    # eigen-decomposition; the second argument True requests eigenvectors
    y, z = torch.eig(x, True)
    assert (y.get() == torch.FloatTensor([[1, 0], [0, 0]])).all()
    # Only the element-wise equality pattern of the eigenvector matrix is checked:
    # entries (0, 0) and (1, 0) must equal 0 and 1, both second-column entries must
    # differ from 0.
    assert ((z.get() == torch.FloatTensor([[0, 0], [1, 0]])) == torch.ByteTensor([[1, 0], [1, 0]])).all()
    # SVD of the 3x3 zero matrix: identity factors and zero singular values
    x = torch.zeros(3, 3).send(remote)
    w, y, z = torch.svd(x)
    assert (w.get() == torch.FloatTensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]])).all()
    assert (y.get() == torch.FloatTensor([0, 0, 0])).all()
    assert (z.get() == torch.FloatTensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]])).all()
开发者ID:TanayGahlot,项目名称:PySyft,代码行数:32,代码来源:torch_test.py
示例3: test_regex_matches_are_initialized_correctly
def test_regex_matches_are_initialized_correctly(self):
    """Initializer regexes containing '.' must still select the right parameters.

    pyhocon treats a '.' inside a key specially, so the config below uses a
    regex with '.*' and the test checks that both initializers were applied.
    """

    class Net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear_1_with_funky_name = torch.nn.Linear(5, 10)
            self.linear_2 = torch.nn.Linear(10, 5)
            self.conv = torch.nn.Conv1d(5, 5, 5)

        def forward(self, inputs):  # pylint: disable=arguments-differ
            pass

    # Same text as a literal multi-line string, written as adjacent literals.
    config_text = (
        '{"initializer": [\n'
        '["conv", {"type": "constant", "val": 5}],\n'
        '["funky_na.*bi", {"type": "constant", "val": 7}]\n'
        ']}\n'
    )
    parsed = Params(pyhocon.ConfigFactory.parse_string(config_text))
    applicator = InitializerApplicator.from_params(parsed['initializer'])
    net = Net()
    applicator(net)
    # every conv parameter is expected to be matched by "conv"
    for conv_param in net.conv.parameters():
        assert torch.equal(conv_param.data, torch.ones(conv_param.size()) * 5)
    # "funky_na.*bi" is expected to match linear_1_with_funky_name.bias
    funky_bias = net.linear_1_with_funky_name.bias
    assert torch.equal(funky_bias.data, torch.ones(funky_bias.size()) * 7)
开发者ID:Jordan-Sauchuk,项目名称:allennlp,代码行数:28,代码来源:initializers_test.py
示例4: test_add_output_dim
def test_add_output_dim(self, cuda=False):
    """Exercise ``add_output_dim`` with too few, matching and extra batch dims.

    Args:
        cuda: If True, run on the CUDA device instead of the CPU.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")
    batch_shape = torch.Size([2])
    for dtype in (torch.float, torch.double):
        tkwargs = {"device": device, "dtype": dtype}
        # too few dimensions for the original batch shape -> ValueError
        too_small = torch.rand(2, 1, **tkwargs)
        with self.assertRaises(ValueError):
            add_output_dim(X=too_small, original_batch_shape=batch_shape)
        # (input shape, index where the output dim is expected to be inserted):
        # first without new batch dims, then with one new batch dim
        for shape, expected_idx in (((2, 2, 1), 0), ((3, 2, 2, 1), 1)):
            X = torch.rand(*shape, **tkwargs)
            X_out, output_dim_idx = add_output_dim(
                X=X, original_batch_shape=batch_shape
            )
            self.assertTrue(torch.equal(X_out, X.unsqueeze(expected_idx)))
            self.assertEqual(output_dim_idx, expected_idx)
开发者ID:saschwan,项目名称:botorch,代码行数:25,代码来源:test_utils.py
示例5: test_q_noisy_expected_improvement
def test_q_noisy_expected_improvement(self, cuda=False):
    """qNoisyExpectedImprovement with a mocked posterior and several samplers.

    The mocked posterior returns fixed samples [1.0, 0.0] and every
    acquisition value is expected to be 1.0. Samplers without
    ``resample=True`` must keep their base samples across calls; the
    resampling Sobol sampler must draw new ones.

    Args:
        cuda: If True, run on the CUDA device instead of the CPU.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")
    for dtype in (torch.float, torch.double):
        # the event shape is `b x q x t` = 1 x 2 x 1
        samples_noisy = torch.tensor([1.0, 0.0], device=device, dtype=dtype).view(1, 2, 1)
        # X_baseline is `q' x d` = 1 x 1
        X_baseline = torch.zeros(1, 1, device=device, dtype=dtype)
        mm_noisy = MockModel(MockPosterior(samples=samples_noisy))
        # X is `q x d` = 1 x 1
        X = torch.zeros(1, 1, device=device, dtype=dtype)

        def evaluate(sampler):
            # build the acquisition function and check its value at X
            acqf = qNoisyExpectedImprovement(
                model=mm_noisy, X_baseline=X_baseline, sampler=sampler
            )
            self.assertEqual(acqf(X).item(), 1.0)
            return acqf

        def check_base_samples(sampler, should_resample):
            # evaluate, then call again and see whether base samples were redrawn
            acqf = evaluate(sampler)
            self.assertEqual(acqf.sampler.base_samples.shape, torch.Size([2, 1, 2, 1]))
            snapshot = acqf.sampler.base_samples.clone()
            acqf(X)
            unchanged = torch.equal(acqf.sampler.base_samples, snapshot)
            self.assertEqual(unchanged, not should_resample)

        # basic test
        evaluate(IIDNormalSampler(num_samples=2))
        # basic test, no resample
        check_base_samples(IIDNormalSampler(num_samples=2, seed=12345), should_resample=False)
        # basic test, qmc, no resample
        check_base_samples(SobolQMCNormalSampler(num_samples=2), should_resample=False)
        # basic test, qmc, resample
        check_base_samples(
            SobolQMCNormalSampler(num_samples=2, resample=True, seed=12345),
            should_resample=True,
        )
开发者ID:saschwan,项目名称:botorch,代码行数:55,代码来源:test_monte_carlo.py
示例6: test_local_tensor_iterable_methods
def test_local_tensor_iterable_methods(self):
    """Check torch.stack and torch.cat on a list of local FloatTensors."""
    x = torch.FloatTensor([1, 2, 3])
    y = torch.FloatTensor([2, 3, 4])
    z = torch.FloatTensor([5, 6, 7])
    # stack adds a new leading dimension: one row per input tensor
    assert torch.equal(torch.stack([x, y, z]), torch.FloatTensor([[1, 2, 3], [2, 3, 4], [5, 6, 7]]))
    # cat joins along the existing dimension; stack does not modify its inputs,
    # so the same x, y, z are reused instead of being rebuilt identically
    assert torch.equal(torch.cat([x, y, z]), torch.FloatTensor([1, 2, 3, 2, 3, 4, 5, 6, 7]))
开发者ID:TanayGahlot,项目名称:PySyft,代码行数:11,代码来源:torch_test.py
示例7: test_MockPosterior
def test_MockPosterior(self):
    """MockPosterior must echo the mean, variance and samples it was given."""
    mean, variance, samples = torch.rand(2), torch.eye(2), torch.rand(1, 2)
    posterior = MockPosterior(mean=mean, variance=variance, samples=samples)
    self.assertTrue(torch.equal(posterior.mean, mean))
    self.assertTrue(torch.equal(posterior.variance, variance))
    # default sample() is compared against the samples with one leading dim
    drawn = posterior.sample()
    self.assertTrue(torch.all(drawn == samples.unsqueeze(0)))
    # an explicit sample shape of 2 is compared against two stacked copies
    drawn_pair = posterior.sample(torch.Size([2]))
    self.assertTrue(torch.all(drawn_pair == samples.repeat(2, 1, 1)))
)
开发者ID:saschwan,项目名称:botorch,代码行数:11,代码来源:test_mock.py
示例8: test_GPyTorchPosterior
def test_GPyTorchPosterior(self, cuda=False):
    """Check GPyTorchPosterior built from a diagonal covariance matrix.

    For each floating dtype this checks device/dtype/event shape, mean and
    variance, the shapes returned by ``rsample``, that incompatible
    ``base_samples`` raise, that identical ``base_samples`` give identical
    draws, and that base samples with a singleton batch dim yield samples of
    the full batch shape.

    Args:
        cuda: If True, run on the CUDA device instead of the CPU.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")
    for dtype in (torch.float, torch.double):
        mean = torch.rand(3, dtype=dtype, device=device)
        variance = 1 + torch.rand(3, dtype=dtype, device=device)
        covar = variance.diag()
        mvn = MultivariateNormal(mean, lazify(covar))
        posterior = GPyTorchPosterior(mvn=mvn)
        # basics
        self.assertEqual(posterior.device.type, device.type)
        self.assertEqual(posterior.dtype, dtype)
        self.assertEqual(posterior.event_shape, torch.Size([3, 1]))
        self.assertTrue(torch.equal(posterior.mean, mean.unsqueeze(-1)))
        self.assertTrue(torch.equal(posterior.variance, variance.unsqueeze(-1)))
        # rsample
        samples = posterior.rsample()
        self.assertEqual(samples.shape, torch.Size([1, 3, 1]))
        samples = posterior.rsample(sample_shape=torch.Size([4]))
        self.assertEqual(samples.shape, torch.Size([4, 3, 1]))
        samples2 = posterior.rsample(sample_shape=torch.Size([4, 2]))
        self.assertEqual(samples2.shape, torch.Size([4, 2, 3, 1]))
        # rsample w/ base samples
        base_samples = torch.randn(4, 3, 1, device=device, dtype=dtype)
        # incompatible shapes
        with self.assertRaises(RuntimeError):
            posterior.rsample(
                sample_shape=torch.Size([3]), base_samples=base_samples
            )
        samples_b1 = posterior.rsample(
            sample_shape=torch.Size([4]), base_samples=base_samples
        )
        samples_b2 = posterior.rsample(
            sample_shape=torch.Size([4]), base_samples=base_samples
        )
        self.assertTrue(torch.allclose(samples_b1, samples_b2))
        base_samples2 = torch.randn(4, 2, 3, 1, device=device, dtype=dtype)
        samples2_b1 = posterior.rsample(
            sample_shape=torch.Size([4, 2]), base_samples=base_samples2
        )
        samples2_b2 = posterior.rsample(
            sample_shape=torch.Size([4, 2]), base_samples=base_samples2
        )
        self.assertTrue(torch.allclose(samples2_b1, samples2_b2))
        # collapse_batch_dims
        b_mean = torch.rand(2, 3, dtype=dtype, device=device)
        b_variance = 1 + torch.rand(2, 3, dtype=dtype, device=device)
        # build the identity with the right dtype/device directly instead of
        # allocating it on the CPU and converting it with type_as
        b_covar = b_variance.unsqueeze(-1) * torch.eye(3, dtype=dtype, device=device)
        b_mvn = MultivariateNormal(b_mean, lazify(b_covar))
        b_posterior = GPyTorchPosterior(mvn=b_mvn)
        # base samples with a batch dim of 1 against a posterior batch of 2
        b_base_samples = torch.randn(4, 1, 3, 1, device=device, dtype=dtype)
        b_samples = b_posterior.rsample(
            sample_shape=torch.Size([4]), base_samples=b_base_samples
        )
        self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 1]))
开发者ID:saschwan,项目名称:botorch,代码行数:54,代码来源:test_gpytorch.py
示例9: test_make_grid_not_inplace
def test_make_grid_not_inplace(self):
    """make_grid must leave its input untouched for every normalization mode."""
    t = torch.rand(5, 3, 10, 10)
    t_clone = t.clone()
    option_sets = (
        {"normalize": False},
        {"normalize": True, "scale_each": False},
        {"normalize": True, "scale_each": True},
    )
    for options in option_sets:
        utils.make_grid(t, **options)
        assert torch.equal(t, t_clone), 'make_grid modified tensor in-place'
开发者ID:WENXINGEVIN,项目名称:vision,代码行数:12,代码来源:test_utils.py
示例10: test_generic_mc_objective
def test_generic_mc_objective(self, cuda=False):
    """GenericMCObjective must return exactly what the wrapped callable returns.

    Args:
        cuda: If True, run on the CUDA device instead of the CPU.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")
    sample_shapes = ((1,), (2,), (3, 1), (3, 2))
    for dtype in (torch.float, torch.double):
        obj = GenericMCObjective(generic_obj)
        for shape in sample_shapes:
            samples = torch.randn(*shape, device=device, dtype=dtype)
            self.assertTrue(torch.equal(obj(samples), generic_obj(samples)))
开发者ID:saschwan,项目名称:botorch,代码行数:12,代码来源:test_objective.py
示例11: test_q_expected_improvement
def test_q_expected_improvement(self, cuda=False):
    """qExpectedImprovement with all-zero mocked samples and several samplers.

    With ``best_f=0`` the value is expected to be 0.0 and with ``best_f=-1``
    it is expected to be 1.0. Samplers without ``resample=True`` must keep
    their base samples across calls; the resampling Sobol sampler must not.

    Args:
        cuda: If True, run on the CUDA device instead of the CPU.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")
    for dtype in (torch.float, torch.double):
        # the event shape is `b x q x t` = 1 x 1 x 1
        samples = torch.zeros(1, 1, 1, device=device, dtype=dtype)
        mm = MockModel(MockPosterior(samples=samples))
        # X is `q x d` = 1 x 1. X is a dummy and unused b/c of mocking
        X = torch.zeros(1, 1, device=device, dtype=dtype)

        def evaluate(sampler, best_f=0, expected=0.0):
            # build the acquisition function and check its value at X
            acqf = qExpectedImprovement(model=mm, best_f=best_f, sampler=sampler)
            self.assertEqual(acqf(X).item(), expected)
            return acqf

        def check_base_samples(sampler, should_resample):
            # evaluate, then call again and see whether base samples were redrawn
            acqf = evaluate(sampler)
            self.assertEqual(acqf.sampler.base_samples.shape, torch.Size([2, 1, 1, 1]))
            snapshot = acqf.sampler.base_samples.clone()
            acqf(X)
            unchanged = torch.equal(acqf.sampler.base_samples, snapshot)
            self.assertEqual(unchanged, not should_resample)

        # basic test
        iid_sampler = IIDNormalSampler(num_samples=2)
        evaluate(iid_sampler)
        # test shifting best_f value
        evaluate(iid_sampler, best_f=-1, expected=1.0)
        # basic test, no resample
        check_base_samples(IIDNormalSampler(num_samples=2, seed=12345), should_resample=False)
        # basic test, qmc, no resample
        check_base_samples(SobolQMCNormalSampler(num_samples=2), should_resample=False)
        # basic test, qmc, resample
        check_base_samples(
            SobolQMCNormalSampler(num_samples=2, resample=True), should_resample=True
        )
开发者ID:saschwan,项目名称:botorch,代码行数:49,代码来源:test_monte_carlo.py
示例12: do_test_per_param_optim
def do_test_per_param_optim(self, fixed_param, free_param):
    """Check per-parameter learning rates and Adam state save/load.

    ``fixed_param`` gets learning rate 0.0 and ``free_param`` gets 0.01. After
    training, the fixed parameter must still equal its zero initialisation
    and the free one must have moved. The optimizer state saved after one
    step and loaded into a second optimizer must carry the step counter over.

    Args:
        fixed_param: Name of the param-store parameter that must not change.
        free_param: Name of the param-store parameter that must be updated.
    """
    import os
    import tempfile

    pyro.clear_param_store()

    def model():
        prior_dist = Normal(self.mu0, torch.pow(self.lam0, -0.5))
        mu_latent = pyro.sample("mu_latent", prior_dist)
        x_dist = Normal(mu_latent, torch.pow(self.lam, -0.5))
        pyro.observe("obs", x_dist, self.data)
        return mu_latent

    def guide():
        mu_q = pyro.param(
            "mu_q",
            Variable(
                torch.zeros(1),
                requires_grad=True))
        log_sig_q = pyro.param(
            "log_sig_q", Variable(
                torch.zeros(1), requires_grad=True))
        sig_q = torch.exp(log_sig_q)
        pyro.sample("mu_latent", Normal(mu_q, sig_q))

    def optim_params(module_name, param_name, tags):
        # per-parameter optimizer arguments; other names get None (unchanged behaviour)
        if param_name == fixed_param:
            return {'lr': 0.00}
        elif param_name == free_param:
            return {'lr': 0.01}

    def adam_step_count(optimizer):
        # 'step' entry of the first state item recorded for the "mu_q" parameter
        return list(optimizer.get_state()['mu_q']['state'].items())[0][1]['step']

    adam = optim.Adam(optim_params)
    adam2 = optim.Adam(optim_params)
    svi = SVI(model, guide, adam, loss="ELBO", trace_graph=True)
    svi2 = SVI(model, guide, adam2, loss="ELBO", trace_graph=True)

    # Save/load through a temporary directory so the test does not leave an
    # 'adam.unittest.save' file behind in the current working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        save_path = os.path.join(tmp_dir, 'adam.unittest.save')
        svi.step()
        adam_initial_step_count = adam_step_count(adam)
        adam.save(save_path)
        svi.step()
        adam_final_step_count = adam_step_count(adam)
        adam2.load(save_path)
    svi2.step()
    adam2_step_count_after_load_and_step = adam_step_count(adam2)

    assert adam_initial_step_count == 1
    assert adam_final_step_count == 2
    # adam2 resumed from the state saved after one step, then stepped once more
    assert adam2_step_count_after_load_and_step == 2

    # both params were initialised to zeros; only the lr=0 one may still be zero
    free_param_unchanged = torch.equal(pyro.param(free_param).data, torch.zeros(1))
    fixed_param_unchanged = torch.equal(pyro.param(fixed_param).data, torch.zeros(1))
    assert fixed_param_unchanged and not free_param_unchanged
开发者ID:Magica-Chen,项目名称:pyro,代码行数:49,代码来源:test_optim.py
示例13: test_match_batch_shape
def test_match_batch_shape(self):
    """match_batch_shape must add/expand batch dims to match Y, or raise."""
    # a missing batch dimension is added
    X = torch.rand(3, 2)
    matched = match_batch_shape(X, torch.rand(1, 3, 2))
    self.assertTrue(torch.equal(matched, X.unsqueeze(0)))
    # a singleton batch dimension is expanded to Y's batch size
    X = torch.rand(1, 3, 2)
    matched = match_batch_shape(X, torch.rand(2, 3, 2))
    self.assertTrue(torch.equal(matched, X.repeat(2, 1, 1)))
    # a larger batch in X than in Y cannot be matched
    X, Y = torch.rand(2, 3, 2), torch.rand(1, 3, 2)
    with self.assertRaises(RuntimeError):
        match_batch_shape(X, Y)
开发者ID:saschwan,项目名称:botorch,代码行数:15,代码来源:test_transforms.py
示例14: test_standardize
def test_standardize(self, cuda=False):
    """standardize on constant input, a 1-d vector and a 2-d (column-wise) input.

    Args:
        cuda: If True, run on the CUDA device instead of the CPU.
    """
    device = torch.device("cuda" if cuda else "cpu")
    for dtype in (torch.float, torch.double):
        tkwargs = {"device": device, "dtype": dtype}
        # constant input must come back unchanged
        zeros = torch.tensor([0.0, 0.0], **tkwargs)
        self.assertTrue(torch.equal(zeros, standardize(zeros)))
        # non-constant vector
        column = torch.tensor([0.0, 1.0, 1.0, 1.0], **tkwargs)
        expected = torch.tensor([-1.5, 0.5, 0.5, 0.5], **tkwargs)
        self.assertTrue(torch.equal(expected, standardize(column)))
        # 2-d input: each column is standardized on its own
        both = torch.tensor(
            [[0.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]], **tkwargs
        ).t()
        both_stdized = standardize(both)
        self.assertTrue(torch.equal(both_stdized[:, 0], expected))
        self.assertTrue(torch.equal(both_stdized[:, 1], torch.zeros(4, **tkwargs)))
开发者ID:saschwan,项目名称:botorch,代码行数:15,代码来源:test_transforms.py
示例15: test_local_tensor_multi_var_methods
def test_local_tensor_multi_var_methods(self):
    """Check local torch functions that return several tensors.

    Covers ``max`` (values and indices), ``eig``, ``qr``, ``kthvalue`` and
    ``svd`` against hard-coded expected outputs.
    """
    x = torch.FloatTensor([[1, 2], [2, 3], [5, 6]])
    # row-wise maximum: values and their column indices
    t, s = torch.max(x, 1)
    # `.all()` states "every element matches" directly instead of counting matches
    assert (t == torch.FloatTensor([2, 3, 6])).all()
    assert (s == torch.LongTensor([1, 1, 1])).all()
    x = torch.FloatTensor([[0, 0], [1, 1]])
    # eigen-decomposition; the second argument True requests eigenvectors
    y, z = torch.eig(x, True)
    assert (y == torch.FloatTensor([[1, 0], [0, 0]])).all()
    # Only the element-wise equality pattern of the eigenvector matrix is checked.
    # `==` yields a ByteTensor on old torch but a BoolTensor from torch 1.2 on, so
    # comparing the mask directly with a ByteTensor depends on the torch version.
    # Casting the mask to long compares against a fixed dtype everywhere.
    eq_mask = (z == torch.FloatTensor([[0, 0], [1, 0]])).long()
    assert torch.equal(eq_mask, torch.LongTensor([[1, 0], [1, 0]]))
    x = torch.FloatTensor([[0, 0], [1, 0]])
    # QR is only unique up to the signs of Q's columns; these expected values
    # follow the backend's sign convention
    y, z = torch.qr(x)
    assert (y == torch.FloatTensor([[0, -1], [-1, 0]])).all()
    assert (z == torch.FloatTensor([[-1, 0], [0, 0]])).all()
    # the 4th smallest value of [1, 2, 3, 4, 5] is 4, at index 3
    x = torch.arange(1, 6)
    y, z = torch.kthvalue(x, 4)
    assert (y == torch.FloatTensor([4])).all()
    assert (z == torch.LongTensor([3])).all()
    # SVD of the 3x3 zero matrix: identity factors and zero singular values
    x = torch.zeros(3, 3)
    w, y, z = torch.svd(x)
    assert (w == torch.FloatTensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]])).all()
    assert (y == torch.FloatTensor([0, 0, 0])).all()
    assert (z == torch.FloatTensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]])).all()
开发者ID:TanayGahlot,项目名称:PySyft,代码行数:26,代码来源:torch_test.py
示例16: test_identity_mc_objective
def test_identity_mc_objective(self, cuda=False):
    """IdentityMCObjective squeezes a trailing singleton dim and otherwise passes through.

    Args:
        cuda: If True, run on the CUDA device instead of the CPU.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")
    # (sample shape, expected output as a function of the samples)
    cases = (
        ((1,), lambda s: s[0]),  # single-element tensor
        ((2,), lambda s: s),  # single-dimensional non-squeezable tensor
        ((3, 1), lambda s: s.squeeze(-1)),  # two-dimensional squeezable tensor
        ((3, 2), lambda s: s),  # two-dimensional non-squeezable tensor
    )
    for dtype in (torch.float, torch.double):
        obj = IdentityMCObjective()
        for shape, expected_of in cases:
            samples = torch.randn(*shape, device=device, dtype=dtype)
            self.assertTrue(torch.equal(obj(samples), expected_of(samples)))
开发者ID:saschwan,项目名称:botorch,代码行数:16,代码来源:test_objective.py
示例17: test_forward_pass_runs_correctly
def test_forward_pass_runs_correctly(self):
    """
    An ensemble built from two copies of the same model must predict the same
    best span (indices and string) as that model on its own.
    """
    ensemble = BidafEnsemble([self.model, self.model])
    batch = Batch(self.instances)
    batch.index_instances(self.vocab)
    tensors = batch.as_tensor_dict()
    single_output = self.model(**tensors)
    ensemble_output = ensemble(**tensors)
    metrics = self.model.get_metrics(reset=True)
    # The data contains a fake answer spanning the whole paragraph, so any valid
    # prediction gives F1 > 0. A zero F1 means the evaluation data or the official
    # evaluation script was not hooked up correctly.
    assert metrics['f1'] > 0
    assert torch.equal(ensemble_output['best_span'], single_output['best_span'])
    assert ensemble_output['best_span_str'] == single_output['best_span_str']
开发者ID:apmoore1,项目名称:allennlp,代码行数:25,代码来源:bidaf_ensemble_test.py
示例18: test_torch_function_with_multiple_output_on_remote_var
def test_torch_function_with_multiple_output_on_remote_var(self):
    """Run multi-output torch functions on ``Var``-wrapped tensors held by a VirtualWorker.

    Each case sends a ``Var`` (presumably ``torch.autograd.Variable`` -- confirm
    the import) to ``remote``, calls ``max``/``qr``/``kthvalue``/``eig``/``svd``
    and compares the retrieved outputs with hard-coded expected values.
    """
    hook = TorchHook(verbose=False)
    me = hook.local_worker
    # worker with id 2, registered with the local worker so data can be sent to it
    remote = VirtualWorker(id=2, hook=hook)
    me.add_worker(remote)
    x = Var(torch.FloatTensor([[1, 2], [4, 3], [5, 6]]))
    # NOTE(review): send()'s return value is discarded, so this relies on send()
    # turning x into a remote pointer in place -- confirm against TorchHook.
    x.send(remote)
    # row-wise maximum; only the values y are checked, the indices z are not
    y, z = torch.max(x, 1)
    # NOTE(review): get()'s return value is discarded and y itself is compared next,
    # so this relies on get() bringing the data back into y in place.
    y.get()
    assert torch.equal(y, Var(torch.FloatTensor([2, 4, 6])))
    x = Var(torch.FloatTensor([[0, 0], [1, 0]])).send(remote)
    # QR is only unique up to the signs of Q's columns, so these expected values
    # are tied to the backend's sign convention.
    y, z = torch.qr(x)
    assert (y.get() == Var(torch.FloatTensor([[0, -1], [-1, 0]]))).all()
    assert (z.get() == Var(torch.FloatTensor([[-1, 0], [0, 0]]))).all()
    # the 4th smallest value of [1, 2, 3, 4, 5] is 4, at index 3
    x = Var(torch.arange(1, 6)).send(remote)
    y, z = torch.kthvalue(x, 4)
    assert (y.get() == Var(torch.FloatTensor([4]))).all()
    assert (z.get() == Var(torch.LongTensor([3]))).all()
    # eigen-decomposition of the 2x2 zero matrix (second argument True requests
    # eigenvectors): all-zero eigenvalue output and identity eigenvectors expected
    x = Var(torch.FloatTensor([[0, 0], [0, 0]]))
    x.send(remote)
    y, z = torch.eig(x, True)
    assert (y.get() == Var(torch.FloatTensor([[0, 0], [0, 0]]))).all()
    assert (z.get() == Var(torch.FloatTensor([[1, 0.], [0, 1]]))).all()
    # SVD of the 3x3 zero matrix: identity factors and zero singular values
    x = Var(torch.zeros(3, 3)).send(remote)
    w, y, z = torch.svd(x)
    assert (w.get() == Var(torch.FloatTensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]]))).all()
    assert (y.get() == Var(torch.FloatTensor([0, 0, 0]))).all()
    assert (z.get() == Var(torch.FloatTensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]]))).all()
开发者ID:TanayGahlot,项目名称:PySyft,代码行数:34,代码来源:torch_test.py
示例19: test_archiving
def test_archiving(self):
    """Train a model, reload it from its archive and check nothing was lost.

    Checks that the reloaded model has identical weights and vocabulary and
    that the archived config equals the parameters used for training.
    """
    # copy params, since they'll get consumed during training
    params_copy = copy.deepcopy(self.params.as_dict())
    # `train_model` should create an archive
    serialization_dir = self.TEST_DIR / 'archive_test'
    model = train_model(self.params, serialization_dir=serialization_dir)
    archive_path = serialization_dir / "model.tar.gz"
    # load from the archive
    archive = load_archive(archive_path)
    model2 = archive.model
    # check that model weights are the same. Each state_dict is built once here
    # rather than twice per key inside the loop.
    state = model.state_dict()
    state2 = model2.state_dict()
    assert set(state.keys()) == set(state2.keys())
    for key in state:
        assert torch.equal(state[key], state2[key]), key
    # check that vocabularies are the same
    vocab = model.vocab
    vocab2 = model2.vocab
    assert vocab._token_to_index == vocab2._token_to_index  # pylint: disable=protected-access
    assert vocab._index_to_token == vocab2._index_to_token  # pylint: disable=protected-access
    # check that params are the same
    params2 = archive.config
    assert params2.as_dict() == params_copy
开发者ID:pyknife,项目名称:allennlp,代码行数:33,代码来源:archival_test.py
示例20: compare_state_dict
def compare_state_dict(sa, sb):
    """Return True if two state dicts have equal keys and identical values.

    The key views are compared with ``==`` first. Then each value of ``sa`` is
    checked with ``torch.equal`` (same size and same elements) against the
    value stored under the same key in ``sb``, stopping at the first mismatch.
    """
    return sa.keys() == sb.keys() and all(
        torch.equal(tensor, sb[name]) for name, tensor in sa.items()
    )
开发者ID:xiaoyongshen,项目名称:Detectron.pytorch,代码行数:7,代码来源:model_builder.py
注：本文中的torch.equal函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论