Aratako commited on
Commit
0038229
·
verified ·
1 Parent(s): cc9d83f

Upload 8 files

Browse files
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: MioTTS 0.1B Demo
3
- emoji: 👀
4
  colorFrom: pink
5
  colorTo: gray
6
  sdk: gradio
@@ -8,6 +8,11 @@ sdk_version: 6.5.1
8
  python_version: '3.12'
9
  app_file: app.py
10
  pinned: false
 
 
 
 
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: MioTTS 0.1B Demo
3
+ emoji: 📈
4
  colorFrom: pink
5
  colorTo: gray
6
  sdk: gradio
 
8
  python_version: '3.12'
9
  app_file: app.py
10
  pinned: false
11
+ license: mit
12
+ short_description: TTS demo for MioTTS-0.1B
13
+ models:
14
+ - Aratako/MioTTS-0.1B
15
+ - Aratako/MioCodec-25Hz-24kHz
16
  ---
17
 
18
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,444 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import re
4
+ from typing import Optional
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ import spaces
9
+ import torch
10
+
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer
12
+ from miocodec import MioCodecModel
13
+
14
+ from text import normalize_text
15
+
16
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Model repos; overridable via environment for forks / local testing.
MODEL_REPO = os.environ.get("MODEL_REPO", "Aratako/MioTTS-0.1B")
CODEC_REPO = os.environ.get("CODEC_REPO", "Aratako/MioCodec-25Hz-24kHz")

# Global variables for lazy loading (populated once by load_models()).
_model = None
_tokenizer = None
_codec = None

# Presets directory (contains <preset_id>.pt speaker-embedding files)
PRESETS_DIR = "presets"
30
+
31
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

TOKEN_PATTERN = re.compile(r"<\|s_(\d+)\|>")


def seed_everything(seed: Optional[int]) -> int:
    """Seed every RNG in play (hash, random, numpy, torch) and return the seed.

    A ``None`` seed is replaced by a fresh value drawn from the system RNG.
    """
    if seed is None:
        seed = random.SystemRandom().randint(0, 2**31 - 1)
        print(f"[Info] No seed provided; using random seed {seed}")

    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        # Make cuDNN deterministic so repeated runs with the same seed match.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    return seed


def parse_speech_tokens(text: str) -> list[int]:
    """Extract the integer ids of every ``<|s_N|>`` speech token in *text*.

    Raises:
        ValueError: when *text* contains no speech tokens at all.
    """
    matches = TOKEN_PATTERN.findall(text)
    if not matches:
        raise ValueError("No speech tokens found in LLM output.")
    return [int(value) for value in matches]
59
+
60
+
61
# ---------------------------------------------------------------------------
# Model Loading
# ---------------------------------------------------------------------------

def load_models():
    """Populate the module-level LLM/tokenizer/codec globals (idempotent)."""
    global _model, _tokenizer, _codec

    if _model is not None:
        # Already loaded — nothing to do.
        return

    print(f"[Info] Loading LLM from {MODEL_REPO}...")
    target_device = "cuda" if torch.cuda.is_available() else "cpu"

    _tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    llm = AutoModelForCausalLM.from_pretrained(MODEL_REPO, torch_dtype=torch.bfloat16)
    _model = llm.to(target_device)
    _model.eval()

    print(f"[Info] Loading codec from {CODEC_REPO}...")
    _codec = MioCodecModel.from_pretrained(CODEC_REPO).eval().to(target_device)

    print("[Info] Models loaded successfully.")
86
+
87
+
88
def get_preset_list() -> list[str]:
    """Return sorted preset ids: ``.pt`` files under PRESETS_DIR, extension stripped."""
    if not os.path.exists(PRESETS_DIR):
        return []
    return sorted(
        name[: -len(".pt")]
        for name in os.listdir(PRESETS_DIR)
        if name.endswith(".pt")
    )
96
+
97
+
98
def load_preset_embedding(preset_id: str) -> torch.Tensor:
    """Load a preset speaker embedding from PRESETS_DIR and squeeze it.

    Raises:
        FileNotFoundError: when no ``<preset_id>.pt`` file exists.
    """
    preset_path = os.path.join(PRESETS_DIR, f"{preset_id}.pt")
    if not os.path.exists(preset_path):
        raise FileNotFoundError(f"Preset '{preset_id}' not found.")
    # weights_only=True restricts the unpickler to tensor data (safe load).
    loaded = torch.load(preset_path, map_location="cpu", weights_only=True)
    if isinstance(loaded, dict):
        # NOTE(review): when the dict lacks "global_embedding" this falls back
        # to the dict itself and .squeeze() below would fail — presets are
        # assumed to always contain the key; confirm.
        loaded = loaded.get("global_embedding", loaded)
    return loaded.squeeze()
106
+
107
+
108
# ---------------------------------------------------------------------------
# GPU-decorated Inference Functions
# ---------------------------------------------------------------------------

@spaces.GPU(duration=120)
def run_inference_gpu(
    target_text: str,
    reference_mode: str,
    reference_audio: Optional[tuple[int, np.ndarray]],
    preset_id: Optional[str],
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
    max_tokens: int,
    seed: Optional[int],
    num_samples: int = 1,
) -> list[tuple[int, np.ndarray]]:
    """Generate speech for *target_text* and return decoded audio samples.

    Runs on a ZeroGPU slot (``@spaces.GPU``) for at most 120 seconds.

    Args:
        target_text: Text to synthesize; normalized before tokenization.
        reference_mode: "upload" clones from *reference_audio*; "preset"
            uses a stored speaker embedding named *preset_id*.
        reference_audio: (sample_rate, waveform) pair from gr.Audio, or None.
        preset_id: Preset embedding id under PRESETS_DIR.
        temperature / top_p / top_k / repetition_penalty: LLM sampling knobs.
        max_tokens: Cap on newly generated tokens per sample.
        seed: RNG seed; None draws a random one.
        num_samples: Number of candidate samples generated in one batch.

    Returns:
        List of (sample_rate, waveform ndarray) tuples — one entry per
        generated sequence that contained valid speech tokens.

    Raises:
        ValueError: if neither reference audio nor preset is supplied, or if
            no generated sequence contained any speech tokens.
    """
    load_models()

    used_seed = seed_everything(None if seed is None else int(seed))
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Normalize text
    normalized_text = normalize_text(target_text)
    print(f"[Info] Normalized text: {normalized_text}")

    # Prepare reference: exactly one of reference_waveform / global_embedding
    # will be set, and the decode step below branches on which one it is.
    reference_waveform = None
    global_embedding = None

    if reference_mode == "upload" and reference_audio is not None:
        sr, audio = reference_audio
        # Convert to tensor (stereo is averaged down to mono first)
        # NOTE(review): gradio "numpy" audio is typically int16; values are
        # cast to float without rescaling — confirm the codec expects that.
        if audio.ndim == 1:
            audio_tensor = torch.from_numpy(audio).float()
        else:
            audio_tensor = torch.from_numpy(audio.mean(axis=1)).float()

        # Resample if needed
        codec_sr = _codec.config.sample_rate
        if sr != codec_sr:
            import torchaudio
            audio_tensor = audio_tensor.unsqueeze(0)
            resampler = torchaudio.transforms.Resample(sr, codec_sr)
            audio_tensor = resampler(audio_tensor).squeeze(0)

        # Trim to max 20 seconds
        max_samples = int(codec_sr * 20)
        if audio_tensor.shape[0] > max_samples:
            audio_tensor = audio_tensor[:max_samples]
            print(f"[Info] Reference audio trimmed to 20 seconds")

        reference_waveform = audio_tensor.to(device)
    elif reference_mode == "preset" and preset_id:
        global_embedding = load_preset_embedding(preset_id).to(device)
    else:
        raise ValueError("Either reference audio or preset must be provided.")

    # Tokenize input via the chat template the model was trained with.
    messages = [{"role": "user", "content": normalized_text}]
    input_text = _tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = _tokenizer(input_text, return_tensors="pt").to(device)
    # Remove token_type_ids if present (not used by this model)
    inputs.pop("token_type_ids", None)

    # Generate (batch): num_return_sequences yields all samples in one call.
    with torch.no_grad():
        outputs = _model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=_tokenizer.eos_token_id,
            num_return_sequences=num_samples,
        )

    # Parse all generated sequences; sequences without speech tokens are
    # skipped with a warning rather than failing the whole batch.
    tokens_list = []
    for i in range(outputs.shape[0]):
        generated_text = _tokenizer.decode(outputs[i], skip_special_tokens=False)
        # Strip the echoed prompt so only newly generated text is parsed.
        generated_part = generated_text[len(input_text):]
        try:
            speech_tokens = parse_speech_tokens(generated_part)
            tokens_list.append(speech_tokens)
        except ValueError as e:
            print(f"[Warning] Sample {i + 1}: {e}")

    if not tokens_list:
        raise ValueError("No valid speech tokens generated.")

    # Decode audio (batch)
    results = []
    sample_rate = _codec.config.sample_rate

    with torch.no_grad():
        # Prepare batch tokens: zero-pad every sequence to the longest one,
        # passing true lengths alongside so padding is ignored by the codec.
        max_len = max(len(t) for t in tokens_list)
        batch_tokens = torch.zeros((len(tokens_list), max_len), dtype=torch.long, device=device)
        content_lengths = []
        for i, tokens in enumerate(tokens_list):
            batch_tokens[i, :len(tokens)] = torch.tensor(tokens, dtype=torch.long)
            content_lengths.append(len(tokens))

        # Get global embeddings (one identical row per sample)
        if reference_waveform is not None:
            # Extract global embedding from reference waveform
            ref_features = _codec.encode(reference_waveform, return_content=False, return_global=True)
            global_embeddings = ref_features.global_embedding.unsqueeze(0).expand(len(tokens_list), -1)
        else:
            global_embeddings = global_embedding.unsqueeze(0).expand(len(tokens_list), -1)

        # Batch decode
        audio_batch, audio_lengths = _codec.decode_batch(
            global_embeddings=global_embeddings,
            content_token_indices=batch_tokens,
            content_lengths=content_lengths,
        )
        # Slice each row back to its true length before returning.
        for i in range(len(tokens_list)):
            audio_len = int(audio_lengths[i])
            audio_np = audio_batch[i, :audio_len].cpu().numpy()
            results.append((sample_rate, audio_np))

    print(f"[Info] Seed used: {used_seed}")
    return results
238
+
239
+
240
# Load models at startup
# (runs at import time so the first request does not pay the load cost)
load_models()


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------

# Upper bound on output audio slots created in the UI; also the maximum of
# the "Number of Samples" slider.
MAX_NUM_SAMPLES = 32
249
+
250
+
251
def gradio_inference(
    target_text: str,
    reference_mode: str,
    reference_audio: Optional[tuple[int, np.ndarray]],
    preset_id: Optional[str],
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
    max_tokens: int,
    seed: str,
    num_samples: int,
):
    """Gradio click handler: run TTS and fan results out to the audio slots.

    Returns:
        A list of MAX_NUM_SAMPLES gr.update objects, one per output audio
        component: the first len(results) carry audio and are visible, the
        rest are hidden.

    Raises:
        gr.Error: on invalid seed input or any inference failure, so the UI
            shows a proper error message.
    """
    if not target_text.strip():
        # Nothing to synthesize: hide every output slot.
        return [gr.update(value=None, visible=False) for _ in range(MAX_NUM_SAMPLES)]

    # Parse the optional seed textbox; blank / "None" means "random seed".
    seed_val = None
    if seed.strip() not in {"", "None", "none"}:
        try:
            # int(float(...)) accepts "42" as well as "42.0".
            seed_val = int(float(seed))
        except ValueError:
            # Bug fix: previously a malformed seed (e.g. "abc") raised an
            # unhandled ValueError traceback instead of a user-facing error.
            raise gr.Error(f"Invalid seed value: {seed!r} — leave blank for random.")

    try:
        results = run_inference_gpu(
            target_text=target_text,
            reference_mode=reference_mode,
            reference_audio=reference_audio,
            preset_id=preset_id,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            max_tokens=max_tokens,
            seed=seed_val,
            num_samples=int(num_samples),
        )
    except Exception as e:
        print(f"[Error] {e}")
        raise gr.Error(str(e))

    # Fill the leading slots with generated audio; hide the remainder.
    outputs = []
    for i in range(MAX_NUM_SAMPLES):
        if i < len(results):
            outputs.append(gr.update(value=results[i], visible=True))
        else:
            outputs.append(gr.update(value=None, visible=False))
    return outputs
297
+
298
+
299
def build_demo():
    """Build and return the Gradio Blocks UI for the MioTTS demo."""
    presets = get_preset_list()

    MODEL_LINK = f"https://huggingface.co/{MODEL_REPO}"
    GITHUB_REPO = "https://github.com/Aratako/MioTTS-Inference"

    title = "# MioTTS-0.1B Demo"
    description = f"""
- **Model**: [{MODEL_REPO}]({MODEL_LINK})
- For faster and more efficient inference, see [MioTTS-Inference]({GITHUB_REPO})

**Usage:**
- Select a preset voice OR upload your own reference audio (max 20 seconds)
- Enter text to synthesize
- Adjust generation parameters as needed
"""

    with gr.Blocks() as demo:
        gr.Markdown(title)
        gr.Markdown(description)

        with gr.Row():
            with gr.Column(scale=1):
                # Reference voice: either a bundled preset embedding or an
                # uploaded clip — mutually exclusive, toggled below.
                reference_mode = gr.Radio(
                    choices=["preset", "upload"],
                    value="preset",
                    label="Reference Mode",
                )
                preset_id = gr.Dropdown(
                    choices=presets,
                    value=presets[0] if presets else None,
                    label="Preset Voice",
                    allow_custom_value=False,
                    visible=True,
                )
                reference_audio = gr.Audio(
                    label="Reference Audio",
                    type="numpy",
                    visible=False,
                )

                # Show exactly one of the two reference inputs at a time.
                def update_reference_visibility(mode):
                    if mode == "preset":
                        return gr.update(visible=True), gr.update(visible=False)
                    else:
                        return gr.update(visible=False), gr.update(visible=True)

                reference_mode.change(
                    fn=update_reference_visibility,
                    inputs=[reference_mode],
                    outputs=[preset_id, reference_audio],
                )

                target_text = gr.Textbox(
                    label="Text to Synthesize",
                    value="",
                    placeholder="Enter text to synthesize",
                    lines=3,
                )

                with gr.Row():
                    seed_box = gr.Textbox(
                        label="Seed (optional)",
                        value="",
                        placeholder="Leave blank for random",
                    )
                    num_samples = gr.Slider(
                        label="Number of Samples",
                        minimum=1,
                        maximum=MAX_NUM_SAMPLES,
                        step=1,
                        value=1,
                    )

                # LLM sampling hyperparameters.
                with gr.Row():
                    temperature = gr.Slider(
                        label="Temperature", minimum=0.1, maximum=1.5, step=0.05, value=0.8
                    )
                    top_p = gr.Slider(
                        label="Top-p", minimum=0.1, maximum=1.0, step=0.05, value=1.0
                    )
                    top_k = gr.Slider(
                        label="Top-k", minimum=0, maximum=100, step=1, value=50
                    )

                with gr.Row():
                    repetition_penalty = gr.Slider(
                        label="Repetition Penalty",
                        minimum=1.0,
                        maximum=1.5,
                        step=0.05,
                        value=1.0,
                    )
                    max_tokens = gr.Slider(
                        label="Max Tokens",
                        minimum=100,
                        maximum=1000,
                        step=50,
                        value=700,
                    )

                generate_button = gr.Button("Generate", variant="primary")

        # Output audio components: MAX_NUM_SAMPLES slots in rows of four;
        # only slot #1 is visible until a generation fills more.
        output_audios = []
        cols_per_row = 4
        num_rows = (MAX_NUM_SAMPLES + cols_per_row - 1) // cols_per_row
        with gr.Column():
            for row_idx in range(num_rows):
                with gr.Row():
                    for col_idx in range(cols_per_row):
                        i = row_idx * cols_per_row + col_idx
                        if i >= MAX_NUM_SAMPLES:
                            break
                        audio = gr.Audio(
                            label=f"Sample #{i+1}",
                            type="numpy",
                            interactive=False,
                            visible=(i == 0),
                        )
                        output_audios.append(audio)

        generate_button.click(
            fn=gradio_inference,
            inputs=[
                target_text,
                reference_mode,
                reference_audio,
                preset_id,
                temperature,
                top_p,
                top_k,
                repetition_penalty,
                max_tokens,
                seed_box,
                num_samples,
            ],
            outputs=output_audios,
        )

    return demo
440
+
441
+
442
if __name__ == "__main__":
    # Build the UI and start the Gradio server.
    build_demo().launch()
presets/en_female.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a386d2ee0b48036586fc322da0ccf4b88f585ef9162e0371a9999637ebd7645
3
+ size 1997
presets/en_male.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33a37cffd19795491707edcd760b5f9ecf9da1354296fbca6b7add25f2de42d1
3
+ size 2096
presets/jp_female.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9fb6483be458d81b7edcb8edc49487a781ea2a344495f964bfd0d463d560dba
3
+ size 2103
presets/jp_male.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb6667eff5e7a1a80314dd2d52b3a818b7cc00d54a3b659e1b21f2423581cb82
3
+ size 2096
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchaudio
3
+ transformers<5
4
+ accelerate
5
+ gradio
6
+ soundfile
7
+ numpy
8
+ miocodec @ git+https://github.com/Aratako/MioCodec@main
text.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from __future__ import annotations

import re

# Regex -> replacement pairs applied first by normalize_text(), in
# insertion order.
REPLACE_MAP: dict[str, str] = {
    # Whitespace / markup artifacts dropped outright.
    r"\t": "",
    r"\[n\]": "",
    r" ": "",
    # Symbols with no pronunciation value.
    r"[;▼♀♂《》≪≫①②③④⑤⑥]": "",
    # Assorted dash / horizontal-bar codepoints, removed entirely.
    r"[\u02d7\u2010-\u2015\u2043\u2212\u23af\u23e4\u2500\u2501\u2e3a\u2e3b]": "",
    # Wave dashes become the long-vowel mark.
    r"[\uff5e\u301C]": "ー",
    # Fullwidth punctuation folded to halfwidth.
    r"?": "?",
    r"!": "!",
    # Circle variants unified.
    r"[●◯〇]": "○",
    r"♥": "♡",
}

# Fullwidth ASCII letters (U+FF21–FF3A, U+FF41–FF5A) -> halfwidth A–Z / a–z.
FULLWIDTH_ALPHA_TO_HALFWIDTH = str.maketrans(
    {
        chr(full): chr(half)
        for full, half in zip(
            list(range(0xFF21, 0xFF3B)) + list(range(0xFF41, 0xFF5B)),
            list(range(0x41, 0x5B)) + list(range(0x61, 0x7B)),
            strict=True,
        )
    }
)
# NOTE(review): these two strings should be the half- and fullwidth katakana
# sets respectively — confirm the first has not been width-normalized by an
# editor/tool, otherwise this table is an identity mapping.
_HALFWIDTH_KATAKANA_CHARS = "ヲァィゥェォャュョッーアイウエオカキクケコサシスセソタチツテトナニヌネノハヒフヘホマミムメモヤユヨラリルレロワン"
_FULLWIDTH_KATAKANA_CHARS = "ヲァィゥェォャュョッーアイウエオカキクケコサシスセソタチツテトナニヌネノハヒフヘホマミムメモヤユヨラリルレロワン"
HALFWIDTH_KATAKANA_TO_FULLWIDTH = str.maketrans(
    _HALFWIDTH_KATAKANA_CHARS, _FULLWIDTH_KATAKANA_CHARS
)
# Fullwidth digits (U+FF10–FF19) -> ASCII 0–9.
FULLWIDTH_DIGITS_TO_HALFWIDTH = str.maketrans(
    {
        chr(full): chr(half)
        for full, half in zip(range(0xFF10, 0xFF1A), range(0x30, 0x3A), strict=True)
    }
)
39
+
40
+
41
def normalize_text(text: str) -> str:
    """Normalize text for TTS.

    Applies the REPLACE_MAP regex substitutions, folds character widths,
    collapses long ellipsis runs, strips one pair of enclosing quotes or
    brackets, and drops trailing 。 characters.
    """
    for pattern, replacement in REPLACE_MAP.items():
        text = re.sub(pattern, replacement, text)

    for table in (
        FULLWIDTH_ALPHA_TO_HALFWIDTH,
        FULLWIDTH_DIGITS_TO_HALFWIDTH,
        HALFWIDTH_KATAKANA_TO_FULLWIDTH,
    ):
        text = text.translate(table)

    # Collapse runs of three or more ellipsis characters down to two.
    text = re.sub(r"…{3,}", "……", text)

    # Strip one layer of each enclosing pair, checked in this fixed order.
    # NOTE(review): the two "("/")" pairs mirror the original source; one of
    # them was likely fullwidth （） before tooling normalized it — confirm.
    for opener, closer in (
        ("「", "」"),
        ("『", "』"),
        ("(", ")"),
        ("【", "】"),
        ("(", ")"),
    ):
        if text.startswith(opener) and text.endswith(closer):
            text = text[1:-1]

    # Equivalent to the original guarded rstrip: no-op unless text ends in 。
    text = text.rstrip("。")

    return text