Pytorch 代码里看到 @ 和 * 两个操作符,其功能如下.

1. 矩阵相乘@

示例如,

import torch

#
X = torch.tensor([[1,2],[3,4],[5,6]]) #torch.Size([3, 2])
#tensor([[1, 2],
#        [3, 4],
#        [5, 6]])
Y = torch.tensor([[7,8],[9, 10]]) #torch.Size([2, 2])
#tensor([[5, 6],
#        [7, 8]])

Z = X@Y #torch.Size([3, 2])
#tensor([[ 25,  28],
#        [ 57,  64],
#        [ 89, 100]])

#等价于:
Z = torch.matmul(X, Y)

2. 矩阵逐元素相乘*

示例如,

import torch

#
X = torch.tensor([[1,2],[3,4],[5,6]]) #torch.Size([3, 2])
#tensor([[1, 2],
#        [3, 4],
#        [5, 6]])
Y = torch.tensor([[7,8],[9, 10],[11,12]]) #torch.Size([3, 2])
#tensor([[ 7,  8],
#        [ 9, 10],
#        [11, 12]])

Z = X*Y
#tensor([[ 7, 16],
#        [27, 40],
#        [55, 72]])

#等价于:
Z = torch.mul(X, Y)
Last modification:June 8th, 2021 at 09:18 pm