import math
import os
import random
import threading
import time
import cv2
import tempfile
import imageio_ffmpeg
import gradio as gr
import torch
from PIL import Image
from transformers import pipeline, AutoProcessor, MusicgenForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer
import torchaudio
import numpy as np
from datetime import datetime, timedelta
from CogVideoX.pipeline_rgba import CogVideoXPipeline
from CogVideoX.rgba_utils import *
from diffusers import CogVideoXDPMScheduler
from diffusers.utils import export_to_video
import moviepy.editor as mp
import gc
from io import BytesIO
import base64
import requests
from mistralai import Mistral
from huggingface_hub import hf_hub_download

device = "cuda" if torch.cuda.is_available() else "cpu"

# MusicGen: the full encoder-decoder model is needed to produce waveforms;
# the decoder-only MusicgenForCausalLM generates audio codes, not audio.
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")

# The checkpoint's config already carries these settings (mono audio at 32 kHz,
# codebook_size 2048, codebook_dim 128; gelu decoder with zero dropout), so the
# sub-configs are left untouched -- replacing them with plain dicts would break
# model.generate(), which reads them via attribute access.

CHATBOT_MODELS = {
    "DialoGPT (Medium)": "microsoft/DialoGPT-medium",
    "BlenderBot (Small)": "facebook/blenderbot_small-90M",
    "GPT-Neo (125M)": "EleutherAI/gpt-neo-125M",
}
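
# Hypothetical addition (not in the original script): cache loaded pipelines so
# a second message, or switching back to a model, does not reload weights from disk.
_chatbot_pipeline_cache = {}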

def load_chatbot_model(model_name):
    if model_name not in CHATBOT_MODELS:
        raise ValueError(f"Model {model_name} not found.")
    if model_name not in _chatbot_pipeline_cache:
        model_path = CHATBOT_MODELS[model_name]
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(model_path)
        # "text-generation" returns [{"generated_text": ...}], which matches how
        # the result is indexed below; the old "conversational" task has been
        # removed from recent transformers releases. Caveat: a seq2seq entry such
        # as BlenderBot would need AutoModelForSeq2SeqLM and "text2text-generation".
        _chatbot_pipeline_cache[model_name] = pipeline("text-generation", model=model, tokenizer=tokenizer)
    return _chatbot_pipeline_cache[model_name]

# TransPixar: CogVideoX-5B plus an RGBA LoRA for joint RGB + alpha video generation.
hf_hub_download(repo_id="wileewang/TransPixar", filename="cogvideox_rgba_lora.safetensors", local_dir="model_cogvideox_rgba_lora")
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5B", torch_dtype=torch.bfloat16)
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
# Token count for a 480x720, 13-frame clip, doubled because the RGB and alpha
# streams are denoised jointly; with CogVideoX's scale factors (8 spatial,
# 4 temporal) and 2x2 latent patches this is 2 * 30 * 45 * 4 = 10800 tokens.
seq_length = 2 * (
    (480 // pipe.vae_scale_factor_spatial // 2)
    * (720 // pipe.vae_scale_factor_spatial // 2)
    * ((13 - 1) // pipe.vae_scale_factor_temporal + 1)
)
prepare_for_rgba_inference(
    pipe.transformer,
    rgba_weights_path="model_cogvideox_rgba_lora/cogvideox_rgba_lora.safetensors",
    device=device,
    dtype=torch.bfloat16,
    text_length=226,
    seq_length=seq_length,
)

os.makedirs("./output", exist_ok=True)
os.makedirs("./gradio_tmp", exist_ok=True)

def generate_music_function(prompt, length, genre, custom_genre, lyrics):
    selected_genre = custom_genre if custom_genre else genre
    input_text = f"{prompt}. Genre: {selected_genre}. Lyrics: {lyrics}"
    inputs = processor(
        text=[input_text],
        padding=True,
        return_tensors="pt",
    )
    # MusicGen emits roughly 50 audio-token frames per second of audio.
    audio_values = model.generate(**inputs, max_new_tokens=int(length * 50))
    output_file = "generated_music.wav"
    sampling_rate = model.config.audio_encoder.sampling_rate
    torchaudio.save(output_file, audio_values[0].cpu(), sampling_rate)
    return output_file
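
# Sketch of a direct call outside the UI, assuming the model loaded above:
#   generate_music_function("A joyful melody for a sunny day", 5, "Pop", "", "")
# returns "generated_music.wav" with about 5 seconds of mono 32 kHz audio.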

def chatbot_interaction(user_input, history, model_name):
    chatbot_pipeline = load_chatbot_model(model_name)
    # text-generation echoes the prompt, so strip it to keep only the reply.
    generated = chatbot_pipeline(user_input, max_new_tokens=100)[0]["generated_text"]
    response = generated[len(user_input):].strip()
    history.append((user_input, response))
    return history, history

def generate_video_function(prompt, seed_value):
    if seed_value == -1:
        seed_value = random.randint(0, 2**8 - 1)
    pipe.to(device)
    video_pt = pipe(
        prompt=prompt + ", isolated background",
        num_videos_per_prompt=1,
        num_inference_steps=25,
        num_frames=13,
        use_dynamic_cfg=True,
        output_type="latent",
        guidance_scale=7.0,
        generator=torch.Generator(device=device).manual_seed(int(seed_value)),
    ).frames
    # The joint latent holds the RGB and alpha halves along the sequence dimension.
    latents_rgb, latents_alpha = video_pt.chunk(2, dim=1)
    frames_rgb = decode_latents(pipe, latents_rgb)
    frames_alpha = decode_latents(pipe, latents_alpha)
    # Pool alpha across channels, then premultiply RGB for compositing.
    pooled_alpha = np.max(frames_alpha, axis=-1, keepdims=True)
    frames_alpha_pooled = np.repeat(pooled_alpha, 3, axis=-1)
    premultiplied_rgb = frames_rgb * frames_alpha_pooled
    rgb_video_path = save_video(premultiplied_rgb[0], fps=8, prefix='rgb')
    alpha_video_path = save_video(frames_alpha_pooled[0], fps=8, prefix='alpha')
    # Move the pipeline back to CPU to free VRAM between requests.
    pipe.to("cpu")
    gc.collect()
    return rgb_video_path, alpha_video_path, seed_value
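
# Sketch: generate_video_function("A glowing jellyfish", -1) returns paths to a
# premultiplied-RGB MP4 and a matching alpha MP4, plus the seed actually used.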

def save_video(tensor, fps=8, prefix='rgb'):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    video_path = f"./output/{prefix}_{timestamp}.mp4"
    export_to_video(tensor, video_path, fps=fps)
    return video_path

def ic_light_tool():
    # The tool body is supplied at runtime through the EXEC environment variable;
    # exec(None) raises a TypeError, so guard against the variable being unset.
    code = os.getenv('EXEC')
    if code is None:
        return "EXEC environment variable is not set."
    exec(code)

# Pixtral client for image captioning (requires MISTRAL_API_KEY in the environment).
api_key = os.getenv("MISTRAL_API_KEY")
Mistralclient = Mistral(api_key=api_key)

def encode_image(image_path):
    """Encode the image to base64."""
    try:
        image = Image.open(image_path).convert("RGB")

        # Resize to a fixed 512-pixel height, preserving the aspect ratio.
        base_height = 512
        h_percent = base_height / float(image.size[1])
        w_size = int(float(image.size[0]) * h_percent)
        image = image.resize((w_size, base_height), Image.LANCZOS)

        # Serialize as JPEG and base64-encode for the data-URL payload.
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
    except FileNotFoundError:
        print(f"Error: The file {image_path} was not found.")
        return None
    except Exception as e:
        print(f"Error: {e}")
        return None
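
# Example: encode_image("photo.jpg") yields a base64 JPEG string, or None on failure
# ("photo.jpg" is a placeholder path, not a file shipped with this app).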

def feifeichat(image):
    try:
        model_name = "pixtral-large-2411"
        base64_image = encode_image(image)
        messages = [{
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Please provide a detailed description of this photo",
                },
                {
                    "type": "image_url",
                    "image_url": f"data:image/jpeg;base64,{base64_image}",
                },
            ],
        }]
        partial_message = ""
        # Stream partial captions back to the UI as they arrive.
        for chunk in Mistralclient.chat.stream(model=model_name, messages=messages):
            if chunk.data.choices[0].delta.content is not None:
                partial_message += chunk.data.choices[0].delta.content
                yield partial_message
    except Exception as e:
        print(f"Error: {e}")
        # This is a generator, so the fallback must be yielded; a return value
        # from a generator is silently discarded.
        yield "Please upload a photo"

def text3d_tool():
    # Like ic_light_tool, the body comes from an environment variable at runtime.
    code = os.environ.get('APP')
    if code is None:
        return "APP environment variable is not set."
    exec(code)

with gr.Blocks(theme='gstaff/sketch') as demo:
    with gr.Row(equal_height=True):
        gr.Markdown("# Multi-Tool Interface: Chatbot, Music, Transpixar, IC Light, Image to Flux Prompt, and Text3D")

    with gr.Tab("Chatbot"):
        chatbot_state = gr.State([])
        chatbot_model = gr.Dropdown(
            choices=list(CHATBOT_MODELS.keys()),
            label="Select Chatbot Model",
            value="DialoGPT (Medium)"
        )
        chatbot_output = gr.Chatbot()
        chatbot_input = gr.Textbox(label="Your Message")
        chatbot_button = gr.Button("Send")
        chatbot_button.click(
            chatbot_interaction,
            inputs=[chatbot_input, chatbot_state, chatbot_model],
            outputs=[chatbot_output, chatbot_state]
        )

    with gr.Tab("Music Generation"):
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Enter a prompt for music generation", placeholder="e.g., A joyful melody for a sunny day")
                length = gr.Slider(minimum=1, maximum=10, value=5, label="Length (seconds)")
                genre = gr.Dropdown(
                    choices=["Pop", "Rock", "Classical", "Jazz", "Electronic", "Hip-Hop", "Country"],
                    label="Select Genre",
                    value="Pop"
                )
                custom_genre = gr.Textbox(label="Or enter a custom genre", placeholder="e.g., Reggae, K-Pop, etc.")
                lyrics = gr.Textbox(label="Enter lyrics (optional)", placeholder="e.g., La la la...")
                generate_music_button = gr.Button("Generate Music")
            with gr.Column():
                music_output = gr.Audio(label="Generated Music")
        generate_music_button.click(
            generate_music_function,
            inputs=[prompt, length, genre, custom_genre, lyrics],
            outputs=music_output
        )

    with gr.Tab("Transpixar"):
        with gr.Row():
            with gr.Column():
                video_prompt = gr.Textbox(label="Enter a prompt for video generation", placeholder="e.g., A futuristic cityscape at night")
                seed_value = gr.Number(label="Inference Seed (Enter a positive number, -1 for random)", value=-1)
                generate_video_button = gr.Button("Generate Video")
            with gr.Column():
                rgb_video_output = gr.Video(label="Generated RGB Video", width=720, height=480)
                alpha_video_output = gr.Video(label="Generated Alpha Video", width=720, height=480)
                seed_text = gr.Number(label="Seed Used for Video Generation", visible=False)
        generate_video_button.click(
            generate_video_function,
            inputs=[video_prompt, seed_value],
            outputs=[rgb_video_output, alpha_video_output, seed_text]
        )

    with gr.Tab("IC Light"):
        gr.Markdown("### IC Light Tool")
        ic_light_button = gr.Button("Run IC Light")
        ic_light_output = gr.Textbox(label="IC Light Output", interactive=False)
        ic_light_button.click(
            ic_light_tool,
            outputs=ic_light_output
        )

    with gr.Tab("Image to Flux Prompt"):
        gr.Markdown("### Image to Flux Prompt")
        input_img = gr.Image(label="Input Picture", height=320, type="filepath")
        submit_btn = gr.Button(value="Submit")
        output_text = gr.Textbox(label="Flux Prompt")
        submit_btn.click(feifeichat, [input_img], [output_text])

    with gr.Tab("Text3D"):
        gr.Markdown("### Text3D Tool")
        text3d_button = gr.Button("Run Text3D")
        text3d_output = gr.Textbox(label="Text3D Output", interactive=False)
        text3d_button.click(
            text3d_tool,
            outputs=text3d_output
        )
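
# Added (not in the original): enable the request queue so generator-based
# handlers (e.g. feifeichat's streaming captions) update the UI progressively.
# Gradio 4 queues by default; under Gradio 3.x this call is required.
demo.queue()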
demo.launch()