Si encuentras alguna parte que no comprendes puedes comentarlo y te ayudaremos rápidamente.
Solución:
Estas buscando
torch.mm(a,b)
Tenga en cuenta que torch.dot()
se comporta diferente a np.dot()
. Ha habido cierta discusión sobre lo que sería deseable aquí. Específicamente, torch.dot()
trata a ambos a
y b
como vectores 1D (independientemente de su forma original) y calcula su producto interno. Se lanza el error, porque este comportamiento hace que su a
un vector de longitud 6 y su b
un vector de longitud 2; por lo tanto, su producto interno no se puede calcular. Para la multiplicación de matrices en PyTorch, use torch.mm()
. de Numpy np.dot()
en cambio es más flexible; calcula el producto interno para arreglos 1D y realiza la multiplicación de matrices para arreglos 2D.
Por demanda popular, la función torch.matmul
realiza multiplicaciones de matrices si ambos argumentos son 2D
y calcula su producto escalar si ambos argumentos son 1D
. Para entradas de tales dimensiones, su comportamiento es el mismo que np.dot
. También le permite hacer transmisiones o matrix x matrix
, matrix x vector
y vector x vector
operaciones por lotes. Para obtener más información, consulte sus documentos.
# 1D inputs, same as torch.dot
a = torch.rand(n)
b = torch.rand(n)
torch.matmul(a, b) # torch.Size([])
# 2D inputs, same as torch.mm
a = torch.rand(m, k)
b = torch.rand(k, j)
torch.matmul(a, b) # torch.Size([m, j])
Si desea hacer una multiplicación de matrices (tensor de rango 2), puede hacerlo de cuatro maneras equivalentes:
AB = A.mm(B) # computes A.B (matrix multiplication)
# or
AB = torch.mm(A, B)
# or
AB = torch.matmul(A, B)
# or, even simpler
AB = A @ B # Python 3.5+
Hay algunas sutilezas. De la documentación de PyTorch:
torch.mm no transmite. Para productos de matriz de difusión, consulte torch.matmul().
Por ejemplo, no puedes multiplicar dos vectores unidimensionales con torch.mm
, ni multiplicar matrices por lotes (rango 3). Para ello, debe utilizar el más versátil torch.matmul
. Para obtener una lista extensa de los comportamientos de transmisión de torch.matmul
consulte la documentación.
Para la multiplicación por elementos, simplemente puede hacer (si A y B tienen la misma forma)
A * B # element-wise matrix multiplication (Hadamard product)
Usar torch.mm(a, b)
o torch.matmul(a, b)
Ambos son lo mismo.
>>> torch.mm
>>> torch.matmul
Hay una opción más que puede ser bueno saber. Es decir @
operador. @Simón H.
>>> a = torch.randn(2, 3)
>>> b = torch.randn(3, 4)
>>> [email protected]
tensor([[ 0.6176, -0.6743, 0.5989, -0.1390],
[ 0.8699, -0.3445, 1.4122, -0.5826]])
>>> a.mm(b)
tensor([[ 0.6176, -0.6743, 0.5989, -0.1390],
[ 0.8699, -0.3445, 1.4122, -0.5826]])
>>> a.matmul(b)
tensor([[ 0.6176, -0.6743, 0.5989, -0.1390],
[ 0.8699, -0.3445, 1.4122, -0.5826]])
Los tres dan los mismos resultados.
Enlaces relacionados:
Operador de multiplicación de matrices
PEP 465 — Un operador infijo dedicado para la multiplicación de matrices
Tienes la posibilidad dar difusión a este post si te ayudó.