ThreadAbort committed
Commit f417d27 · 2 parents: e3e7558, 0393dfa

Merge branch 'main' of https://github.com/8b-is/IndexTTS-Rust

README.md CHANGED
````diff
@@ -39,7 +39,7 @@ Compared to the Python implementation:

 ```bash
 # Clone the repository
-git clone https://github.com/your-org/IndexTTS-Rust.git
+git clone https://github.com/8b-is/IndexTTS-Rust.git
 cd IndexTTS-Rust

 # Build in release mode (optimized)
````
benches/mel_spectrogram.rs CHANGED
```diff
@@ -8,9 +8,7 @@ fn bench_mel_spectrogram(c: &mut Criterion) {

     // Generate 1 second of audio
     let num_samples = config.sample_rate as usize;
-    let signal: Vec<f32> = (0..num_samples)
-        .map(|i| (i as f32 * 0.01).sin())
-        .collect();
+    let signal: Vec<f32> = (0..num_samples).map(|i| (i as f32 * 0.01).sin()).collect();

     c.bench_function("mel_spectrogram_1s", |b| {
         b.iter(|| mel_spectrogram(black_box(&signal), black_box(&config)))
@@ -29,9 +27,7 @@ fn bench_mel_spectrogram(c: &mut Criterion) {
 fn bench_stft(c: &mut Criterion) {
     let config = AudioConfig::default();
     let num_samples = config.sample_rate as usize;
-    let signal: Vec<f32> = (0..num_samples)
-        .map(|i| (i as f32 * 0.01).sin())
-        .collect();
+    let signal: Vec<f32> = (0..num_samples).map(|i| (i as f32 * 0.01).sin()).collect();

     c.bench_function("stft_1s", |b| {
         b.iter(|| {
```
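For readers unfamiliar with Criterion, the pattern used in this bench file (wrap the measured inputs and outputs in `black_box` so the optimizer cannot fold the work away) looks roughly like the standalone sketch below; the bench name and the 24 kHz sample rate are illustrative assumptions, not values taken from the crate.

```rust
// Hypothetical standalone Criterion bench mirroring the pattern above.
// Assumes criterion as a dev-dependency; the name and 24 kHz rate are illustrative.
use criterion::{criterion_group, criterion_main, Criterion};
use std::hint::black_box;

fn bench_sine_signal(c: &mut Criterion) {
    let num_samples = 24_000usize; // ~1 s at an assumed 24 kHz sample rate

    c.bench_function("sine_signal_1s", |b| {
        b.iter(|| {
            // black_box keeps the compiler from constant-folding the input
            // or discarding the unused result.
            let signal: Vec<f32> = (0..black_box(num_samples))
                .map(|i| (i as f32 * 0.01).sin())
                .collect();
            black_box(signal)
        })
    });
}

criterion_group!(benches, bench_sine_signal);
criterion_main!(benches);
```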
src/audio/dsp.rs CHANGED
```diff
@@ -1,6 +1,5 @@
 //! Digital Signal Processing utilities

-use crate::Result;

 /// Apply pre-emphasis filter to audio signal
 ///
```
src/audio/mod.rs CHANGED
```diff
@@ -7,9 +7,12 @@ mod io;
 pub mod mel;
 mod resample;

-pub use dsp::{apply_preemphasis, dynamic_range_compression, dynamic_range_decompression, normalize_audio, normalize_audio_peak, apply_fade};
+pub use dsp::{
+    apply_fade, apply_preemphasis, dynamic_range_compression, dynamic_range_decompression,
+    normalize_audio, normalize_audio_peak,
+};
 pub use io::{load_audio, save_audio, AudioData};
-pub use mel::{mel_spectrogram, MelFilterbank, mel_to_linear};
+pub use mel::{mel_spectrogram, mel_to_linear, MelFilterbank};
 pub use resample::resample;

 use crate::Result;
@@ -48,10 +51,7 @@ impl Default for AudioConfig {
 }

 /// Compute mel spectrogram from audio file
-pub fn compute_mel_from_file(
-    path: &str,
-    config: &AudioConfig,
-) -> Result<ndarray::Array2<f32>> {
+pub fn compute_mel_from_file(path: &str, config: &AudioConfig) -> Result<ndarray::Array2<f32>> {
     let audio = load_audio(path, Some(config.sample_rate))?;
     mel_spectrogram(&audio.samples, config)
 }
```
src/config/mod.rs CHANGED
```diff
@@ -289,7 +289,7 @@ impl Config {
         if self.gpt.heads == 0 {
             return Err(Error::Config("GPT heads must be > 0".into()));
         }
-        if self.gpt.model_dim % self.gpt.heads != 0 {
+        if !self.gpt.model_dim.is_multiple_of(self.gpt.heads) {
             return Err(Error::Config(
                 "GPT model_dim must be divisible by heads".into(),
             ));
```
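The `%`-to-`is_multiple_of` change is behaviour-preserving here because `heads == 0` is already rejected just above. A minimal sketch of the equivalence, assuming a toolchain where `usize::is_multiple_of` is stable:

```rust
// Sketch: usize::is_multiple_of vs. the `%` check it replaces.
fn main() {
    let (model_dim, heads) = (1280_usize, 20_usize);
    assert_eq!(model_dim % heads != 0, !model_dim.is_multiple_of(heads));

    // Unlike `%`, is_multiple_of does not panic on a zero divisor:
    // it returns true only when the dividend is also zero.
    assert!(!5_usize.is_multiple_of(0));
    assert!(0_usize.is_multiple_of(0));
}
```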
src/model/embedding.rs CHANGED
```diff
@@ -136,7 +136,7 @@ impl EmotionEncoder {
             .map_err(|e| Error::ModelLoading(format!("Missing emotion_matrix: {}", e)))?;

         let shape = tensor.shape();
-        let mut data: Vec<f32> = tensor.data().chunks_exact(4).map(|b| {
+        let data: Vec<f32> = tensor.data().chunks_exact(4).map(|b| {
             f32::from_le_bytes([b[0], b[1], b[2], b[3]])
         }).collect();
         if !tensor.data().chunks_exact(4).remainder().is_empty() {
@@ -170,7 +170,7 @@ impl EmotionEncoder {
         let mut embedding = vec![0.0f32; embedding_dim];

         let mut offset = 0;
-        for (dim_idx, (&value, &dim_size)) in emotion_vector.iter().zip(self.dim_sizes.iter()).enumerate() {
+        for (_dim_idx, (&value, &dim_size)) in emotion_vector.iter().zip(self.dim_sizes.iter()).enumerate() {
             // Interpolate between discrete emotion levels
             let continuous_idx = value * (dim_size - 1) as f32;
             let lower_idx = continuous_idx.floor() as usize;
```
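For context on the loop body visible in this hunk: each continuous emotion value in [0, 1] is mapped onto two neighbouring discrete levels plus a blend weight. A hypothetical standalone sketch of that index/weight computation (the helper name is illustrative, not from the crate):

```rust
// Hypothetical helper mirroring the interpolation step above: map a value in
// [0, 1] onto neighbouring discrete levels and a linear blend weight.
fn level_blend(value: f32, dim_size: usize) -> (usize, usize, f32) {
    let continuous_idx = value.clamp(0.0, 1.0) * (dim_size - 1) as f32;
    let lower_idx = continuous_idx.floor() as usize;
    let upper_idx = (lower_idx + 1).min(dim_size - 1);
    let frac = continuous_idx - lower_idx as f32; // weight of the upper level
    (lower_idx, upper_idx, frac)
}

fn main() {
    // A value of 0.5 over 8 levels lands halfway between levels 3 and 4.
    assert_eq!(level_blend(0.5, 8), (3, 4, 0.5));
}
```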
src/model/gpt.rs CHANGED
```diff
@@ -1,7 +1,7 @@
 //! GPT-based sequence generation model

 use crate::{Error, Result};
-use ndarray::{Array, Array1, Array2, Array3, IxDyn};
+use ndarray::{Array, Array1, Array2, IxDyn};
 use std::collections::HashMap;
 use std::path::Path;

```
src/model/mod.rs CHANGED
```diff
@@ -10,8 +10,6 @@ pub use gpt::{GptModel, GptConfig};
 pub use embedding::{SpeakerEncoder, EmotionEncoder, SemanticEncoder};
 pub use session::{OnnxSession, ModelCache};

-use crate::{Error, Result};
-use ndarray::{Array1, Array2, Array3};

 /// Sampling strategy for generation
 #[derive(Debug, Clone)]
```
src/pipeline/synthesis.rs CHANGED
```diff
@@ -3,10 +3,9 @@
 use crate::{
     audio::{load_audio, save_audio, AudioConfig, AudioData},
     config::Config,
-    model::{EmotionEncoder, GptConfig, SamplingStrategy, SemanticEncoder, SpeakerEncoder},
+    model::{EmotionEncoder, SamplingStrategy, SemanticEncoder, SpeakerEncoder},
     text::{TextNormalizer, TextTokenizer, TokenizerConfig},
-    vocoder::{BigVGAN, BigVGANConfig, Vocoder},
-    Error, Result, SAMPLE_RATE,
+    vocoder::{BigVGAN, BigVGANConfig, Vocoder}, Result,
 };
 use ndarray::Array1;
 use std::path::{Path, PathBuf};
```
src/text/mod.rs CHANGED
```diff
@@ -69,7 +69,7 @@ pub fn contains_chinese(text: &str) -> bool {

 /// Check if text contains only ASCII
 pub fn is_ascii_only(text: &str) -> bool {
-    text.chars().all(|c| c.is_ascii())
+    text.is_ascii()
 }

 /// Split text into segments by language
```
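The `is_ascii_only` change is a straightforward equivalence: `str::is_ascii` checks the underlying bytes directly instead of decoding each `char`. A minimal check of the equivalence:

```rust
// Sketch: str::is_ascii() matches the per-char check it replaces.
fn main() {
    for s in ["hello 123!", "héllo", "你好", ""] {
        assert_eq!(s.is_ascii(), s.chars().all(|c| c.is_ascii()));
    }
}
```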
src/text/normalizer.rs CHANGED
```diff
@@ -1,6 +1,6 @@
 //! Text normalization for TTS

-use crate::{Error, Result};
+use crate::Result;
 use lazy_static::lazy_static;
 use regex::Regex;
 use std::collections::HashMap;
```
src/text/phoneme.rs CHANGED
```diff
@@ -3,7 +3,6 @@
 //! Provides grapheme-to-phoneme (G2P) conversion for English
 //! and Pinyin handling for Chinese

-use crate::Result;
 use lazy_static::lazy_static;
 use std::collections::HashMap;

```
src/vocoder/activations.rs CHANGED
```diff
@@ -39,7 +39,7 @@ pub fn anti_aliased_snake(x: &[f32], alpha: f32, upsample_factor: usize) -> Vec<
     // Upsample
     let upsampled: Vec<f32> = x
         .iter()
-        .flat_map(|&v| std::iter::repeat(v).take(upsample_factor))
+        .flat_map(|&v| std::iter::repeat_n(v, upsample_factor))
         .collect();

     // Apply activation
```
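`std::iter::repeat_n(v, n)` yields `v` exactly `n` times, so this change preserves the nearest-neighbour upsampling behaviour of `repeat(v).take(n)`. A minimal sketch, assuming a toolchain where `repeat_n` is stable:

```rust
// Sketch: repeat_n-based nearest-neighbour upsampling, as in the hunk above.
fn main() {
    let x = [1.0f32, 2.0, 3.0];
    let upsample_factor = 2;
    let upsampled: Vec<f32> = x
        .iter()
        .flat_map(|&v| std::iter::repeat_n(v, upsample_factor))
        .collect();
    assert_eq!(upsampled, vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]);
}
```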
src/vocoder/bigvgan.rs CHANGED
```diff
@@ -3,7 +3,7 @@
 //! High-quality neural vocoder for mel-spectrogram to waveform conversion

 use crate::{Error, Result};
-use ndarray::{Array, Array2, IxDyn};
+use ndarray::{Array2, IxDyn};
 use std::collections::HashMap;
 use std::path::Path;

```
src/vocoder/mod.rs CHANGED
```diff
@@ -8,9 +8,8 @@ mod activations;
 pub use bigvgan::{BigVGAN, BigVGANConfig, create_bigvgan_22k, create_bigvgan_24k};
 pub use activations::{snake_activation, snake_beta_activation, snake_activation_vec};

-use crate::{Error, Result};
+use crate::Result;
 use ndarray::Array2;
-use num_complex::Complex;

 /// Vocoder trait for mel-to-waveform conversion
 pub trait Vocoder {
```