| from typing import Any, Dict, Optional, List |
| import torch |
| from transformers import GenerationMixin |
| from transformers import AutoTokenizer |
| import re |
| import traceback |
|
|
|
|
class WebGenerationMixin(GenerationMixin):
    """Generation mixin that threads web-specific state through decoding.

    On top of the usual per-step bookkeeping (cache, token type ids,
    attention masks, cache position) it forwards ``outputs.state`` and
    ``outputs.html_tree`` into ``model_kwargs`` and lazily creates a
    ``web_attention_mask`` entry for decoder-only models.
    """

    def _update_model_kwargs_for_generation(
        self,
        outputs,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        """Prepare ``model_kwargs`` for the next decoding step.

        Args:
            outputs: Model output of the step that just finished. It is
                passed to ``_extract_past_from_model_output``; ``state`` is
                read if present and, for decoder-only models, ``html_tree``
                is read unconditionally.
            model_kwargs: Keyword arguments carried between steps. Mutated
                in place and also returned.
            is_encoder_decoder: Selects which attention mask is extended
                (``decoder_attention_mask`` vs. ``attention_mask``).
            standardize_cache_format: Forwarded unchanged to
                ``_extract_past_from_model_output``.

        Returns:
            The same ``model_kwargs`` dict, updated.

        Raises:
            KeyError: For decoder-only models when neither
                ``web_attention_mask`` nor ``attention_mask`` is present.
        """
        # Key/value cache produced by the step that just ran.
        past = self._extract_past_from_model_output(
            outputs, standardize_cache_format=standardize_cache_format
        )
        model_kwargs["past_key_values"] = past

        # Optional extra state exposed by the model output.
        if getattr(outputs, "state", None) is not None:
            model_kwargs["state"] = outputs.state

        # Repeat the last token type id for the token that was just added.
        if "token_type_ids" in model_kwargs:
            type_ids = model_kwargs["token_type_ids"]
            last_column = type_ids[:, -1].unsqueeze(-1)
            model_kwargs["token_type_ids"] = torch.cat([type_ids, last_column], dim=-1)

        if is_encoder_decoder:
            # Grow the decoder mask by one attended position.
            if "decoder_attention_mask" in model_kwargs:
                dec_mask = model_kwargs["decoder_attention_mask"]
                extra_column = dec_mask.new_ones((dec_mask.shape[0], 1))
                model_kwargs["decoder_attention_mask"] = torch.cat(
                    [dec_mask, extra_column], dim=-1
                )
        else:
            # First step only: seed a lower-triangular mask sized by the
            # current attention mask, with a leading dimension of 1.
            # NOTE(review): this lookup is unguarded although the next
            # statement treats "attention_mask" as optional, and the tensor
            # is created on the default device — confirm both are intended.
            if "web_attention_mask" not in model_kwargs:
                prompt_mask = model_kwargs["attention_mask"]
                seq_len = prompt_mask.shape[-1]
                square = torch.ones((seq_len, seq_len), dtype=prompt_mask.dtype)
                model_kwargs["web_attention_mask"] = torch.tril(square).unsqueeze(0)

            # Grow the regular mask by one attended position.
            if "attention_mask" in model_kwargs:
                step_mask = model_kwargs["attention_mask"]
                extra_column = step_mask.new_ones((step_mask.shape[0], 1))
                model_kwargs["attention_mask"] = torch.cat(
                    [step_mask, extra_column], dim=-1
                )

            # Hand the tree reported by this step to the next one.
            model_kwargs["html_tree"] = outputs.html_tree

        # Advance the cache position to the slot of the next token.
        cache_position = model_kwargs.get("cache_position")
        if cache_position is not None:
            model_kwargs["cache_position"] = cache_position[-1:] + 1
        return model_kwargs

    def _reorder_cache(self, past_key_values, beam_idx):
        """Beam-search cache reordering hook; always raises here.

        Raises:
            NotImplementedError: Always, naming the module and class that
                would need to provide an implementation.
        """
        raise NotImplementedError(
            f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to"
            f" enable beam search for {self.__class__}"
        )
| |
|
|
class TreeNode():
    """A node of the HTML tree assembled from generated tokens.

    Which attribute is populated decides how the node is treated:

    * ``text`` (non-empty)            -> a text node,
    * ``self_closing_tag`` (non-empty) -> a self-closing / void element,
    * otherwise                        -> a regular element described by
      ``open_tag`` and, once closed, ``end_tag``.

    The ``*_range`` attributes are ``[start, end)`` pairs; they are expanded
    with ``range(*pair)`` in :meth:`get_range`.
    """

    def __init__(self, content: Optional[List[str]], idx: int):
        """Create a node.

        Args:
            content: Tokens of the (possibly still incomplete) opening tag,
                or ``None`` for nodes that have no opening tag yet.
            idx: Identifier stored on the node as ``self.idx``.
        """
        self.open_tag: Optional[List[str]] = content
        self.end_tag: Optional[List[str]] = None
        self.self_closing_tag: Optional[List[str]] = None
        # Starts as an empty (falsy) string; TreeBuilder later replaces it
        # with a list of tokens for text nodes.
        self.text = ""

        self.name: Optional[str] = None
        self.parent: Optional['TreeNode'] = None

        self.open_tag_range: Optional[List[int]] = None
        self.end_tag_range: Optional[List[int]] = None
        self.text_range = [-1, -1]
        self.self_closing_tag_range = [-1, -1]

        self.idx: int = idx
        self.children: List['TreeNode'] = []

    def partially_open(self):
        """Return True when ``open_tag`` contains ``<`` but no ``>`` yet."""
        if not self.open_tag:
            return False
        has_open_bracket = any('<' in s for s in self.open_tag)
        has_close_bracket = any('>' in s for s in self.open_tag)
        return has_open_bracket and not has_close_bracket

    def add_child(self, child):
        """Append ``child`` to this node and set its ``parent``.

        Raises:
            AssertionError: If ``child`` already has a parent or is already
                listed in ``self.children``.
        """
        # Raised explicitly (instead of via ``assert``) so the checks are
        # not stripped when Python runs with -O; type and message are the
        # same as before.
        if child.parent is not None:
            raise AssertionError("Child already has a parent")
        if child in self.children:
            raise AssertionError("Child is already in children list")
        child.parent = self
        self.children.append(child)

    def get_range(self):
        """Return the token positions covered by this node.

        Text and self-closing nodes return their single range; regular
        elements return the opening-tag range followed by the end-tag range
        (each only if set).
        """
        if self.text:
            return list(range(*self.text_range))
        elif self.self_closing_tag:
            return list(range(*self.self_closing_tag_range))
        else:
            attn_range = []
            if self.open_tag_range:
                attn_range += list(range(*self.open_tag_range))
            if self.end_tag_range:
                attn_range += list(range(*self.end_tag_range))
            return attn_range

    def __repr__(self):
        return f"Node(name='{self.open_tag}', idx = {self.idx})"

    @staticmethod
    def _tokens_to_text(tokens, tokenizer):
        """Turn a token list into stripped display text.

        Uses ``tokenizer.convert_tokens_to_string`` when a tokenizer is
        given; otherwise falls back to plain concatenation so the tree can
        still be rendered (the public methods default to ``tokenizer=None``).
        """
        if tokenizer is None:
            return "".join(tokens).strip()
        return tokenizer.convert_tokens_to_string(tokens).strip()

    def _iter_lines(self, level, tokenizer):
        """Yield ``(depth, text)`` pairs for this subtree in document order.

        Shared traversal behind :meth:`print_tree` and :meth:`get_tree`.
        A node without text, self-closing tag or opening tag contributes no
        line of its own, only its children.
        """
        if self.text:
            yield level, self._tokens_to_text(self.text, tokenizer)
        elif self.self_closing_tag:
            yield level, self._tokens_to_text(self.self_closing_tag, tokenizer)
        else:
            if self.open_tag:
                yield level, self._tokens_to_text(self.open_tag, tokenizer)
            for child in self.children:
                yield from child._iter_lines(level + 1, tokenizer)
            # The end tag is only rendered for nodes that have an open tag.
            if self.open_tag and self.end_tag:
                yield level, self._tokens_to_text(self.end_tag, tokenizer)

    def print_tree(self, level=0, input_ids=None, tokenizer=None):
        """Print the subtree, one line per tag/text, indented by depth.

        Args:
            level: Depth of this node; a call with ``level == 0`` frames the
                output with ``--------`` lines.
            input_ids: Unused; kept for interface compatibility.
            tokenizer: Object providing ``convert_tokens_to_string``. When
                ``None`` the tokens are simply concatenated.
        """
        if level == 0:
            print("--------")
        for depth, line in self._iter_lines(level, tokenizer):
            indent = " " * depth
            print(f"{indent}{line}, level = {depth} ")
        if level == 0:
            print("--------")

    def get_tree(self, level=0, input_ids=None, tokenizer=None):
        """Return the subtree as a string, one indented line per tag/text.

        Args:
            level: Depth of this node, used as the indentation width.
            input_ids: Unused; kept for interface compatibility.
            tokenizer: Object providing ``convert_tokens_to_string``. When
                ``None`` the tokens are simply concatenated.

        Returns:
            The rendered lines, each terminated by ``" \\n"``.
        """
        pieces = []
        for depth, line in self._iter_lines(level, tokenizer):
            indent = " " * depth
            pieces.append(f"{indent}{line} \n")
        return "".join(pieces)
|
|
|
|
class TreeBuilder():
    """Incrementally builds a ``TreeNode`` tree from a stream of token strings.

    Tokens are appended to ``self.buffer`` by :meth:`update_buffer`; whenever
    the buffer contains a recognisable piece of markup (opening tag, closing
    tag, self-closing tag, text) it is moved onto a node and
    ``self.buffer_start_index`` is advanced by the number of tokens consumed.
    """

    def __init__(self, tokenizer: AutoTokenizer = None, root: TreeNode = None, cur_node: TreeNode = None):
        """Create a builder with a fresh, empty root node.

        Args:
            tokenizer: Object whose ``convert_tokens_to_string`` is used to
                turn token lists back into text.
            root: NOTE(review): accepted but not used — a new root is always
                created below.
            cur_node: NOTE(review): accepted but not used — the cursor always
                starts at the new root.
        """
        self.tokenizer = tokenizer
        self.root = TreeNode(None, 0)
        # Node currently being filled by update_buffer.
        self.cur_node = self.root
        # Tokens received but not yet assigned to a node.
        self.buffer = []
        # Position of buffer[0] in the overall token sequence.
        self.buffer_start_index = 0
        # Identifier handed to the next TreeNode that gets created.
        self.idx = 0
        # Tokens seen before the first '<'; None until that '<' arrives.
        self.full_attention_list= None
        # Not read or written by any method of this class.
        self.web_attention_mask = None
        self.input_ids = None
        # Tag names stored as self-closing nodes even without a trailing '/>'.
        self.void_elements = [
            "area",
            "base",
            "br",
            "col",
            "embed",
            "hr",
            "img",
            "input",
            "link",
            "meta",
            "param",
            "source",
            "track",
            "wbr"
        ]


    def is_empty(self):
        """Return True if ``self.root`` is ``None``."""
        # NOTE(review): `is None` would be the idiomatic comparison here.
        return self.root == None


    def in_buffer(self, text):
        """Return True if any buffered token contains ``text`` as a substring."""
        if len(self.buffer) == 0:
            return False
        return any(text in s for s in self.buffer)

    def find_buffer(self, text):
        """Return the index of the first buffered token containing ``text``, or -1."""

        for index, s in enumerate(self.buffer):
            if text in s:
                return index
        return -1


    def extract_open_tag_name(self,buffer):
        """Return the tag name of the first ``<name ...>`` found in ``buffer``.

        ``buffer`` is a token list; it is converted to a string with the
        tokenizer first. Returns ``None`` when nothing matches.
        """
        input_string = self.tokenizer.convert_tokens_to_string(buffer)
        match = re.search(r'<\s*(\w+)(?:\s+[^>]*)?>', input_string)
        if match:
            return match.group(1)
        return None


    def extract_close_tag_name(self,buffer):
        """Return the tag name of the first ``</name>`` found in ``buffer``.

        ``buffer`` is a token list; it is converted to a string with the
        tokenizer first. Returns ``None`` when nothing matches.
        """


        input_string = self.tokenizer.convert_tokens_to_string(buffer)
        match = re.search(r'</\s*(\w+)(?:\s+[^>]*)?>', input_string)
        if match:
            return match.group(1)
        return None

    def is_not_empty_buffer(self):
        """Return True if the buffer decodes to something other than whitespace."""
        return self.tokenizer.convert_tokens_to_string(self.buffer).strip() != ''

    def get_parent_and_siblings_attention_range(self):
        """Collect token positions of the current node's parent and siblings.

        Returns:
            The parent's opening-tag positions (if set) followed by the
            positions of every other child of that parent. Empty when the
            current node has no parent.

        Raises:
            Exception: If a sibling is neither a closed element, a text node
                nor a self-closing node.
        """
        attn_range = []
        if self.cur_node.parent:
            parent = self.cur_node.parent
            if parent.open_tag_range:
                attn_range += list(range(*parent.open_tag_range))
            for child in parent.children:
                if child is not self.cur_node:
                    if child.open_tag and child.end_tag:
                        # Closed element: only its tags, not its contents.
                        attn_range += list(range(*child.open_tag_range))
                        attn_range += list(range(*child.end_tag_range))
                    elif child.text:
                        attn_range += list(range(*child.text_range))
                    elif child.self_closing_tag:
                        attn_range += list(range(*child.self_closing_tag_range))
                    else:
                        raise Exception(f"??? line 151, get p and s attention range")

        return attn_range

    def update_buffer(self, cur_decoded_token):
        """Feed new tokens to the builder and return the resulting position list.

        Args:
            cur_decoded_token: Non-empty list of token strings to append to
                the buffer.

        Returns:
            A list of integer positions. Before the first ``<`` has been
            seen it is ``range(len(self.buffer))``. Afterwards it is the
            concatenation of: every position of ``full_attention_list``, the
            parent/sibling positions, the current node's own positions, and
            the positions of the tokens still waiting in the buffer.

        Raises:
            AssertionError: If ``cur_decoded_token`` is not a list or its
                first element is not a ``str`` (raised before the ``try``).
            Exception: Any error raised while updating the tree is re-raised
                wrapped in a plain ``Exception``.
        """

        assert isinstance(cur_decoded_token,list), f"{cur_decoded_token}"
        # In-place extend: a node created with TreeNode(self.buffer, ...)
        # shares this list object, so its open_tag grows with the buffer
        # until the buffer is rebound to a new list.
        self.buffer+=cur_decoded_token
        assert isinstance(cur_decoded_token[0],str)

        try:

            # Case 1: a complete closing tag ('</' ... '>') is in the buffer.
            if self.in_buffer('</' ) and self.in_buffer('>') and self.find_buffer('</') <= self.find_buffer('>'):
                close_tag_name = self.extract_close_tag_name(self.buffer)

                if self.cur_node.open_tag and not self.cur_node.end_tag:
                    # The closing tag must match the element currently open.
                    assert close_tag_name == self.extract_open_tag_name(self.cur_node.open_tag), f"close_tag_name is {close_tag_name}, with buffer: {self.buffer}, open is-----{self.cur_node.open_tag}---"
                elif self.cur_node.text or self.cur_node.self_closing_tag or self.cur_node.end_tag:
                    # The cursor should already have moved to the parent
                    # (see case 2); dump the tree and fail.
                    content = None
                    if self.cur_node.text: content = self.cur_node.text
                    elif self.cur_node.self_closing_tag: content = self.cur_node.self_closing_tag
                    elif self.cur_node.end_tag: content = self.cur_node.end_tag
                    self.root.print_tree(0,None,self.tokenizer)
                    raise Exception(f"This should never happen\n {content}, buffer is {self.buffer}")


                else:
                    raise Exception(f"having end tag without having an open tag\n {self.cur_node.text}")

                # Move everything up to and including the '>' token onto
                # the node as its end tag.
                self.cur_node.end_tag = self.buffer[:self.find_buffer('>')+1]
                self.cur_node.end_tag_range = [self.buffer_start_index, self.buffer_start_index + self.find_buffer('>')+1]
                self.buffer_start_index += self.find_buffer('>')+1
                self.buffer = self.buffer[self.find_buffer('>')+1:]

            # Case 2: a closing tag has started ('</') but is not complete.
            elif self.in_buffer('</'):
                if self.cur_node.open_tag and not self.cur_node.end_tag:
                    # Cursor is on the element being closed; wait for '>'.
                    pass
                elif self.cur_node.text or self.cur_node.self_closing_tag or (self.cur_node.open_tag and self.cur_node.end_tag):
                    cur_end_tag_index = self.find_buffer('</')

                    # Tokens before '</' are attached to the finished node.
                    if self.cur_node.text:
                        self.cur_node.text += self.buffer[:cur_end_tag_index]
                        self.cur_node.text_range[1] += len(self.buffer[:cur_end_tag_index])
                    elif self.cur_node.self_closing_tag:
                        self.cur_node.self_closing_tag += self.buffer[:cur_end_tag_index]
                        self.cur_node.self_closing_tag_range[1] += len(self.buffer[:cur_end_tag_index])
                    else:
                        self.cur_node.end_tag += self.buffer[:cur_end_tag_index]
                        self.cur_node.end_tag_range[1] += len(self.buffer[:cur_end_tag_index])
                    self.buffer_start_index += len(self.buffer[:cur_end_tag_index])
                    self.buffer =self.buffer[cur_end_tag_index:]
                    # The closing tag belongs to the parent element.
                    self.cur_node = self.cur_node.parent
                else:
                    raise Exception(f"having end tag without having an open tag\n {self.cur_node.text} {self.cur_node} {self.cur_node.parent.open_tag}")

            # Case 3: a complete opening or self-closing tag is in the buffer.
            elif self.in_buffer('<') and self.in_buffer('>'):

                if self.in_buffer('/>'):
                    self.cur_node.open_tag = None
                    self.cur_node.self_closing_tag = self.buffer[:self.find_buffer(">")+1]
                    self.cur_node.self_closing_tag_range = [self.buffer_start_index, self.buffer_start_index + self.find_buffer('>')+1]
                else:
                    open_tag_name = self.extract_open_tag_name(self.buffer)
                    if open_tag_name in self.void_elements:
                        # Void elements are stored like self-closing tags.
                        self.cur_node.open_tag = None
                        self.cur_node.self_closing_tag = self.buffer[:self.find_buffer(">")+1]
                        self.cur_node.self_closing_tag_range = [self.buffer_start_index, self.buffer_start_index + self.find_buffer('>')+1]
                    else:
                        self.cur_node.open_tag = self.buffer[:self.find_buffer(">")+1]
                        self.cur_node.open_tag_range = [self.buffer_start_index, self.buffer_start_index + self.find_buffer('>')+1]

                self.buffer_start_index += self.find_buffer('>')+1
                self.buffer = self.buffer[self.find_buffer(">")+1:]
            # Case 4: a tag has started ('<') but is not complete.
            elif self.in_buffer('<'):
                if self.full_attention_list is None:
                    # First '<' ever: everything except the last buffered
                    # token becomes the always-attended prefix.
                    # NOTE(review): presumably the '<' is in the last token
                    # because tokens arrive one at a time — confirm.
                    self.full_attention_list = self.buffer[:-1]
                    self.buffer = self.buffer[-1:]
                    self.buffer_start_index = len(self.full_attention_list)
                else:
                    cur_open_tag_index = self.find_buffer('<')

                    if not self.cur_node.partially_open() and self.cur_node.open_tag:
                        if self.cur_node.end_tag:
                            # Current element is closed: tokens before '<'
                            # extend its end tag and the new node becomes
                            # a sibling.
                            self.cur_node.end_tag += self.buffer[:cur_open_tag_index]
                            self.cur_node.end_tag_range[1] += len(self.buffer[:cur_open_tag_index])
                            self.buffer_start_index += len(self.buffer[:cur_open_tag_index])
                            self.buffer =self.buffer[cur_open_tag_index:]
                            child_node = TreeNode(self.buffer, self.idx)
                            if self.cur_node.parent:
                                self.cur_node.parent.add_child(child_node)
                            else:
                                raise Exception(f"This should never happen, a html element with full open tag should have a parent, {self.cur_node.open_tag}")
                            self.idx += 1
                            self.cur_node = child_node
                        else:
                            # Current element is still open: the new node
                            # becomes its child.
                            child_node = TreeNode(self.buffer, self.idx)
                            self.cur_node.add_child(child_node)
                            self.idx += 1
                            self.cur_node = child_node
                    elif self.cur_node.text or self.cur_node.self_closing_tag:
                        # Tokens before '<' extend the finished text or
                        # self-closing node; the new node becomes a sibling.
                        if self.cur_node.text:
                            self.cur_node.text += self.buffer[:cur_open_tag_index]
                            self.cur_node.text_range[1] += len(self.buffer[:cur_open_tag_index])
                        elif self.cur_node.self_closing_tag:
                            self.cur_node.self_closing_tag += self.buffer[:cur_open_tag_index]
                            self.cur_node.self_closing_tag_range[1] += len(self.buffer[:cur_open_tag_index])

                        self.buffer_start_index += len(self.buffer[:cur_open_tag_index])
                        self.buffer =self.buffer[cur_open_tag_index:]
                        child_node = TreeNode(self.buffer, self.idx)
                        self.cur_node.parent.add_child(child_node)
                        self.idx += 1
                        self.cur_node = child_node

            # Case 5: non-blank text right after a tag starts a text node.
            elif (self.cur_node.open_tag or self.cur_node.self_closing_tag) and not self.in_buffer('<') and self.is_not_empty_buffer():
                child_node = TreeNode(None, self.idx)
                child_node.text = self.buffer
                child_node.text_range[0] = self.buffer_start_index
                child_node.text_range[1] = self.buffer_start_index + len(self.buffer)

                # After a closed or self-closing element the text is a
                # sibling; inside an open element it is a child.
                if self.cur_node.end_tag or self.cur_node.self_closing_tag:
                    self.cur_node.parent.add_child(child_node)
                else:
                    self.cur_node.add_child(child_node)

                self.idx += 1
                self.cur_node = child_node
                self.buffer_start_index += len(self.buffer)
                self.buffer = []

            # Case 6: more non-blank text while already on a text node.
            elif self.cur_node.text and not self.in_buffer('<') and self.is_not_empty_buffer():
                self.cur_node.text += self.buffer
                assert self.cur_node.text_range[0] != -1 and self.cur_node.text_range[1] != -1, f"self.cur_node.text_range[0] and [1] should not be -1 but: {self.cur_node.text_range[0]}, {self.cur_node.text_range[1]}"
                self.cur_node.text_range[1] += len(self.buffer)
                self.buffer_start_index += len(self.buffer)
                self.buffer =[]

        except Exception as e:
            # NOTE(review): the string returned by format_exc() is discarded,
            # so this call has no visible effect.
            traceback.format_exc()
            raise Exception(e)

        if self.full_attention_list is None:
            attn_range = list(range(len(self.buffer)))
        else:
            attn_range = list(range(len(self.full_attention_list))) + self.get_parent_and_siblings_attention_range() + self.cur_node.get_range() + [i + self.buffer_start_index for i in list(range(len(self.buffer)))]
        return attn_range
|
|