| --- |
| language: |
| - en |
| license: apache-2.0 |
| pipeline_tag: text-generation |
| library_name: transformers |
| datasets: |
| - HuggingFaceFW/fineweb-edu |
| - HuggingFaceTB/stack-edu |
| - HuggingFaceTB/finemath |
| tags: |
| - causal-lm |
| - 100m-parameters |
| - single-gpu-training |
| - flashattention2 |
| - gqa |
| model-index: |
| - name: Rain-v2 |
| results: |
| - task: |
| type: multiple-choice-qa |
| name: ARC-Easy (5-shot) |
| metrics: |
| - type: accuracy |
| value: 0.35-0.40 |
| - task: |
| type: multiple-choice-qa |
| name: HellaSwag (5-shot) |
| metrics: |
| - type: accuracy |
| value: 0.28-0.30 |
| - task: |
| type: multiple-choice-qa |
| name: PIQA (5-shot) |
| metrics: |
| - type: accuracy |
| value: 0.60 |
| - task: |
| type: coreference-resolution |
| name: Winogrande (5-shot) |
| metrics: |
| - type: accuracy |
| value: 0.51-0.52 |
| --- |
| |
| # Rain-v2 |
|
|
| Rain-v2 是一个约 1 亿参数的英文自回归语言模型,在 RTX 4090 约两天内完成预训练,展示了在有限算力下从数据到模型的完整实践路径。 |
|
|
| ## 模型与训练配置 |
|
|
| - 参数规模:≈100M |
| - 架构:32 层解码器,隐藏维 512,8 头 GQA(4 个 KV 头),RoPE,RMSNorm,SwiGLU,输入/输出权重共享 |
| - 词表:自训 BPE,16,384 词,面向英文/代码/数学混合语料 |
| - 上下文长度:1024 |
| - 学习率调度:1% warmup + cosine decay |
| - 训练总量:≈6.64×10^8 tokens,总用时 ~40 小时 @ RTX 4090 |
|
|
| ## 数据配比 |
|
|
| - FineWeb-Edu(高质量英文教育语料)60% |
| - Stack-Edu(Python 教学代码/问答子集)30% |
| - FineMath-4+(高质量数学/逻辑)10% |
|
|
| 总量约 10 B。 |
|
|
| ## 评测摘要(5-shot) |
|
|
| - ARC-Easy:40% |
| - HellaSwag:30% |
| - PIQA:60% |
| - Winogrande: 51% |
|
|
| ## 安全与限制 |
|
|
| 易输出错误事实或伪造信息。未经对齐,会生成偏见/有害/违法内容;请勿直接面向终端用户。 |
|
|
| ## 使用示例 |
|
|
| ```python |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| import torch |
| |
| model = AutoModelForCausalLM.from_pretrained("raincandy-u/Rain-v2", torch_dtype=torch.bfloat16, device_map="auto") |
| tok = AutoTokenizer.from_pretrained("your-namespace/Rain-v2") |
| |
| prompt = "Here's a fairy tale about a little pig. A long, long time ago, there was a little pig called " |
| inputs = tok(prompt, return_tensors="pt").to(model.device) |
| out = model.generate(**inputs, max_new_tokens=120, temperature=0.8, top_p=0.9) |
| print(tok.decode(out[0], skip_special_tokens=True)) |
| ``` |