rootp1 commited on
Commit
e91d4b8
·
1 Parent(s): 36157e9
Files changed (3) hide show
  1. app.py +36 -0
  2. model.py +58 -0
  3. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ from flask_cors import CORS
3
+ from model import load_model, predict_species, get_label_names
4
+ import os
5
+ from dotenv import load_dotenv
6
+
7
+ # Load environment variables
8
+ load_dotenv()
9
+
10
+ app = Flask(__name__)
11
+
12
+ # Configure CORS with environment variables
13
+ cors_origins = os.getenv('CORS_ORIGINS', 'http://localhost:3000').split(',')
14
+ CORS(app, origins=cors_origins)
15
+
16
+ model = load_model()
17
+ label_names = get_label_names()
18
+
19
+ @app.route('/predict', methods=['GET'])
20
+ def predict():
21
+ image_url = request.args.get('url')
22
+ if not image_url:
23
+ return jsonify({'error': 'URL parameter is missing'}), 400
24
+ try:
25
+ predicted_species = predict_species(model, image_url, label_names)
26
+ return jsonify({'species': predicted_species})
27
+ except Exception as e:
28
+ return jsonify({'error': str(e)}), 500
29
+
30
+ if __name__ == '__main__':
31
+ # Get configuration from environment variables
32
+ host = os.getenv('FLASK_HOST', '127.0.0.1')
33
+ port = int(os.getenv('FLASK_PORT', 5000))
34
+ debug = os.getenv('FLASK_DEBUG', 'True').lower() == 'true'
35
+
36
+ app.run(host=host, port=port, debug=debug)
model.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import timm
2
+ import torch
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+ import requests
6
+ from io import BytesIO
7
+
8
+
9
+ def load_model():
10
+ """Load the pre-trained model."""
11
+ model = timm.create_model("hf_hub:timm/vit_large_patch14_clip_336.laion2b_ft_in12k_in1k_inat21", pretrained=True)
12
+ model.eval()
13
+ return model
14
+
15
+
16
+ def get_label_names():
17
+ """Fetch the class labels from the Hugging Face Hub."""
18
+ config_url = "https://huggingface.co/timm/vit_large_patch14_clip_336.laion2b_ft_in12k_in1k_inat21/resolve/main/config.json"
19
+ response = requests.get(config_url)
20
+ response.raise_for_status()
21
+ config = response.json()
22
+ return config["label_names"]
23
+
24
+
25
+ def preprocess_image(image_url):
26
+ """Fetch and preprocess the image."""
27
+ preprocess = transforms.Compose([
28
+ transforms.Resize(336),
29
+ transforms.CenterCrop(336),
30
+ transforms.ToTensor(),
31
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
32
+ ])
33
+
34
+ response = requests.get(image_url)
35
+ response.raise_for_status()
36
+ image = Image.open(BytesIO(response.content))
37
+ input_tensor = preprocess(image).unsqueeze(0) # Add a batch dimension
38
+ return input_tensor
39
+
40
+
41
+ def predict_species(model, image_url, label_names):
42
+ """Make a prediction using the model."""
43
+ input_tensor = preprocess_image(image_url)
44
+
45
+ # Move to GPU if available
46
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
47
+ model = model.to(device)
48
+ input_tensor = input_tensor.to(device)
49
+
50
+ # Make prediction
51
+ with torch.no_grad():
52
+ output = model(input_tensor)
53
+ _, predicted_class = torch.max(output, 1)
54
+
55
+ # Map prediction to species
56
+ predicted_species = label_names[predicted_class.item()]
57
+ return predicted_species
58
+ #finish
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ flask==2.3.3
2
+ flask-cors==4.0.0
3
+ timm==0.9.5
4
+ torch==2.8.0
5
+ torchvision==0.23.0
6
+ pillow==10.0.0
7
+ requests==2.31.0
8
+ gunicorn
9
+ python-dotenv==1.0.0