nn.transformer

发布时间 2023-11-26 22:42:36作者: 黑逍逍

 

torch上给的案例


transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12) # 创建一个具有16个注意力头和12个编码器层的Transformer模型

src = torch.rand((10, 32, 512))# 创建一个形状为 (10, 32, 512) 的随机输入张量,代表序列的编码器输入

tgt = torch.rand((20, 32, 512)) # 创建一个形状为 (20, 32, 512) 的随机输入张量,代表序列的解码器输入

out = transformer_model(src, tgt)# 将输入张量传递给Transformer模型进行前向传播