File size: 23,114 Bytes
d8ec62e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
import os
import json
import shutil
import librosa
import soundfile
import numpy as np
import gradio as gr
from UVR_interface import root, UVRInterface, VR_MODELS_DIR, MDX_MODELS_DIR, DEMUCS_MODELS_DIR
from gui_data.constants import *
from typing import List, Dict, Callable, Union
import wget

class UVRWebUI:
    """Gradio web front-end for a ``UVRInterface`` separation backend.

    The whole Blocks layout is built in ``__init__`` (via ``define_layout``);
    call ``launch`` to serve it. Uploaded audio is staged in
    ``input_temp_dir`` and the backend is asked to export stems into
    ``export_path``.
    """

    def __init__(self, uvr: UVRInterface, online_data_path: str) -> None:
        """Read the model download list, build the UI and create working dirs.

        Args:
            uvr: backend used to look up model data and run separation.
            online_data_path: JSON file containing the ``vr_download_list``,
                ``mdx_download_list`` and ``demucs_download_list`` entries.
        """
        self.uvr = uvr
        self.models_url = self.get_models_url(online_data_path)
        self.define_layout()
        self.input_temp_dir = '__temp'
        self.export_path = 'out'
        # exist_ok=True already tolerates existing directories; no pre-check needed.
        os.makedirs(self.input_temp_dir, exist_ok=True)
        os.makedirs(self.export_path, exist_ok=True)

    def get_models_url(self, models_info_path: str) -> Dict[str, Dict]:
        """Map each arch type to its downloadable models.

        VR and MDX entries map model name -> full URL (``NORMAL_REPO`` + the
        file-name part from the JSON). The Demucs entry is returned exactly
        as stored; the download tab treats each of its values as a mapping of
        file name -> URL.

        Raises:
            OSError, json.JSONDecodeError: file missing or malformed.
            KeyError: one of the expected download lists is absent.
        """
        # JSON is UTF-8 by spec; be explicit so a non-UTF-8 locale cannot break decoding.
        with open(models_info_path, 'r', encoding='utf-8') as f:
            online_data = json.load(f)
        models_url = {
            arch: {name: NORMAL_REPO + filename_part for name, filename_part in online_data[list_key].items()}
            for arch, list_key in ((VR_ARCH_TYPE, 'vr_download_list'), (MDX_ARCH_TYPE, 'mdx_download_list'))
        }
        models_url[DEMUCS_ARCH_TYPE] = online_data['demucs_download_list']
        return models_url

    def get_local_models(self, arch: str) -> List[str]:
        """Return the sorted names (file stems) of the models installed for ``arch``.

        The arch's model directory is created when missing, in which case the
        result is empty. Unknown arch types are reported and yield ``[]``.
        """
        model_config = {VR_ARCH_TYPE: (VR_MODELS_DIR, '.pth'), MDX_ARCH_TYPE: (MDX_MODELS_DIR, '.onnx'), DEMUCS_ARCH_TYPE: (DEMUCS_MODELS_DIR, '.yaml')}
        try:
            model_dir, suffix = model_config[arch]
        except KeyError:
            print(f'Error: Unknown arch type: {arch} in get_local_models')
            return []
        if not os.path.exists(model_dir):
            os.makedirs(model_dir, exist_ok=True)
            return []
        return sorted(
            os.path.splitext(entry)[0]
            for entry in os.listdir(model_dir)
            if entry.endswith(suffix) and os.path.isfile(os.path.join(model_dir, entry))
        )

    def set_arch_setting_value(self, arch: str, setting1, setting2):
        """Push the two arch-specific UI settings into the shared ``root`` state.

        VR: window size / aggression. MDX: batch size / volume compensation.
        Demucs (and unknown archs) have no settings wired here.
        """
        if arch == VR_ARCH_TYPE:
            root.window_size_var.set(setting1)
            root.aggression_setting_var.set(setting2)
        elif arch == MDX_ARCH_TYPE:
            root.mdx_batch_size_var.set(setting1)
            root.compensate_var.set(setting2)
        elif arch == DEMUCS_ARCH_TYPE:
            pass

    def arch_select_update(self, arch: str) -> List[Dict]:
        """Refresh the model dropdown and the two setting dropdowns for ``arch``.

        Returns updates for ``[model_choice, arch_setting1, arch_setting2]``.
        The VR/MDX updates set ``visible=True`` explicitly because the Demucs
        branch hides both setting dropdowns.
        """
        choices = self.get_local_models(arch)
        if not choices:
            print(f'Warning: No local models found for {arch}. Dropdown will be empty.')
        if arch == VR_ARCH_TYPE:
            model_update_label = SELECT_VR_MODEL_MAIN_LABEL
            setting1_update = self.arch_setting1.update(choices=VR_WINDOW, label=WINDOW_SIZE_MAIN_LABEL, value=root.window_size_var.get(), visible=True)
            setting2_update = self.arch_setting2.update(choices=VR_AGGRESSION, label=AGGRESSION_SETTING_MAIN_LABEL, value=root.aggression_setting_var.get(), visible=True)
        elif arch == MDX_ARCH_TYPE:
            model_update_label = CHOOSE_MDX_MODEL_MAIN_LABEL
            setting1_update = self.arch_setting1.update(choices=BATCH_SIZE, label=BATCHES_MDX_MAIN_LABEL, value=root.mdx_batch_size_var.get(), visible=True)
            setting2_update = self.arch_setting2.update(choices=VOL_COMPENSATION, label=VOL_COMP_MDX_MAIN_LABEL, value=root.compensate_var.get(), visible=True)
        elif arch == DEMUCS_ARCH_TYPE:
            model_update_label = CHOOSE_DEMUCS_MODEL_MAIN_LABEL
            setting1_update = self.arch_setting1.update(choices=[], label='Demucs Setting 1', value=None, visible=False)
            setting2_update = self.arch_setting2.update(choices=[], label='Demucs Setting 2', value=None, visible=False)
        else:
            # The previous code built a gr.Error without raising it (a no-op); log it and
            # reset the three components so the UI does not show stale choices.
            print(f'Error: Unknown arch type: {arch}')
            model_update = self.model_choice.update(choices=[], value=CHOOSE_MODEL, label='Error: Unknown Arch')
            setting1_update = self.arch_setting1.update(choices=[], value=None, label='Setting 1')
            setting2_update = self.arch_setting2.update(choices=[], value=None, label='Setting 2')
            return [model_update, setting1_update, setting2_update]
        model_update = self.model_choice.update(choices=choices, value=CHOOSE_MODEL, label=model_update_label)
        return [model_update, setting1_update, setting2_update]

    def model_select_update(self, arch: str, model_name: str) -> List[Union[str, Dict, None]]:
        """Relabel the stem checkboxes and output players for the chosen model.

        Returns updates for ``[primary_stem_only, secondary_stem_only,
        primary_stem_out, secondary_stem_out]``. With no model selected the
        generic ``PRIMARY_STEM`` / ``SECONDARY_STEM`` labels are restored.

        Raises:
            gr.Error: the backend returned no model data, or the model's
                ``model_status`` is falsy. Raising (instead of returning
                ``None`` for every output) reports the problem and does not
                overwrite the components' values.
        """
        if model_name == CHOOSE_MODEL or model_name is None:
            return [self.primary_stem_only.update(label=f'{PRIMARY_STEM} only'), self.secondary_stem_only.update(label=f'{SECONDARY_STEM} only'), self.primary_stem_out.update(label=f'Output {PRIMARY_STEM}'), self.secondary_stem_out.update(label=f'Output {SECONDARY_STEM}')]
        model_data_list = self.uvr.assemble_model_data(model_name, arch)
        if not model_data_list:
            raise gr.Error(f'Cannot get model data for model {model_name}, arch {arch}. Model list empty.')
        model = model_data_list[0]
        if not model.model_status:
            raise gr.Error(f'Cannot get model data, model hash = {model.model_hash}')
        return [
            self.primary_stem_only.update(label=f'{model.primary_stem} Only'),
            self.secondary_stem_only.update(label=f'{model.secondary_stem} Only'),
            self.primary_stem_out.update(label=f'Output {model.primary_stem}'),
            self.secondary_stem_out.update(label=f'Output {model.secondary_stem}'),
        ]

    def checkbox_set_root_value(self, checkbox: gr.Checkbox, root_attr: str):
        """On every change, write ``checkbox``'s value via ``getattr(root, root_attr).set``."""
        checkbox.change(lambda value: getattr(root, root_attr).set(value), inputs=checkbox)

    def set_checkboxes_exclusive(self, checkboxes: List[gr.Checkbox], pure_callbacks: List[Callable], exclusive_value=True):
        """Wire ``checkboxes`` so that at most one of them holds ``exclusive_value``.

        ``pure_callbacks[i]`` is called with the value checkbox ``i`` ends up
        with. When checkbox ``i`` changes to ``exclusive_value`` every other
        checkbox is set to ``not exclusive_value`` (and its callback called);
        on any other change the remaining checkboxes are left untouched.
        Handlers are attached only for the pairs that ``zip`` yields.
        """

        def make_handler(changed_idx: int) -> Callable:
            def handler(is_checked):
                pure_callbacks[changed_idx](is_checked)
                outputs = []
                for other_idx in range(len(checkboxes)):
                    if other_idx == changed_idx:
                        outputs.append(is_checked)
                    elif is_checked == exclusive_value:
                        pure_callbacks[other_idx](not exclusive_value)
                        outputs.append(not exclusive_value)
                    else:
                        # gr.update() with no arguments leaves the component as it is.
                        outputs.append(gr.update())
                return tuple(outputs)
            return handler

        for idx, (checkbox, _callback) in enumerate(zip(checkboxes, pure_callbacks)):
            checkbox.change(make_handler(idx), inputs=checkbox, outputs=checkboxes)

    @staticmethod
    def _read_stem_output(separator, stem, skip_attr: str, kind: str):
        """Load one exported stem written by ``separator``.

        Returns ``((rate, data) or None, message_line)``. ``kind`` ('Primary'
        or 'Secondary') only shapes the message text; nothing is read when the
        separator flag named ``skip_attr`` is truthy.
        """
        if not (separator.export_path and separator.audio_file_base and stem):
            return None, f'Error: Missing data in separator object for {kind.lower()} stem.\n'
        if getattr(separator, skip_attr):
            return None, ''
        stem_path = os.path.join(separator.export_path, f'{separator.audio_file_base}_({stem}).wav')
        if not os.path.exists(stem_path):
            return None, f'Error: {kind} stem file not found at {stem_path}\n'
        data, rate = soundfile.read(stem_path)
        # gr.Audio(type='numpy') takes (sample_rate, samples) -- the reverse of soundfile.read.
        return (rate, data), f'{stem} saved at {stem_path}\n'

    def process(self, input_audio, input_filename, model_name, arch, setting1, setting2, progress=gr.Progress(track_tqdm=True)):
        """Separate ``input_audio`` with the selected model.

        Args:
            input_audio: ``(sampling_rate, samples)`` from ``gr.Audio(type='numpy')``, or None.
            input_filename: name for the staged temp file (presumably also the
                base of the exported stem names -- depends on the backend);
                only its base name is used.
            model_name, arch: model and architecture chosen in the UI.
            setting1, setting2: arch-specific settings, see ``set_arch_setting_value``.
            progress: Gradio progress tracker with tqdm tracking; not otherwise used here.

        Returns:
            ``(primary_audio, secondary_audio, message)``; each audio item is
            ``(rate, data)`` or None.
        """
        if input_audio is None:
            return (None, None, 'Error: No input audio provided.')
        if model_name == CHOOSE_MODEL or model_name is None:
            return (None, None, 'Error: Please select a model.')

        def set_progress_func(step, inference_iterations=0):
            # The backend requires a callback; progress is not forwarded to the UI from here.
            pass
        sampling_rate, audio_data = input_audio
        if np.issubdtype(audio_data.dtype, np.integer):
            audio_data = (audio_data / np.iinfo(audio_data.dtype).max).astype(np.float32)
        elif not np.issubdtype(audio_data.dtype, np.floating):
            return (None, None, f'Error: Unsupported audio data type {audio_data.dtype}')
        # Heuristic: a 2-D array whose first axis is longer than 5 is taken to be
        # (samples, channels); librosa.to_mono wants channels first.
        if len(audio_data.shape) > 1 and audio_data.shape[0] > 5:
            audio_data = audio_data.T
        # NOTE(review): multi-channel input is down-mixed to mono before separation.
        if len(audio_data.shape) > 1:
            audio_data = librosa.to_mono(audio_data)

        # Keep only the base name so a user-supplied path cannot escape input_temp_dir.
        input_filename = os.path.basename((input_filename or '').strip())
        if not input_filename:
            input_filename = 'audio_input.wav'
        else:
            stem, ext = os.path.splitext(input_filename)
            if ext.lower() in ('.mp3', '.flac'):
                # The staged file always contains WAV data; give it a matching extension.
                input_filename = stem + '.wav'
            elif ext.lower() != '.wav':
                input_filename += '.wav'
        input_path = os.path.join(self.input_temp_dir, input_filename)
        try:
            try:
                soundfile.write(input_path, audio_data, sampling_rate, format='wav')
            except Exception as e:
                return (None, None, f'Error writing temporary input file: {e}')
            self.set_arch_setting_value(arch, setting1, setting2)
            separator = self.uvr.process(model_name=model_name, arch_type=arch, audio_file=input_path, export_path=self.export_path, is_model_sample_mode=root.model_sample_mode_var.get(), set_progress_func=set_progress_func)
            if separator is None:
                return (None, None, 'Error during processing. Separator object is None.')
            primary_audio_out, primary_msg = self._read_stem_output(separator, separator.primary_stem, 'is_secondary_stem_only', 'Primary')
            secondary_audio_out, secondary_msg = self._read_stem_output(separator, separator.secondary_stem, 'is_primary_stem_only', 'Secondary')
            return (primary_audio_out, secondary_audio_out, (primary_msg + secondary_msg).strip())
        finally:
            # Always drop the staged input, including when the backend raises.
            if os.path.exists(input_path):
                os.remove(input_path)

    def define_layout(self):
        """Build the Gradio Blocks layout and wire all event handlers.

        The Blocks object is stored on ``self.app`` and every interactive
        component on ``self`` so the handlers can return updates for them.
        """
        with gr.Blocks() as app:
            self.app = app
            gr.HTML('<h1> 🎵 Ultimate Vocal Remover WebUI Local Patch By Germanized🎵 </h1>')
            gr.Markdown('This is an experimental demo with CPU. Duplicate the space for use in private')
            gr.Markdown('[![Duplicate this Space](https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-sm-dark.svg)](https://huggingface.co/spaces/r3gm/Ultimate-Vocal-Remover-WebUI?duplicate=true)\n\n')
            with gr.Tabs():
                with gr.TabItem('Process'):
                    with gr.Row():
                        # Only VR and MDX are offered; Demucs is not selectable here.
                        self.arch_choice = gr.Dropdown(choices=[VR_ARCH_TYPE, MDX_ARCH_TYPE], value=VR_ARCH_TYPE, label=CHOOSE_PROC_METHOD_MAIN_LABEL, interactive=True)
                        self.model_choice = gr.Dropdown(choices=self.get_local_models(VR_ARCH_TYPE), value=CHOOSE_MODEL, label=SELECT_VR_MODEL_MAIN_LABEL + ' 👋Select a model', interactive=True)
                    with gr.Row():
                        self.arch_setting1 = gr.Dropdown(choices=VR_WINDOW, value=root.window_size_var.get(), label=WINDOW_SIZE_MAIN_LABEL + ' 👋Select one', interactive=True)
                        self.arch_setting2 = gr.Dropdown(choices=VR_AGGRESSION, value=root.aggression_setting_var.get(), label=AGGRESSION_SETTING_MAIN_LABEL, interactive=True)
                    with gr.Row():
                        # NOTE(review): shown checked, but root.is_gpu_conversion_var is only written
                        # when this box changes -- confirm the backend default agrees.
                        self.use_gpu = gr.Checkbox(label='Rhythmic Transmutation Device', value=True, interactive=True)
                        self.primary_stem_only = gr.Checkbox(label=f'{PRIMARY_STEM} only', value=root.is_primary_stem_only_var.get(), interactive=True)
                        self.secondary_stem_only = gr.Checkbox(label=f'{SECONDARY_STEM} only', value=root.is_secondary_stem_only_var.get(), interactive=True)
                        self.sample_mode = gr.Checkbox(label=SAMPLE_MODE_CHECKBOX(root.model_sample_mode_duration_var.get()), value=root.model_sample_mode_var.get(), interactive=True)
                    with gr.Row():
                        self.input_filename = gr.Textbox(label='Input filename (e.g., song.wav)', value='temp.wav', interactive=True)
                    with gr.Row():
                        self.audio_in = gr.Audio(label='Input audio', type='numpy', interactive=True)
                    with gr.Row():
                        self.process_submit = gr.Button(START_PROCESSING, variant='primary')
                    with gr.Row():
                        self.primary_stem_out = gr.Audio(label=f'Output {PRIMARY_STEM}', interactive=False)
                        self.secondary_stem_out = gr.Audio(label=f'Output {SECONDARY_STEM}', interactive=False)
                    with gr.Row():
                        self.out_message = gr.Textbox(label='Output Message', interactive=False)
                with gr.TabItem('Settings'):
                    with gr.Tabs():
                        with gr.TabItem('Settings Guide (Placeholder)'):
                            gr.Markdown('Details about settings would go here.')
                        with gr.TabItem('Additional Settings'):
                            # NOTE(review): no change handlers are attached to these two dropdowns,
                            # so they do not affect processing yet.
                            self.wav_type = gr.Dropdown(choices=WAV_TYPE, label='Wav Type Output', value='PCM_16', interactive=True)
                            self.mp3_rate = gr.Dropdown(choices=MP3_BIT_RATES, label='MP3 Bitrate Output', value='320k', interactive=True)
                        with gr.TabItem('Download Models'):
                            gr.Markdown('Select a model category and model name to see its download URL. Models are downloaded automatically on startup if missing.')

                            def md_url(url, text=None):
                                # Render a Markdown link; the URL doubles as the text by default.
                                if text is None:
                                    text = url
                                return f'[{text}]({url})'
                            with gr.Row():
                                vr_models_for_dl = self.models_url.get(VR_ARCH_TYPE, {})
                                self.vr_download_choice = gr.Dropdown(choices=list(vr_models_for_dl.keys()), label=f'Select {VR_ARCH_TYPE} Model', interactive=True)
                                self.vr_download_url = gr.Markdown()
                                self.vr_download_choice.change(lambda model: md_url(vr_models_for_dl.get(model, 'URL not found')) if model else '', inputs=self.vr_download_choice, outputs=self.vr_download_url)
                            with gr.Row():
                                mdx_models_for_dl = self.models_url.get(MDX_ARCH_TYPE, {})
                                self.mdx_download_choice = gr.Dropdown(choices=list(mdx_models_for_dl.keys()), label=f'Select {MDX_ARCH_TYPE} Model', interactive=True)
                                self.mdx_download_url = gr.Markdown()
                                self.mdx_download_choice.change(lambda model: md_url(mdx_models_for_dl.get(model, 'URL not found')) if model else '', inputs=self.mdx_download_choice, outputs=self.mdx_download_url)
                            with gr.Row():
                                # Each Demucs entry maps file name -> URL, rendered as a bullet list.
                                demucs_models_for_dl: Dict[str, Dict] = self.models_url.get(DEMUCS_ARCH_TYPE, {})
                                self.demucs_download_choice = gr.Dropdown(choices=list(demucs_models_for_dl.keys()), label=f'Select {DEMUCS_ARCH_TYPE} Model', interactive=True)
                                self.demucs_download_url = gr.Markdown()
                                self.demucs_download_choice.change(lambda model: '\n'.join(['- ' + md_url(url, text=filename) for filename, url in demucs_models_for_dl.get(model, {}).items()]) if model else '', inputs=self.demucs_download_choice, outputs=self.demucs_download_url)
            self.arch_choice.change(self.arch_select_update, inputs=self.arch_choice, outputs=[self.model_choice, self.arch_setting1, self.arch_setting2])
            self.model_choice.change(self.model_select_update, inputs=[self.arch_choice, self.model_choice], outputs=[self.primary_stem_only, self.secondary_stem_only, self.primary_stem_out, self.secondary_stem_out])
            self.checkbox_set_root_value(self.use_gpu, 'is_gpu_conversion_var')
            self.checkbox_set_root_value(self.sample_mode, 'model_sample_mode_var')

            def make_exclusive_primary(is_checked_primary):
                # Checking "primary only" clears "secondary only" (both in root and in the UI).
                root.is_primary_stem_only_var.set(is_checked_primary)
                if is_checked_primary:
                    root.is_secondary_stem_only_var.set(False)
                    return (gr.update(value=is_checked_primary), gr.update(value=False))
                return (gr.update(value=is_checked_primary), gr.update())

            def make_exclusive_secondary(is_checked_secondary):
                # Mirror of make_exclusive_primary for the "secondary only" box.
                root.is_secondary_stem_only_var.set(is_checked_secondary)
                if is_checked_secondary:
                    root.is_primary_stem_only_var.set(False)
                    return (gr.update(value=False), gr.update(value=is_checked_secondary))
                return (gr.update(), gr.update(value=is_checked_secondary))
            self.primary_stem_only.change(make_exclusive_primary, inputs=self.primary_stem_only, outputs=[self.primary_stem_only, self.secondary_stem_only])
            self.secondary_stem_only.change(make_exclusive_secondary, inputs=self.secondary_stem_only, outputs=[self.primary_stem_only, self.secondary_stem_only])
            self.process_submit.click(self.process, inputs=[self.audio_in, self.input_filename, self.model_choice, self.arch_choice, self.arch_setting1, self.arch_setting2], outputs=[self.primary_stem_out, self.secondary_stem_out, self.out_message])

    def launch(self, **kwargs):
        """Enable the request queue and start the Gradio server; ``kwargs`` go to ``launch``."""
        self.app.queue().launch(**kwargs)
ONLINE_DATA_PATH = 'models/download_checks.json'
# arch type -> (target directory, expected file suffix) for the models fetched at startup.
_MODEL_DOWNLOAD_TARGETS = {VR_ARCH_TYPE: (VR_MODELS_DIR, '.pth'), MDX_ARCH_TYPE: (MDX_MODELS_DIR, '.onnx')}


def _expected_local_model_path(model_base_name, model_full_url, target_model_dir, expected_suffix):
    """Return where a downloaded model is expected to live inside ``target_model_dir``.

    The file name is the last component of the URL *path* (so a query string
    or fragment is ignored) when it already ends with ``expected_suffix``;
    otherwise ``model_base_name + expected_suffix`` is used.
    """
    from urllib.parse import urlparse

    filename_from_url = os.path.basename(urlparse(model_full_url).path)
    if filename_from_url.endswith(expected_suffix):
        return os.path.join(target_model_dir, filename_from_url)
    return os.path.join(target_model_dir, model_base_name + expected_suffix)


def _download_missing_models(model_dict_to_download):
    """Download every VR/MDX model from ``model_dict_to_download`` that is not on disk yet.

    Demucs and unknown categories are skipped with a log line. Failures are
    logged per model and do not stop the remaining downloads (best effort).
    """
    for category, models_in_category in model_dict_to_download.items():
        if category == DEMUCS_ARCH_TYPE:
            print(f'INFO: Skipping direct download for {category} in this loop. Demucs models are handled by their own mechanism or need specific download paths.')
            continue
        if category not in _MODEL_DOWNLOAD_TARGETS:
            print(f'INFO: Unknown category for download: {category}')
            continue
        target_model_dir, expected_suffix = _MODEL_DOWNLOAD_TARGETS[category]
        if not target_model_dir:
            continue
        if not os.path.exists(target_model_dir):
            os.makedirs(target_model_dir, exist_ok=True)
            print(f'INFO: Created directory: {target_model_dir}')
        print(f'\nINFO: Checking/Downloading models for {category} into {target_model_dir}...')
        if not isinstance(models_in_category, dict):
            print(f'Warning: Expected a dictionary of models for {category}, but got {type(models_in_category)}. Skipping.')
            continue
        for model_base_name, model_full_url in models_in_category.items():
            local_file_path = _expected_local_model_path(model_base_name, model_full_url, target_model_dir, expected_suffix)
            if os.path.exists(local_file_path):
                print(f'INFO: Skipping {local_file_path}, already exists.')
                continue
            print(f'INFO: Downloading {model_full_url} to {target_model_dir} (expected as {os.path.basename(local_file_path)})...')
            try:
                downloaded_filepath_actual = wget.download(model_full_url, out=target_model_dir)
                # wget may choose a different name; rename so get_local_models finds the expected file.
                if os.path.basename(downloaded_filepath_actual) != os.path.basename(local_file_path):
                    print(f'INFO: Downloaded as {os.path.basename(downloaded_filepath_actual)}, renaming to {os.path.basename(local_file_path)}')
                    if os.path.exists(local_file_path):
                        os.remove(local_file_path)
                    shutil.move(downloaded_filepath_actual, local_file_path)
                print(f'INFO: Download successful: {local_file_path}')
            except Exception as e:
                # Best effort: one failed model must not block the others or the UI launch.
                print(f'ERROR: wget download failed for {model_full_url}: {e}')
        print(f'INFO: Finished checking/downloading for {category}.')


uvr_interface_instance = UVRInterface()
uvr_interface_instance.cached_sources_clear()
webui_instance = UVRWebUI(uvr_interface_instance, online_data_path=ONLINE_DATA_PATH)
print('INFO: Checking and downloading models if necessary...')
model_dict_to_download = webui_instance.models_url
_download_missing_models(model_dict_to_download)
print('INFO: Model download check complete.')
# The first instance's dropdowns were filled before the downloads ran; building a fresh UI
# re-scans the model directories so newly downloaded models become selectable.
print('INFO: Re-initializing WebUI to pick up downloaded models for dropdowns...')
webui_instance = UVRWebUI(uvr_interface_instance, online_data_path=ONLINE_DATA_PATH)
print('INFO: Launching WebUI...')
webui_instance.launch()