|
|
@@ -46,11 +46,8 @@ class MultiScaleAttention(nn.Module):
|
|
|
|
|
|
self.dim = dim
|
|
|
self.dim_out = dim_out
|
|
|
-
|
|
|
self.num_heads = num_heads
|
|
|
- head_dim = dim_out // num_heads
|
|
|
- self.scale = head_dim**-0.5
|
|
|
-
|
|
|
+
|
|
|
self.q_pool = q_pool
|
|
|
self.qkv = nn.Linear(dim, dim_out * 3)
|
|
|
self.proj = nn.Linear(dim_out, dim_out)
|