lll2343 committed
Commit 25a2ede · verified · Parent(s): d7ac46a

Update attn_mask_utils.py

Files changed (1)
  1. attn_mask_utils.py +52 -1
attn_mask_utils.py CHANGED
@@ -30,7 +30,6 @@ def find_prefix_seq_length_by_pe(
     return seq_len


-
 def update_causal_mask_with_pad_non_visible_2d(
     input_ids: torch.Tensor,
     attn_mask_2d: torch.Tensor,
@@ -96,6 +95,58 @@ def update_causal_mask_with_pad_non_visible_2d(
     return attn_mask_2d


+def update_causal_mask_with_pad_non_visible_2d_for_ssd_cache(
+    input_ids: torch.Tensor,
+    attn_mask_2d: torch.Tensor,
+    block_size: int = 4,
+    use_cache: bool = True,
+    causal_attn: bool = False
+) -> torch.Tensor:
+    """
+    Updates a 2D attention mask for self-speculative decoding generation.
+
+    Details are available in Appendix B, Figure 5.
+
+    Args:
+        input_ids: Input token IDs (unused in the current implementation)
+        attn_mask_2d: 2D attention mask of shape [q_len, kv_len] where:
+            - 0.0 indicates allowed attention
+            - -inf indicates masked attention
+        block_size: Size of the diffusion window
+        use_cache: Whether the key-value cache is being used
+        causal_attn: If True, maintains strict causal masking throughout
+
+    Returns:
+        Modified attention mask with updated visibility patterns
+    """
+
+    q_len, kv_len = attn_mask_2d.shape
+
+    if q_len == kv_len:
+        # Prefill: no cached positions yet, use the single-window update.
+        return update_causal_mask_for_one_gen_window_2d(
+            input_ids=input_ids,
+            attn_mask_2d=attn_mask_2d,
+            block_size=block_size,
+            use_cache=use_cache,
+            causal_attn=causal_attn
+        )
+    # Decode with cache: open each diffusion window, mask stale draft slots.
+    start_ix = q_len - block_size
+    start_jx = kv_len - block_size
+    for ix in range(block_size - 1, -1, -1):
+        attn_mask_2d[start_ix:start_ix + block_size, start_jx:start_jx + block_size] = 0.0
+        attn_mask_2d[start_ix + block_size:, start_jx - ix:start_jx + block_size] = -float('inf')
+
+        start_ix = start_ix - ix - block_size
+        start_jx = start_jx - ix - block_size
+
+    attn_mask_2d[start_ix + block_size:, start_jx + block_size - 1] = -float('inf')
+
+    return attn_mask_2d
+
+
+
 def update_causal_mask_for_one_gen_window_2d(
     input_ids: torch.Tensor,
     attn_mask_2d: torch.Tensor,
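
For context, here is a minimal usage sketch of the new decode path (not part of the commit): the toy shapes, the all-visible starting mask, and the dummy input_ids are illustrative assumptions, and the import assumes attn_mask_utils.py is on the path.

# Hypothetical usage sketch (not part of the commit): exercises the
# decode-path branch, where q_len != kv_len because cached keys/values
# extend the key/value axis. Shapes and the all-visible starting mask
# are illustrative assumptions.
import torch

from attn_mask_utils import update_causal_mask_with_pad_non_visible_2d_for_ssd_cache

block_size = 2
q_len, kv_len = 4, 8  # q_len != kv_len selects the cached-decode path

# 0.0 = attention allowed, -inf = masked, per the docstring convention.
attn_mask_2d = torch.zeros(q_len, kv_len)

updated = update_causal_mask_with_pad_non_visible_2d_for_ssd_cache(
    input_ids=torch.zeros(1, q_len, dtype=torch.long),  # unused by the function
    attn_mask_2d=attn_mask_2d,
    block_size=block_size,
)
print(updated)  # block-wise visibility pattern over the cached positions

The prefill branch (q_len == kv_len) defers to the existing update_causal_mask_for_one_gen_window_2d, so the new function only adds handling for the cached-decode case.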