• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python torch.index_select函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中torch.index_select函数的典型用法代码示例。如果您正苦于以下问题:Python index_select函数的具体用法?Python index_select怎么用?Python index_select使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了index_select函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: forward

 def forward(self, sents, sent_lengths):
     '''
     Embed a padded batch, run it through the GRU as a packed sequence and
     return the results in the caller's original batch order.

     sents is (batch_size by padded_length); evaluating a single sentence
     means batch_size = 1, e.g. [[1, 2, 3, 4]].
     sent_lengths holds one length per row of sents.

     Returns (rnn_out, self.hidden). rnn_out is re-ordered along dim 0 and
     the hidden state along dim 1 so both line up with the input rows.
     '''
     batch_size = sents.size()[0]
     sent_lengths = list(sent_lengths)

     # The batch is packed in order of decreasing length, so sort the
     # (length, original_position) pairs once and split them afterwards.
     by_length = sorted(zip(sent_lengths, range(len(sent_lengths))), reverse=True)
     descending_lengths = torch.tensor([length for length, _ in by_length])
     sorted_positions = [position for _, position in by_length]
     descending_indices = torch.tensor(sorted_positions).to(device)
     # index_select takes the dimension as a plain int.
     descending_sents = torch.index_select(sents, 0, descending_indices)

     # get embedding, then pack the padded sequence
     embed = self.embedding(descending_sents)
     embed = torch.nn.utils.rnn.pack_padded_sequence(embed, descending_lengths, batch_first=True)

     # fprop through the RNN (no debugger breakpoint in the hot path)
     self.hidden = self.init_hidden(batch_size)
     rnn_out, self.hidden = self.gru(embed, self.hidden)
     rnn_out, _ = torch.nn.utils.rnn.pad_packed_sequence(rnn_out, batch_first=True)

     # Inverse permutation: entry k is the sorted row that holds original row k.
     change_it_back = [row for _, row in sorted(zip(sorted_positions, range(len(sorted_positions))))]
     restore_order = torch.LongTensor(change_it_back).to(device)
     self.hidden = torch.index_select(self.hidden, 1, restore_order)
     rnn_out = torch.index_select(rnn_out, 0, restore_order)

     return rnn_out, self.hidden
开发者ID:vwrj,项目名称:neural_machine_translation,代码行数:33,代码来源:V2-Attention-Vish.py


示例2: get_triplet_loss

    def get_triplet_loss(image_a_pred, image_b_pred, matches_a, matches_b, non_matches_a, non_matches_b, alpha):
        r"""
        Computes the triplet loss

        \sum_{triplets} ||D(I_a, u_a, I_b, u_{b,match})||_2^2 - ||D(I_a, u_a, I_b, u_{b,non-match)||_2^2 + alpha

        The squared differences are taken per descriptor component, the hinge
        (clamp at 0) is applied element-wise, and the sum is divided by the
        number of non-matches.

        :param image_a_pred: descriptors of image A, indexed along dim 1
        :param image_b_pred: descriptors of image B, indexed along dim 1
        :param matches_a: 1-D index tensor; only its length is used here
        :param matches_b: 1-D index tensor, replicated to the non-match count
        :param non_matches_a: 1-D index tensor, already replicated matches_a
        :param non_matches_b: 1-D index tensor of non-matching pixels in B
        :param alpha: margin added to every triplet term
        :return: scalar tensor with the averaged triplet loss
        """
        num_matches = matches_a.size()[0]
        num_non_matches = non_matches_a.size()[0]
        # Tensor.repeat needs an integer count; true division would hand it a
        # float under Python 3 and raise.
        multiplier = num_non_matches // num_matches

        ## non_matches_a is already replicated up to be the right size
        ## non_matches_b is also that size
        ## matches_a is just a smaller version of non_matches_a
        ## matches_b is the only thing that needs to be replicated up in size

        # repeat + transpose turns [m0, m1, ...] into [m0, m0, ..., m1, m1, ...]
        matches_b_long = torch.t(matches_b.repeat(multiplier, 1)).contiguous().view(-1)

        matches_a_descriptors = torch.index_select(image_a_pred, 1, non_matches_a)
        matches_b_descriptors = torch.index_select(image_b_pred, 1, matches_b_long)
        non_matches_b_descriptors = torch.index_select(image_b_pred, 1, non_matches_b)

        triplet_losses = (matches_a_descriptors - matches_b_descriptors).pow(2) - (matches_a_descriptors - non_matches_b_descriptors).pow(2) + alpha
        triplet_loss = 1.0 / num_non_matches * torch.clamp(triplet_losses, min=0).sum()

        return triplet_loss
开发者ID:shooter2062424,项目名称:pytorch-dense-correspondence,代码行数:26,代码来源:pixelwise_contrastive_loss.py


示例3: get_loss

    def get_loss(self, image_a_pred, image_b_pred, mask_a, mask_b):
        """
        Hinge loss between descriptors of randomly paired masked pixels.

        Draws 10000 random pixels (with replacement) from the non-zero entries
        of each mask, pairs sample k of image A with sample k of image B and
        returns sum(max(0, ||d_a - d_b||^2 - 2 * M_margin)) with M_margin = 0.5.

        :param image_a_pred: descriptors of image A; indexed along dim 1 and
            reduced along dim 2 (assumes shape [1, W * H, D] -- TODO confirm)
        :param image_b_pred: same layout as image_a_pred
        :param mask_a: flat mask over the pixels of image A (non-zero = object)
        :param mask_b: flat mask over the pixels of image B
        :return: scalar loss tensor; a zero tensor of shape [1] that requires
            grad when either mask has no non-zero entry
        """
        # get the nonzero indices
        mask_a_indices_flat = torch.nonzero(mask_a)
        mask_b_indices_flat = torch.nonzero(mask_b)
        if len(mask_a_indices_flat) == 0 or len(mask_b_indices_flat) == 0:
            # Only floating point tensors can require grad, so the placeholder
            # loss takes dtype and device from the prediction.
            return image_a_pred.new_zeros(1, requires_grad=True)

        # take 10000 random pixel samples of the object, using the mask
        num_samples = 10000

        def sample_masked_pixels(indices_flat):
            # floor(U[0,1) * K) picks uniformly among the K masked pixels; the
            # index is moved to wherever the mask lives instead of assuming CUDA.
            rand_numbers = torch.rand(num_samples) * len(indices_flat)
            rand_indices = torch.floor(rand_numbers).long().to(indices_flat.device)
            return torch.index_select(indices_flat, 0, rand_indices).squeeze(1)

        randomized_mask_a_indices_flat = sample_masked_pixels(mask_a_indices_flat)
        randomized_mask_b_indices_flat = sample_masked_pixels(mask_b_indices_flat)

        # index into the image and get descriptors
        M_margin = 0.5  # margin parameter
        random_img_a_object_descriptors = torch.index_select(image_a_pred, 1, randomized_mask_a_indices_flat)
        random_img_b_object_descriptors = torch.index_select(image_b_pred, 1, randomized_mask_b_indices_flat)
        pixel_wise_loss = (random_img_a_object_descriptors - random_img_b_object_descriptors).pow(2).sum(dim=2)
        pixel_wise_loss = torch.add(pixel_wise_loss, -2 * M_margin)
        zeros_vec = torch.zeros_like(pixel_wise_loss)

        return torch.max(zeros_vec, pixel_wise_loss).sum()
开发者ID:shooter2062424,项目名称:pytorch-dense-correspondence,代码行数:32,代码来源:semantic_consistency_loss.py


示例4: forward

    def forward(cls, ctx, indices, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq,
                sparse=False):
        """
        Look up rows of ``weight`` for every entry of ``indices``.

        Records the lookup settings on ``ctx``, optionally renormalises the
        selected rows through ``cls._renorm`` and returns the gathered rows:
        one row per index for 1-D ``indices``, or a
        (indices.size(0), indices.size(1), weight.size(1)) tensor otherwise.
        """
        ctx.padding_idx = padding_idx
        ctx.scale_grad_by_freq = scale_grad_by_freq
        ctx._indices = None
        ctx.sparse = sparse

        assert indices.dim() <= 2
        assert not ctx.needs_input_grad[0], (
            "Embedding doesn't "
            "compute the gradient w.r.t. the indices")

        ctx._backend = type2backend[type(weight)]
        ctx._weight_size = weight.size()

        # Contiguous indices are saved via save_for_backward; otherwise a
        # contiguous copy is kept on ctx and used for the lookup.
        if indices.is_contiguous():
            ctx.save_for_backward(indices)
        else:
            indices = ctx._indices = indices.contiguous()

        output = weight.new()  # placeholder, replaced by the lookup below
        if max_norm is not None:
            cls._renorm(ctx, indices, weight, max_norm, norm_type)

        needs_reshape = indices.dim() != 1
        flat_indices = indices.view(-1) if needs_reshape else indices
        output = torch.index_select(weight, 0, flat_indices)
        if needs_reshape:
            output = output.view(indices.size(0), indices.size(1), weight.size(1))

        return output
开发者ID:Northrend,项目名称:pytorch,代码行数:32,代码来源:sparse.py


示例5: dual_OT_model

    def dual_OT_model(self, Xs_batch, i_t):
        """
        Evaluate the dual OT batch loss for one mini-batch.

        ``self.u`` is applied to ``Xs_batch``; rows ``i_t`` of ``self.v`` and
        ``self.Xt`` are gathered along dim 0, and everything is forwarded to
        ``self.dual_OT_batch_loss`` together with the batch size ``len(i_t)``.
        """
        return self.dual_OT_batch_loss(
            batch_size=i_t.shape[0],
            u_batch=self.u(Xs_batch),
            v_batch=torch.index_select(self.v, dim=0, index=i_t),
            Xs_batch=Xs_batch,
            Xt_batch=torch.index_select(self.Xt, dim=0, index=i_t),
        )
开发者ID:vivienseguy,项目名称:Large-Scale-OT,代码行数:8,代码来源:StochasticOTSemiDiscrete.py


示例6: get_loss_original

    def get_loss_original(self, image_a_pred, image_b_pred, matches_a,
                          matches_b, non_matches_a, non_matches_b,
                          M_margin=0.5, non_match_loss_weight=1.0):

        # this is pegged to its implementation at sha 87abdb63bb5b99d9632f5c4360b5f6f1cf54245f
        r"""
        Pixelwise contrastive loss of a Dense Correspondence Network (DCN).

        match_loss = 1/num_matches \sum_{num_matches} ||descriptor_a - descriptor_b||_2^2
        non_match_loss = 1/num_non_matches \sum_{num_non_matches} max(0, M_margin - ||descriptor_a - descriptor_b||_2^2 )
        loss = match_loss + non_match_loss

        (non_match_loss is additionally scaled by non_match_loss_weight.)

        :param image_a_pred: output of the DCN on image A, documented upstream
            as shape [1, W * H, D]; descriptors are selected along dim 1
        :param image_b_pred: same as image_a_pred, for image B
        :param matches_a: 1-D index tensor of length num_matches; a (u,v) pixel
            is encoded as image_width * v + u
        :param matches_b: same as matches_a, for image B
        :param non_matches_a: 1-D index tensor of length num_non_matches, same
            pixel encoding
        :param non_matches_b: same as non_matches_a, for image B
        :param M_margin: margin of the non-match hinge
        :param non_match_loss_weight: multiplier applied to the non-match term
        :return: loss, match_loss, non_match_loss
        """

        def descriptors_at(prediction, pixel_indices):
            # gather the descriptors of the requested pixels (dim 1)
            return torch.index_select(prediction, 1, pixel_indices)

        match_count = matches_a.size(0)
        non_match_count = non_matches_a.size(0)

        # matches: pull the two descriptors together
        match_residual = descriptors_at(image_a_pred, matches_a) - descriptors_at(image_b_pred, matches_b)
        match_loss = 1.0 / match_count * match_residual.pow(2).sum()

        # non-matches: penalise pairs whose squared distance is below the margin
        non_match_residual = (descriptors_at(image_a_pred, non_matches_a)
                              - descriptors_at(image_b_pred, non_matches_b))
        squared_distance = non_match_residual.pow(2).sum(dim=2)
        margin_gap = torch.add(torch.neg(squared_distance), M_margin)
        hinge = torch.max(torch.zeros_like(margin_gap), margin_gap)
        non_match_loss = non_match_loss_weight * 1.0 / non_match_count * hinge.sum()

        loss = match_loss + non_match_loss
        return loss, match_loss, non_match_loss
开发者ID:shooter2062424,项目名称:pytorch-dense-correspondence,代码行数:55,代码来源:pixelwise_contrastive_loss.py


示例7: forward

    def forward(self, base_feat, im_info, gt_boxes, num_boxes):
        """
        Run the region proposal head on a feature map.

        Computes objectness scores and box offsets from ``base_feat``, feeds
        them to the proposal layer to obtain ``rois`` and, in training mode,
        also builds the classification and box-regression losses.

        :param base_feat: feature map fed to ``self.RPN_Conv``
        :param im_info: forwarded untouched to the proposal / anchor-target layers
        :param gt_boxes: ground-truth boxes; must not be None in training mode
        :param num_boxes: forwarded untouched to the anchor-target layer
        :return: (rois, rpn_loss_cls, rpn_loss_box); both losses are 0 when
            the module is not in training mode
        """

        batch_size = base_feat.size(0)

        # return feature map after convrelu layer
        rpn_conv1 = F.relu(self.RPN_Conv(base_feat), inplace=True)
        # get rpn classification score
        rpn_cls_score = self.RPN_cls_score(rpn_conv1)

        # reshape so the softmax over dim 1 runs across 2 channels, then
        # reshape the probabilities back to self.nc_score_out channels
        rpn_cls_score_reshape = self.reshape(rpn_cls_score, 2)
        rpn_cls_prob_reshape = F.softmax(rpn_cls_score_reshape, dim=1)
        rpn_cls_prob = self.reshape(rpn_cls_prob_reshape, self.nc_score_out)

        # get rpn offsets to the anchor boxes
        rpn_bbox_pred = self.RPN_bbox_pred(rpn_conv1)

        # proposal layer
        cfg_key = 'TRAIN' if self.training else 'TEST'

        # .data is passed, so the proposal layer sees values detached from autograd
        rois = self.RPN_proposal((rpn_cls_prob.data, rpn_bbox_pred.data,
                                 im_info, cfg_key))

        # losses stay 0 unless the training branch below overwrites them
        self.rpn_loss_cls = 0
        self.rpn_loss_box = 0

        # generating training labels and build the rpn loss
        if self.training:
            assert gt_boxes is not None

            # rpn_data[0] holds the labels, rpn_data[1:] the box targets and weights
            rpn_data = self.RPN_anchor_target((rpn_cls_score.data, gt_boxes, im_info, num_boxes))

            # compute classification loss
            rpn_cls_score = rpn_cls_score_reshape.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, 2)
            rpn_label = rpn_data[0].view(batch_size, -1)

            # keep only the entries whose label is not -1
            rpn_keep = Variable(rpn_label.view(-1).ne(-1).nonzero().view(-1))
            rpn_cls_score = torch.index_select(rpn_cls_score.view(-1,2), 0, rpn_keep)
            rpn_label = torch.index_select(rpn_label.view(-1), 0, rpn_keep.data)
            rpn_label = Variable(rpn_label.long())
            self.rpn_loss_cls = F.cross_entropy(rpn_cls_score, rpn_label)
            # NOTE(review): fg_cnt (number of non-zero labels) is not used in this method
            fg_cnt = torch.sum(rpn_label.data.ne(0))

            rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights = rpn_data[1:]

            # compute bbox regression loss
            rpn_bbox_inside_weights = Variable(rpn_bbox_inside_weights)
            rpn_bbox_outside_weights = Variable(rpn_bbox_outside_weights)
            rpn_bbox_targets = Variable(rpn_bbox_targets)

            self.rpn_loss_box = _smooth_l1_loss(rpn_bbox_pred, rpn_bbox_targets, rpn_bbox_inside_weights,
                                                            rpn_bbox_outside_weights, sigma=3, dim=[1,2,3])

        return rois, self.rpn_loss_cls, self.rpn_loss_box
开发者ID:lianDaniel,项目名称:R-FCN.pytorch,代码行数:53,代码来源:rpn.py


示例8: barycentric_mapping_loss_model

    def barycentric_mapping_loss_model(self, neuralNet, Xs_batch, i_t):
        """
        Evaluate the barycentric mapping loss for one mini-batch.

        Switches ``self.u`` to eval mode and turns off gradients on ``self.v``,
        then passes u(Xs_batch), the rows ``i_t`` of ``self.v`` and ``self.Xt``,
        and neuralNet(Xs_batch) to ``self.barycentric_model_batch_loss``.
        """
        self.u.eval()
        self.v.requires_grad_(False)

        source_potential = self.u(Xs_batch)
        target_potential, target_points = (
            torch.index_select(table, dim=0, index=i_t) for table in (self.v, self.Xt)
        )
        mapped_sources = neuralNet(Xs_batch)

        return self.barycentric_model_batch_loss(
            source_potential, target_potential, Xs_batch, target_points, mapped_sources)
开发者ID:vivienseguy,项目名称:Large-Scale-OT,代码行数:12,代码来源:StochasticOTSemiDiscrete.py


示例9: forward

    def forward(self, input, output, input_lens=None, output_lens=None, lookup=None, **kwargs):
        """
        Encode ``input``, transform the final encoder state and decode ``output``.

        :param input: encoder input; dim 0 is the batch size
        :param output: decoder input
        :param input_lens: forwarded to the encoder as ``lens``
        :param output_lens: forwarded to the decoder as ``lens``
        :param lookup: optional index tensor; when given, the encoder states
            are re-selected along dim 1 with it before the transfer step
        :param kwargs: forwarded to ``self.transfer``
        :return: the first element of the decoder's return value (log probs)
        """
        batch_size = input.size(0)
        h0 = self.h0.expand(1, batch_size, self.hidden_dim).contiguous()
        c0 = self.c0.expand(1, batch_size, self.hidden_dim).contiguous()
        input_encoded, input_h, input_c = self.encoder(input, h0, c0, lens=input_lens)

        # Compare against None explicitly: truth-testing a tensor raises for
        # more than one element and would silently skip a lookup equal to [0].
        if lookup is not None:
            input_h = th.index_select(input_h, 1, lookup)
            input_c = th.index_select(input_c, 1, lookup)

        transfer_h, transfer_c = self.transfer(input_h, input_c, **kwargs)
        log_probs, _, _ = self.decoder(output, transfer_h, transfer_c, lens=output_lens)
        return log_probs
开发者ID:douwekiela,项目名称:nncg-negation,代码行数:12,代码来源:models.py


示例10: updateOutput

    def updateOutput(self, input):
        """
        Gather rows of ``self.weight`` for every index in ``input``.

        A 1-D input yields one row per index; a 2-D input yields a tensor of
        shape (input.size(0), input.size(1), self.weight.size(1)). Any other
        rank raises RuntimeError. The result is stored in and returned as
        ``self.output``.
        """
        self.renorm(input)
        input = self._makeInputContiguous(input)

        rank = input.dim()
        if rank not in (1, 2):
            raise RuntimeError("input must be a vector or matrix")

        # A matrix of indices is looked up flat and folded back afterwards.
        flat_indices = input if rank == 1 else input.view(-1)
        torch.index_select(self.weight, 0, flat_indices, out=self.output)
        if rank == 2:
            self.output = self.output.view(input.size(0), input.size(1), self.weight.size(1))

        return self.output
开发者ID:Northrend,项目名称:pytorch,代码行数:12,代码来源:LookupTable.py


示例11: random_sample_from_masked_image_torch

def random_sample_from_masked_image_torch(img_mask, num_samples):
    """
    Draw random pixel locations (with replacement) from the non-zero part of a mask.

    :param img_mask: Numpy array [H,W] or torch.Tensor with shape [H,W]
    :type img_mask:
    :param num_samples: an integer
    :type num_samples:
    :return: whatever utils.flattened_pixel_locations_to_u_v returns for the
    sampled flat indices (documented upstream as a (u,v) tuple of
    torch.LongTensor, each of shape [num_samples]); (None, None) when the
    mask has no non-zero entry
    :rtype:
    """
    image_height, image_width = img_mask.shape

    # numpy masks are converted to float tensors; tensors are used as given
    img_mask_torch = img_mask
    if isinstance(img_mask, np.ndarray):
        img_mask_torch = torch.from_numpy(img_mask).float()

    # flatten the mask and collect the flat positions of its non-zero pixels
    flat_mask = img_mask_torch.view(image_width * image_height, 1).squeeze(1)
    mask_indices_flat = torch.nonzero(flat_mask)
    candidate_count = len(mask_indices_flat)
    if candidate_count == 0:
        return (None, None)

    # floor(U[0,1) * K) picks one of the K candidates uniformly
    picks = torch.floor(torch.rand(num_samples) * candidate_count).long()
    uv_vec_flattened = torch.index_select(mask_indices_flat, 0, picks).squeeze(1)
    return utils.flattened_pixel_locations_to_u_v(uv_vec_flattened, image_width)
开发者ID:shooter2062424,项目名称:pytorch-dense-correspondence,代码行数:30,代码来源:correspondence_finder.py


示例12: next

    def next(self):
        """
        Produce the next batch, or raise StopIteration when fewer than
        ``batch_size`` items remain.

        For every label kind k, each target is paired with ``k_shot[k]``
        examples drawn at random (np.random.choice) from the indices that
        share the target's label value. The batch is built as a VHEBatch and
        passed through every transform in ``self.transforms``.
        """
        batch_end = self.next_i + self.batch_size
        if batch_end > len(self.data):
            raise StopIteration()

        x_idx = self.x_idx[self.next_i:batch_end]
        self.next_i += self.batch_size

        labels = {}
        for name in self.labels:
            labels[name] = torch.index_select(self.labels[name], 0, x_idx)
        x = self.select_data(x_idx)

        inputs, sizes = {}, {}
        for k, v in labels.items():
            shots = self.k_shot[k]
            # candidate indices sharing each target's label value
            possibilities = [self.label_idxs[k][v[i].item()] for i in range(len(x_idx))]
            sizes[k] = [len(pool) for pool in possibilities]
            input_idx = [np.random.choice(pool, size=shots) for pool in possibilities]

            # _inputs[j] holds the j-th drawn example of every target
            _inputs = []
            for j in range(shots):
                jth_draws = torch.LongTensor([drawn[j] for drawn in input_idx])
                _inputs.append(self.select_data(jth_draws))

            if self.mode == "tensor":
                inputs[k] = torch.cat([shot.unsqueeze(1) for shot in _inputs], dim=1)
            elif self.mode == "list":
                inputs[k] = [[shot[i] for shot in _inputs]
                             for i in range(len(_inputs[0]))]

        batch = VHEBatch(target=x, inputs=inputs, sizes=sizes)
        for transform in self.transforms:
            batch = transform.apply(batch)
        return batch
开发者ID:insperatum,项目名称:vhe,代码行数:29,代码来源:vhe.py


示例13: forward

 def forward(self, x):
     '''
     Select the channels listed in self.index (dim 1), apply norm, relu and
     conv in that order, and return the result of ShuffleLayer with
     self.groups.
     '''
     selected = torch.index_select(x, 1, Variable(self.index))
     normalized = self.norm(selected)
     activated = self.relu(normalized)
     convolved = self.conv(activated)
     return ShuffleLayer(convolved, self.groups)
开发者ID:dengshuo,项目名称:CondenseNet,代码行数:7,代码来源:layers.py


示例14: model

        def model():
            """
            Chain of Bernoulli latents with one observation per step.

            At every step the previous latent value (cast to long) selects a
            row of the "p1" parameter for the next latent; the new latent then
            selects a row of "p2" for the observation of self.data[t].
            Returns the sum over all latents, including the initial one.
            """
            p_latent = pyro.param("p1", Variable(torch.Tensor([[0.7], [0.3]])))
            p_obs = pyro.param("p2", Variable(torch.Tensor([[0.9], [0.1]])))

            latents = [Variable(torch.ones(1, 1))]
            observes = []
            for t in range(self.model_steps):
                # the transition is indexed by the previous latent
                previous_state = latents[-1].view(-1).long()
                transition = Bernoulli(torch.index_select(p_latent, 0, previous_state))
                latents.append(pyro.sample("latent_{}".format(str(t)), transition))

                # the emission is indexed by the latent that was just sampled
                current_state = latents[-1].view(-1).long()
                emission = Bernoulli(torch.index_select(p_obs, 0, current_state))
                observes.append(
                    pyro.observe("observe_{}".format(str(t)), emission, self.data[t]))
            return torch.sum(torch.cat(latents))
开发者ID:Magica-Chen,项目名称:pyro,代码行数:17,代码来源:test_sampling.py


示例15: deprocess_img

def deprocess_img(img):
    """
    Reverse the channel order along dim 0 and rescale from [-1, 1] to [0, 1].

    :param img: tensor whose dim 0 has 3 entries (BGR channel order)
    :return: a new tensor with dim 0 reordered to RGB and values (x + 1) / 2;
        the input tensor is left untouched
    """
    # BGR to RGB; the index is created on the image's device so that CUDA
    # images work as well as CPU ones
    idx = torch.tensor([2, 1, 0], dtype=torch.long, device=img.device)
    img = torch.index_select(img, 0, idx)

    # [-1,1] to [0,1]; in-place is safe because index_select returned a copy
    img = img.add_(1).div_(2)

    return img
开发者ID:vivek231,项目名称:vivek,代码行数:9,代码来源:util.py


示例16: nms

def nms(boxes, scores, overlap=0.5, top_k=100):
    """
    Greedy non-maximum suppression.

    Repeatedly keeps the highest-scoring remaining box and drops every other
    candidate whose IoU with it exceeds ``overlap``.

    :param boxes: tensor of shape [N, 4] holding (x1, y1, x2, y2) per row
    :param scores: tensor of shape [N] with one score per box
    :param overlap: IoU threshold; candidates with IoU <= overlap survive
    :param top_k: only the top_k highest-scoring boxes are considered
    :return: (keep, count) -- ``keep`` is a long tensor of length N whose
        first ``count`` entries are the indices of the kept boxes, best score
        first; the remaining entries are 0. Empty input returns (keep, 0) so
        callers can always unpack two values.
    """
    keep = scores.new(scores.size(0)).zero_().long()
    count = 0
    if boxes.numel() == 0:
        return keep, count

    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    area = torch.mul(x2 - x1, y2 - y1)
    _, idx = scores.sort(0)  # sort in ascending order
    idx = idx[-top_k:]  # indices of the top-k largest vals

    while idx.numel() > 0:
        i = idx[-1]  # index of current largest val
        keep[count] = i
        count += 1
        if idx.size(0) == 1:
            break
        idx = idx[:-1]  # remove kept element from view

        # corners of the intersection between box i and each remaining box
        xx1 = torch.clamp(torch.index_select(x1, 0, idx), min=x1[i])
        yy1 = torch.clamp(torch.index_select(y1, 0, idx), min=y1[i])
        xx2 = torch.clamp(torch.index_select(x2, 0, idx), max=x2[i])
        yy2 = torch.clamp(torch.index_select(y2, 0, idx), max=y2[i])

        # disjoint boxes give negative extents; clamp them to zero area
        w = torch.clamp(xx2 - xx1, min=0.0)
        h = torch.clamp(yy2 - yy1, min=0.0)
        inter = w * h

        # IoU = i / (area(a) + area(b) - i)
        rem_areas = torch.index_select(area, 0, idx)  # areas of remaining boxes
        union = (rem_areas - inter) + area[i]
        IoU = inter / union
        # keep only elements with an IoU <= overlap
        idx = idx[IoU.le(overlap)]
    return keep, count
开发者ID:rohitgeo,项目名称:arcgis-python-api,代码行数:49,代码来源:util.py


示例17: updateOutput

    def updateOutput(self, input):
        """
        Compute the output of the wrapped network for the active partition.

        ``self.output`` is set to ``self.network.forward([input, self.partition])``.
        When a bias is present, its columns listed in ``self.partition`` are
        added to every output row, and ``self.addBuffer`` is (re)sized to
        input.size(0) and filled with ones if its length does not match.
        """
        self.output.set_(self.network.forward([input, self.partition]))
        if self.bias is None:
            return self.output

        partition_bias = torch.index_select(self.bias, 1, self.partition)
        self.output.add_(partition_bias.expand_as(self.output))

        if self.addBuffer is None:
            self.addBuffer = input.new()
        row_count = input.size(0)
        if self.addBuffer.nelement() != row_count:
            self.addBuffer.resize_(row_count).fill_(1)

        return self.output
开发者ID:Jsmilemsj,项目名称:pytorch,代码行数:10,代码来源:PartialLinear.py


示例18: reverse_sequences_torch

def reverse_sequences_torch(mini_batch, seq_lengths):
    """
    Reverse each sequence of a padded batch along the time axis.

    For batch entry b only the first seq_lengths[b] time steps are reversed;
    the remaining positions of the returned tensor are zero.

    :param mini_batch: 3-D tensor indexed as [batch, time, feature]
    :param seq_lengths: per-entry sequence lengths, indexable by batch position
    :return: new tensor with the same size as mini_batch
    """
    on_gpu = 'cuda' in mini_batch.data.type()
    reversed_mini_batch = mini_batch.new_zeros(mini_batch.size())
    for b in range(mini_batch.size(0)):
        T = seq_lengths[b]
        # T-1, T-2, ..., 0
        descending_steps = np.arange(T - 1, -1, -1)
        # the index tensor has to live on the same kind of device as the data
        if on_gpu:
            time_slice = torch.cuda.LongTensor(descending_steps)
        else:
            time_slice = torch.LongTensor(descending_steps)
        reversed_mini_batch[b, 0:T, :] = torch.index_select(mini_batch[b, :, :], 0, time_slice)
    return reversed_mini_batch
开发者ID:lewisKit,项目名称:pyro,代码行数:10,代码来源:polyphonic_data_loader.py


示例19: non_match_descriptor_loss

    def non_match_descriptor_loss(image_a_pred, image_b_pred, non_matches_a, non_matches_b, M=0.5, invert=False):
        """
        Computes the max(0, M - D(I_a,I_b,u_a,u_b))^2 term

        This is effectively:       "a and b should be AT LEAST M away from each other"
        With invert=True, this is: "a and b should be AT MOST  M away from each other"

        :param image_a_pred: Output of DCN network on image A, documented
            upstream as shape [1, W * H, D]; descriptors are selected along dim 1
        :param image_b_pred: same as image_a_pred
        :param non_matches_a: 1-D index tensor of length num_non_matches; a
            (u,v) pixel is encoded as image_width * v + u
        :param non_matches_b: same as non_matches_a
        :param M: the margin
        :type M: float
        :param invert: when True, penalise distances above M instead of below
        :return: (non_match_loss, num_hard_negatives, non_matches_a_descriptors,
            non_matches_b_descriptors); non_match_loss has one entry per
            non-match and num_hard_negatives counts its non-zero entries
        """

        # Drop only the leading axis. A bare squeeze() would also collapse the
        # pixel axis when there is a single non-match and the descriptor axis
        # when D == 1, after which the norm over dim 1 cannot be taken.
        non_matches_a_descriptors = torch.index_select(image_a_pred, 1, non_matches_a).squeeze(0)
        non_matches_b_descriptors = torch.index_select(image_b_pred, 1, non_matches_b).squeeze(0)

        norm_degree = 2
        non_match_loss = (non_matches_a_descriptors - non_matches_b_descriptors).norm(norm_degree, 1)
        if not invert:
            non_match_loss = torch.clamp(M - non_match_loss, min=0).pow(2)
        else:
            non_match_loss = torch.clamp(non_match_loss - M, min=0).pow(2)

        # pairs that still contribute to the loss
        hard_negative_idxs = torch.nonzero(non_match_loss)
        num_hard_negatives = len(hard_negative_idxs)

        return non_match_loss, num_hard_negatives, non_matches_a_descriptors, non_matches_b_descriptors
开发者ID:shooter2062424,项目名称:pytorch-dense-correspondence,代码行数:43,代码来源:pixelwise_contrastive_loss.py


示例20: crop1d

def crop1d(x,cutoff,dim):
    '''Return a copy of x without its first and last cutoff entries along dim.
    Example:
    x=torch.FloatTensor([1,2,3,4,5,6,7,8]).cuda(1)
    crop1d(x,2,0)   # -> 3, 4, 5, 6 '''
    first_kept = cutoff
    past_last_kept = x.shape[dim] - cutoff
    kept = torch.arange(first_kept, past_last_kept).long()
    # the index tensor must sit on the same GPU as x
    if x.is_cuda:
        kept = kept.cuda(x.get_device())
    return x.index_select(dim, kept)
开发者ID:Apogentus,项目名称:common,代码行数:10,代码来源:tensor.py



注:本文中的torch.index_select函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python torch.is_tensor函数代码示例发布时间:2022-05-27
下一篇:
Python torch.gather函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap