# streamlit_app.py
# Compute-Optimal LLM Training Estimator (Chinchilla-style)
# ---------------------------------------------------------
# Usage: `streamlit run streamlit_app.py`
# This tool helps estimate total FLOPs, steps, wall-clock time, and rough cost
# for LLM pretraining given model parameters, token budget, and hardware.

import math

import streamlit as st

st.set_page_config(page_title="LLM Compute Estimator", page_icon="🧮", layout="centered")

st.title("🧮 LLM Compute-Optimal Estimator")
st.caption("Estimate total FLOPs, wall-clock time, steps, and cost for pretraining — with a Chinchilla-style token rule.")

# --- Sidebar: assumptions ---
with st.sidebar:
    st.logo('./static/logo_light.png')
    st.header("Assumptions & Notes")
    st.markdown(
        """
**Formulas**
- **Total FLOPs** ≈ `c * N_params * N_tokens`, with default **c = 6** (≈2 FLOPs forward + ≈4 backward per parameter per token).
- **Compute-optimal tokens** (rule of thumb): `N_tokens ≈ k * N_params`, default **k = 20**.
- **Effective compute** = `GPU_count * (peak TFLOPs × 1e12) * efficiency`.

**Disclaimers**
- This is a *back-of-the-envelope* estimator. Real training efficiency depends on the data pipeline, parallelism strategy, sequence length, kernel fusion, optimizer, etc.
- Preset TFLOPs are **approximate** and depend on precision (FP8/BF16), sparsity, clocks, and vendor kernels.
"""
    )
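
# Illustrative worked example of the formulas above (e.g. 4 B params, k = 20,
# 8 GPUs at 989 peak TFLOPs, 50% MFU — the app's default inputs):
#   tokens       ≈ 20 * 4e9          = 80e9 (80 B)
#   total FLOPs  ≈ 6 * 4e9 * 80e9    ≈ 1.92e21
#   effective    ≈ 8 * 989e12 * 0.5  ≈ 3.96e15 FLOP/s
#   time         ≈ 1.92e21 / 3.96e15 ≈ 4.85e5 s ≈ 135 h (~5.6 days)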

# --- 1) Model size & tokens ---
st.subheader("1) Model & Token Budget")
col1, col2, col3 = st.columns([1.2, 1, 1])
with col1:
    model_params_b = st.number_input("Model size (Billions of parameters)", min_value=0.05, value=4.0, step=0.5, format="%.2f")
with col2:
    c_overhead = st.number_input("c (FLOPs constant)", min_value=4.0, value=6.0, step=0.5)
with col3:
    k_tokens_per_param = st.number_input("k (tokens per param for compute-optimal)", min_value=5.0, value=20.0, step=1.0)

use_compute_optimal = st.toggle("Use compute‑optimal tokens (k × params)", value=True)
if use_compute_optimal:
    tokens_b = model_params_b * k_tokens_per_param
    st.info(f"Compute‑optimal token budget ≈ **{tokens_b:,.2f} B** (k = {k_tokens_per_param:g})")
else:
    tokens_b = st.number_input("Token budget (Billions)", min_value=1.0, value=80.0, step=5.0, format="%.2f")

# --- 2) Hardware (defined before batch settings so gpu_count is available) ---
st.subheader("2) Hardware")
col6, col7 = st.columns(2)

# Approximate dense peak TFLOPs per GPU; see the sidebar disclaimer.
preset_map = {
    "A100 80GB BF16 ≈ 312": 312.0,
    "H100 SXM BF16 ≈ 989": 989.0,
    "B200 FP8 (dense) ≈ 4500": 4500.0,
}
with col6:
    gpu_preset = st.selectbox(
        "GPU preset (approx peak TFLOPs per GPU)",
        ("Custom", *preset_map),
        index=0,
        help="Values are back-of-the-envelope. Choose 'Custom' to enter your own.",
    )
with col7:
    if gpu_preset == "Custom":
        peak_tflops = st.number_input("Peak TFLOPs per GPU (approx)", min_value=10.0, value=989.0, step=100.0)
    else:
        peak_tflops = preset_map[gpu_preset]
        st.number_input("Peak TFLOPs per GPU (approx)", value=peak_tflops, disabled=True)

col8, col9, col10 = st.columns(3)
with col8:
    gpu_count = st.number_input("GPU count", min_value=1, value=8, step=1)
with col9:
    efficiency = st.slider("Training efficiency (MFU, %)", min_value=10, max_value=95, value=50, step=1)
with col10:
    price_per_gpu_hour = st.number_input("Price per GPU·hour (USD)", min_value=0.0, value=25.0, step=1.0)

# --- 3) Batch & Sequence Settings (tokens_per_step computed from gpu_count) ---
st.subheader("3) Batch & Sequence Settings")
col4, col5 = st.columns(2)
with col4:
    micro_batch = st.number_input("Micro batch size per GPU", min_value=1, value=4, step=1,
                                  help="Number of sequences per GPU per optimizer step.")
with col5:
    seq_len = st.number_input("Sequence length (tokens)", min_value=128, value=2048, step=128)

tokens_per_step = micro_batch * seq_len * gpu_count
st.info(f"Tokens per optimization step ≈ {tokens_per_step:,} (with {gpu_count} GPUs)")
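
# Continuing the worked example with the default batch settings
# (micro_batch = 4, seq_len = 2048, 8 GPUs):
#   tokens per optimizer step ≈ 4 * 2048 * 8   = 65,536
#   steps for 80 B tokens     ≈ 80e9 / 65,536  ≈ 1.22e6
# Note: this estimate assumes no gradient accumulation; with accumulation, the
# tokens per optimizer step are multiplied by the number of accumulation steps.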

# --- Compute ---
N_params = model_params_b * 1e9
N_tokens = tokens_b * 1e9
c = c_overhead

# Total training FLOPs (scalar)
flops_total = c * N_params * N_tokens  # in FLOPs

# Effective machine compute per second
effective_flops_per_s = gpu_count * (peak_tflops * 1e12) * (efficiency / 100.0)

# Time estimate
seconds = flops_total / effective_flops_per_s if effective_flops_per_s > 0 else float('inf')
hours = seconds / 3600
days = hours / 24

# Optimizer steps (rounded up to a whole number of steps)
steps = math.ceil(N_tokens / tokens_per_step) if tokens_per_step > 0 else float('inf')

# Throughput
throughput_tokens_per_s = N_tokens / seconds if seconds > 0 else float('inf')

# Cost
cost = price_per_gpu_hour * gpu_count * hours

# --- Display ---
st.divider()
st.subheader("Results")

colA, colB = st.columns(2)
with colA:
    st.metric("Total FLOPs", f"{flops_total:,.2e} FLOPs")
    st.metric("Effective compute", f"{effective_flops_per_s:,.2e} FLOPs/s")
    st.metric("Steps (est)", f"{0 if steps == float('inf') else steps:,.0f}")
with colB:
    st.metric("Wall‑clock time", f"{hours:,.1f} h (~{days:,.2f} d)")
    st.metric("Throughput", f"{0 if throughput_tokens_per_s == float('inf') else throughput_tokens_per_s:,.0f} tok/s")
    st.metric("Projected cost", f"${0 if cost == float('inf') else cost:,.0f}")

st.divider()
st.markdown(
    f"""
**Summary**
- Params: **{model_params_b:,.2f}B** · Tokens: **{tokens_b:,.2f}B** (compute‑optimal: {use_compute_optimal})
- Constant **c = {c:g}** → Total ≈ **{flops_total:,.2e} FLOPs**.
- Hardware: **{gpu_count}× GPU**, peak **{peak_tflops:g} TFLOPs/GPU**, MFU **{efficiency}%** → Effective ≈ **{effective_flops_per_s:,.2e} FLOPs/s**.
- Time ≈ **{hours:,.1f} hours** (≈ {days:,.2f} days). Steps ≈ **{0 if steps == float('inf') else steps:,.0f}** (@ {tokens_per_step:,} tok/step).
- Rough cost ≈ **${0 if cost == float('inf') else cost:,.0f}** (@ ${price_per_gpu_hour:g}/GPU·h).
"""
)

with st.expander("What is the Chinchilla rule? Is it 1 epoch?"):
    st.markdown(
        """
**Chinchilla scaling** is a *compute‑optimal* rule of thumb: for a fixed compute budget,
scale the **training tokens** roughly in proportion to the **model parameters**
(commonly ~20× tokens per parameter). It is **not** about training for exactly one epoch.
In web‑scale pretraining, datasets are often sampled with replacement or mixed, so the model
may see some data several times and other data less than once. The rule concerns the *total
number of tokens* the model should process to make the best use of compute, not the number
of dataset passes.
"""
    )

st.success("Ready. Tweak inputs on the left to explore different scenarios.")
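
# --- Optional: UI-free sanity check -----------------------------------------
# A minimal sketch that mirrors the formulas above so they can be checked in a
# REPL or a unit test without launching Streamlit. The helper name `_estimate`
# and its signature are illustrative assumptions, not part of the app's UI.
def _estimate(n_params_b, n_tokens_b, n_gpus, tflops_per_gpu, mfu, usd_per_gpu_hour, c=6.0):
    """Return (total_flops, hours, cost_usd) under the same back-of-the-envelope model."""
    total_flops = c * (n_params_b * 1e9) * (n_tokens_b * 1e9)   # c * N_params * N_tokens
    flops_per_s = n_gpus * tflops_per_gpu * 1e12 * mfu          # delivered FLOPs per second
    est_hours = total_flops / flops_per_s / 3600.0
    return total_flops, est_hours, usd_per_gpu_hour * n_gpus * est_hours

# Example, matching the worked example in the comments above:
#   _estimate(4, 80, 8, 989, 0.5, 25.0)  ->  (~1.92e21 FLOPs, ~135 h, ~$27k)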