pytorch gather b2 = a.gather(1, b.view(-1, 1))

发布时间: 2023-03-25 19:27:37 作者: 无左无右
import torch

# Demo of Tensor.gather along dim=1: for every row i of `a`, pick the single
# element a[i, b[i, 0]].
a = torch.randint(0, 100, (6, 3))

# gather needs an int64 index with the same number of dimensions as the input.
# Build it as int64 directly instead of creating a float tensor and casting.
b = torch.tensor([0, 1, 1, 2, 0, 2], dtype=torch.long).unsqueeze(1)  # (6, 1)

# view(-1, 1) is a no-op on an already (6, 1) tensor; it is kept to mirror the
# `index.view(-1, 1)` form that is used when the index starts out 1-D.
b0 = b.view(-1, 1)

b2 = a.gather(1, b0)  # (6, 1): b2[i, 0] == a[i, b[i, 0]]

print(a)
print(a.shape)
print(b)
print(b.shape)

print(b2)

输出

tensor([[ 9, 10, 79],
        [98, 43,  2],
        [94, 82, 24],
        [93, 72,  3],
        [30, 29, 86],
        [94, 25,  4]])
torch.Size([6, 3])
tensor([[0],
        [1],
        [1],
        [2],
        [0],
        [2]])
torch.Size([6, 1])
tensor([[ 9],
        [43],
        [82],
        [ 3],
        [30],
        [ 4]])

PyTorch SSD 里面 gather 的用法

       # Compute max conf across batch for hard negative mining
        # Shapes recorded by the original author for their run:
        # conf_data [3, 8732, 21] -> batch_conf [3*8732, 21] = [26196, 21]
        batch_conf = conf_data.view(-1, self.num_classes)  # collapse leading dims: [26196, 21]
        # b1, b00 and b2 are not used again in this excerpt; they only expose
        # the intermediate terms of loss_c below for shape inspection.
        b1 = log_sum_exp(batch_conf)  # [26196, 1] per the original note
        b00 = conf_t.view(-1, 1)  # index column for gather: [26196, 1]
        # For each row, pick the entry at the column given by conf_t.
        # NOTE(review): presumably conf_t holds integer class labels — confirm.
        b2 = batch_conf.gather(1, conf_t.view(-1, 1))  # [26196, 1]

        # Reference form left by the author (kept disabled):
        # loss_c1 = F.cross_entropy(batch_conf, conf_t.view(-1))

        # loss_c: [26196, 1]; background: https://zhuanlan.zhihu.com/p/153535799
        # NOTE(review): if log_sum_exp reduces over dim 1, this is the
        # unreduced per-row cross-entropy (log-sum-exp minus the score at
        # the index given by conf_t) — confirm against the helper's definition.
        loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))