import torch
import torch.nn as nn
class Multi_Head_Attention(nn.Module):
    def __init__(self, dim, num_heads=8, attn_drop=0.5, proj_drop=0.5):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5  # 1/sqrt(d_k) scaling factor
        # Single linear layer producing Q, K and V in one projection
        self.qkv = nn.Linear(dim, dim * 3)
        self.softmax = nn.Softmax(dim=-1)
        self.attn_drop = nn.Dropout(attn_drop)
        # Output projection applied after the heads are merged back together
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
    def forward(self, x):
        B, N, C = x.shape  # batch size, sequence length, embedding dim
        # Project to Q, K, V and split into heads:
        # (B, N, 3*C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Scaled dot-product attention weights: (B, num_heads, N, N)
        attn = self.softmax((q @ k.transpose(-2, -1)) * self.scale)
        attn = self.attn_drop(attn)
        # Weighted sum of values, then swap the head and sequence dims
        # before merging heads: (B, num_heads, N, head_dim) -> (B, N, C)
        res = (attn @ v).transpose(1, 2).reshape(B, N, C)
        res = self.proj_drop(self.proj(res))
        return res
x = torch.randn(size=(64, 10, 512))
att = Multi_Head_Attention(dim=512)
print(att(x).shape)
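
import torch.nn.functional as F

# Optional sanity check (a minimal sketch, not part of the module above):
# assuming PyTorch >= 2.0, the hand-rolled attention should match the built-in
# F.scaled_dot_product_attention on the same q, k, v once dropout is disabled
# via eval() (both use the 1/sqrt(head_dim) scale by default).
att.eval()
with torch.no_grad():
    qkv = att.qkv(x).reshape(64, 10, 3, att.num_heads, att.head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]
    manual = att.softmax((q @ k.transpose(-2, -1)) * att.scale) @ v
    builtin = F.scaled_dot_product_attention(q, k, v)
    print(torch.allclose(manual, builtin, atol=1e-5))  # expected: True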