miiann committed on
Commit
d9b26f0
·
verified ·
1 Parent(s): b8777eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -27
app.py CHANGED
@@ -2,41 +2,28 @@ import gradio as gr
2
  from transformers import AutoModel, AutoTokenizer
3
  import torch
4
 
5
- # Load model (cached after first run)
6
- def load_model():
7
- tokenizer = AutoTokenizer.from_pretrained(
8
- "Qwen/Qwen3-Embedding-8B",
9
- trust_remote_code=True
10
- )
11
- model = AutoModel.from_pretrained(
12
- "Qwen/Qwen3-Embedding-8B",
13
- trust_remote_code=True,
14
- device_map="auto"
15
- ).eval()
16
- return tokenizer, model
17
-
18
- tokenizer, model = load_model()
19
 
20
- # Embedding generation function
21
  def get_embedding(text):
22
  inputs = tokenizer(text, return_tensors="pt", truncation=True).to(model.device)
23
  with torch.no_grad():
24
  outputs = model(**inputs)
25
- # Mean pooling for sentence embedding
26
  embedding = outputs.last_hidden_state.mean(dim=1).squeeze().tolist()
27
- return {
28
- "input_text": text,
29
- "embedding_size": len(embedding),
30
- "first_5_values": embedding[:5] # Preview
31
- }
32
 
33
- # Gradio Interface
34
  demo = gr.Interface(
35
  fn=get_embedding,
36
- inputs=gr.Textbox(label="Input Text", placeholder="Enter text to embed..."),
37
- outputs=gr.JSON(label="Embedding Result"),
38
- title="Qwen3-Embedding-8B Demo",
39
- examples=["Hello world", "How does AI work?", "上海天气怎么样?"]
40
  )
41
-
42
  demo.launch()
 
2
  from transformers import AutoModel, AutoTokenizer
3
  import torch
4
 
5
+ # Load model with caching
6
+ tokenizer = AutoTokenizer.from_pretrained(
7
+ "Qwen/Qwen3-Embedding-8B",
8
+ trust_remote_code=True
9
+ )
10
+ model = AutoModel.from_pretrained(
11
+ "Qwen/Qwen3-Embedding-8B",
12
+ trust_remote_code=True,
13
+ device_map="auto"
14
+ ).eval()
 
 
 
 
15
 
 
16
  def get_embedding(text):
17
  inputs = tokenizer(text, return_tensors="pt", truncation=True).to(model.device)
18
  with torch.no_grad():
19
  outputs = model(**inputs)
 
20
  embedding = outputs.last_hidden_state.mean(dim=1).squeeze().tolist()
21
+ return {"text": text, "embedding_size": len(embedding)}
 
 
 
 
22
 
 
23
  demo = gr.Interface(
24
  fn=get_embedding,
25
+ inputs=gr.Textbox(label="Input text"),
26
+ outputs=gr.JSON(),
27
+ title="Qwen3 Embeddings"
 
28
  )
 
29
  demo.launch()