Spaces:
Running
Running
import json
import logging
import os
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import gradio as gr
# Configure logging: basicConfig installs a default handler on the root logger
# (unless one is already configured) at INFO level; this module then logs
# through its own logger named after the module.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)
class StemSeparator:
    """Modern stem separation with multiple model support.

    Wraps the ``demucs`` and ``spleeter`` command-line tools. Each entry of
    ``supported_models`` maps a model id to:

    * ``command``     - which CLI runs it (``"demucs"`` or ``"spleeter"``),
    * ``model_name``  - optional model/preset name passed to that CLI,
    * ``stems``       - number of stems the model is described as producing,
    * ``description`` - human-readable text shown in the UI.
    """

    def __init__(self, output_root: "str | os.PathLike[str]" = "./separated_stems"):
        """Create a separator.

        Args:
            output_root: Directory that finished stems are copied into so they
                outlive the per-request temporary directory. Relative paths
                resolve against the current working directory. Defaults to
                ``./separated_stems``, the location previously hard-coded.
        """
        self.output_root = Path(output_root)
        self.supported_models = {
            "htdemucs": {
                "command": "demucs",
                "stems": 4,
                "description": "HTDemucs - High quality 4-stem separation"
            },
            "htdemucs_ft": {
                "command": "demucs",
                "model_name": "htdemucs_ft",
                "stems": 4,
                "description": "HTDemucs Fine-tuned - Enhanced 4-stem separation"
            },
            "htdemucs_6s": {
                "command": "demucs",
                "model_name": "htdemucs_6s",
                "stems": 6,
                "description": "HTDemucs 6-stem - Bass, Drums, Vocals, Other, Guitar, Piano"
            },
            "mdx": {
                "command": "demucs",
                "model_name": "mdx",
                "stems": 4,
                "description": "MDX - Optimized for vocal separation"
            },
            "mdx_extra": {
                "command": "demucs",
                "model_name": "mdx_extra",
                "stems": 4,
                "description": "MDX Extra - Enhanced vocal separation"
            },
            # NOTE(review): Spleeter's bundled presets are commonly referenced as
            # "spleeter:4stems" / "spleeter:5stems"; confirm the "-waveform"
            # names below resolve in the installed Spleeter version.
            "spleeter_4stems": {
                "command": "spleeter",
                "model_name": "spleeter:4stems-waveform",
                "stems": 4,
                "description": "Spleeter 4-stem - Vocals, Bass, Drums, Other"
            },
            "spleeter_5stems": {
                "command": "spleeter",
                "model_name": "spleeter:5stems-waveform",
                "stems": 5,
                "description": "Spleeter 5-stem - Vocals, Bass, Drums, Piano, Other"
            }
        }

    @staticmethod
    def _python_executable() -> str:
        """Interpreter used to run ``python -m demucs``.

        Uses the interpreter running this app so the subprocess sees the same
        installed packages; falls back to ``"python"`` if it is unknown.
        """
        return sys.executable or "python"

    @staticmethod
    def _probe(command: List[str]) -> bool:
        """Return True if ``command`` runs and exits with status 0 within 10 s."""
        try:
            result = subprocess.run(command, capture_output=True, text=True, timeout=10)
        except (subprocess.TimeoutExpired, OSError):
            # OSError covers FileNotFoundError (tool not on PATH) as well as
            # PermissionError (found but not executable).
            return False
        return result.returncode == 0

    def check_dependencies(self) -> Dict[str, bool]:
        """Check which separation tools can be launched.

        Returns:
            ``{"demucs": bool, "spleeter": bool}`` - True when the tool's
            ``--help`` invocation exits with status 0 within 10 seconds.
        """
        probes = {
            "demucs": [self._python_executable(), "-m", "demucs", "--help"],
            "spleeter": ["spleeter", "--help"],
        }
        return {tool: self._probe(cmd) for tool, cmd in probes.items()}

    def separate_audio(self, audio_file: str, model_choice: str,
                       timeout: float = 300) -> Tuple[List[str], str]:
        """Separate audio into stems using the selected model.

        Args:
            audio_file: Path to the input audio. It is copied into a temporary
                directory; the original file is not modified.
            model_choice: A key of ``supported_models``.
            timeout: Seconds allowed for the separation subprocess
                (default 300, i.e. 5 minutes).

        Returns:
            ``(stem_paths, status_message)``. ``stem_paths`` is a sorted list of
            copies under ``output_root``; on any failure it is ``[]`` and the
            message describes the error.
        """
        if not audio_file:
            return [], "β No audio file provided"
        if model_choice not in self.supported_models:
            return [], f"β Unsupported model: {model_choice}"
        model_config = self.supported_models[model_choice]
        try:
            with tempfile.TemporaryDirectory() as temp_dir:
                temp_path = Path(temp_dir)
                # Copy input file to temp directory with proper extension.
                input_file = Path(audio_file)
                temp_input = temp_path / f"input{input_file.suffix}"
                shutil.copy2(audio_file, temp_input)
                # The tool writes into its own subdirectory. _collect_stems()
                # globs every audio file under the directory it is given, so if
                # the input copy lived there it would be returned as a "stem".
                stems_dir = temp_path / "stems"
                stems_dir.mkdir()
                logger.info(f"Processing {temp_input} with {model_choice}")
                # Build command based on model type.
                if model_config["command"] == "demucs":
                    command = self._build_demucs_command(temp_input, stems_dir, model_config)
                elif model_config["command"] == "spleeter":
                    command = self._build_spleeter_command(temp_input, stems_dir, model_config)
                else:
                    return [], f"β Unknown command type: {model_config['command']}"
                # Execute separation.
                logger.info(f"Running command: {' '.join(command)}")
                result = subprocess.run(
                    command,
                    capture_output=True,
                    text=True,
                    timeout=timeout,
                    cwd=temp_dir
                )
                if result.returncode != 0:
                    error_msg = f"β Separation failed: {result.stderr}"
                    logger.error(error_msg)
                    return [], error_msg
                # Collect output stems (must happen before temp_dir is deleted).
                stems = self._collect_stems(stems_dir, model_choice)
                if not stems:
                    return [], "β No stems were generated"
                success_msg = f"β Successfully separated into {len(stems)} stems"
                logger.info(success_msg)
                return stems, success_msg
        except subprocess.TimeoutExpired:
            return [], "β Process timed out - file may be too large"
        except Exception as e:
            # UI boundary: report the error to the user, keep the traceback in logs.
            error_msg = f"β Error during separation: {str(e)}"
            logger.exception(error_msg)
            return [], error_msg

    def _build_demucs_command(self, input_file: Path, output_dir: Path, model_config: Dict) -> List[str]:
        """Build the ``python -m demucs`` argument list.

        ``-n`` is only passed when the config names a model; otherwise demucs
        uses its own default model.
        """
        command = [self._python_executable(), "-m", "demucs"]
        if "model_name" in model_config:
            command.extend(["-n", model_config["model_name"]])
        command.extend([
            "-o", str(output_dir),
            "--filename", "{track}/{stem}.{ext}",  # Organized output structure
            str(input_file)
        ])
        return command

    def _build_spleeter_command(self, input_file: Path, output_dir: Path, model_config: Dict) -> List[str]:
        """Build the ``spleeter separate`` argument list.

        Output files are named ``{instrument}.{codec}`` directly inside
        ``output_dir``.
        """
        model_name = model_config.get("model_name", "spleeter:4stems-waveform")
        command = [
            "spleeter", "separate",
            "-p", model_name,
            "-o", str(output_dir),
            "--filename_format", "{instrument}.{codec}",
            str(input_file)
        ]
        return command

    def _collect_stems(self, output_dir: Path, model_choice: str) -> List[str]:
        """Copy every non-empty audio file under ``output_dir`` to ``output_root``.

        Searches recursively for .wav/.mp3/.flac/.m4a files. ``model_choice`` is
        currently unused and kept for interface compatibility.

        Returns:
            Sorted paths of the permanent copies; files that fail to copy are
            skipped.
        """
        stems = []
        for pattern in ("*.wav", "*.mp3", "*.flac", "*.m4a"):
            for audio_file in output_dir.rglob(pattern):
                if audio_file.is_file() and audio_file.stat().st_size > 0:
                    permanent_path = self._copy_to_permanent_location(audio_file)
                    if permanent_path:
                        stems.append(permanent_path)
        return sorted(stems)

    def _copy_to_permanent_location(self, temp_file: Path) -> Optional[str]:
        """Copy ``temp_file`` into ``output_root`` so it outlives the temp dir.

        The copy is named ``<original stem>_<epoch-milliseconds><suffix>``.

        Returns:
            The copy's path as a string, or None if copying failed (the error
            is logged).
        """
        import time
        try:
            output_dir = self.output_root
            output_dir.mkdir(parents=True, exist_ok=True)
            timestamp = int(time.time() * 1000)
            permanent_file = output_dir / f"{temp_file.stem}_{timestamp}{temp_file.suffix}"
            # Two same-named files copied within one millisecond would otherwise
            # overwrite each other; bump the number until the name is free.
            while permanent_file.exists():
                timestamp += 1
                permanent_file = output_dir / f"{temp_file.stem}_{timestamp}{temp_file.suffix}"
            shutil.copy2(temp_file, permanent_file)
            return str(permanent_file)
        except Exception as e:
            logger.error(f"Failed to copy {temp_file}: {e}")
            return None
# Initialize separator
# One module-level StemSeparator instance, shared by get_available_models(),
# separate_stems_ui() and the dependency-status panel in create_interface().
separator = StemSeparator()
def get_available_models() -> List[Tuple[str, str]]:
    """Return ``(label, model_id)`` dropdown choices for runnable models.

    A model is offered only when the CLI named in its ``command`` entry passed
    ``separator.check_dependencies()``. If none did, a single placeholder
    choice whose value is ``"none"`` is returned, so the list is never empty.
    """
    installed = separator.check_dependencies()
    choices = [
        (f"{model_id} ({cfg['stems']} stems) - {cfg['description']}", model_id)
        for model_id, cfg in separator.supported_models.items()
        if installed.get(cfg["command"])
    ]
    return choices or [("No models available - install demucs or spleeter", "none")]
def separate_stems_ui(audio_file: str, model_choice: str) -> Tuple[List[str], str]:
    """Gradio-facing wrapper around ``separator.separate_audio``.

    The ``"none"`` placeholder (offered when no back-end is installed) yields an
    install hint instead of a separation attempt. Otherwise the separator's
    ``(stem_paths, status_message)`` pair is passed through.
    """
    if model_choice != "none":
        stem_paths, status = separator.separate_audio(audio_file, model_choice)
        return stem_paths, status
    return [], "β Please install demucs and/or spleeter first"
def create_audio_gallery(stems: List[str]) -> List[gr.Audio]:
    """Build one read-only, downloadable ``gr.Audio`` per stem path.

    Labels read ``"Stem <n>: <file stem>"`` with 1-based numbering. Empty or
    None input yields an empty list.

    NOTE(review): not called anywhere else in this file.
    """
    return [
        gr.Audio(
            value=path,
            label=f"Stem {number}: {Path(path).stem}",
            interactive=False,
            show_download_button=True,
        )
        for number, path in enumerate(stems or [], start=1)
    ]
# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks app for the stem separator.

    Layout: upload, model picker and run button; a status box; a results
    column holding one pre-created (initially hidden) ``gr.Audio`` slot per
    stem the largest configured model can produce; and a collapsible
    dependency-status panel.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    # get_available_models() shells out to demucs and spleeter (10 s timeout
    # each), so probe once rather than once per use.
    model_choices = get_available_models()
    default_model = model_choices[0][1] if model_choices else "none"
    # Most stems any configured model produces (6 for htdemucs_6s).
    max_stems = max(cfg["stems"] for cfg in separator.supported_models.values())

    with gr.Blocks(
        title="π΅ Advanced Music Stem Separator",
        theme=gr.themes.Soft(),
        css="""
        .audio-container { margin: 10px 0; }
        .status-success { color: #22c55e; font-weight: bold; }
        .status-error { color: #ef4444; font-weight: bold; }
        """
    ) as demo:
        gr.Markdown("""
        # π΅ Advanced Music Stem Separator
        Separate music into individual stems (vocals, instruments, etc.) using state-of-the-art AI models.
        Supports up to 6 stems depending on the model chosen.
        **Supported Models:**
        - **Demucs Models**: HTDemucs, HTDemucs-FT, HTDemucs-6s, MDX, MDX-Extra
        - **Spleeter Models**: 4-stem and 5-stem separation
        **Requirements**: Install `demucs` and/or `spleeter` packages
        """)
        with gr.Row():
            with gr.Column(scale=2):
                # gr.Audio does not document an `info` parameter, so the
                # format hint is rendered as its own line instead.
                audio_input = gr.Audio(
                    type="filepath",
                    label="πΌ Upload Audio File"
                )
                gr.Markdown("Supported formats: WAV, MP3, FLAC, M4A")
                model_dropdown = gr.Dropdown(
                    choices=model_choices,
                    value=default_model,
                    label="π§ Separation Model",
                    info="Choose the AI model for stem separation"
                )
                separate_btn = gr.Button(
                    "ποΈ Separate Stems",
                    variant="primary",
                    size="lg"
                )
            with gr.Column(scale=1):
                gr.Markdown("""
                ### βΉοΈ Model Info
                - **4-stem**: Vocals, Bass, Drums, Other
                - **5-stem**: + Piano
                - **6-stem**: + Guitar
                ### π‘ Tips
                - Higher quality input = better separation
                - Processing time varies by model and file length
                - Results will appear below after processing
                """)
        # Status display
        status_display = gr.Textbox(
            label="Status",
            interactive=False,
            visible=True
        )
        # Audio outputs. Gradio only renders components created while the
        # layout is being built; components constructed inside an event
        # handler are never shown. So the slots are created up front, hidden,
        # and filled/revealed via updates.
        stems_state = gr.State([])
        with gr.Column(visible=False) as audio_outputs:
            stems_header = gr.Markdown()
            audio_slots = [
                gr.Audio(
                    label=f"Stem {i + 1}",
                    interactive=False,
                    show_download_button=True,
                    visible=False
                )
                for i in range(max_stems)
            ]

        def process_and_display(audio_file, model_choice):
            """Run separation and map the resulting files onto the audio slots.

            Returns values for: stems_state, status_display, audio_outputs,
            stems_header, then one update per audio slot. Stems beyond the
            number of slots stay in stems_state but are not displayed.
            """
            if not audio_file:
                stems, message = [], "β Please upload an audio file"
            else:
                stems, message = separate_stems_ui(audio_file, model_choice)
            slot_updates = []
            for i in range(len(audio_slots)):
                if i < len(stems):
                    # Copies are named "<stem>_<timestamp>.<ext>" by
                    # StemSeparator; drop the timestamp so labels read "Vocals".
                    stem_name = Path(stems[i]).stem.rsplit("_", 1)[0].replace("_", " ").title()
                    slot_updates.append(gr.update(value=stems[i], label=f"π΅ {stem_name}", visible=True))
                else:
                    slot_updates.append(gr.update(value=None, visible=False))
            header = f"### πΆ Separated Stems ({len(stems)} files)" if stems else ""
            return (stems, message, gr.update(visible=bool(stems)),
                    gr.update(value=header), *slot_updates)

        separate_btn.click(
            fn=process_and_display,
            inputs=[audio_input, model_dropdown],
            outputs=[stems_state, status_display, audio_outputs, stems_header, *audio_slots],
            show_progress=True
        )
        # Dependency check display
        with gr.Accordion("π§ System Status", open=False):
            def check_system():
                """Render dependency availability (plus install hints when nothing is installed) as Markdown."""
                deps = separator.check_dependencies()
                status_text = "**Dependency Status:**\n"
                for tool, available in deps.items():
                    status = "β Available" if available else "β Not installed"
                    status_text += f"- {tool}: {status}\n"
                if not any(deps.values()):
                    status_text += "\n**Installation Instructions:**\n"
                    status_text += "```bash\n"
                    status_text += "# Install Demucs (recommended)\n"
                    status_text += "pip install demucs\n\n"
                    status_text += "# Install Spleeter (alternative)\n"
                    status_text += "pip install spleeter tensorflow\n"
                    status_text += "```"
                return status_text
            system_status = gr.Markdown(value=check_system())
            gr.Button("π Refresh Status").click(
                fn=check_system,
                outputs=system_status
            )
    return demo
# Launch the interface
if __name__ == "__main__":
    # Listen on all interfaces at port 7860, without a public share link,
    # surfacing errors in the UI and running in debug mode.
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        debug=True,
    )
    demo = create_interface()
    demo.launch(**launch_options)