lukeingawesome committed on
Commit 850592d (verified)
1 Parent(s): f1a89e8

Add vendored pooling_latent.py module

Files changed (1)
  1. pooling_latent.py +91 -0
pooling_latent.py ADDED
@@ -0,0 +1,91 @@
+ """
+ Latent Attention Pooling implementation for LLM2Vec4CXR.
+ Vendored to make the model self-contained (no external llm2vec dependency required).
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class LatentAttentionPooling(nn.Module):
+     """
+     Latent attention pooling layer that uses a trainable latent dictionary
+     to aggregate token embeddings into a fixed-size representation.
+     """
+
+     def __init__(self, d_model, num_latents=512, num_heads=8):
+         """
+         Args:
+             d_model: Hidden size of the model (e.g., 4096 for Llama-7B)
+             num_latents: Number of learnable latent vectors (default: 512)
+             num_heads: Number of attention heads (default: 8)
+         """
+         super().__init__()
+         self.num_latents = num_latents
+         self.d_model = d_model
+
+         # Trainable latent dictionary (used as both keys and values)
+         self.latents = nn.Parameter(torch.randn(num_latents, d_model))
+
+         # Multihead attention layer
+         # batch_first=True means input shape is (batch, seq_length, hidden_size)
+         self.multihead_attn = nn.MultiheadAttention(
+             embed_dim=d_model,
+             num_heads=num_heads,
+             batch_first=True
+         )
+
+         # Simple MLP: Linear -> GELU -> Linear
+         self.mlp = nn.Sequential(
+             nn.Linear(d_model, d_model),
+             nn.GELU(),
+             nn.Linear(d_model, d_model)
+         )
+
+     def forward(self, hidden_states, attention_mask=None):
+         """
+         Apply latent attention pooling to hidden states.
+
+         Args:
+             hidden_states: Token embeddings of shape (batch_size, seq_len, d_model)
+             attention_mask: Optional mask of shape (batch_size, seq_len)
+
+         Returns:
+             Pooled embeddings of shape (batch_size, d_model)
+         """
+         batch_size, seq_len, d_model = hidden_states.shape
+         device = hidden_states.device
+
+         # Ensure the module is on the same device as input
+         if next(self.parameters()).device != device:
+             self.to(device)
+
+         # Expand latents to match batch size: (batch_size, num_latents, d_model)
+         latents = self.latents.unsqueeze(0).expand(batch_size, -1, -1)
+
+         # Apply multihead attention
+         # Use hidden_states as queries and latent dictionary as keys/values
+         # This computes: O = softmax((QK^T)/√d)V
+         attn_output, _ = self.multihead_attn(
+             query=hidden_states,
+             key=latents,
+             value=latents
+         )
+
+         # Apply MLP to attention output
+         mlp_output = self.mlp(attn_output)
+
+         # Mean pool over sequence dimension
+         if attention_mask is not None:
+             # Mask out padding tokens before pooling
+             mask_expanded = attention_mask.unsqueeze(-1).expand(mlp_output.size()).float()
+             sum_embeddings = torch.sum(mlp_output * mask_expanded, dim=1)
+             sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
+             pooled = sum_embeddings / sum_mask
+         else:
+             # Simple mean pooling if no mask provided
+             pooled = mlp_output.mean(dim=1)
+
+         return pooled
+
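
For reference, a minimal usage sketch of the vendored pooling module. The tensor sizes and random inputs below are illustrative only and are not taken from the LLM2Vec4CXR pipeline; the module simply maps (batch_size, seq_len, d_model) token embeddings to (batch_size, d_model) pooled embeddings.

import torch
from pooling_latent import LatentAttentionPooling

# Illustrative sizes only; the actual backbone hidden size may differ.
batch_size, seq_len, d_model = 2, 16, 2048

pooler = LatentAttentionPooling(d_model=d_model, num_latents=512, num_heads=8)

# Dummy token embeddings plus an attention mask with padding on the second sample.
hidden_states = torch.randn(batch_size, seq_len, d_model)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
attention_mask[1, 10:] = 0  # treat the trailing positions of sample 1 as padding

pooled = pooler(hidden_states, attention_mask)
print(pooled.shape)  # torch.Size([2, 2048])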