import logging
import os
from zipfile import ZipFile

import gradio as gr
from PIL import Image

from .inference import run_model
from .utils import load_pred_volume_to_numpy
from .utils import load_to_numpy
from .utils import nifti_to_glb


class WebUI:
    """Gradio front-end that runs a segmentation model on an uploaded file.

    The interface has an upload widget, a task selector, a run button, a 2D
    slice viewer with the predicted mask overlaid, a 3D model viewer, and a
    button that packages ``prediction.nii.gz`` into a zip for download.
    """

    def __init__(
        self,
        model_name: str = None,
        cwd: str = "/home/user/app/",
        share: int = 1,
    ):
        """Create the (not yet launched) UI.

        Args:
            model_name: Stored on the instance; not otherwise used in this class.
            cwd: Base directory holding ``resources/models/`` and the example
                file ``t1gd.nii.gz``.
            share: Forwarded to ``launch(share=...)``; a truthy value asks
                Gradio for a public share link.
        """
        # Gradio components; created in setup_interface_inputs/outputs().
        self.file_output = None
        self.model_selector = None
        self.stripped_cb = None
        self.registered_cb = None
        self.run_btn = None
        self.slider = None
        self.download_btn = None
        self.download_file = None

        # global states
        # NOTE(review): this state lives on the single WebUI instance, so it is
        # shared by every connected browser session.
        self.images = []
        self.pred_images = []
        self.image_boxes = []
        self.model_name = model_name
        self.cwd = cwd
        self.share = share

        self.class_name = "tumorcore"  # default
        # UI task key -> task identifier passed to run_model(task=...).
        self.class_names = {
            "tumorcore": "MRI_TumorCore",
            "NETC": "MRI_Necrosis",
            "residual-tumor": "MRI_TumorCE_Postop",
            "cavity": "MRI_Cavity",
            "brain": "MRI_Brain",
        }
        # UI task key -> value passed to run_model(name=...).
        self.result_names = {
            "tumorcore": "Tumor",
            "NETC": "NETC",
            "residual-tumor": "Tumor",
            "cavity": "Cavity",
            "brain": "Brain",
        }

        # Created outside the Blocks context; placed in the layout later via
        # .render() in setup_interface_outputs().
        self.volume_renderer = gr.Model3D(
            clear_color=[0.0, 0.0, 0.0, 0.0],
            label="3D Model",
            visible=True,
            elem_id="model-3d",
            height=512,
        )

    def set_class_name(self, value):
        """Select the segmentation task (a key of ``self.class_names``)."""
        print("Changed task to:", value)
        self.class_name = value

    def combine_ct_and_seg(self, img, pred):
        """Pair an image with its mask in the ``(image, [(mask, label)])`` form."""
        return (img, [(pred, self.class_name)])

    def upload_file(self, file):
        """Return the ``.name`` (file path) of an uploaded file object."""
        return file.name

    def process(self, mesh_file_name, stripped_inputs_status: bool = False):
        """Run segmentation on the uploaded file and prepare both viewers.

        Args:
            mesh_file_name: Uploaded file object; its ``.name`` is the input
                path given to ``run_model`` (despite the name, this is the
                input image, not a mesh).
            stripped_inputs_status: State of the "Stripped inputs" checkbox,
                forwarded to ``run_model``.

        Returns:
            A tuple ``(model_path, slider)``: the path shown by the 3D viewer
            and a new slider spanning the loaded slices, set to the middle one.
        """
        path = mesh_file_name.name
        # Presumably writes ./prediction.nii.gz, which is read back below.
        run_model(
            path,
            model_path=os.path.join(self.cwd, "resources/models/"),
            task=self.class_names[self.class_name],
            name=self.result_names[self.class_name],
            stripped_inputs_status=stripped_inputs_status,
        )
        # NOTE(review): the 3D model is produced by nifti_to_glb(), yet
        # "./prediction.obj" is what gets returned; confirm the output path.
        nifti_to_glb("prediction.nii.gz")

        self.images = load_to_numpy(path)
        self.pred_images = load_pred_volume_to_numpy("./prediction.nii.gz")

        # Clamp so an empty volume cannot produce maximum=-1 (< minimum=0).
        last_index = max(len(self.images) - 1, 0)
        slider = gr.Slider(
            minimum=0,
            maximum=last_index,
            value=int(len(self.images) / 2),
            step=1,
            label="Which 2D slice to show",
            interactive=True,
        )
        return "./prediction.obj", slider

    def get_img_pred_pair(self, k):
        """Return slice ``k`` and its mask in ``gr.AnnotatedImage`` format.

        Returns ``None`` (which clears the viewer) when no volume has been
        loaded yet or ``k`` is outside the loaded range. This happens, for
        example, if the slider is moved before a segmentation has been run.
        """
        if k is None:
            return None
        k = int(k)  # defensive: slider values may arrive as numbers, not ints
        if not 0 <= k < min(len(self.images), len(self.pred_images)):
            return None
        img_pil = Image.fromarray(self.images[k])
        seg_list = [(self.pred_images[k], self.class_name)]
        return img_pil, seg_list

    def setup_interface_inputs(self):
        """Create the upload, task-selector, checkbox and run-button widgets."""
        with gr.Row():
            with gr.Column():
                self.file_output = gr.File(file_count="single", elem_id="upload")

            with gr.Column():
                self.model_selector = gr.Dropdown(
                    list(self.class_names.keys()),
                    # Show the task that is actually active by default.
                    value=self.class_name,
                    label="Segmentation task",
                    info="Select the segmentation model to run",
                    multiselect=False,
                )

            with gr.Column():
                with gr.Row():
                    self.stripped_cb = gr.Checkbox(label="Stripped inputs")
                    # NOTE(review): not wired to any event in this class.
                    self.registered_cb = gr.Checkbox(label="Co-registered inputs")

                with gr.Row():
                    self.run_btn = gr.Button("Run segmentation", scale=1)

    def setup_interface_outputs(self):
        """Create the 2D slice viewer, slider, 3D viewer and download widgets."""
        with gr.Row():
            with gr.Group():
                with gr.Column():
                    annotated_image = gr.AnnotatedImage(
                        visible=True,
                        elem_id="model-2d",
                        # Overlays are labelled with the currently selected
                        # task key, so give every task the highlight colour,
                        # not just the one active when the UI was built.
                        color_map={name: "#ffae00" for name in self.class_names},
                        height=512,
                        width=512,
                    )

                    self.slider = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=0,
                        step=1,
                        label="Which 2D slice to show",
                        interactive=True,
                    )

                    self.slider.change(
                        fn=self.get_img_pred_pair,
                        inputs=self.slider,
                        outputs=annotated_image,
                    )

            with gr.Group():
                self.volume_renderer.render()
                self.download_btn = gr.DownloadButton(
                    label="Download results", visible=False
                )
                self.download_file = gr.File(
                    label="Download Zip", interactive=True, visible=False
                )

    def package_results(self):
        """Zip ``./prediction.nii.gz`` and return the archive path.

        The archive is written to ``temp_output/generated_files.zip`` (relative
        to the process working directory), overwriting any previous one.
        """
        output_dir = "temp_output"
        os.makedirs(output_dir, exist_ok=True)
        zip_filename = os.path.join(output_dir, "generated_files.zip")
        with ZipFile(zip_filename, "w") as zf:
            zf.write("./prediction.nii.gz")
        return zip_filename

    def run(self):
        """Build the Blocks layout, wire up the events and launch the server."""
        css = """
        #model-3d {
        height: 512px;
        }
        #model-2d {
        height: 512px;
        margin: auto;
        }
        #upload {
        height: 120px;
        }
        """
        with gr.Blocks(css=css) as demo:
            # Define the interface components first
            self.setup_interface_inputs()
            with gr.Row():
                gr.Examples(
                    examples=[
                        os.path.join(self.cwd, "t1gd.nii.gz"),
                    ],
                    inputs=self.file_output,
                    outputs=self.file_output,
                    fn=self.upload_file,
                    cache_examples=True,
                )
            self.setup_interface_outputs()

            # Define the signals/slots
            self.file_output.upload(
                self.upload_file, self.file_output, self.file_output
            )
            self.model_selector.input(
                fn=self.set_class_name, inputs=self.model_selector, outputs=None
            )
            # After segmentation finishes, reveal the download button.
            self.run_btn.click(
                fn=self.process,
                inputs=[self.file_output, self.stripped_cb],
                outputs=[self.volume_renderer, self.slider],
            ).then(
                fn=lambda: gr.DownloadButton(visible=True),
                inputs=None,
                outputs=self.download_btn,
            )
            # Build the zip, then show it in the (initially hidden) file widget.
            self.download_btn.click(
                fn=self.package_results, inputs=[], outputs=self.download_file
            ).then(
                fn=lambda file_path: gr.File(
                    label="Download Zip", visible=True, value=file_path
                ),
                inputs=self.download_file,
                outputs=self.download_file,
            )

        # sharing app publicly -> share=True:
        # https://gradio.app/sharing-your-app/
        # inference times > 60 seconds -> need queue():
        # https://github.com/tloen/alpaca-lora/issues/60#issuecomment-1510006062
        demo.queue().launch(
            server_name="0.0.0.0", server_port=7860, share=self.share
        )