frogleo committed · verified
Commit 8aaffb1 · Parent(s): f4f56a0

Update app.py

Files changed (1):
  1. app.py (+78 -17)
app.py CHANGED
@@ -15,7 +15,15 @@ from PIL import Image
 import json
 import base64
 from huggingface_hub import InferenceClient
+import logging
 
+# Enhanced logging configuration
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    datefmt='%Y-%m-%d %H:%M:%S'
+)
+logger = logging.getLogger(__name__)
 
 subprocess.check_call([sys.executable, "-m", "pip", "install", "spaces==0.43.0"])
 
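Note on the logging hunk: a minimal standalone sketch of the same configuration, showing the record format it produces (the log message here is illustrative):

import logging

# Identical configuration to the diff; emits records like
# "2025-01-01 12:00:00 - __main__ - INFO - Loading NSFW detector..."
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logging.getLogger(__name__).info("Loading NSFW detector...")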
@@ -78,6 +86,34 @@ pipe = Flux2Pipeline.from_pretrained(
 )
 pipe.to(device)
 
+class GenerationError(Exception):
+    """Custom exception for generation errors"""
+    pass
+
+# -------------------- Load the NSFW detection model --------------------
+try:
+    logger.info("Loading NSFW detector...")
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    from transformers import AutoProcessor, AutoModelForImageClassification
+    nsfw_processor = AutoProcessor.from_pretrained("Falconsai/nsfw_image_detection")
+    nsfw_model = AutoModelForImageClassification.from_pretrained(
+        "Falconsai/nsfw_image_detection"
+    ).to(device)
+    logger.info("NSFW detector loaded successfully.")
+except Exception as e:
+    logger.error(f"Failed to load NSFW detector: {e}")
+    nsfw_model = None
+    nsfw_processor = None
+
+def detect_nsfw(image: Image.Image, threshold: float = 0.5) -> bool:
+    """Returns True if the image is NSFW."""
+    inputs = nsfw_processor(images=image, return_tensors="pt").to(device)
+    with torch.no_grad():
+        outputs = nsfw_model(**inputs)
+    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
+    nsfw_score = probs[0][1].item()  # label 1 = NSFW
+    return nsfw_score > threshold
+
 # Pull pre-compiled Flux2 Transformer blocks from HF hub
 # The flash-attn library was presumably updated and now conflicts, so the pre-compiled blocks are no longer used
 # spaces.aoti_blocks_load(pipe.transformer, "zerogpu-aoti/FLUX.2", variant="fa3")
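Note on the NSFW hunk: the detector can be smoke-tested outside the app. The sketch below mirrors the code added above; "test.png" is a placeholder input, and the id2label lookup is an assumption used here to avoid hard-coding class index 1 (the diff's own comment states label 1 = NSFW):

import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageClassification

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
processor = AutoProcessor.from_pretrained("Falconsai/nsfw_image_detection")
model = AutoModelForImageClassification.from_pretrained(
    "Falconsai/nsfw_image_detection"
).to(device)

image = Image.open("test.png").convert("RGB")  # placeholder input file
inputs = processor(images=image, return_tensors="pt").to(device)
with torch.no_grad():
    logits = model(**inputs).logits
probs = torch.nn.functional.softmax(logits, dim=-1)
# Look the NSFW class up by name rather than hard-coding index 1
# (assumes the checkpoint's id2label includes an "nsfw" entry).
nsfw_idx = next(i for i, lbl in model.config.id2label.items() if lbl.lower() == "nsfw")
print(f"NSFW score: {probs[0][nsfw_idx].item():.3f}")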
@@ -171,7 +207,7 @@ def get_duration(prompt_embeds, image_list, width, height, num_inference_steps,
     return max(65, num_inference_steps * step_duration + 10)
 
 @spaces.GPU(duration=get_duration)
-def generate_image(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, progress=gr.Progress()):
+def _generate_image(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, progress=gr.Progress()):
     # Move embeddings to GPU only when inside the GPU decorated function
     prompt_embeds = prompt_embeds.to(device)
 
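Note on the @spaces.GPU hunk: ZeroGPU accepts a callable for duration that is invoked with the same arguments as the decorated function, which is how get_duration scales the GPU reservation with num_inference_steps. A minimal sketch of the pattern, with an assumed placeholder for per-step timing:

import spaces

def estimate_duration(num_inference_steps, *args, **kwargs):
    step_duration = 2  # assumed seconds per denoising step (placeholder)
    return max(65, num_inference_steps * step_duration + 10)

@spaces.GPU(duration=estimate_duration)
def run_on_gpu(num_inference_steps, *args, **kwargs):
    ...  # GPU-bound work goes here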
@@ -187,21 +223,42 @@ def generate_image(prompt_embeds, image_list, width, height, num_inference_steps
         progress(progress_value, desc=f"Image generating, {step + 1}/{num_inference_steps} steps")
         return callback_kwargs
 
-    image = pipe(
-        prompt_embeds=prompt_embeds,
-        image=image_list,
-        num_inference_steps=num_inference_steps,
-        guidance_scale=guidance_scale,
-        generator=generator,
-        width=width,
-        height=height,
-        callback_on_step_end=callback_fn,
-    ).images[0]
-
-    path = save_image(image, "./outputs")
-    progress(1, desc="Complete")
-
-    return path
+    try:
+        image = pipe(
+            prompt_embeds=prompt_embeds,
+            image=image_list,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            generator=generator,
+            width=width,
+            height=height,
+            callback_on_step_end=callback_fn,
+        ).images[0]
+
+        # NSFW detection
+        if nsfw_model and nsfw_processor:
+            if detect_nsfw(image):
+                msg = "Generated image contains NSFW content and cannot be displayed. Please modify the input image or prompt and try again."
+                raise Exception(msg)
+
+        path = save_image(image, "./outputs")
+        progress(1, desc="Complete")
+        info = {
+            "status": "success"
+        }
+        return path, info
+    except GenerationError as e:
+        error_info = {
+            "error": str(e),
+            "status": "failed",
+        }
+        return None, error_info
+    except Exception as e:
+        error_info = {
+            "error": str(e),
+            "status": "failed",
+        }
+        return None, error_info
 
 def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=50, guidance_scale=2.5, prompt_upsampling=False, progress=gr.Progress()):
 
@@ -233,7 +290,7 @@ def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024,
 
     # 3. Image Generation (GPU bound)
     progress(0.3, desc="Waiting for GPU...")
-    image = generate_image(
+    image, info = _generate_image(
         prompt_embeds,
         image_list,
         width,
@@ -243,6 +300,10 @@ def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024,
         seed,
         progress
     )
+
+    # If an error occurred, raise an exception
+    if info["status"] == "failed":
+        raise gr.Error(info["error"])
 
     return image, seed
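Taken together, these hunks replace exceptions raised inside the GPU worker with a (result, info) return contract, which infer() then converts into gr.Error for the UI. A self-contained mock of the flow (the names here are illustrative, not part of the app):

import gradio as gr

def _mock_generate():
    # Stands in for _generate_image(): failures are returned, not raised
    try:
        raise Exception("Generated image contains NSFW content and cannot be displayed.")
    except Exception as e:
        return None, {"error": str(e), "status": "failed"}

def mock_infer():
    image, info = _mock_generate()
    if info["status"] == "failed":
        raise gr.Error(info["error"])  # shown as an error toast in the Gradio UI
    return image

One observation on the design: the NSFW rejection raises a plain Exception, so it is caught by the generic except branch rather than the GenerationError branch; since both branches return the same failed payload, the UI behavior is identical either way.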
 