#!/usr/bin/env python3
"""
Teacher-student knowledge distillation script.
Distills a teacher model trained with SFT + PPO RLHF into a smaller student model.
"""
import os
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    logging,
)
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, TaskType
import wandb
from typing import Dict, List
import json
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
logging.set_verbosity(logging.CRITICAL)
class DistillationConfig:
    """Distillation training configuration."""
    # Model paths
    teacher_model_path = "./rlhf_teacher_model"  # Teacher model after RLHF
    student_model_name = "microsoft/DialoGPT-medium"  # Replace with the actual OpenAI OSS 20B model
    # Distillation parameters
    temperature = 4.0  # Distillation temperature
    alpha = 0.7  # Weight of the distillation loss
    beta = 0.3  # Weight of the student's own LM loss
    gamma = 0.1  # Weight of a feature-matching loss (reserved; not used below)
    # Training parameters
    learning_rate = 1e-4
    num_train_epochs = 3
    per_device_train_batch_size = 2
    per_device_eval_batch_size = 4
    gradient_accumulation_steps = 8
    warmup_ratio = 0.1
    weight_decay = 0.01
    logging_steps = 50
    eval_steps = 500
    save_steps = 1000
    # LoRA configuration (adds LoRA to the student for training efficiency)
    use_lora = True
    lora_r = 32
    lora_alpha = 64
    lora_dropout = 0.1
    # Data configuration
    max_length = 512
    num_distill_samples = 10000  # Number of samples used for distillation
    # Output configuration
    output_dir = "./distilled_student_model"
    run_name = "teacher-student-distillation"
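# The objective optimized below combines two terms (a sketch of the standard
# Hinton-style KD loss; gamma's feature-matching term is not wired in):
#
#   L_total = alpha * T^2 * KL( softmax(z_teacher / T) || softmax(z_student / T) )
#           + beta  * CE( z_student, labels )
#
# with alpha=0.7, beta=0.3, and T=4.0 as configured above.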
class DistillationDataset(Dataset):
    """Dataset for distillation training."""
    def __init__(self, teacher_outputs: List[Dict], tokenizer, max_length: int = 512):
        self.data = teacher_outputs
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        # Build the full prompt-response sequence
        full_text = f"### Human: {item['prompt']}\n### Assistant: {item['response']}"
        # Tokenize
        encoded = self.tokenizer(
            full_text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)
        # Mask padding positions out of the LM loss
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        # Stored teacher logits have shape (1, seq_len, vocab); drop the batch dim
        teacher_logits = torch.tensor(item["teacher_logits"], dtype=torch.float).squeeze(0)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "teacher_logits": teacher_logits,
            "labels": labels,
        }
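# Shapes of a single sample, assuming max_length=512 and teacher vocab size V:
#   input_ids / attention_mask / labels: (512,)
#   teacher_logits: (teacher_seq_len, V), where teacher_seq_len is the length
#   of the tokenized prompt+response at generation time (at most 512)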
class KnowledgeDistillationTrainer(Trainer):
    """Trainer for knowledge distillation."""
    def __init__(self, teacher_model, student_model, temperature=4.0, alpha=0.7, beta=0.3, gamma=0.1, **kwargs):
        super().__init__(model=student_model, **kwargs)
        self.teacher_model = teacher_model
        self.teacher_model.eval()  # Freeze the teacher model
        self.temperature = temperature
        self.alpha = alpha  # Weight of the distillation loss
        self.beta = beta    # Weight of the student's own LM loss
        self.gamma = gamma  # Weight of the feature-matching loss (unused below)
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the combined distillation loss.

        **kwargs absorbs extra arguments (e.g. num_items_in_batch) passed by
        newer transformers versions.
        """
        labels = inputs.get("labels")
        teacher_logits = inputs.get("teacher_logits")
        # Student forward pass (teacher logits are not a model input)
        student_outputs = model(**{k: v for k, v in inputs.items() if k != "teacher_logits"})
        student_logits = student_outputs.logits
        losses = {}
        # 1. Standard language-modeling loss (the student's own loss)
        if labels is not None:
            shift_logits = student_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = torch.nn.CrossEntropyLoss()
            losses["student_loss"] = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
        # 2. Distillation loss (KL divergence against the teacher's soft targets)
        if teacher_logits is not None:
            teacher_logits = teacher_logits.to(student_logits.device)
            # Align sequence lengths; zero-padded teacher positions still leak
            # a uniform-distribution target, so masking them out would be a
            # further refinement
            if teacher_logits.shape != student_logits.shape:
                min_seq_len = min(teacher_logits.shape[1], student_logits.shape[1])
                teacher_logits = teacher_logits[:, :min_seq_len, :]
                student_logits_for_distill = student_logits[:, :min_seq_len, :]
            else:
                student_logits_for_distill = student_logits
            # Soften both distributions with the temperature
            teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)
            student_log_probs = F.log_softmax(student_logits_for_distill / self.temperature, dim=-1)
            # KL divergence loss, scaled by T^2
            losses["distill_loss"] = F.kl_div(
                student_log_probs,
                teacher_probs,
                reduction="batchmean"
            ) * (self.temperature ** 2)
        # 3. Combine into the total loss
        total_loss = 0
        if "student_loss" in losses:
            total_loss += self.beta * losses["student_loss"]
        if "distill_loss" in losses:
            total_loss += self.alpha * losses["distill_loss"]
        # Log the individual loss terms
        self.log({
            "train/total_loss": total_loss.item(),
            "train/student_loss": losses["student_loss"].item() if "student_loss" in losses else 0,
            "train/distill_loss": losses["distill_loss"].item() if "distill_loss" in losses else 0,
        })
        return (total_loss, student_outputs) if return_outputs else total_loss
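    # Notes on the KL term above (standard KD facts, stated here as comments):
    # - F.kl_div(input, target) expects `input` as log-probabilities and
    #   `target` as probabilities, and computes KL(target || input), i.e.
    #   KL(teacher || student) in this trainer.
    # - Dividing logits by T > 1 softens both distributions; the gradient of
    #   the softened KL scales as 1/T^2, so multiplying by T^2 keeps its
    #   magnitude comparable to the hard-label loss as T changes.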
def prepare_student_model(config: DistillationConfig):
    """Prepare the student model."""
    print("🎓 Preparing student model...")
    # Load the student base model
    student_model = AutoModelForCausalLM.from_pretrained(
        config.student_model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    # Add LoRA (optional, for parameter-efficient training)
    if config.use_lora:
        print("🔧 Adding LoRA to student model...")
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
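            # NOTE (assumption): the projection names below match LLaMA-style
            # blocks; GPT-2-style checkpoints such as DialoGPT expose
            # "c_attn"/"c_proj" instead, so adjust target_modules to the
            # student architecture actually loaded.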
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ]
        )
        student_model = get_peft_model(student_model, lora_config)
        student_model.print_trainable_parameters()
    return student_model
def load_teacher_model(config: DistillationConfig):
    """Load the teacher model."""
    print("👨‍🏫 Loading teacher model...")
    teacher_model = AutoModelForCausalLM.from_pretrained(
        config.teacher_model_path,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    teacher_model.eval()
    return teacher_model
def generate_distillation_data(teacher_model, tokenizer, config: DistillationConfig):
    """Generate the distillation dataset."""
    print("📊 Generating distillation dataset...")
    # Load the prompt dataset(s)
    dataset_sources = [
        "smangrul/ad-copy-generation",
        # More sources can be added here
    ]
    all_prompts = []
    for source in dataset_sources:
        try:
            ds = load_dataset(source, split="train")
            # Extract the prompts
            for item in ds:
                if "conversations" in item and len(item["conversations"]) > 0:
                    prompt = item["conversations"][0].get("value", "")
                    if len(prompt.strip()) > 10:
                        all_prompts.append(prompt.strip())
        except Exception as e:
            print(f"⚠️ Error loading {source}: {e}")
    # Cap the number of samples
    if len(all_prompts) > config.num_distill_samples:
        all_prompts = all_prompts[:config.num_distill_samples]
    print(f"📝 Generating responses for {len(all_prompts)} prompts...")
    distillation_data = []
    teacher_model.eval()
    with torch.no_grad():
        for i, prompt in enumerate(tqdm(all_prompts, desc="Generating teacher responses")):
            try:
                # Format the input
                formatted_prompt = f"### Human: {prompt}\n### Assistant:"
                inputs = tokenizer(
                    formatted_prompt,
                    return_tensors="pt",
                    truncation=True,
                    max_length=config.max_length // 2
                ).to(teacher_model.device)
                # Generate a response
                outputs = teacher_model.generate(
                    **inputs,
                    max_new_tokens=200,
                    temperature=0.7,
                    top_p=0.9,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id,
                    return_dict_in_generate=True,
                    output_scores=True
                )
                # Decode the response
                generated_ids = outputs.sequences[0][inputs.input_ids.shape[1]:]
                response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
                # Re-run the full sequence to get the teacher's logits
                full_text = f"### Human: {prompt}\n### Assistant: {response}"
                full_inputs = tokenizer(
                    full_text,
                    return_tensors="pt",
                    truncation=True,
                    max_length=config.max_length
                ).to(teacher_model.device)
                teacher_outputs = teacher_model(**full_inputs)
                teacher_logits = teacher_outputs.logits.cpu().numpy()
                # WARNING: full-vocabulary logits are seq_len x vocab floats per
                # sample, so the JSON written below gets very large; see the
                # top-k compression sketch after this function.
                distillation_data.append({
                    "prompt": prompt,
                    "response": response,
                    "teacher_logits": teacher_logits.tolist()
                })
                # Report progress periodically
                if (i + 1) % 100 == 0:
                    print(f"Generated {i + 1}/{len(all_prompts)} samples")
            except Exception as e:
                print(f"⚠️ Error generating for prompt {i}: {e}")
                continue
    print(f"✅ Generated {len(distillation_data)} teacher-student pairs")
    # Save the distillation data
    with open("distillation_data.json", "w", encoding="utf-8") as f:
        json.dump(distillation_data, f, ensure_ascii=False, indent=2)
    return distillation_data
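# Sketch (hypothetical helper, not called above): storing only the top-k
# teacher logits per position cuts the payload from seq_len x vocab floats to
# seq_len x k values plus indices, at the cost of an approximate KL target.
def compress_teacher_logits(teacher_logits: torch.Tensor, k: int = 50) -> Dict:
    """Keep the k largest logits per position; everything else is dropped."""
    values, indices = torch.topk(teacher_logits, k, dim=-1)
    return {"values": values.tolist(), "indices": indices.tolist()}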
def create_data_collator(tokenizer):
    """Build a collator that can also batch the extra `teacher_logits` field
    (the stock LM collator cannot pad unknown tensor keys)."""
    # tokenizer is kept in the signature for compatibility but is unused:
    # the dataset already pads the text fields to a fixed max_length
    def collate(features):
        batch = {k: torch.stack([f[k] for f in features])
                 for k in ("input_ids", "attention_mask", "labels")}
        logits = [f["teacher_logits"] for f in features]
        max_len = max(t.shape[0] for t in logits)
        batch["teacher_logits"] = torch.stack(
            [F.pad(t, (0, 0, 0, max_len - t.shape[0])) for t in logits])
        return batch
    return collate
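# Trainer accepts any callable that maps a list of dataset samples to a batch
# dict, so the closure above can be passed as `data_collator` directly; the
# resulting batch shapes are input_ids (B, 512) and teacher_logits (B, max_len, V).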
def run_distillation():
    """Main distillation training pipeline."""
    print("🚀 Starting Teacher-Student Distillation...")
    config = DistillationConfig()
    # Initialize wandb (the config holds class-level attributes, so vars() on
    # the instance would be empty; read them off the class instead)
    wandb.init(
        project="teacher-student-distillation",
        config={k: v for k, v in vars(DistillationConfig).items() if not k.startswith("__")},
        name=config.run_name
    )
    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config.teacher_model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Load the models
    teacher_model = load_teacher_model(config)
    student_model = prepare_student_model(config)
    # Generate (or reuse) the distillation data
    if os.path.exists("distillation_data.json"):
        print("📂 Loading existing distillation data...")
        with open("distillation_data.json", "r", encoding="utf-8") as f:
            distillation_data = json.load(f)
    else:
        distillation_data = generate_distillation_data(teacher_model, tokenizer, config)
    # Build the datasets
    train_size = int(0.9 * len(distillation_data))
    train_data = distillation_data[:train_size]
    eval_data = distillation_data[train_size:]
    train_dataset = DistillationDataset(train_data, tokenizer, config.max_length)
    eval_dataset = DistillationDataset(eval_data, tokenizer, config.max_length)
    print(f"📊 Training samples: {len(train_dataset)}")
    print(f"📊 Evaluation samples: {len(eval_dataset)}")
    # Training arguments
    training_args = TrainingArguments(
        output_dir=config.output_dir,
        overwrite_output_dir=True,
        num_train_epochs=config.num_train_epochs,
        per_device_train_batch_size=config.per_device_train_batch_size,
        per_device_eval_batch_size=config.per_device_eval_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
        warmup_ratio=config.warmup_ratio,
        logging_steps=config.logging_steps,
        eval_steps=config.eval_steps,
        save_steps=config.save_steps,
        evaluation_strategy="steps",
        save_strategy="steps",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        report_to="wandb",
        run_name=config.run_name,
        fp16=True,
        dataloader_pin_memory=False,
        remove_unused_columns=False,  # keep teacher_logits in the batch
        group_by_length=True,
    )
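    # Effective global batch size is per_device_train_batch_size (2)
    # x gradient_accumulation_steps (8) x number of GPUs, e.g.
    # 2 x 8 x 2 = 32 sequences per optimizer step on the two GPUs
    # assumed in the __main__ block below.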
    # Create the data collator
    data_collator = create_data_collator(tokenizer)
    # Create the distillation trainer
    trainer = KnowledgeDistillationTrainer(
        teacher_model=teacher_model,
        student_model=student_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
        temperature=config.temperature,
        alpha=config.alpha,
        beta=config.beta,
        gamma=config.gamma,
    )
    # Start training
    print("🔥 Starting distillation training...")
    trainer.train()
    # Save the final model
    print("💾 Saving distilled student model...")
    trainer.save_model()
    tokenizer.save_pretrained(config.output_dir)
    # Evaluate the model
    print("🧪 Evaluating distilled model...")
    evaluate_distilled_model(trainer.model, tokenizer, config)
    wandb.finish()
    print("✅ Distillation training completed!")
def evaluate_distilled_model(model, tokenizer, config: DistillationConfig):
    """Evaluate the distilled model."""
    print("📊 Evaluating distilled student model...")
    test_prompts = [
        "Create an advertisement for a revolutionary AI-powered fitness tracker",
        "Write marketing copy for an eco-friendly electric vehicle",
        "Generate a slogan for a productivity app for remote workers",
        "Create ad copy for a sustainable fashion brand targeting millennials",
        "Write promotional content for a mental health app",
    ]
    model.eval()
    results = []
    for prompt in test_prompts:
        formatted_prompt = f"### Human: {prompt}\n### Assistant:"
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Slice off the prompt tokens rather than the decoded string, since
        # decoding does not always reproduce the prompt text exactly
        generated_ids = outputs[0][inputs.input_ids.shape[1]:]
        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        results.append({
            "prompt": prompt,
            "response": generated_text
        })
        print(f"\n🔍 Prompt: {prompt}")
        print(f"📝 Student Response: {generated_text}")
        print("-" * 80)
    # Save the evaluation results
    with open(f"{config.output_dir}/evaluation_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)
    return results
if __name__ == "__main__":
    # Set environment variables
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    # Check for GPUs
    if torch.cuda.is_available():
        print(f"🔥 Using {torch.cuda.device_count()} GPUs")
        for i in range(torch.cuda.device_count()):
            print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
    else:
        print("⚠️ Warning: No GPU available, using CPU (very slow)")
    # Start distillation training
    run_distillation()
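# Example invocation (assumptions: the script is saved as distill_student.py
# and an RLHF teacher checkpoint exists at ./rlhf_teacher_model):
#   python distill_student.py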