本文整理汇总了Python中torch.eq函数的典型用法代码示例。如果您正苦于以下问题:Python eq函数的具体用法?Python eq怎么用?Python eq使用的例子?那么恭喜您,这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了eq函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: updateGradInput
def updateGradInput(self, input, y):
    """Compute the criterion's gradient w.r.t. both inputs, in place.

    Reads the buffers ``self.w1``, ``self.w22``, ``self.w32``, ``self.w`` and
    ``self._outputs`` (presumably filled by a preceding ``updateOutput`` call
    -- confirm against the enclosing class) and writes the result into
    ``self.gradInput[0]`` and ``self.gradInput[1]``.

    Args:
        input: Sequence whose first two items are the tensors being compared.
        y: Per-sample labels; rows whose label equals 1 get their gradient
            negated.

    Returns:
        ``self.gradInput`` (the same list object, updated in place).
    """
    x1, x2 = input[0], input[1]
    grad1, grad2 = self.gradInput[0], self.gradInput[1]

    grad1.resize_as_(x1).copy_(x2)
    grad2.resize_as_(x1).copy_(x1)

    # grad1 = (x2 - (w1 * w22) * x1) * w
    torch.mul(self.w1, self.w22, out=self.buffer)
    grad1.addcmul_(self.buffer.expand_as(x1), x1, value=-1)
    grad1.mul_(self.w.expand_as(x1))

    # grad2 = (x1 - (w1 * w32) * x2) * w
    torch.mul(self.w1, self.w32, out=self.buffer)
    grad2.addcmul_(self.buffer.expand_as(x1), x2, value=-1)
    grad2.mul_(self.w.expand_as(x1))

    # self._idx is kept as a flat per-sample buffer.  The expanded masks are
    # locals, so a later ``out=self._idx`` never writes into a stride-0 view
    # whose elements alias one another.
    torch.le(self._outputs, 0, out=self._idx)
    inactive = self._idx.view(-1, 1).expand(grad1.size())
    grad1[inactive] = 0
    grad2[inactive] = 0

    torch.eq(y, 1, out=self._idx)
    positive = self._idx.view(-1, 1).expand(grad2.size())
    grad1[positive] = grad1[positive].mul_(-1)
    grad2[positive] = grad2[positive].mul_(-1)

    if self.sizeAverage:
        batch = y.size(0)
        grad1.div_(batch)
        grad2.div_(batch)

    return self.gradInput
开发者ID:Jsmilemsj,项目名称:pytorch,代码行数:33,代码来源:CosineEmbeddingCriterion.py
示例2: __iter__
def __iter__(self):
    """Yield one tuple of zero-padded tensors per batch in ``self.data``.

    Each batch is a sequence of examples that is transposed into columns:
    columns 0-3 are the context token ids, per-token feature vectors, tag ids
    and entity ids; column 4 is the question token ids; the last two columns
    are passed through as ``text`` and ``span``.  Outside evaluation mode,
    columns 5 and 6 are additionally turned into the ``y_s`` / ``y_e``
    LongTensors.

    Yields:
        ``(context_id, context_feature, context_tag, context_ent,
        context_mask, question_id, question_mask, context_order,
        question_order[, y_s, y_e], text, span)``.
    """

    def padded_ids(sequences, rows, width):
        # Zero-filled LongTensor with each sequence written left-aligned.
        out = torch.LongTensor(rows, width).fill_(0)
        for row, seq in enumerate(sequences):
            out[row, :len(seq)] = torch.LongTensor(seq)
        return out

    def positions(sequences, rows, width):
        # 1-based position of every real token; padding stays 0.
        out = torch.LongTensor(rows, width).fill_(0)
        for row, seq in enumerate(sequences):
            out[row, :len(seq)] = torch.from_numpy(np.arange(1, len(seq) + 1))
        return out

    for raw_batch in self.data:
        batch_size = len(raw_batch)
        columns = list(zip(*raw_batch))
        # Evaluation examples carry two columns fewer (no y_s / y_e).
        assert len(columns) == (7 if self.eval else 9)

        docs, doc_features, doc_tags, doc_ents, questions = columns[:5]

        context_len = max(len(doc) for doc in docs)
        context_id = padded_ids(docs, batch_size, context_len)
        context_order = positions(docs, batch_size, context_len)

        feature_len = len(doc_features[0][0])
        context_feature = torch.Tensor(batch_size, context_len, feature_len).fill_(0)
        for row, token_features in enumerate(doc_features):
            for col, feature in enumerate(token_features):
                context_feature[row, col, :] = torch.Tensor(feature)

        context_tag = padded_ids(doc_tags, batch_size, context_len)
        context_ent = padded_ids(doc_ents, batch_size, context_len)

        question_len = max(len(q) for q in questions)
        question_id = padded_ids(questions, batch_size, question_len)
        question_order = positions(questions, batch_size, question_len)

        # Id 0 is the fill value, so a mask entry is set exactly on padding.
        context_mask = torch.eq(context_id, 0)
        question_mask = torch.eq(question_id, 0)

        if not self.eval:
            y_s = torch.LongTensor(columns[5])
            y_e = torch.LongTensor(columns[6])
        text = list(columns[-2])
        span = list(columns[-1])

        tensors = (context_id, context_feature, context_tag, context_ent,
                   context_mask, question_id, question_mask,
                   context_order, question_order)
        if self.gpu:
            tensors = tuple(t.pin_memory() for t in tensors)

        if self.eval:
            yield tensors + (text, span)
        else:
            yield tensors + (y_s, y_e, text, span)
开发者ID:chickenbestlover,项目名称:DrQA-RN,代码行数:60,代码来源:train_RN_multiattn2.py
示例3: calc_precision
def calc_precision(pred, label):
    """Return the top-1 and top-5 error rates of ``pred`` against ``label``.

    Despite the name, the returned values are error rates: one minus the
    fraction of rows whose label appears among the k highest-scoring indices.

    Args:
        pred: Score tensor; ``torch.topk`` is taken along its last dimension.
        label: Labels, one per row of ``pred``.

    Returns:
        Tuple ``(top1_error, top5_error)``.
    """
    targets = label.view(-1, 1)
    total = len(label)
    errors = []
    for k in (1, 5):
        top_idx = torch.topk(pred, k)[-1]
        hits = top_idx[torch.eq(top_idx, targets)]
        errors.append(1 - len(hits) / total)
    return tuple(errors)
开发者ID:ZhangXinNan,项目名称:LearnPractice,代码行数:8,代码来源:train.py
示例4: updateOutput
def updateOutput(self, input, y):
    """Compute the cosine-embedding criterion value for a batch.

    For each row the cosine similarity of ``input[0]`` and ``input[1]`` is
    computed.  Rows with ``y == -1`` contribute ``max(0, cos - margin)`` and
    rows with ``y == 1`` contribute ``1 - cos``.  The contributions are summed
    and, when ``self.sizeAverage`` is set, divided by ``y.size(0)``.

    The intermediate buffers (``w1``, ``w22``, ``w32``, ``w``, ``_outputs``,
    ``_idx``) are stored on ``self``.

    Args:
        input: Sequence whose first two items are 2-D tensors of equal shape
            (reductions run over dimension 1).
        y: Per-row labels, compared against -1 and 1.

    Returns:
        The criterion value as a Python number (also stored in
        ``self.output``).
    """
    x1, x2 = input[0], input[1]

    # keep backward compatibility
    if self.buffer is None:
        self.buffer = x1.new()
        self.w1 = x1.new()
        self.w22 = x1.new()
        self.w = x1.new()
        self.w32 = x1.new()
        self._outputs = x1.new()

    # w1 = <x1, x2> per row
    torch.mul(x1, x2, out=self.buffer)
    torch.sum(self.buffer, 1, out=self.w1, keepdim=True)

    epsilon = 1e-12
    # w22 = 1 / (|x1|^2 + eps)
    torch.mul(x1, x1, out=self.buffer)
    torch.sum(self.buffer, 1, out=self.w22, keepdim=True).add_(epsilon)

    # self._outputs is also used as a temporary buffer
    self._outputs.resize_as_(self.w22).fill_(1)
    torch.div(self._outputs, self.w22, out=self.w22)
    self.w.resize_as_(self.w22).copy_(self.w22)

    # w32 = 1 / (|x2|^2 + eps);  w = sqrt(w22 * w32)
    torch.mul(x2, x2, out=self.buffer)
    torch.sum(self.buffer, 1, out=self.w32, keepdim=True).add_(epsilon)
    torch.div(self._outputs, self.w32, out=self.w32)
    self.w.mul_(self.w32)
    self.w.sqrt_()

    torch.mul(self.w1, self.w, out=self._outputs)
    self._outputs = self._outputs.select(1, 0)

    # Masks are computed fresh instead of being written with ``out=`` into a
    # pre-allocated buffer.  The result then lives on y's device, whatever
    # tensor type is in use, and never lands in an expanded view that
    # updateGradInput may have left in ``self._idx``.
    self._idx = torch.eq(y, -1)
    self._outputs[self._idx] = self._outputs[self._idx].add_(-self.margin).clamp_(min=0)
    self._idx = torch.eq(y, 1)
    self._outputs[self._idx] = self._outputs[self._idx].mul_(-1).add_(1)

    self.output = self._outputs.sum().item()
    if self.sizeAverage:
        self.output = self.output / y.size(0)

    return self.output
开发者ID:Jsmilemsj,项目名称:pytorch,代码行数:50,代码来源:CosineEmbeddingCriterion.py
示例5: test_local_var_binary_methods
def test_local_var_binary_methods(self):
''' Unit tests for methods mentioned on issue 1385
https://github.com/OpenMined/PySyft/issues/1385'''
x = torch.FloatTensor([1, 2, 3, 4])
y = torch.FloatTensor([[1, 2, 3, 4]])
z = torch.matmul(x, y.t())
assert (torch.equal(z, torch.FloatTensor([30])))
z = torch.add(x, y)
assert (torch.equal(z, torch.FloatTensor([[2, 4, 6, 8]])))
x = torch.FloatTensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
y = torch.FloatTensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
z = torch.cross(x, y, dim=1)
assert (torch.equal(z, torch.FloatTensor([[0, 0, 0], [0, 0, 0], [0, 0, 0]])))
x = torch.FloatTensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
y = torch.FloatTensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
z = torch.dist(x, y)
t = torch.FloatTensor([z])
assert (torch.equal(t, torch.FloatTensor([0.])))
x = torch.FloatTensor([1, 2, 3])
y = torch.FloatTensor([1, 2, 3])
z = torch.dot(x, y)
t = torch.FloatTensor([z])
assert torch.equal(t, torch.FloatTensor([14]))
z = torch.eq(x, y)
assert (torch.equal(z, torch.ByteTensor([1, 1, 1])))
z = torch.ge(x, y)
assert (torch.equal(z, torch.ByteTensor([1, 1, 1])))
开发者ID:TanayGahlot,项目名称:PySyft,代码行数:27,代码来源:torch_test.py
示例6: forward
def forward(self, output, context):
    """Apply dot-product attention of ``output`` over ``context``.

    Args:
        output: Decoder features, shape ``(batch, out_len, dim)``.
        context: Encoder features, shape ``(batch, in_len, dim)``.

    Returns:
        Tuple ``(attended, attn)``.  ``attended`` is ``(batch, out_len, dim)``,
        produced by ``self.linear_out`` followed by tanh.  ``attn`` holds the
        softmax attention weights, ``(batch, out_len, in_len)``.
    """
    batch_size = output.size(0)
    hidden_size = output.size(2)
    input_size = context.size(1)

    # (batch, out_len, dim) * (batch, in_len, dim) -> (batch, out_len, in_len)
    attn = torch.bmm(output, context.transpose(1, 2))

    # Scores that are exactly 0 are excluded from the softmax (presumably they
    # mark zero-padded context positions -- confirm with the caller).  The
    # mask keeps the comparison's native dtype instead of being forced to
    # uint8, which masked_fill_ may not accept.
    mask = torch.eq(attn, 0).data
    attn.data.masked_fill_(mask, -float('inf'))
    attn = F.softmax(attn.view(-1, input_size), dim=1).view(batch_size, -1, input_size)

    # (batch, out_len, in_len) * (batch, in_len, dim) -> (batch, out_len, dim)
    mix = torch.bmm(attn, context)

    # concat -> (batch, out_len, 2*dim)
    combined = torch.cat((mix, output), dim=2)

    # output -> (batch, out_len, dim)
    projected = self.linear_out(combined.view(-1, 2 * hidden_size))
    output = torch.tanh(projected).view(batch_size, -1, hidden_size)

    # .contiguous() is a no-op for tensors that are already contiguous.
    return output.contiguous(), attn
开发者ID:shruthi0898,项目名称:Writing-editing-Network,代码行数:25,代码来源:attention.py
示例7: test_rescale_torch_tensor
def test_rescale_torch_tensor(self):
    """Rescaling to a new range and back must reproduce the original tensor."""
    rows, cols = 3, 5
    original_tensor = torch.randint(low=10, high=40, size=(rows, cols)).float()
    prev_max_tensor = torch.ones(1, 5) * 40.0
    prev_min_tensor = torch.ones(1, 5) * 10.0
    new_min_tensor = torch.ones(1, 5) * -1.0
    new_max_tensor = torch.ones(1, 5).float()

    print("Original tensor: ", original_tensor)
    rescaled_tensor = rescale_torch_tensor(
        original_tensor,
        new_min_tensor,
        new_max_tensor,
        prev_min_tensor,
        prev_max_tensor,
    )
    print("Rescaled tensor: ", rescaled_tensor)
    reconstructed_original_tensor = rescale_torch_tensor(
        rescaled_tensor,
        prev_min_tensor,
        prev_max_tensor,
        new_min_tensor,
        new_max_tensor,
    )
    print("Reconstructed Original tensor: ", reconstructed_original_tensor)

    # The previous check, assertTrue(torch.sum(eq), rows * cols), treated
    # rows * cols as the failure *message*, so it passed whenever at least one
    # element matched.  Compare every element instead, with a floating-point
    # tolerance because the round trip goes through two affine maps.
    self.assertTrue(
        torch.allclose(original_tensor, reconstructed_original_tensor),
        "reconstructed tensor differs from the original",
    )
开发者ID:sra4077,项目名称:Horizon,代码行数:28,代码来源:test_utils.py
示例8: test_remote_var_binary_methods
def test_remote_var_binary_methods(self):
    '''Unit tests for methods mentioned on issue 1385
    https://github.com/OpenMined/PySyft/issues/1385

    Same operations as the local test, but on Variables sent to a
    VirtualWorker; results are fetched back with ``.get()``.
    '''
    hook = TorchHook(verbose=False)
    local = hook.local_worker
    remote = VirtualWorker(hook, 1)
    local.add_worker(remote)

    def on_remote(values):
        # Wrap the values in a float Variable and ship it to the remote worker.
        return Var(torch.FloatTensor(values)).send(remote)

    def expect_float(values):
        return Var(torch.FloatTensor(values))

    def expect_byte(values):
        return Var(torch.ByteTensor(values))

    square = [[1, 2, 3], [3, 4, 5], [5, 6, 7]]

    x = on_remote([1, 2, 3, 4])
    y = on_remote([[1, 2, 3, 4]])
    z = torch.matmul(x, y.t())
    assert (torch.equal(z.get(), expect_float([30])))
    z = torch.add(x, y)
    assert (torch.equal(z.get(), expect_float([[2, 4, 6, 8]])))

    x = on_remote(square)
    y = on_remote(square)
    z = torch.cross(x, y, dim=1)
    assert (torch.equal(z.get(), expect_float([[0, 0, 0], [0, 0, 0], [0, 0, 0]])))

    x = on_remote(square)
    y = on_remote(square)
    z = torch.dist(x, y)
    assert (torch.equal(z.get(), expect_float([0.])))

    x = on_remote([1, 2, 3])
    y = on_remote([1, 2, 3])
    z = torch.dot(x, y)
    # NOTE(review): the dot-product result is only printed, not asserted.
    print(torch.equal(z.get(), expect_float([14])))
    z = torch.eq(x, y)
    assert (torch.equal(z.get(), expect_byte([1, 1, 1])))
    z = torch.ge(x, y)
    assert (torch.equal(z.get(), expect_byte([1, 1, 1])))
开发者ID:TanayGahlot,项目名称:PySyft,代码行数:30,代码来源:torch_test.py
示例9: evaluate
def evaluate(attention_model, x_test, y_test):
    """Return the accuracy of ``attention_model`` on a held-out set.

    The whole test set is run as a single batch: the model's ``batch_size``
    and ``hidden_state`` attributes are overwritten before the forward pass.

    Args:
        attention_model: Callable model exposing ``batch_size``,
            ``init_hidden()`` and ``type``.  A truthy ``type`` selects argmax
            over dimension 1; otherwise the outputs are rounded.
        x_test: numpy array of inputs, converted to a LongTensor.
        y_test: numpy array of labels.

    Returns:
        Fraction of predictions equal to the labels.
    """
    attention_model.batch_size = x_test.shape[0]
    attention_model.hidden_state = attention_model.init_hidden()
    x_test_var = Variable(torch.from_numpy(x_test).type(torch.LongTensor))
    y_test_pred, _ = attention_model(x_test_var)
    if bool(attention_model.type):
        y_preds = torch.max(y_test_pred, 1)[1]
        y_test_var = Variable(torch.from_numpy(y_test).type(torch.LongTensor))
    else:
        y_preds = torch.round(y_test_pred.type(torch.DoubleTensor).squeeze(1))
        y_test_var = Variable(torch.from_numpy(y_test).type(torch.DoubleTensor))

    # Count matches in floating point: dividing an integer count tensor by an
    # int floors to 0 on PyTorch versions with integer tensor division.
    correct = torch.eq(y_preds, y_test_var).data.float().sum()
    return correct / x_test_var.size(0)
开发者ID:daiyongya,项目名称:Structured-Self-Attention,代码行数:28,代码来源:train.py
示例10: rpn_bbox_loss
def rpn_bbox_loss(target_bbox, rpn_match, rpn_bbox, config):
    """Return the RPN bounding box loss graph.

    config: the model config object (``IMAGES_PER_GPU`` is read).
    target_bbox: [batch, max positive anchors, (dy, dx, log(dh), log(dw))].
        Uses 0 padding to fill in unused bbox deltas.
    rpn_match: [batch, anchors, 1]. Anchor match type. 1=positive,
               -1=negative, 0=neutral anchor.
    rpn_bbox: [batch, anchors, (dy, dx, log(dh), log(dw))]

    Returns the mean smooth-L1 loss over the deltas of all positive anchors,
    or a zero that stays attached to ``rpn_bbox`` when there are none.
    """
    # Positive anchors contribute to the loss, but negative and
    # neutral anchors (match value of 0 or -1) don't.
    positive = torch.eq(rpn_match, 1)
    batch_counts = torch.sum(positive.float(), dim=1)

    # With no positive anchor the mean over an empty selection would be NaN.
    if float(batch_counts.sum()) == 0:
        return rpn_bbox.sum() * 0

    # masked_select flattens its result; restore one row per positive anchor
    # so predictions and targets have the same shape.
    delta_size = rpn_bbox.size(-1)
    rpn_bbox = torch.masked_select(rpn_bbox, positive).view(-1, delta_size)

    # Keep only the leading, non-padded targets of every image.  Slicing works
    # on whatever device target_bbox lives on (no CUDA-only index tensor).
    outputs = [
        target_bbox[i, :int(batch_counts[i].sum())]
        for i in range(config.IMAGES_PER_GPU)
    ]
    target_bbox = torch.cat(outputs, dim=0)

    # Default reduction is the mean, matching the former size_average=True.
    return F.smooth_l1_loss(rpn_bbox, target_bbox)
开发者ID:huanglizhi,项目名称:Pytorch_Mask_RCNN,代码行数:25,代码来源:mask_rcnn.py
示例11: knn
def knn(Mxx, Mxy, Myy, k, sqrt):
    """Score a leave-one-out k-NN classifier built from pairwise distances.

    The first ``Mxx.size(0)`` points are labelled 1 and the remaining
    ``Myy.size(0)`` points are labelled 0.  Each point is predicted as 1 when
    at least ``k / 2`` of its ``k`` nearest other points carry label 1.

    Args:
        Mxx, Mxy, Myy: Blocks of the pairwise distance matrix (within the
            first set, across sets, within the second set).
        k: Number of neighbours.
        sqrt: If true, distances are replaced by ``sqrt(abs(d))`` first.

    Returns:
        A ``Score_knn`` with tp/fp/fn/tn, precision, recall, acc_t, acc_f,
        acc and k filled in.
    """
    n0, n1 = Mxx.size(0), Myy.size(0)
    total = n0 + n1
    label = torch.cat((torch.ones(n0), torch.zeros(n1)))

    upper = torch.cat((Mxx, Mxy), 1)
    lower = torch.cat((Mxy.transpose(0, 1), Myy), 1)
    M = torch.cat((upper, lower), 0)
    if sqrt:
        M = M.abs().sqrt()

    # +inf on the diagonal keeps a point from being its own neighbour.
    INFINITY = float('inf')
    no_self = M + torch.diag(INFINITY * torch.ones(total))
    _, idx = no_self.topk(k, 0, False)

    # Number of label-1 points among each column's k nearest neighbours.
    count = torch.zeros(total)
    for rank in range(k):
        count = count + label.index_select(0, idx[rank])
    threshold = (float(k) / 2) * torch.ones(total)
    pred = torch.ge(count, threshold).float()

    not_pred = 1 - pred
    not_label = 1 - label

    s = Score_knn()
    s.tp = (pred * label).sum()
    s.fp = (pred * not_label).sum()
    s.fn = (not_pred * label).sum()
    s.tn = (not_pred * not_label).sum()
    s.precision = s.tp / (s.tp + s.fp)
    s.recall = s.tp / (s.tp + s.fn)
    s.acc_t = s.tp / (s.tp + s.fn)
    s.acc_f = s.tn / (s.tn + s.fp)
    s.acc = torch.eq(label, pred).float().mean()
    s.k = k
    return s
开发者ID:RobinROAR,项目名称:TensorflowTutorialsCode,代码行数:28,代码来源:metric.py
示例12: test
def test(net, testloader, config):
    """Evaluate ``net`` on ``testloader`` and print the test accuracy.

    Batches whose first dimension differs from ``config.batch_size`` are
    skipped.  Relies on the module-level ``use_GPU`` flag.

    Args:
        net: Model called as ``net(X, S1, S2, config)``; its first return
            value holds the per-class scores.
        testloader: Iterable yielding ``(X, S1, S2, labels)`` batches.
        config: Object providing ``batch_size``.
    """
    total, correct = 0.0, 0.0
    for X, S1, S2, labels in testloader:
        if X.size()[0] != config.batch_size:
            continue  # Drop those data, if not enough for a batch
        # Send Tensors to GPU if available
        if use_GPU:
            X = X.cuda()
            S1 = S1.cuda()
            S2 = S2.cuda()
            labels = labels.cuda()
        # Wrap to autograd.Variable
        X, S1, S2 = Variable(X), Variable(S1), Variable(S2)
        # Forward pass
        outputs, predictions = net(X, S1, S2, config)
        # Select actions with max scores(logits)
        _, predicted = torch.max(outputs, dim=1, keepdim=True)
        # Unwrap autograd.Variable to Tensor
        predicted = predicted.data
        # Accumulate a plain number rather than a (possibly CUDA) tensor.
        correct += float(torch.eq(torch.squeeze(predicted), labels).sum())
        total += labels.size()[0]

    if total == 0:
        # Every batch was dropped; avoid dividing by zero.
        print('Test Accuracy: n/a (no complete batch)')
        return
    print('Test Accuracy: {:.2f}%'.format(100 * (correct / total)))
开发者ID:Kaushalya,项目名称:pytorch-value-iteration-networks,代码行数:25,代码来源:train.py
示例13: forward
def forward(self, input, target):
    """Batch-hard triplet loss over pairwise Euclidean distances.

    For every row the farthest row with the same id and the closest row with
    a different id are selected.  The loss is
    ``margin_ranking_loss(neg_min, pos_max, 1, margin)``,
    i.e. ``max(0, pos_max - neg_min + margin)`` per row.

    Args:
        input: Feature matrix, one row per sample (2-D).
        target: Identity label per row.

    Returns:
        The reduced loss: mean when ``self.size_average`` is truthy or None,
        sum otherwise.
    """
    y_true = target.int().unsqueeze(-1)
    same_id = torch.eq(y_true, y_true.t()).type_as(input)

    pos_mask = same_id
    neg_mask = 1 - same_id

    def _mask_max(input_tensor, mask, axis=None, keepdims=False):
        # Push masked-out entries far down so they never win the max.
        input_tensor = input_tensor - 1e6 * (1 - mask)
        _max, _idx = torch.max(input_tensor, dim=axis, keepdim=keepdims)
        return _max, _idx

    def _mask_min(input_tensor, mask, axis=None, keepdims=False):
        # Push masked-out entries far up so they never win the min.
        input_tensor = input_tensor + 1e6 * (1 - mask)
        _min, _idx = torch.min(input_tensor, dim=axis, keepdim=keepdims)
        return _min, _idx

    # output[i, j] = || feature[i, :] - feature[j, :] ||_2
    dist_squared = torch.sum(input ** 2, dim=1, keepdim=True) + \
        torch.sum(input.t() ** 2, dim=0, keepdim=True) - \
        2.0 * torch.matmul(input, input.t())
    # Clamp before sqrt: rounding can make squared distances slightly negative.
    dist = dist_squared.clamp(min=1e-16).sqrt()

    pos_max, _ = _mask_max(dist, pos_mask, axis=-1)
    neg_min, _ = _mask_min(dist, neg_mask, axis=-1)
    pos_max = pos_max.float()
    neg_min = neg_min.float()

    # loss(x, y) = max(0, -y * (x1 - x2) + margin)
    # The target is created next to the distances instead of being moved to a
    # module-global DEVICE, so it always matches the input's device.
    y = torch.ones_like(neg_min)

    # Translate the legacy size_average flag into the reduction keyword.
    use_mean = self.size_average is None or bool(self.size_average)
    return F.margin_ranking_loss(neg_min,
                                 pos_max,
                                 y,
                                 margin=self.margin,
                                 reduction='mean' if use_mean else 'sum')
开发者ID:jiangqy,项目名称:reid-mgn,代码行数:33,代码来源:triplet.py
示例14: test_train
def test_train(self):
    """In train mode the metric function is called with the expected tensors.

    After processing ``self._states`` the mocked metric function must have
    been called twice, and once more on ``process_final``.

    The tensor checks call ``.all()``.  Previously the bound method ``.all``
    itself was passed to ``assertTrue``, which is always truthy, so the
    assertions could never fail.
    """
    self._metric.train()
    calls = [[torch.FloatTensor([0.0]), torch.LongTensor([0])],
             [torch.FloatTensor([0.0, 0.1, 0.2, 0.3]), torch.LongTensor([0, 1, 2, 3])]]
    for state in self._states:
        self._metric.process(state)
    self.assertEqual(2, len(self._metric_function.call_args_list))
    for recorded, (expected_first, expected_second) in zip(
            self._metric_function.call_args_list, calls):
        args = recorded[0]
        self.assertTrue(bool(torch.eq(args[0], expected_first).all()))
        self.assertTrue(bool(
            torch.lt(torch.abs(torch.add(args[1], -expected_second)), 1e-12).all()))

    self._metric_function.reset_mock()
    self._metric.process_final({})
    self._metric_function.assert_called_once()
    final_args = self._metric_function.call_args_list[0][0]
    self.assertTrue(bool(
        torch.eq(final_args[1], torch.LongTensor([0, 1, 2, 3, 4])).all()))
    self.assertTrue(bool(
        torch.lt(torch.abs(torch.add(
            final_args[0], -torch.FloatTensor([0.0, 0.1, 0.2, 0.3, 0.4]))), 1e-12).all()))
开发者ID:little1tow,项目名称:torchbearer,代码行数:16,代码来源:test_wrappers.py
示例15: test_serialization
def test_serialization(self):
    """A NestedField survives a torch.save / torch.load round trip.

    The reloaded field must compare equal to the original and numericalize
    the same character-level examples to identical tensors.
    """
    nesting_field = data.Field(batch_first=True)
    field = data.NestedField(nesting_field)
    ex1 = data.Example.fromlist(["john loves mary"], [("words", field)])
    ex2 = data.Example.fromlist(["mary cries"], [("words", field)])
    dataset = data.Dataset([ex1, ex2], [("words", field)])
    field.build_vocab(dataset)

    def word(chars, pad=0):
        # One word as characters between <w>/</w>, right-padded with <cpad>.
        return ["<w>"] + list(chars) + ["</w>"] + ["<cpad>"] * pad

    sentence_start = word(["<s>"], 4)
    sentence_end = word(["</s>"], 4)
    examples_data = [
        [
            sentence_start,
            word("john", 1),
            word("loves"),
            word("mary", 1),
            sentence_end,
        ],
        [
            list(sentence_start),
            word("mary", 1),
            word("cries"),
            list(sentence_end),
            ["<cpad>"] * 7,
        ],
    ]

    field_pickle_path = os.path.join(self.test_dir, "char_field.pl")
    torch.save(field, field_pickle_path)
    loaded_field = torch.load(field_pickle_path)
    assert loaded_field == field

    before = field.numericalize(examples_data)
    after = loaded_field.numericalize(examples_data)
    assert torch.all(torch.eq(before, after))
开发者ID:tu-artem,项目名称:text,代码行数:35,代码来源:test_field.py
示例16: test_serialization_built_vocab
def test_serialization_built_vocab(self):
    """A Field with a built vocabulary survives a torch.save / torch.load round trip.

    The vocabulary is built from the TSV test dataset; the reloaded field must
    compare equal to the original and numericalize identically.
    """
    self.write_test_ppid_dataset(data_format="tsv")
    question_field = data.Field(sequential=True)
    tsv_fields = [
        ("id", None),
        ("q1", question_field),
        ("q2", question_field),
        ("label", None),
    ]
    tsv_dataset = data.TabularDataset(
        path=self.test_ppid_dataset_path,
        format="tsv",
        fields=tsv_fields,
    )
    question_field.build_vocab(tsv_dataset)

    pickle_path = os.path.join(self.test_dir, "question.pl")
    torch.save(question_field, pickle_path)
    restored_field = torch.load(pickle_path)
    assert restored_field == question_field

    test_example_data = [
        ["When", "do", "you", "use", "シ", "instead", "of", "し?"],
        ["What", "is", "2+2", "<pad>", "<pad>", "<pad>", "<pad>", "<pad>"],
        ["Here", "is", "a", "sentence", "with", "some", "oovs", "<pad>"],
    ]

    # Test results of numericalization
    before = question_field.numericalize(test_example_data)
    after = restored_field.numericalize(test_example_data)
    assert torch.all(torch.eq(before, after))
开发者ID:tu-artem,项目名称:text,代码行数:31,代码来源:test_field.py
示例17: forward
def forward(self, input, target):
    """Compute the hinge-embedding loss.

    Elements with ``target == 1`` contribute ``input``; elements with
    ``target == -1`` contribute ``max(0, margin - input)``.  The sum is
    divided by the number of elements when ``self.size_average`` is set.

    Args:
        input: Tensor of values (typically distances).
        target: Tensor of the same shape, compared against 1 and -1.

    Returns:
        A one-element tensor of the same type as ``input`` holding the loss.
    """
    scratch = input.new()
    scratch.resize_as_(input).copy_(input)
    scratch[torch.eq(target, -1.)] = 0
    output = scratch.sum()

    # margin - input, written with sub_ instead of the deprecated
    # add_(scalar, tensor) overload.
    scratch.fill_(self.margin).sub_(input)
    scratch.clamp_(min=0)
    scratch[torch.eq(target, 1.)] = 0
    output += scratch.sum()

    if self.size_average:
        output = output / input.nelement()

    self.save_for_backward(input, target)
    return input.new((output,))
开发者ID:athiwatp,项目名称:pytorch,代码行数:16,代码来源:loss.py
示例18: train
def train():
    """Run one pass over ``epoches`` and return average loss and accuracy.

    Uses the module-level ``epoches``, ``model``, ``criterion``, ``optimizer``
    and ``epoch_size``.  Each item of ``epoches`` is indexed as
    ``(inputs, targets)``.

    Returns:
        Tuple ``(loss_avg, cort_num_avg)``: the mean loss per item and the
        number of correct predictions divided by the item count and by
        ``epoch_size``.

    Raises:
        ZeroDivisionError: if ``epoches`` yields nothing.
    """
    epoch_num, loss_sum, cort_num_sum = 0, 0.0, 0
    for epoch in epoches:
        epoch_num += 1
        inputs = Variable(epoch[0])
        target = Variable(epoch[1])
        output = model(inputs)
        loss = criterion(output, target)
        # reset gradients
        optimizer.zero_grad()
        # backward pass
        loss.backward()
        # update parameters
        optimizer.step()
        # get training information
        # .item() reads the scalar loss; indexing a 0-dim tensor with [0] fails.
        loss_sum += loss.item()
        _, pred = torch.max(output.data, 1)
        # Accumulate a plain int so the averages below are Python numbers.
        cort_num_sum += int(torch.eq(pred, epoch[1]).sum())
    loss_avg = loss_sum / float(epoch_num)
    cort_num_avg = cort_num_sum / float(epoch_num) / float(epoch_size)
    return loss_avg, cort_num_avg
开发者ID:wu-yy,项目名称:pytorchExample,代码行数:27,代码来源:mnist_train.py
示例19: updateOutput
def updateOutput(self, input, y):
    """Compute the hinge-embedding criterion value.

    Elements with ``y == 1`` contribute ``input``; elements with ``y == -1``
    contribute ``max(0, margin - input)``.  The sum is divided by the number
    of elements when ``self.sizeAverage`` is set.  ``self.buffer`` is used as
    scratch space and is allocated on first use.

    Args:
        input: Tensor of values.
        y: Tensor of the same shape, compared against 1 and -1.

    Returns:
        The criterion value (also stored in ``self.output``).
    """
    if self.buffer is None:
        self.buffer = input.new()

    self.buffer.resize_as_(input).copy_(input)
    self.buffer[torch.eq(y, -1.)] = 0
    self.output = self.buffer.sum()

    # margin - input, written with sub_ instead of the deprecated
    # add_(scalar, tensor) overload.
    self.buffer.fill_(self.margin).sub_(input)
    self.buffer.clamp_(min=0)
    self.buffer[torch.eq(y, 1.)] = 0
    self.output = self.output + self.buffer.sum()

    if self.sizeAverage:
        self.output = self.output / input.nelement()

    return self.output
开发者ID:Northrend,项目名称:pytorch,代码行数:16,代码来源:HingeEmbeddingCriterion.py
示例20: updateGradInput
def updateGradInput(self, input, y):
    """Write the hinge-embedding gradient into ``self.gradInput``.

    The gradient starts as a copy of ``y`` and is zeroed wherever ``y == -1``
    and ``input`` exceeds ``self.margin``.  With ``self.sizeAverage`` set it
    is scaled by ``1 / input.nelement()``.

    Returns:
        ``self.gradInput``.
    """
    grad = self.gradInput
    grad.resize_as_(input).copy_(y)

    # Negative pairs that already exceed the margin contribute no gradient.
    beyond_margin = torch.eq(y, -1) & torch.gt(input, self.margin)
    grad[beyond_margin] = 0

    if self.sizeAverage:
        grad.mul_(1. / input.nelement())
    return self.gradInput
开发者ID:Northrend,项目名称:pytorch,代码行数:8,代码来源:HingeEmbeddingCriterion.py
注:本文中的torch.eq函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论