MEscriva commited on
Commit
b6b335c
·
verified ·
1 Parent(s): 9568604

Upload diarization_pyannote_demo.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. diarization_pyannote_demo.py +444 -0
diarization_pyannote_demo.py ADDED
@@ -0,0 +1,444 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Script de diarisation utilisant pyannote.audio (Community-1 ou 3.1).
4
+
5
+ Ce script prend en entrée un fichier audio et génère :
6
+ - Un fichier RTTM
7
+ - Un fichier JSON avec les segments de diarisation
8
+
9
+ Le modèle Community-1 est utilisé par défaut (meilleur que 3.1 selon les benchmarks).
10
+
11
+ Usage:
12
+ python diarization_pyannote_demo.py <input_audio.wav> [--output_dir OUTPUT_DIR]
13
+ python diarization_pyannote_demo.py audio.wav --num_speakers 3
14
+ python diarization_pyannote_demo.py audio.wav --model pyannote/speaker-diarization-precision-2
15
+ """
16
+
17
+ import argparse
18
+ import json
19
+ import os
20
+ import sys
21
+ from pathlib import Path
22
+ from typing import List, Dict, Any
23
+
24
+ try:
25
+ # Importer pyannote en évitant les imports NeMo si possible
26
+ import os
27
+ # Désactiver temporairement l'import NeMo dans pyannote si nécessaire
28
+ os.environ['PYANNOTE_DISABLE_NEMO'] = '1'
29
+
30
+ from pyannote.audio import Pipeline
31
+ from pyannote.core import Annotation
32
+ try:
33
+ from pyannote.audio.pipelines.utils.hook import ProgressHook
34
+ HAS_PROGRESS_HOOK = True
35
+ except ImportError:
36
+ HAS_PROGRESS_HOOK = False
37
+ except ImportError as e:
38
+ print("ERREUR: pyannote.audio n'est pas installé. Voir INSTALL.md pour les instructions.")
39
+ print(f"Détails: {e}")
40
+ sys.exit(1)
41
+ except Exception as e:
42
+ # Si l'import échoue à cause de NeMo, donner des instructions
43
+ if 'nemo' in str(e).lower() or 'transformers' in str(e).lower():
44
+ print("ERREUR: Conflit de dépendances avec NeMo/transformers.")
45
+ print("Solution recommandée: Utiliser un environnement conda dédié.")
46
+ print("Exécuter: ./setup_nemo_env.sh")
47
+ print(f"Détails: {e}")
48
+ else:
49
+ print(f"ERREUR: {e}")
50
+ sys.exit(1)
51
+
52
+ import torch
53
+
54
+ # Corriger le problème PyTorch 2.6 avec weights_only
55
+ if hasattr(torch.serialization, 'add_safe_globals'):
56
+ try:
57
+ torch.serialization.add_safe_globals([torch.torch_version.TorchVersion])
58
+ except:
59
+ pass
60
+
61
+
62
+ def load_pyannote_pipeline(
63
+ model_name: str = "pyannote/speaker-diarization-community-1",
64
+ token: str = None
65
+ ) -> Pipeline:
66
+ """
67
+ Charge le pipeline de diarisation pyannote.
68
+
69
+ Args:
70
+ model_name: Nom du modèle Hugging Face
71
+ - "pyannote/speaker-diarization-community-1" (défaut, meilleur que 3.1)
72
+ - "pyannote/speaker-diarization-3.1" (legacy)
73
+ - "pyannote/speaker-diarization-precision-2" (nécessite API key pyannoteAI)
74
+ token: Token d'authentification (HF_TOKEN ou API key pyannoteAI)
75
+
76
+ Returns:
77
+ Pipeline pyannote configuré
78
+ """
79
+ print(f"Chargement du pipeline pyannote: {model_name}")
80
+
81
+ # Déterminer le token à utiliser
82
+ if token is None:
83
+ # Pour precision-2, utiliser l'API key pyannoteAI si disponible
84
+ if "precision-2" in model_name:
85
+ token = os.environ.get("PYANNOTEAI_API_KEY") or os.environ.get("HF_TOKEN")
86
+ else:
87
+ token = os.environ.get("HF_TOKEN")
88
+
89
+ # Configurer le token dans huggingface_hub si disponible
90
+ if token:
91
+ try:
92
+ from huggingface_hub import login
93
+ login(token=token, add_to_git_credential=False)
94
+ except Exception:
95
+ # Si login échoue, on essaiera quand même avec use_auth_token
96
+ pass
97
+
98
+ if not token:
99
+ print("ATTENTION: Token d'authentification non défini.")
100
+ if "precision-2" in model_name:
101
+ print("Pour precision-2, définir: export PYANNOTEAI_API_KEY='votre_api_key'")
102
+ else:
103
+ print("Définir: export HF_TOKEN='votre_token'")
104
+ print("Note: Le script fonctionnera mais le téléchargement du modèle peut échouer.")
105
+
106
+ try:
107
+ # Ne pas passer use_auth_token car il cause des erreurs avec les nouvelles versions
108
+ # Le token est déjà configuré via huggingface_hub.login() si disponible
109
+ pipeline = Pipeline.from_pretrained(model_name)
110
+
111
+ # Déplacer sur GPU si disponible
112
+ if torch.cuda.is_available():
113
+ pipeline = pipeline.to(torch.device("cuda"))
114
+ print("Pipeline chargé sur GPU")
115
+ else:
116
+ print("Pipeline chargé sur CPU")
117
+
118
+ return pipeline
119
+
120
+ except Exception as e:
121
+ print(f"ERREUR lors du chargement du pipeline: {e}")
122
+ print("\nSolutions possibles:")
123
+ print("1. Vérifier que vous avez accepté les conditions d'utilisation sur Hugging Face")
124
+ print("2. Configurer un token: export HF_TOKEN='votre_token'")
125
+ if "precision-2" in model_name:
126
+ print("3. Pour precision-2, créer une API key sur pyannoteAI dashboard")
127
+ print("4. Vérifier votre connexion internet")
128
+ sys.exit(1)
129
+
130
+
131
+ def convert_audio_if_needed(audio_path: str) -> str:
132
+ """
133
+ Convertit l'audio en WAV si nécessaire (pour les formats non supportés).
134
+
135
+ Args:
136
+ audio_path: Chemin vers le fichier audio
137
+
138
+ Returns:
139
+ Chemin vers le fichier audio (converti si nécessaire)
140
+ """
141
+ ext = Path(audio_path).suffix.lower()
142
+
143
+ # Formats supportés directement par pyannote
144
+ supported_formats = {'.wav', '.flac', '.ogg'}
145
+
146
+ if ext in supported_formats:
147
+ return audio_path
148
+
149
+ # Convertir en WAV si nécessaire
150
+ if ext in {'.m4a', '.mp3', '.mp4', '.aac'}:
151
+ print(f"Conversion de {ext} en WAV...")
152
+ import librosa
153
+ import soundfile as sf
154
+
155
+ wav_path = str(Path(audio_path).with_suffix('.wav'))
156
+
157
+ # Vérifier si le fichier WAV existe déjà
158
+ if os.path.exists(wav_path):
159
+ print(f"Fichier WAV existant trouvé: {wav_path}")
160
+ return wav_path
161
+
162
+ try:
163
+ y, sr = librosa.load(audio_path, sr=16000, mono=True)
164
+ sf.write(wav_path, y, sr)
165
+ print(f"✅ Converti en WAV: {wav_path}")
166
+ return wav_path
167
+ except Exception as e:
168
+ print(f"ATTENTION: Erreur lors de la conversion, utilisation du fichier original: {e}")
169
+ return audio_path
170
+
171
+ return audio_path
172
+
173
+
174
+ def run_pyannote_diarization(
175
+ audio_path: str,
176
+ output_dir: str = "outputs/pyannote",
177
+ model_name: str = "pyannote/speaker-diarization-community-1",
178
+ num_speakers: int = None,
179
+ min_speakers: int = None,
180
+ max_speakers: int = None,
181
+ use_exclusive: bool = False,
182
+ show_progress: bool = True
183
+ ) -> Dict[str, Any]:
184
+ """
185
+ Exécute le pipeline de diarisation pyannote.
186
+
187
+ Args:
188
+ audio_path: Chemin vers le fichier audio
189
+ output_dir: Répertoire de sortie
190
+ model_name: Nom du modèle à utiliser
191
+ num_speakers: Nombre exact de locuteurs (si connu)
192
+ min_speakers: Nombre minimum de locuteurs
193
+ max_speakers: Nombre maximum de locuteurs
194
+ use_exclusive: Utiliser exclusive_speaker_diarization (Community-1+)
195
+ show_progress: Afficher la progression
196
+
197
+ Returns:
198
+ Dictionnaire contenant les résultats de diarisation
199
+ """
200
+ # Convertir l'audio si nécessaire
201
+ audio_path = convert_audio_if_needed(audio_path)
202
+ print(f"Chargement de l'audio: {audio_path}")
203
+
204
+ # Créer le répertoire de sortie si nécessaire
205
+ os.makedirs(output_dir, exist_ok=True)
206
+
207
+ # Charger le pipeline
208
+ pipeline = load_pyannote_pipeline(model_name)
209
+
210
+ # Préparer les options de diarisation
211
+ diarization_options = {}
212
+ if num_speakers is not None:
213
+ diarization_options["num_speakers"] = num_speakers
214
+ print(f"Nombre de locuteurs fixé: {num_speakers}")
215
+ if min_speakers is not None:
216
+ diarization_options["min_speakers"] = min_speakers
217
+ print(f"Nombre minimum de locuteurs: {min_speakers}")
218
+ if max_speakers is not None:
219
+ diarization_options["max_speakers"] = max_speakers
220
+ print(f"Nombre maximum de locuteurs: {max_speakers}")
221
+
222
+ # Exécuter la diarisation
223
+ print("Exécution de la diarisation...")
224
+ try:
225
+ if show_progress and HAS_PROGRESS_HOOK:
226
+ with ProgressHook() as hook:
227
+ diarization = pipeline(audio_path, hook=hook, **diarization_options)
228
+ else:
229
+ diarization = pipeline(audio_path, **diarization_options)
230
+ except Exception as e:
231
+ print(f"ERREUR lors de la diarisation: {e}")
232
+ sys.exit(1)
233
+
234
+ # Utiliser exclusive_speaker_diarization si disponible et demandé
235
+ if use_exclusive and hasattr(diarization, 'exclusive_speaker_diarization'):
236
+ print("Utilisation de exclusive_speaker_diarization")
237
+ annotation = diarization.exclusive_speaker_diarization
238
+ else:
239
+ annotation = diarization
240
+
241
+ # Convertir l'annotation pyannote en format standard
242
+ segments = annotation_to_segments(annotation)
243
+
244
+ # Calculer les statistiques
245
+ num_speakers_detected = len(set(s["speaker"] for s in segments))
246
+
247
+ # Calculer la durée totale
248
+ if segments:
249
+ duration = max(s["end"] for s in segments)
250
+ else:
251
+ duration = 0.0
252
+
253
+ return {
254
+ "segments": segments,
255
+ "num_speakers": num_speakers_detected,
256
+ "duration": duration
257
+ }
258
+
259
+
260
+ def annotation_to_segments(annotation: Annotation) -> List[Dict[str, Any]]:
261
+ """
262
+ Convertit une annotation pyannote en liste de segments.
263
+
264
+ Args:
265
+ annotation: Annotation pyannote
266
+
267
+ Returns:
268
+ Liste de segments au format [{"speaker": "...", "start": ..., "end": ...}]
269
+ """
270
+ segments = []
271
+
272
+ # Obtenir tous les locuteurs uniques
273
+ speakers = sorted(annotation.labels())
274
+
275
+ # Créer un mapping pour normaliser les IDs
276
+ speaker_mapping = {}
277
+ for idx, speaker in enumerate(speakers):
278
+ speaker_mapping[speaker] = f"SPEAKER_{idx:02d}"
279
+
280
+ # Parcourir tous les segments
281
+ for segment, track, speaker in annotation.itertracks(yield_label=True):
282
+ normalized_speaker = speaker_mapping.get(speaker, speaker)
283
+
284
+ segments.append({
285
+ "speaker": normalized_speaker,
286
+ "start": round(segment.start, 2),
287
+ "end": round(segment.end, 2)
288
+ })
289
+
290
+ # Trier par temps de début
291
+ segments.sort(key=lambda x: x["start"])
292
+ return segments
293
+
294
+
295
+ def write_rttm(segments: List[Dict[str, Any]], output_path: str, audio_name: str):
296
+ """
297
+ Écrit un fichier RTTM à partir des segments.
298
+
299
+ Args:
300
+ segments: Liste de segments
301
+ output_path: Chemin du fichier RTTM de sortie
302
+ audio_name: Nom du fichier audio (sans extension)
303
+ """
304
+ with open(output_path, 'w') as f:
305
+ for seg in segments:
306
+ duration = seg["end"] - seg["start"]
307
+ # Format RTTM: SPEAKER <file> 1 <start> <duration> <NA> <NA> <speaker_id> <NA> <NA>
308
+ f.write(f"SPEAKER {audio_name} 1 {seg['start']:.3f} {duration:.3f} <NA> <NA> {seg['speaker']} <NA> <NA>\n")
309
+
310
+
311
+ def write_json(segments: List[Dict[str, Any]], output_path: str):
312
+ """
313
+ Écrit un fichier JSON à partir des segments.
314
+
315
+ Args:
316
+ segments: Liste de segments
317
+ output_path: Chemin du fichier JSON de sortie
318
+ """
319
+ with open(output_path, 'w', encoding='utf-8') as f:
320
+ json.dump(segments, f, indent=2, ensure_ascii=False)
321
+
322
+
323
+ def main():
324
+ parser = argparse.ArgumentParser(
325
+ description="Diarisation avec pyannote.audio 3.x",
326
+ formatter_class=argparse.RawDescriptionHelpFormatter,
327
+ epilog=__doc__
328
+ )
329
+ parser.add_argument(
330
+ "audio_path",
331
+ type=str,
332
+ help="Chemin vers le fichier audio"
333
+ )
334
+ parser.add_argument(
335
+ "--output_dir",
336
+ type=str,
337
+ default="outputs/pyannote",
338
+ help="Répertoire de sortie (défaut: outputs/pyannote)"
339
+ )
340
+ parser.add_argument(
341
+ "--model",
342
+ type=str,
343
+ default="pyannote/speaker-diarization-community-1",
344
+ help="Nom du modèle Hugging Face (défaut: pyannote/speaker-diarization-community-1). "
345
+ "Options: community-1, 3.1, precision-2 (nécessite API key pyannoteAI)"
346
+ )
347
+ parser.add_argument(
348
+ "--num_speakers",
349
+ type=int,
350
+ default=None,
351
+ help="Nombre exact de locuteurs (si connu à l'avance)"
352
+ )
353
+ parser.add_argument(
354
+ "--min_speakers",
355
+ type=int,
356
+ default=None,
357
+ help="Nombre minimum de locuteurs"
358
+ )
359
+ parser.add_argument(
360
+ "--max_speakers",
361
+ type=int,
362
+ default=None,
363
+ help="Nombre maximum de locuteurs"
364
+ )
365
+ parser.add_argument(
366
+ "--exclusive",
367
+ action="store_true",
368
+ help="Utiliser exclusive_speaker_diarization (Community-1+, simplifie la réconciliation avec transcription)"
369
+ )
370
+ parser.add_argument(
371
+ "--no-progress",
372
+ action="store_true",
373
+ help="Ne pas afficher la barre de progression"
374
+ )
375
+
376
+ args = parser.parse_args()
377
+
378
+ if not os.path.exists(args.audio_path):
379
+ print(f"ERREUR: Fichier audio introuvable: {args.audio_path}")
380
+ sys.exit(1)
381
+
382
+ # Normaliser le nom du modèle si version courte fournie
383
+ model_name = args.model
384
+ if model_name == "community-1":
385
+ model_name = "pyannote/speaker-diarization-community-1"
386
+ elif model_name == "3.1":
387
+ model_name = "pyannote/speaker-diarization-3.1"
388
+ elif model_name == "precision-2":
389
+ model_name = "pyannote/speaker-diarization-precision-2"
390
+
391
+ # Exécuter la diarisation
392
+ results = run_pyannote_diarization(
393
+ args.audio_path,
394
+ args.output_dir,
395
+ model_name,
396
+ num_speakers=args.num_speakers,
397
+ min_speakers=args.min_speakers,
398
+ max_speakers=args.max_speakers,
399
+ use_exclusive=args.exclusive,
400
+ show_progress=not args.no_progress
401
+ )
402
+
403
+ # Préparer les chemins de sortie
404
+ audio_name = Path(args.audio_path).stem
405
+ rttm_path = os.path.join(args.output_dir, f"{audio_name}.rttm")
406
+ json_path = os.path.join(args.output_dir, f"{audio_name}.json")
407
+
408
+ # Écrire les fichiers de sortie
409
+ write_rttm(results["segments"], rttm_path, audio_name)
410
+ write_json(results["segments"], json_path)
411
+
412
+ # Afficher les statistiques
413
+ print("\n" + "="*50)
414
+ print("RÉSULTATS DE LA DIARISATION")
415
+ print("="*50)
416
+ print(f"Nombre de locuteurs détectés: {results['num_speakers']}")
417
+ print(f"Durée totale: {results['duration']:.2f} secondes")
418
+ print(f"Nombre de segments: {len(results['segments'])}")
419
+
420
+ # Statistiques par locuteur
421
+ speaker_stats = {}
422
+ for seg in results["segments"]:
423
+ speaker = seg["speaker"]
424
+ duration = seg["end"] - seg["start"]
425
+ if speaker not in speaker_stats:
426
+ speaker_stats[speaker] = {"total_duration": 0.0, "num_segments": 0}
427
+ speaker_stats[speaker]["total_duration"] += duration
428
+ speaker_stats[speaker]["num_segments"] += 1
429
+
430
+ print("\nStatistiques par locuteur:")
431
+ for speaker, stats in sorted(speaker_stats.items()):
432
+ avg_duration = stats["total_duration"] / stats["num_segments"] if stats["num_segments"] > 0 else 0
433
+ print(f" {speaker}: {stats['num_segments']} segments, "
434
+ f"{stats['total_duration']:.2f}s total, "
435
+ f"{avg_duration:.2f}s moyenne/segment")
436
+
437
+ print(f"\nFichiers générés:")
438
+ print(f" RTTM: {rttm_path}")
439
+ print(f" JSON: {json_path}")
440
+
441
+
442
+ if __name__ == "__main__":
443
+ main()
444
+