This article collects typical usage examples of the torch.index_select function in Python. If you have been wondering what exactly torch.index_select does, how to call it, or what it looks like in real code, the curated examples below should help.
The following 20 code examples of index_select are drawn from open-source projects and are sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code samples.
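Before diving into the examples, a minimal sketch of the function's contract may help: torch.index_select(input, dim, index) gathers slices of input along dimension dim at the positions given by index, which must be a 1-D integer (LongTensor) index. The snippet below is my own illustration, not taken from any of the projects cited here:

import torch

x = torch.arange(12).view(3, 4)       # [[ 0,  1,  2,  3],
                                      #  [ 4,  5,  6,  7],
                                      #  [ 8,  9, 10, 11]]
idx = torch.tensor([2, 0])            # indices must be a 1-D LongTensor

rows = torch.index_select(x, 0, idx)  # rows 2 and 0:    [[8, 9, 10, 11], [0, 1, 2, 3]]
cols = torch.index_select(x, 1, idx)  # columns 2 and 0: [[2, 0], [6, 4], [10, 8]]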
Example 1: forward
def forward(self, sents, sent_lengths):
    '''
    sents is (batch_size x padded_length).
    When evaluating sentence by sentence, use batch_size = 1 and the padded length,
    e.g. [[1, 2, 3, 4]].
    '''
    batch_size = sents.size()[0]
    sent_lengths = list(sent_lengths)

    # Sort by length (descending) so we can use pack_padded_sequence.
    descending_lengths = [x for x, _ in sorted(zip(sent_lengths, range(len(sent_lengths))), reverse=True)]
    descending_indices = [x for _, x in sorted(zip(sent_lengths, range(len(sent_lengths))), reverse=True)]
    descending_lengths = torch.tensor(descending_lengths)
    descending_indices = torch.tensor(descending_indices).to(device)
    descending_sents = torch.index_select(sents, 0, descending_indices)  # dim must be an int, not a tensor

    # get embeddings
    embed = self.embedding(descending_sents)

    # pack padded sequence
    embed = torch.nn.utils.rnn.pack_padded_sequence(embed, descending_lengths, batch_first=True)

    # forward-prop through the RNN
    self.hidden = self.init_hidden(batch_size)
    rnn_out, self.hidden = self.gru(embed, self.hidden)
    rnn_out, _ = torch.nn.utils.rnn.pad_packed_sequence(rnn_out, batch_first=True)
    # rnn_out is (batch_size x max_length x hidden_dim), e.g. 32 x 72 x 256

    # restore the original batch order
    change_it_back = [x for _, x in sorted(zip(descending_indices, range(len(descending_indices))))]
    self.hidden = torch.index_select(self.hidden, 1, torch.LongTensor(change_it_back).to(device))
    rnn_out = torch.index_select(rnn_out, 0, torch.LongTensor(change_it_back).to(device))

    return rnn_out, self.hidden
Author: vwrj, Project: neural_machine_translation, Lines: 33, Source file: V2-Attention-Vish.py
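The sort-pack-unsort pattern in Example 1 is worth isolating, since it is the reason index_select appears twice there. A minimal standalone sketch (my own, with hypothetical shapes), using argsort to build the inverse permutation:

import torch

lengths = torch.tensor([3, 5, 2])
batch = torch.randn(3, 5, 8)                     # (batch, time, features)

order = torch.argsort(lengths, descending=True)  # permutation that sorts by length
sorted_batch = torch.index_select(batch, 0, order)

# ... run pack_padded_sequence / the RNN on sorted_batch here ...

inverse = torch.argsort(order)                   # inverse permutation
restored = torch.index_select(sorted_batch, 0, inverse)
assert torch.equal(restored, batch)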
Example 2: get_triplet_loss
def get_triplet_loss(image_a_pred, image_b_pred, matches_a, matches_b, non_matches_a, non_matches_b, alpha):
    """
    Computes the loss function

    \sum_{triplets} ||D(I_a, u_a, I_b, u_{b,match})||_2^2 - ||D(I_a, u_a, I_b, u_{b,non-match})||_2^2 + alpha
    """
    num_matches = matches_a.size()[0]
    num_non_matches = non_matches_a.size()[0]
    multiplier = num_non_matches // num_matches  # integer ratio; repeat() needs an int

    ## non_matches_a is already replicated up to be the right size
    ## non_matches_b is also that size
    ## matches_a is just a smaller version of non_matches_a
    ## matches_b is the only thing that needs to be replicated up in size
    matches_b_long = torch.t(matches_b.repeat(multiplier, 1)).contiguous().view(-1)

    matches_a_descriptors = torch.index_select(image_a_pred, 1, non_matches_a)
    matches_b_descriptors = torch.index_select(image_b_pred, 1, matches_b_long)
    non_matches_b_descriptors = torch.index_select(image_b_pred, 1, non_matches_b)

    triplet_losses = (matches_a_descriptors - matches_b_descriptors).pow(2) - (matches_a_descriptors - non_matches_b_descriptors).pow(2) + alpha
    triplet_loss = 1.0 / num_non_matches * torch.clamp(triplet_losses, min=0).sum()

    return triplet_loss
Author: shooter2062424, Project: pytorch-dense-correspondence, Lines: 26, Source file: pixelwise_contrastive_loss.py
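To make the margin arithmetic in Example 2 concrete, here is a toy computation of the same hinge on hand-made 2-D descriptors (my own illustration, not the DCN pipeline):

import torch

anchor    = torch.tensor([[0.0, 0.0]])  # descriptor at u_a in image A
match     = torch.tensor([[0.1, 0.0]])  # descriptor at u_b,match in image B
non_match = torch.tensor([[1.0, 0.0]])  # descriptor at u_b,non-match in image B
alpha = 0.5

triplet = (anchor - match).pow(2).sum() - (anchor - non_match).pow(2).sum() + alpha
loss = torch.clamp(triplet, min=0)      # 0.01 - 1.0 + 0.5 = -0.49, clamped to 0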
Example 3: get_loss
def get_loss(self, image_a_pred, image_b_pred, mask_a, mask_b):
    loss = 0

    # get the nonzero indices
    mask_a_indices_flat = torch.nonzero(mask_a)
    mask_b_indices_flat = torch.nonzero(mask_b)
    if len(mask_a_indices_flat) == 0:
        return Variable(torch.cuda.LongTensor([0]), requires_grad=True)
    if len(mask_b_indices_flat) == 0:
        return Variable(torch.cuda.LongTensor([0]), requires_grad=True)

    # take num_samples random pixel samples of the object, using the mask
    num_samples = 10000

    rand_numbers_a = (torch.rand(num_samples) * len(mask_a_indices_flat)).cuda()
    rand_indices_a = Variable(torch.floor(rand_numbers_a).type(torch.cuda.LongTensor), requires_grad=False)
    randomized_mask_a_indices_flat = torch.index_select(mask_a_indices_flat, 0, rand_indices_a).squeeze(1)

    rand_numbers_b = (torch.rand(num_samples) * len(mask_b_indices_flat)).cuda()
    rand_indices_b = Variable(torch.floor(rand_numbers_b).type(torch.cuda.LongTensor), requires_grad=False)
    randomized_mask_b_indices_flat = torch.index_select(mask_b_indices_flat, 0, rand_indices_b).squeeze(1)

    # index into the image and get descriptors
    M_margin = 0.5  # margin parameter
    random_img_a_object_descriptors = torch.index_select(image_a_pred, 1, randomized_mask_a_indices_flat)
    random_img_b_object_descriptors = torch.index_select(image_b_pred, 1, randomized_mask_b_indices_flat)
    pixel_wise_loss = (random_img_a_object_descriptors - random_img_b_object_descriptors).pow(2).sum(dim=2)
    pixel_wise_loss = torch.add(pixel_wise_loss, -2 * M_margin)
    zeros_vec = torch.zeros_like(pixel_wise_loss)
    loss += torch.max(zeros_vec, pixel_wise_loss).sum()

    return loss
Author: shooter2062424, Project: pytorch-dense-correspondence, Lines: 32, Source file: semantic_consistency_loss.py
Example 4: forward
def forward(cls, ctx, indices, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq,
            sparse=False):
    ctx.padding_idx = padding_idx
    ctx.scale_grad_by_freq = scale_grad_by_freq
    ctx._indices = None
    ctx.sparse = sparse

    assert indices.dim() <= 2
    assert not ctx.needs_input_grad[0], "Embedding doesn't " \
        "compute the gradient w.r.t. the indices"

    ctx._backend = type2backend[type(weight)]
    ctx._weight_size = weight.size()

    if not indices.is_contiguous():
        ctx._indices = indices.contiguous()
        indices = ctx._indices
    else:
        ctx.save_for_backward(indices)

    output = weight.new()
    if max_norm is not None:
        cls._renorm(ctx, indices, weight, max_norm, norm_type)

    if indices.dim() == 1:
        output = torch.index_select(weight, 0, indices)
    else:
        output = torch.index_select(weight, 0, indices.view(-1))
        output = output.view(indices.size(0), indices.size(1), weight.size(1))

    return output
Author: Northrend, Project: pytorch, Lines: 32, Source file: sparse.py
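Example 4 makes explicit what an embedding lookup really is: index_select on the weight matrix along dimension 0. A minimal sketch of the same idea (my own, with hypothetical sizes):

import torch

vocab_size, embed_dim = 10, 4
weight = torch.randn(vocab_size, embed_dim)     # the embedding table
indices = torch.tensor([[1, 5, 3], [2, 2, 7]])  # (batch, seq_len)

# 2-D indices: flatten, gather rows, reshape back -- exactly as in Example 4
flat = torch.index_select(weight, 0, indices.view(-1))
output = flat.view(indices.size(0), indices.size(1), weight.size(1))
assert output.shape == (2, 3, embed_dim)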
Example 5: dual_OT_model
def dual_OT_model(self, Xs_batch, i_t):
    batch_size = i_t.shape[0]
    u_batch = self.u(Xs_batch)
    v_batch = torch.index_select(self.v, dim=0, index=i_t)
    Xt_batch = torch.index_select(self.Xt, dim=0, index=i_t)
    return self.dual_OT_batch_loss(batch_size=batch_size, u_batch=u_batch, v_batch=v_batch, Xs_batch=Xs_batch, Xt_batch=Xt_batch)
Author: vivienseguy, Project: Large-Scale-OT, Lines: 8, Source file: StochasticOTSemiDiscrete.py
Example 6: get_loss_original
def get_loss_original(self, image_a_pred, image_b_pred, matches_a,
                      matches_b, non_matches_a, non_matches_b,
                      M_margin=0.5, non_match_loss_weight=1.0):
    # this is pegged to its implementation at sha 87abdb63bb5b99d9632f5c4360b5f6f1cf54245f
    """
    Computes the loss function

    DCN = Dense Correspondence Network
    num_images = number of images in this batch
    num_matches = number of matches
    num_non_matches = number of non-matches
    W = image width
    H = image height
    D = descriptor dimension

    match_loss = 1/num_matches \sum_{num_matches} ||descriptor_a - descriptor_b||_2^2
    non_match_loss = 1/num_non_matches \sum_{num_non_matches} max(0, M_margin - ||descriptor_a - descriptor_b||_2^2)
    loss = match_loss + non_match_loss

    :param image_a_pred: Output of DCN network on image A.
    :type image_a_pred: torch.Variable(torch.FloatTensor) shape [1, W * H, D]
    :param image_b_pred: same as image_a_pred
    :type image_b_pred:
    :param matches_a: torch.Variable(torch.LongTensor) with shape [num_matches,]; a (u,v) pixel is mapped
        to the flat index image_width * v + u, which matches the second dimension of image_a_pred
    :type matches_a: torch.Variable(torch.LongTensor)
    :param matches_b: same as matches_a
    :type matches_b:
    :param non_matches_a: torch.Variable(torch.LongTensor) with shape [num_non_matches,]; same
        (u,v) ---> image_width * v + u mapping as above
    :type non_matches_a: torch.Variable(torch.LongTensor)
    :param non_matches_b: same as non_matches_a
    :type non_matches_b:
    :return: loss, match_loss, non_match_loss
    :rtype: torch.Variable(torch.FloatTensor) each of shape torch.Size([1])
    """
    num_matches = matches_a.size()[0]
    num_non_matches = non_matches_a.size()[0]

    matches_a_descriptors = torch.index_select(image_a_pred, 1, matches_a)
    matches_b_descriptors = torch.index_select(image_b_pred, 1, matches_b)
    match_loss = 1.0 / num_matches * (matches_a_descriptors - matches_b_descriptors).pow(2).sum()

    # add loss via non_matches
    non_matches_a_descriptors = torch.index_select(image_a_pred, 1, non_matches_a)
    non_matches_b_descriptors = torch.index_select(image_b_pred, 1, non_matches_b)
    pixel_wise_loss = (non_matches_a_descriptors - non_matches_b_descriptors).pow(2).sum(dim=2)
    pixel_wise_loss = torch.add(torch.neg(pixel_wise_loss), M_margin)
    zeros_vec = torch.zeros_like(pixel_wise_loss)
    non_match_loss = non_match_loss_weight * 1.0 / num_non_matches * torch.max(zeros_vec, pixel_wise_loss).sum()

    loss = match_loss + non_match_loss

    return loss, match_loss, non_match_loss
Author: shooter2062424, Project: pytorch-dense-correspondence, Lines: 55, Source file: pixelwise_contrastive_loss.py
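Several of these dense-correspondence examples share one indexing convention: a pixel (u, v) is flattened to image_width * v + u, so descriptors for an arbitrary set of pixels can be pulled out of the [1, W*H, D] prediction with a single index_select along dim 1. A minimal sketch of my own, with hypothetical sizes:

import torch

W, H, D = 4, 3, 2
image_pred = torch.randn(1, W * H, D)   # [1, W*H, D] descriptor image

u = torch.tensor([0, 3])                # pixel columns
v = torch.tensor([1, 2])                # pixel rows
flat = W * v + u                        # (u,v) -> image_width * v + u

descriptors = torch.index_select(image_pred, 1, flat)  # shape [1, 2, D]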
Example 7: forward
def forward(self, base_feat, im_info, gt_boxes, num_boxes):
    batch_size = base_feat.size(0)

    # return feature map after conv-relu layer
    rpn_conv1 = F.relu(self.RPN_Conv(base_feat), inplace=True)

    # get rpn classification score
    rpn_cls_score = self.RPN_cls_score(rpn_conv1)
    rpn_cls_score_reshape = self.reshape(rpn_cls_score, 2)
    rpn_cls_prob_reshape = F.softmax(rpn_cls_score_reshape, dim=1)
    rpn_cls_prob = self.reshape(rpn_cls_prob_reshape, self.nc_score_out)

    # get rpn offsets to the anchor boxes
    rpn_bbox_pred = self.RPN_bbox_pred(rpn_conv1)

    # proposal layer
    cfg_key = 'TRAIN' if self.training else 'TEST'
    rois = self.RPN_proposal((rpn_cls_prob.data, rpn_bbox_pred.data,
                              im_info, cfg_key))

    self.rpn_loss_cls = 0
    self.rpn_loss_box = 0

    # generate training labels and build the rpn loss
    if self.training:
        assert gt_boxes is not None
        rpn_data = self.RPN_anchor_target((rpn_cls_score.data, gt_boxes, im_info, num_boxes))

        # compute classification loss
        rpn_cls_score = rpn_cls_score_reshape.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 2)
        rpn_label = rpn_data[0].view(batch_size, -1)

        # keep only anchors whose label is not -1 ("don't care")
        rpn_keep = Variable(rpn_label.view(-1).ne(-1).nonzero().view(-1))
        rpn_cls_score = torch.index_select(rpn_cls_score.view(-1, 2), 0, rpn_keep)
        rpn_label = torch.index_select(rpn_label.view(-1), 0, rpn_keep.data)
        rpn_label = Variable(rpn_label.long())
        self.rpn_loss_cls = F.cross_entropy(rpn_cls_score, rpn_label)
        fg_cnt = torch.sum(rpn_label.data.ne(0))

        rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights = rpn_data[1:]

        # compute bbox regression loss
        rpn_bbox_inside_weights = Variable(rpn_bbox_inside_weights)
        rpn_bbox_outside_weights = Variable(rpn_bbox_outside_weights)
        rpn_bbox_targets = Variable(rpn_bbox_targets)
        self.rpn_loss_box = _smooth_l1_loss(rpn_bbox_pred, rpn_bbox_targets, rpn_bbox_inside_weights,
                                            rpn_bbox_outside_weights, sigma=3, dim=[1, 2, 3])

    return rois, self.rpn_loss_cls, self.rpn_loss_box
Author: lianDaniel, Project: R-FCN.pytorch, Lines: 53, Source file: rpn.py
Example 8: barycentric_mapping_loss_model
def barycentric_mapping_loss_model(self, neuralNet, Xs_batch, i_t):
    self.u.eval()
    self.v.requires_grad_(False)
    u_batch = self.u(Xs_batch)
    v_batch = torch.index_select(self.v, dim=0, index=i_t)
    Xt_batch = torch.index_select(self.Xt, dim=0, index=i_t)
    fXs_batch = neuralNet(Xs_batch)
    return self.barycentric_model_batch_loss(u_batch, v_batch, Xs_batch, Xt_batch, fXs_batch)
Author: vivienseguy, Project: Large-Scale-OT, Lines: 12, Source file: StochasticOTSemiDiscrete.py
Example 9: forward
def forward(self, input, output, input_lens=None, output_lens=None, lookup=None, **kwargs):
    h0 = self.h0.expand(1, input.size(0), self.hidden_dim).contiguous()
    c0 = self.c0.expand(1, input.size(0), self.hidden_dim).contiguous()
    input_encoded, input_h, input_c = self.encoder(input, h0, c0, lens=input_lens)
    if lookup is not None:  # truth-testing a multi-element tensor would raise
        input_h = th.index_select(input_h, 1, lookup)
        input_c = th.index_select(input_c, 1, lookup)
    transfer_h, transfer_c = self.transfer(input_h, input_c, **kwargs)
    log_probs, _, _ = self.decoder(output, transfer_h, transfer_c, lens=output_lens)
    return log_probs
Author: douwekiela, Project: nncg-negation, Lines: 12, Source file: models.py
Example 10: updateOutput
def updateOutput(self, input):
    self.renorm(input)
    input = self._makeInputContiguous(input)
    if input.dim() == 1:
        torch.index_select(self.weight, 0, input, out=self.output)
    elif input.dim() == 2:
        torch.index_select(self.weight, 0, input.view(-1), out=self.output)
        self.output = self.output.view(input.size(0), input.size(1), self.weight.size(1))
    else:
        raise RuntimeError("input must be a vector or matrix")
    return self.output
Author: Northrend, Project: pytorch, Lines: 12, Source file: LookupTable.py
Example 11: random_sample_from_masked_image_torch
def random_sample_from_masked_image_torch(img_mask, num_samples):
    """
    :param img_mask: Numpy array [H,W] or torch.Tensor with shape [H,W]
    :type img_mask:
    :param num_samples: an integer
    :type num_samples:
    :return: tuple of torch.LongTensor in (u,v) format. Each torch.LongTensor has shape
        [num_samples]
    :rtype:
    """
    image_height, image_width = img_mask.shape

    if isinstance(img_mask, np.ndarray):
        img_mask_torch = torch.from_numpy(img_mask).float()
    else:
        img_mask_torch = img_mask

    # This code randomly subsamples from the mask
    mask = img_mask_torch.view(image_width * image_height, 1).squeeze(1)
    mask_indices_flat = torch.nonzero(mask)
    if len(mask_indices_flat) == 0:
        return (None, None)

    rand_numbers = torch.rand(num_samples) * len(mask_indices_flat)
    rand_indices = torch.floor(rand_numbers).long()
    uv_vec_flattened = torch.index_select(mask_indices_flat, 0, rand_indices).squeeze(1)
    uv_vec = utils.flattened_pixel_locations_to_u_v(uv_vec_flattened, image_width)
    return uv_vec
Author: shooter2062424, Project: pytorch-dense-correspondence, Lines: 30, Source file: correspondence_finder.py
Example 12: next
def next(self):
    if self.next_i + self.batch_size > len(self.data):
        raise StopIteration()
    else:
        x_idx = self.x_idx[self.next_i:self.next_i + self.batch_size]
        self.next_i += self.batch_size
        labels = {k: torch.index_select(self.labels[k], 0, x_idx) for k in self.labels}
        x = self.select_data(x_idx)

        inputs = {}
        sizes = {}
        for k, v in labels.items():
            possibilities = [self.label_idxs[k][v[i].item()] for i in range(len(x_idx))]
            sizes[k] = [len(X) for X in possibilities]
            input_idx = [np.random.choice(X, size=self.k_shot[k]) for X in possibilities]
            _inputs = [
                self.select_data(torch.LongTensor([I[j] for I in input_idx]))
                for j in range(self.k_shot[k])]
            if self.mode == "tensor":
                inputs[k] = torch.cat([x.unsqueeze(1) for x in _inputs], dim=1)
            elif self.mode == "list":
                inputs[k] = [[_inputs[j][i] for j in range(self.k_shot[k])]
                             for i in range(len(_inputs[0]))]

        batch = VHEBatch(target=x, inputs=inputs, sizes=sizes)
        for transform in self.transforms:
            batch = transform.apply(batch)
        return batch
Author: insperatum, Project: vhe, Lines: 29, Source file: vhe.py
Example 13: forward
def forward(self, x):
    x = torch.index_select(x, 1, Variable(self.index))
    x = self.norm(x)
    x = self.relu(x)
    x = self.conv(x)
    x = ShuffleLayer(x, self.groups)
    return x
Author: dengshuo, Project: CondenseNet, Lines: 7, Source file: layers.py
Example 14: model
def model():
    p_latent = pyro.param("p1", Variable(torch.Tensor([[0.7], [0.3]])))
    p_obs = pyro.param("p2", Variable(torch.Tensor([[0.9], [0.1]])))
    latents = [Variable(torch.ones(1, 1))]
    observes = []
    for t in range(self.model_steps):
        latents.append(
            pyro.sample("latent_{}".format(str(t)),
                        Bernoulli(torch.index_select(p_latent, 0, latents[-1].view(-1).long()))))
        observes.append(
            pyro.observe("observe_{}".format(str(t)),
                         Bernoulli(torch.index_select(p_obs, 0, latents[-1].view(-1).long())),
                         self.data[t]))
    return torch.sum(torch.cat(latents))
Author: Magica-Chen, Project: pyro, Lines: 17, Source file: test_sampling.py
Example 15: deprocess_img
def deprocess_img(img):
    # BGR to RGB
    idx = torch.LongTensor([2, 1, 0])
    img = torch.index_select(img, 0, idx)
    # [-1,1] to [0,1]
    img = img.add_(1).div_(2)
    return img
Author: vivek231, Project: vivek, Lines: 9, Source file: util.py
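Example 15's channel swap is one of the most common everyday uses of index_select. A quick sanity check of the idea on a dummy image (my own toy data):

import torch

bgr = torch.stack([torch.full((2, 2), 0.0),   # B channel
                   torch.full((2, 2), 0.5),   # G channel
                   torch.full((2, 2), 1.0)])  # R channel; shape (3, H, W)

rgb = torch.index_select(bgr, 0, torch.LongTensor([2, 1, 0]))
assert rgb[0, 0, 0] == 1.0 and rgb[2, 0, 0] == 0.0  # R first, B last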
Example 16: nms
def nms(boxes, scores, overlap=0.5, top_k=100):
    keep = scores.new(scores.size(0)).zero_().long()
    if boxes.numel() == 0:
        return keep
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    area = torch.mul(x2 - x1, y2 - y1)
    v, idx = scores.sort(0)  # sort in ascending order
    idx = idx[-top_k:]  # indices of the top-k largest vals
    xx1 = boxes.new()
    yy1 = boxes.new()
    xx2 = boxes.new()
    yy2 = boxes.new()
    w = boxes.new()
    h = boxes.new()

    count = 0
    while idx.numel() > 0:
        i = idx[-1]  # index of current largest val
        keep[count] = i
        count += 1
        if idx.size(0) == 1:
            break
        idx = idx[:-1]  # remove kept element from view
        # load bboxes of next highest vals
        torch.index_select(x1, 0, idx, out=xx1)
        torch.index_select(y1, 0, idx, out=yy1)
        torch.index_select(x2, 0, idx, out=xx2)
        torch.index_select(y2, 0, idx, out=yy2)
        # clip the remaining boxes to the coordinates of the kept box
        xx1 = torch.clamp(xx1, min=x1[i])
        yy1 = torch.clamp(yy1, min=y1[i])
        xx2 = torch.clamp(xx2, max=x2[i])
        yy2 = torch.clamp(yy2, max=y2[i])
        w.resize_as_(xx2)
        h.resize_as_(yy2)
        w = xx2 - xx1
        h = yy2 - yy1
        # clamp intersection widths/heights to be non-negative
        w = torch.clamp(w, min=0.0)
        h = torch.clamp(h, min=0.0)
        inter = w * h
        # IoU = i / (area(a) + area(b) - i)
        rem_areas = torch.index_select(area, 0, idx)  # load remaining areas
        union = (rem_areas - inter) + area[i]
        IoU = inter / union
        # keep only elements with an IoU <= overlap
        idx = idx[IoU.le(overlap)]
    return keep, count
Author: rohitgeo, Project: arcgis-python-api, Lines: 49, Source file: util.py
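A short usage sketch for the nms above, on toy data of my own: two heavily overlapping boxes plus one distant box. Note the function returns the kept indices zero-padded to full length, together with the kept count:

import torch

boxes = torch.FloatTensor([[ 0,  0, 10, 10],
                           [ 1,  1, 10, 10],    # overlaps the first box heavily
                           [20, 20, 30, 30]])   # far away from both
scores = torch.FloatTensor([0.9, 0.8, 0.7])

keep, count = nms(boxes, scores, overlap=0.5)
print(keep[:count])  # expected: tensor([0, 2]) -- box 1 is suppressed (IoU = 0.81)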
Example 17: updateOutput
def updateOutput(self, input):
    self.output.set_(self.network.forward([input, self.partition]))
    if self.bias is not None:
        self.output.add_(torch.index_select(self.bias, 1, self.partition).expand_as(self.output))
        if self.addBuffer is None:
            self.addBuffer = input.new()
        if self.addBuffer.nelement() != input.size(0):
            self.addBuffer.resize_(input.size(0)).fill_(1)
    return self.output
Author: Jsmilemsj, Project: pytorch, Lines: 10, Source file: PartialLinear.py
Example 18: reverse_sequences_torch
def reverse_sequences_torch(mini_batch, seq_lengths):
    reversed_mini_batch = mini_batch.new_zeros(mini_batch.size())
    for b in range(mini_batch.size(0)):
        T = seq_lengths[b]
        time_slice = np.arange(T - 1, -1, -1)
        time_slice = torch.cuda.LongTensor(time_slice) if 'cuda' in mini_batch.data.type() \
            else torch.LongTensor(time_slice)
        reversed_sequence = torch.index_select(mini_batch[b, :, :], 0, time_slice)
        reversed_mini_batch[b, 0:T, :] = reversed_sequence
    return reversed_mini_batch
Author: lewisKit, Project: pyro, Lines: 10, Source file: polyphonic_data_loader.py
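As an aside (my addition, not part of the original snippet): on recent PyTorch versions the same per-sequence reversal can be written with torch.flip on each unpadded slice, which avoids building an explicit index tensor. The index_select version above remains a good fit when you want the index tensor anyway:

import torch

def reverse_sequences_flip(mini_batch, seq_lengths):
    # Reverse only the first T steps of each sequence, leaving padding in place.
    reversed_mini_batch = torch.zeros_like(mini_batch)
    for b in range(mini_batch.size(0)):
        T = seq_lengths[b]
        reversed_mini_batch[b, :T] = mini_batch[b, :T].flip(0)
    return reversed_mini_batch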
Example 19: non_match_descriptor_loss
def non_match_descriptor_loss(image_a_pred, image_b_pred, non_matches_a, non_matches_b, M=0.5, invert=False):
    """
    Computes the max(0, M - D(I_a, I_b, u_a, u_b))^2 term.

    This is effectively: "a and b should be AT LEAST M away from each other".
    With invert=True, this is: "a and b should be AT MOST M away from each other".

    :param image_a_pred: Output of DCN network on image A.
    :type image_a_pred: torch.Variable(torch.FloatTensor) shape [1, W * H, D]
    :param image_b_pred: same as image_a_pred
    :type image_b_pred:
    :param non_matches_a: torch.Variable(torch.LongTensor) with shape [num_non_matches,]; a (u,v) pair is mapped
        to the flat index image_width * v + u, which matches the second dimension of image_a_pred
    :type non_matches_a: torch.Variable(torch.LongTensor)
    :param non_matches_b: same as non_matches_a
    :param M: the margin
    :type M: float
    :return: torch.FloatTensor with shape torch.Size([num_non_matches])
    :rtype:
    """
    non_matches_a_descriptors = torch.index_select(image_a_pred, 1, non_matches_a).squeeze()
    non_matches_b_descriptors = torch.index_select(image_b_pred, 1, non_matches_b).squeeze()

    # crazily enough, if there is only one element to index_select into
    # above, then the first dimension is collapsed down, and we end up
    # with shape [D,] where we want [1, D]; this unsqueeze fixes that case
    if len(non_matches_a) == 1:
        non_matches_a_descriptors = non_matches_a_descriptors.unsqueeze(0)
        non_matches_b_descriptors = non_matches_b_descriptors.unsqueeze(0)

    norm_degree = 2
    non_match_loss = (non_matches_a_descriptors - non_matches_b_descriptors).norm(norm_degree, 1)
    if not invert:
        non_match_loss = torch.clamp(M - non_match_loss, min=0).pow(2)
    else:
        non_match_loss = torch.clamp(non_match_loss - M, min=0).pow(2)

    hard_negative_idxs = torch.nonzero(non_match_loss)
    num_hard_negatives = len(hard_negative_idxs)

    return non_match_loss, num_hard_negatives, non_matches_a_descriptors, non_matches_b_descriptors
Author: shooter2062424, Project: pytorch-dense-correspondence, Lines: 43, Source file: pixelwise_contrastive_loss.py
Example 20: crop1d
def crop1d(x, cutoff, dim):
    '''Crops tensor x by cutoff elements from the beginning and the end along dimension dim.

    Example:
        x = torch.FloatTensor([1, 2, 3, 4, 5, 6, 7, 8]).cuda(1)
        crop1d(x, 2, 0)  # -> [3, 4, 5, 6]
    '''
    idx = torch.arange(cutoff, x.shape[dim] - cutoff).long()
    if x.is_cuda:
        dev = x.get_device()
        idx = idx.cuda(dev)
    return torch.index_select(x, dim, idx)
Author: Apogentus, Project: common, Lines: 10, Source file: tensor.py
Note: The torch.index_select examples in this article were compiled by 纯净天空 from source code and documentation hosted on GitHub, MSDocs, and similar platforms. The snippets are selected from open-source projects contributed by their respective authors; copyright remains with the original authors, and any use or redistribution should follow the corresponding project's license. Please do not republish without permission.