stem-remixer / app.py
ahk-d's picture
Create app.py
314aa29 verified
raw
history blame
15.2 kB
import json
import logging
import os
import shutil
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import gradio as gr
# Configure logging: set the root logging level to INFO and create the
# module-level logger that the rest of this file writes to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class StemSeparator:
    """Separate an audio file into stems by running the Demucs or Spleeter CLI.

    ``supported_models`` maps a model id to its configuration dict:

    * ``command``     - which tool runs the model (``"demucs"`` or ``"spleeter"``)
    * ``model_name``  - optional name handed to the tool; when absent the
      tool's own default model is used
    * ``stems``       - number of stems advertised for the model
    * ``description`` - human-readable text shown in the UI
    """

    # Seconds allowed for one separation run before it is abandoned.
    SEPARATION_TIMEOUT = 300
    # Seconds allowed for each ``--help`` probe in check_dependencies().
    DEPENDENCY_CHECK_TIMEOUT = 10
    # Glob patterns of the files that count as generated stems.
    STEM_PATTERNS = ("*.wav", "*.mp3", "*.flac", "*.m4a")

    def __init__(self):
        """Register the models this separator knows how to run."""
        # NOTE(review): the "spleeter:*-waveform" configuration names are taken
        # verbatim from the original code - confirm they exist in the installed
        # Spleeter release.
        self.supported_models = {
            "htdemucs": {
                "command": "demucs",
                "stems": 4,
                "description": "HTDemucs - High quality 4-stem separation"
            },
            "htdemucs_ft": {
                "command": "demucs",
                "model_name": "htdemucs_ft",
                "stems": 4,
                "description": "HTDemucs Fine-tuned - Enhanced 4-stem separation"
            },
            "htdemucs_6s": {
                "command": "demucs",
                "model_name": "htdemucs_6s",
                "stems": 6,
                "description": "HTDemucs 6-stem - Bass, Drums, Vocals, Other, Guitar, Piano"
            },
            "mdx": {
                "command": "demucs",
                "model_name": "mdx",
                "stems": 4,
                "description": "MDX - Optimized for vocal separation"
            },
            "mdx_extra": {
                "command": "demucs",
                "model_name": "mdx_extra",
                "stems": 4,
                "description": "MDX Extra - Enhanced vocal separation"
            },
            "spleeter_4stems": {
                "command": "spleeter",
                "model_name": "spleeter:4stems-waveform",
                "stems": 4,
                "description": "Spleeter 4-stem - Vocals, Bass, Drums, Other"
            },
            "spleeter_5stems": {
                "command": "spleeter",
                "model_name": "spleeter:5stems-waveform",
                "stems": 5,
                "description": "Spleeter 5-stem - Vocals, Bass, Drums, Piano, Other"
            }
        }

    @staticmethod
    def _python_executable() -> str:
        """Return the interpreter used to run ``python -m demucs``.

        The running interpreter is preferred so the module is looked up in the
        same environment as this app; ``"python"`` is the fallback when
        ``sys.executable`` is empty.
        """
        return sys.executable or "python"

    def check_dependencies(self) -> Dict[str, bool]:
        """Check if required tools are installed.

        Returns:
            A dict with the keys ``"demucs"`` and ``"spleeter"``; a value is
            True when the tool's ``--help`` invocation exits with status 0.
        """
        probes = {
            "demucs": [self._python_executable(), "-m", "demucs", "--help"],
            "spleeter": ["spleeter", "--help"],
        }
        dependencies = {}
        for tool, probe in probes.items():
            try:
                result = subprocess.run(
                    probe,
                    capture_output=True,
                    text=True,
                    timeout=self.DEPENDENCY_CHECK_TIMEOUT,
                )
            except (subprocess.TimeoutExpired, OSError):
                # OSError covers a missing executable (FileNotFoundError) as
                # well as one that exists but cannot be started.
                dependencies[tool] = False
            else:
                dependencies[tool] = result.returncode == 0
        return dependencies

    def separate_audio(self, audio_file: str, model_choice: str) -> Tuple[List[str], str]:
        """Separate audio into stems using the selected model.

        Args:
            audio_file: Path of the audio file to separate.
            model_choice: A key of ``supported_models``.

        Returns:
            ``(stems, message)`` where ``stems`` is the sorted list of stem
            file paths (empty on any failure) and ``message`` is a status
            string for display. Failures are reported through the message;
            no exception is raised.
        """
        if not audio_file:
            return [], "❌ No audio file provided"
        if model_choice not in self.supported_models:
            return [], f"❌ Unsupported model: {model_choice}"

        model_config = self.supported_models[model_choice]
        builders = {
            "demucs": self._build_demucs_command,
            "spleeter": self._build_spleeter_command,
        }
        build_command = builders.get(model_config["command"])
        if build_command is None:
            return [], f"❌ Unknown command type: {model_config['command']}"

        try:
            with tempfile.TemporaryDirectory() as temp_dir:
                temp_input, output_dir = self._stage_input(audio_file, Path(temp_dir))
                logger.info("Processing %s with %s", temp_input, model_choice)

                command = build_command(temp_input, output_dir, model_config)
                logger.info("Running command: %s", " ".join(command))
                result = subprocess.run(
                    command,
                    capture_output=True,
                    text=True,
                    timeout=self.SEPARATION_TIMEOUT,
                    cwd=temp_dir
                )
                if result.returncode != 0:
                    error_msg = f"❌ Separation failed: {result.stderr}"
                    logger.error(error_msg)
                    return [], error_msg

                # Must happen inside the ``with``: the stems are copied out
                # before the temporary directory is deleted.
                stems = self._collect_stems(output_dir, model_choice)
                if not stems:
                    return [], "❌ No stems were generated"

                success_msg = f"✅ Successfully separated into {len(stems)} stems"
                logger.info(success_msg)
                return stems, success_msg
        except subprocess.TimeoutExpired:
            return [], "❌ Process timed out - file may be too large"
        except Exception as e:
            # Boundary handler: the UI expects a status message, not a raise.
            error_msg = f"❌ Error during separation: {str(e)}"
            logger.exception(error_msg)
            return [], error_msg

    def _stage_input(self, audio_file: str, work_dir: Path) -> Tuple[Path, Path]:
        """Copy the input into ``work_dir`` and create a separate output folder.

        The output folder is a sub-directory that does not contain the staged
        input, so a recursive search of it for stems cannot pick up the input
        file itself.

        Returns:
            ``(staged_input, output_dir)``.
        """
        # Keep the original extension so the tools can detect the format.
        staged_input = work_dir / f"input{Path(audio_file).suffix}"
        shutil.copy2(audio_file, staged_input)
        output_dir = work_dir / "output"
        output_dir.mkdir(exist_ok=True)
        return staged_input, output_dir

    def _build_demucs_command(self, input_file: Path, output_dir: Path, model_config: Dict) -> List[str]:
        """Build the ``python -m demucs`` argument list.

        ``-n <model_name>`` is added only when the configuration names a model.
        """
        command = [self._python_executable(), "-m", "demucs"]
        if "model_name" in model_config:
            command.extend(["-n", model_config["model_name"]])
        command.extend([
            "-o", str(output_dir),
            "--filename", "{track}/{stem}.{ext}",  # Organized output structure
            str(input_file)
        ])
        return command

    def _build_spleeter_command(self, input_file: Path, output_dir: Path, model_config: Dict) -> List[str]:
        """Build the ``spleeter separate`` argument list.

        Falls back to ``spleeter:4stems-waveform`` when the configuration has
        no ``model_name``.
        """
        model_name = model_config.get("model_name", "spleeter:4stems-waveform")
        return [
            "spleeter", "separate",
            "-p", model_name,
            "-o", str(output_dir),
            "--filename_format", "{instrument}.{codec}",
            str(input_file)
        ]

    def _collect_stems(self, output_dir: Path, model_choice: str) -> List[str]:
        """Collect generated stem files.

        Every non-empty file under ``output_dir`` that matches one of
        ``STEM_PATTERNS`` is copied to the permanent folder.

        Args:
            output_dir: Directory searched recursively.
            model_choice: Accepted for interface compatibility; not used.

        Returns:
            The sorted paths of the permanent copies.
        """
        stems = []
        for pattern in self.STEM_PATTERNS:
            for audio_file in output_dir.rglob(pattern):
                if audio_file.is_file() and audio_file.stat().st_size > 0:
                    permanent_path = self._copy_to_permanent_location(audio_file)
                    if permanent_path:
                        stems.append(permanent_path)
        return sorted(stems)

    def _copy_to_permanent_location(self, temp_file: Path) -> Optional[str]:
        """Copy temporary file to permanent location for Gradio.

        The copy lands in ``./separated_stems`` (created on demand) under the
        name ``<stem>_<millisecond timestamp><suffix>``; a numeric suffix is
        appended when that name is already taken.

        Returns:
            The path of the copy, or None when the copy failed (the failure is
            logged).
        """
        try:
            output_dir = Path("./separated_stems")
            output_dir.mkdir(exist_ok=True)

            timestamp = int(time.time() * 1000)
            permanent_file = output_dir / f"{temp_file.stem}_{timestamp}{temp_file.suffix}"
            attempt = 0
            # Avoid overwriting a file that was written in the same millisecond.
            while permanent_file.exists():
                attempt += 1
                permanent_file = output_dir / f"{temp_file.stem}_{timestamp}_{attempt}{temp_file.suffix}"

            shutil.copy2(temp_file, permanent_file)
            return str(permanent_file)
        except OSError as e:
            logger.exception(f"Failed to copy {temp_file}: {e}")
            return None
# Initialize separator: one module-level StemSeparator instance shared by the
# UI helper functions defined after it.
separator = StemSeparator()
def get_available_models() -> List[Tuple[str, str]]:
    """Return ``(label, model_id)`` pairs for every model whose tool is installed.

    The dependency probe of the shared separator decides which tools count as
    installed. When no model qualifies, a single placeholder entry whose id is
    ``"none"`` is returned instead of an empty list.
    """
    installed = separator.check_dependencies()
    choices = [
        (f"{model_id} ({config['stems']} stems) - {config['description']}", model_id)
        for model_id, config in separator.supported_models.items()
        if installed.get(config["command"])
    ]
    if choices:
        return choices
    return [("No models available - install demucs or spleeter", "none")]
def separate_stems_ui(audio_file: str, model_choice: str) -> Tuple[List[str], str]:
    """Run a separation for the UI and return ``(stems, message)``.

    The placeholder model id ``"none"`` short-circuits with an installation
    hint; every other id is forwarded to the shared separator.
    """
    if model_choice != "none":
        stem_paths, status = separator.separate_audio(audio_file, model_choice)
        return stem_paths, status
    return [], "❌ Please install demucs and/or spleeter first"
def create_audio_gallery(stems: List[str]) -> List[gr.Audio]:
    """Build one read-only, downloadable audio component per stem path.

    Each component is labelled ``Stem <n>: <file stem>`` with ``n`` counted
    from 1. A falsy ``stems`` yields an empty list.
    """
    if not stems:
        return []
    return [
        gr.Audio(
            value=path,
            label=f"Stem {position}: {Path(path).stem}",
            interactive=False,
            show_download_button=True
        )
        for position, path in enumerate(stems, start=1)
    ]
# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks app for stem separation.

    The page holds the upload widget, the model selector, a status box, a
    hidden results area and a collapsible dependency report.

    The results area contains one pre-built, hidden audio player per stem of
    the largest supported model. The click handler fills and reveals as many
    players as stems were produced; components created inside an event
    handler are not attached to the page, so they cannot be built on demand.

    Returns:
        The ``gr.Blocks`` instance, ready for ``launch()``.
    """
    # Each call to get_available_models() runs the dependency probes
    # (subprocesses), so probe once and reuse the result.
    model_choices = get_available_models()
    default_model = model_choices[0][1] if model_choices else "none"
    # Number of player slots needed to show every stem of any model.
    max_stems = max(config["stems"] for config in separator.supported_models.values())

    with gr.Blocks(
        title="🎵 Advanced Music Stem Separator",
        theme=gr.themes.Soft(),
        css="""
        .audio-container { margin: 10px 0; }
        .status-success { color: #22c55e; font-weight: bold; }
        .status-error { color: #ef4444; font-weight: bold; }
        """
    ) as demo:
        gr.Markdown("""
        # 🎵 Advanced Music Stem Separator

        Separate music into individual stems (vocals, instruments, etc.) using state-of-the-art AI models.
        Supports up to 6 stems depending on the model chosen.

        **Supported Models:**
        - **Demucs Models**: HTDemucs, HTDemucs-FT, HTDemucs-6s, MDX, MDX-Extra
        - **Spleeter Models**: 4-stem and 5-stem separation

        **Requirements**: Install `demucs` and/or `spleeter` packages
        """)

        with gr.Row():
            with gr.Column(scale=2):
                # NOTE(review): confirm the installed Gradio release accepts
                # `info` on gr.Audio; it is kept as in the original code.
                audio_input = gr.Audio(
                    type="filepath",
                    label="🎼 Upload Audio File",
                    info="Supported formats: WAV, MP3, FLAC, M4A"
                )
                model_dropdown = gr.Dropdown(
                    choices=model_choices,
                    value=default_model,
                    label="🧠 Separation Model",
                    info="Choose the AI model for stem separation"
                )
                separate_btn = gr.Button(
                    "🎛️ Separate Stems",
                    variant="primary",
                    size="lg"
                )
            with gr.Column(scale=1):
                gr.Markdown("""
                ### ℹ️ Model Info
                - **4-stem**: Vocals, Bass, Drums, Other
                - **5-stem**: + Piano
                - **6-stem**: + Guitar

                ### 💡 Tips
                - Higher quality input = better separation
                - Processing time varies by model and file length
                - Results will appear below after processing
                """)

        # Status display
        status_display = gr.Textbox(
            label="Status",
            interactive=False,
            visible=True
        )

        # Result area: hidden until a separation succeeds.
        stems_state = gr.State([])
        with gr.Column(visible=False) as audio_outputs:
            stems_heading = gr.Markdown()
            stem_players = [
                gr.Audio(
                    label=f"Stem {slot + 1}",
                    show_download_button=True,
                    interactive=False,
                    visible=False
                )
                for slot in range(max_stems)
            ]

        def process_and_display(audio_file, model_choice):
            """Run the separation and return one update per registered output.

            Output order: stems state, status text, results column, heading,
            then one entry per player slot.
            """
            cleared_players = [gr.update(value=None, visible=False) for _ in range(max_stems)]

            if not audio_file:
                return (
                    [],
                    "❌ Please upload an audio file",
                    gr.update(visible=False),
                    gr.update(value=""),
                    *cleared_players,
                )

            # Process the audio
            stems, message = separate_stems_ui(audio_file, model_choice)
            if not stems:
                return (
                    [],
                    message,
                    gr.update(visible=False),
                    gr.update(value=""),
                    *cleared_players,
                )

            shown = stems[:max_stems]
            if len(stems) > len(shown):
                # More files than player slots: say so instead of hiding it.
                message = f"{message} (showing the first {len(shown)})"

            player_updates = []
            for stem_path in shown:
                stem_name = Path(stem_path).stem.replace("_", " ").title()
                player_updates.append(
                    gr.update(value=stem_path, label=f"🎵 {stem_name}", visible=True)
                )
            # Hide the slots this run does not need.
            player_updates.extend(cleared_players[len(shown):])

            return (
                stems,
                message,
                gr.update(visible=True),
                gr.update(value=f"### 🎶 Separated Stems ({len(stems)} files)"),
                *player_updates,
            )

        separate_btn.click(
            fn=process_and_display,
            inputs=[audio_input, model_dropdown],
            outputs=[stems_state, status_display, audio_outputs, stems_heading, *stem_players],
            show_progress=True
        )

        # Dependency check display
        with gr.Accordion("🔧 System Status", open=False):
            def check_system():
                """Return a Markdown report of which separation tools are installed."""
                deps = separator.check_dependencies()
                parts = ["**Dependency Status:**\n"]
                for tool, available in deps.items():
                    status = "✅ Available" if available else "❌ Not installed"
                    parts.append(f"- {tool}: {status}\n")
                # Installation help is only shown when nothing is installed.
                if not any(deps.values()):
                    parts.extend([
                        "\n**Installation Instructions:**\n",
                        "```bash\n",
                        "# Install Demucs (recommended)\n",
                        "pip install demucs\n\n",
                        "# Install Spleeter (alternative)\n",
                        "pip install spleeter tensorflow\n",
                        "```",
                    ])
                return "".join(parts)

            system_status = gr.Markdown(value=check_system())
            gr.Button("🔄 Refresh Status").click(
                fn=check_system,
                outputs=system_status
            )
    return demo
# Launch the interface when the file is executed as a script.
if __name__ == "__main__":
    # Listen on every network interface, port 7860, without a public share
    # link; show errors in the UI and run in debug mode.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
        "debug": True,
    }
    demo = create_interface()
    demo.launch(**launch_options)