# Source: Hugging Face file view - author ttoosi, commit 226318e ("Update inference.py", verified).
"""Complete generative inference module with model loading and inference capabilities."""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.models.resnet import ResNet50_Weights
from PIL import Image
import numpy as np
import os
import requests
import time
import copy
from collections import OrderedDict
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
def _pick_device():
    """Return the preferred torch device: CUDA first, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # MPS = Apple Metal Performance Shaders, available on M-series Macs.
    mps_usable = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    return torch.device("mps" if mps_usable else "cpu")


# Module-wide device used for every model and tensor created below.
device = _pick_device()
print(f"Using device: {device}")
# Constants for model URLs
MODEL_URLS = {
'resnet50_robust': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps3.ckpt',
'resnet50_standard': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps0.ckpt',
'resnet50_robust_face': 'https://huggingface.co/ttoosi/resnet50_robust_face/resolve/main/resnet50_vggface2_L2_eps_0.50_checkpoint150.pt'
}
# Model-specific preprocessing configurations
MODEL_CONFIGS = {
'resnet50_robust_face': {
'input_size': 112,
'norm_mean': [0.5, 0.5, 0.5],
'norm_std': [0.5, 0.5, 0.5],
'n_classes': 500,
'dataset': 'VGGFace2'
},
'resnet50_standard': {
'input_size': 224,
'norm_mean': [0.485, 0.456, 0.406],
'norm_std': [0.229, 0.224, 0.225],
'n_classes': 1000,
'dataset': 'ImageNet'
},
'resnet50_robust': {
'input_size': 224,
'norm_mean': [0.485, 0.456, 0.406],
'norm_std': [0.229, 0.224, 0.225],
'n_classes': 1000,
'dataset': 'ImageNet'
}
}
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
def get_iterations_to_show(n_itr):
    """Return the (1-based) iterations at which inference snapshots are recorded.

    The schedule is dense early and sparser later, and always ends with
    ``n_itr``. Only iterations that actually occur (``1 <= it <= n_itr``) are
    kept, without duplicates and in ascending order. The list therefore lines
    up one-to-one with the snapshots the inference loop records.

    Args:
        n_itr: total number of inference iterations.

    Returns:
        Sorted list of unique iteration numbers in ``[1, n_itr]`` (empty when
        ``n_itr < 1``).

    >>> get_iterations_to_show(20)
    [1, 5, 10, 20]
    """
    schedule = [1, 5, 10, 20, 30, 40, 50]
    if n_itr <= 50:
        pass
    elif n_itr <= 100:
        schedule += [60, 70, 80, 90, 100]
    elif n_itr <= 200:
        schedule += [75, 100, 125, 150, 175, 200]
    else:
        schedule += [75, 100, 150, 200, 250, 300, 350, 400, 450, 500]
        if n_itr > 500:
            # Past 500 iterations, add snapshots at 60%..90% of the run.
            schedule += [int(n_itr * frac) for frac in (0.6, 0.7, 0.8, 0.9)]
    schedule.append(n_itr)
    # Drop iterations that never happen and duplicates (e.g. n_itr == 50).
    return sorted({it for it in schedule if 1 <= it <= n_itr})
def get_inference_configs(inference_type='IncreaseConfidence', eps=0.5, n_itr=50, step_size=1.0):
    """Build the configuration dict consumed by the generative inference loop.

    Args:
        inference_type: stored under 'loss_infer'. The known types
            ('IncreaseConfidence', 'Prior-Guided Drift Diffusion',
            'GradModulation', 'CompositionalFusion') also receive a
            'loss_function' and method-specific overrides. Any other value gets
            only the shared defaults.
        eps: allowed per-element deviation from the input image.
        n_itr: number of inference iterations.
        step_size: length of each normalized gradient step.

    Returns:
        A new dict (with its own 'misc_info' dict) on every call.
    """
    misc_info = {'keep_grads': False}
    config = dict(
        loss_infer=inference_type,
        n_itr=n_itr,
        eps=eps,
        step_size=step_size,
        diffusion_noise_ratio=0.0,
        initial_inference_noise_ratio=0.0,
        top_layer='all',
        inference_normalization=False,
        recognition_normalization=False,
        iterations_to_show=get_iterations_to_show(n_itr),
        misc_info=misc_info,
    )
    if inference_type == 'Prior-Guided Drift Diffusion':
        # Feature-matching on layer4 against a noisy copy of the input.
        config.update(
            loss_function='MSE',
            initial_inference_noise_ratio=0.05,
            diffusion_noise_ratio=0.01,
            top_layer='layer4',
        )
    elif inference_type in ('IncreaseConfidence', 'GradModulation', 'CompositionalFusion'):
        config['loss_function'] = 'CE'
        if inference_type == 'GradModulation':
            misc_info['grad_modulation'] = 0.5
        elif inference_type == 'CompositionalFusion':
            misc_info.update(positive_classes=[], negative_classes=[])
    return config
def get_model_preprocessing(model_type: str) -> Dict:
    """Return the MODEL_CONFIGS entry for ``model_type``.

    Unknown model types fall back to the 'resnet50_standard' (ImageNet) entry
    and print a notice. The shared dict itself is returned, not a copy.
    """
    config = MODEL_CONFIGS.get(model_type)
    if config is not None:
        return config
    print(f"Fall-back: Unknown model type {model_type}, using ImageNet defaults")
    return MODEL_CONFIGS['resnet50_standard']
class NormalizeByChannelMeanStd(nn.Module):
    """Differentiable per-channel ``(x - mean) / std`` normalization layer.

    ``mean`` and ``std`` are registered as buffers, so they move with
    ``.to(device)`` and are part of the module's state dict.
    """

    def __init__(self, mean, std):
        super().__init__()
        mean_t = mean if isinstance(mean, torch.Tensor) else torch.tensor(mean)
        std_t = std if isinstance(std, torch.Tensor) else torch.tensor(std)
        self.register_buffer("mean", mean_t)
        self.register_buffer("std", std_t)

    def forward(self, tensor):
        return self.normalize_fn(tensor, self.mean, self.std)

    def normalize_fn(self, tensor, mean, std):
        """Differentiable counterpart of torchvision's functional ``normalize``.

        ``mean``/``std`` are indexed as ``[None, :, None, None]``, i.e. treated as
        per-channel vectors broadcast over an (N, C, H, W)-shaped ``tensor``.
        """
        channel_mean = mean[None, :, None, None]
        channel_std = std[None, :, None, None]
        return (tensor - channel_mean) / channel_std
class InferStep:
    """Normalized gradient step plus projection around a fixed reference image.

    Attributes:
        orig_image: reference tensor that iterates are kept close to.
        eps: maximum absolute per-element deviation from ``orig_image``.
        step_size: length of each step after per-sample L2 normalization.
    """

    def __init__(self, orig_image: torch.Tensor, eps: float, step_size: float):
        self.orig_image = orig_image
        self.eps = eps
        self.step_size = step_size

    def project(self, x: torch.Tensor) -> torch.Tensor:
        """Clamp ``x - orig_image`` element-wise to [-eps, eps], then clamp to [0, 1].

        NOTE(review): this is an L-infinity box while ``step`` normalizes by the
        L2 norm; confirm the mix is intended.
        """
        delta = (x - self.orig_image).clamp(min=-self.eps, max=self.eps)
        return (self.orig_image + delta).clamp(0, 1)

    def step(self, x: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
        """Return ``grad`` rescaled per sample to L2 norm ``step_size``.

        Only the update is returned; the caller adds it to ``x`` (``x`` is used
        here just for its rank). The 1e-10 term guards against a zero gradient.
        """
        trailing_ones = [1] * (x.dim() - 1)
        per_sample_norm = grad.reshape(grad.shape[0], -1).norm(dim=1).reshape(-1, *trailing_ones)
        unit_grad = grad / (per_sample_norm + 1e-10)
        return unit_grad * self.step_size
def extract_middle_layers(model: nn.Module, layer_index: Union[str, int]) -> nn.Module:
    """Truncate ``model`` after the direct child named ``str(layer_index)``.

    Args:
        model: module whose direct children (``named_children``) are scanned.
        layer_index: child name (or int index, compared as a string), or 'all'.

    Returns:
        ``model`` itself for 'all'. Otherwise, an ``nn.Sequential`` of the
        children up to and including the first matching one. If no child
        matches, ``model`` is returned unchanged after a printed fall-back notice.
    """
    if isinstance(layer_index, str) and layer_index == 'all':
        return model
    wanted = str(layer_index)
    kept = []
    for child_name, child in model.named_children():
        kept.append((child_name, child))
        if child_name == wanted:
            return nn.Sequential(OrderedDict(kept))
    print(f"Fall-back: Module {layer_index} not found, using full model")
    return model
def calculate_loss(output_model: torch.Tensor, class_indices: List[int], loss_inference: str) -> torch.Tensor:
    """Average a per-class loss over ``class_indices``.

    For each class index, the loss is computed between ``output_model``
    (logits indexed as (batch, n_classes)) and a target that assigns that class
    to every sample in the batch. The per-class losses are then averaged.

    Args:
        output_model: model output; row ``b`` holds the logits of sample ``b``.
        class_indices: target class indices; must be non-empty.
        loss_inference: 'CE' (cross-entropy vs. the class index) or 'MSE'
            (mean squared error vs. a one-hot vector).

    Returns:
        Scalar tensor: mean of the per-class losses.

    Raises:
        ValueError: unsupported ``loss_inference`` or empty ``class_indices``.
    """
    if loss_inference == 'CE':
        criterion = nn.CrossEntropyLoss()
    elif loss_inference == 'MSE':
        criterion = nn.MSELoss()
    else:
        raise ValueError(f"Unsupported loss_inference: {loss_inference}")
    if len(class_indices) == 0:
        raise ValueError("class_indices must contain at least one class index")

    batch_size = output_model.shape[0]
    losses = []
    for idx in class_indices:
        if loss_inference == 'CE':
            # One target per sample so batches larger than 1 work too.
            target = torch.full((batch_size,), idx, dtype=torch.long, device=output_model.device)
            losses.append(criterion(output_model, target))
        else:
            one_hot_target = torch.zeros_like(output_model)
            one_hot_target[:, idx] = 1  # every sample, not only sample 0
            losses.append(criterion(output_model, one_hot_target))
    return torch.stack(losses).mean()
def download_model(model_type):
    """Make sure the checkpoint for ``model_type`` is available under ``./models``.

    Args:
        model_type: key of MODEL_URLS.

    Returns:
        Path of the local checkpoint, or None when MODEL_URLS has no URL for
        ``model_type``. An existing file is reused without re-downloading.

    Raises:
        RuntimeError: if the server replies with a status other than 200.
        requests.RequestException: on connection errors or timeouts.
    """
    url = MODEL_URLS.get(model_type)
    if url is None:
        return None
    os.makedirs("models", exist_ok=True)
    if model_type == 'resnet50_robust_face':
        model_path = Path("models/resnet50_vggface2_L2_eps_0.50_checkpoint150.pt")
    else:
        model_path = Path(f"models/{model_type}.pt")
    if model_path.exists():
        return model_path

    print(f"Downloading {model_type} model...")
    # Stream into a temporary file and rename only on success. An interrupted
    # download therefore never leaves a truncated file that later runs would
    # mistake for a cached checkpoint.
    partial_path = model_path.with_name(model_path.name + ".part")
    try:
        # (connect, read) timeouts in seconds; the read timeout bounds each wait
        # for data, not the whole transfer.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            if response.status_code != 200:
                raise RuntimeError(f"Failed to download model: {response.status_code}")
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        partial_path.replace(model_path)
    except BaseException:
        # Also covers KeyboardInterrupt; the error is always re-raised.
        if partial_path.exists():
            partial_path.unlink()
        raise
    print(f"Model downloaded and saved to {model_path}")
    return model_path
class GenerativeInferenceModel:
    """Load ResNet-50 variants and run gradient-based generative inference on images.

    Loaded models are cached per model type in ``self.models``. The
    preprocessing config used for each one is kept in ``self.model_preproc``.
    """

    # loss_infer values implemented by ``inference``. Other values that
    # get_inference_configs can produce (e.g. 'GradModulation',
    # 'CompositionalFusion') are rejected up front instead of silently
    # degrading into a random walk.
    _SUPPORTED_LOSS_INFER = ('IncreaseConfidence', 'Prior-Guided Drift Diffusion')

    # Prefixes stripped from checkpoint keys to recover torchvision ResNet-50
    # parameter names. They are tried in this order, and an earlier prefix wins
    # when two checkpoint keys map to the same name. The empty prefix is last,
    # so already-bare keys still match. (Putting '' first, as before, matched
    # every key and prevented any prefix from being stripped.)
    _CHECKPOINT_KEY_PREFIXES = (
        'module.model.model.',
        'module.model.',
        'module.attacker.model.',
        'module.',
        'model.',
        'attacker.model.',
        'attacker.',
        '',
    )

    def __init__(self):
        self.models = {}
        self.model_preproc = {}
        self.labels = self.get_imagenet_labels()

    def get_imagenet_labels(self):
        """Fetch human-readable ImageNet class names.

        Returns:
            The JSON list from the imagenet-simple-labels repository, or the
            placeholders ``class_0`` ... ``class_999`` if the request fails,
            returns a non-200 status, or the body is not valid JSON.
        """
        url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
        try:
            response = requests.get(url, timeout=10)  # timeout prevents hanging at startup
            if response.status_code == 200:
                return response.json()
            print("Fall-back: Failed to fetch ImageNet labels, using placeholder")
        except Exception as e:
            print(f"Fall-back: Error fetching labels: {e}")
        return [f"class_{i}" for i in range(1000)]

    @classmethod
    def _extract_resnet_state_dict(cls, state_dict, resnet_keys):
        """Map checkpoint keys onto ResNet parameter names.

        Args:
            state_dict: mapping of checkpoint key -> tensor.
            resnet_keys: set of parameter/buffer names of the target ResNet.

        Returns:
            dict of ResNet name -> value for every name that could be matched,
            trying ``_CHECKPOINT_KEY_PREFIXES`` in order (first match wins). A
            leftover leading 'model.' after stripping a prefix is removed as well.
        """
        extracted = {}
        for prefix in cls._CHECKPOINT_KEY_PREFIXES:
            for source_key, value in state_dict.items():
                if not source_key.startswith(prefix):
                    continue
                target_key = source_key[len(prefix):]
                if target_key not in resnet_keys and target_key.startswith('model.'):
                    target_key = target_key[len('model.'):]
                if target_key in resnet_keys and target_key not in extracted:
                    extracted[target_key] = value
        return extracted

    @staticmethod
    def _pretrained_imagenet_model(normalizer):
        """Return ``normalizer`` followed by torchvision's IMAGENET1K_V1 ResNet-50."""
        return nn.Sequential(normalizer, models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1))

    def load_model(self, model_type):
        """Load (and cache) the model for ``model_type``.

        Builds ``nn.Sequential(NormalizeByChannelMeanStd, resnet50)`` from the
        preprocessing config of ``get_model_preprocessing`` and fills the ResNet
        from the checkpoint returned by ``download_model``. For every model type
        except 'resnet50_robust_face', the model falls back to torchvision's
        IMAGENET1K_V1 weights in these cases:
        - no checkpoint is available,
        - the checkpoint cannot be loaded, or
        - less than 50% of the ResNet tensors match.
        The face model has no such fallback.

        Args:
            model_type: key of MODEL_URLS / MODEL_CONFIGS.

        Returns:
            The model in eval mode on ``device``. It expects unnormalized image
            tensors, since normalization is its first layer.
        """
        if model_type in self.models:
            print(f"Using cached {model_type} model")
            return self.models[model_type]
        start_time = time.time()

        preproc_config = get_model_preprocessing(model_type)
        self.model_preproc[model_type] = preproc_config
        normalizer = NormalizeByChannelMeanStd(
            preproc_config['norm_mean'],
            preproc_config['norm_std'],
        ).to(device)
        resnet = models.resnet50(num_classes=preproc_config['n_classes'])
        model = nn.Sequential(normalizer, resnet)
        is_face_model = model_type == 'resnet50_robust_face'

        model_path = download_model(model_type)
        if model_path:
            print(f"Loading {model_type} model from {model_path}...")
            try:
                # NOTE(review): torch.load unpickles the file; keep MODEL_URLS
                # pointed at trusted sources only.
                checkpoint = torch.load(model_path, map_location=device)
                if 'model' in checkpoint:
                    state_dict = checkpoint['model']
                    print("Using 'model' key from checkpoint")
                elif 'state_dict' in checkpoint:
                    state_dict = checkpoint['state_dict']
                    print("Using 'state_dict' key from checkpoint")
                else:
                    state_dict = checkpoint
                    print("Using checkpoint directly as state_dict")

                resnet_keys = set(resnet.state_dict().keys())
                resnet_state_dict = self._extract_resnet_state_dict(state_dict, resnet_keys)
                if resnet_state_dict:
                    resnet.load_state_dict(resnet_state_dict, strict=False)
                    loaded_percent = len(resnet_state_dict) / len(resnet_keys) * 100
                    print(f"Model loading: {len(resnet_state_dict)}/{len(resnet_keys)} parameters ({loaded_percent:.1f}%)")
                    if loaded_percent < 50:
                        print(f"Fall-back: Loading too incomplete ({loaded_percent:.1f}%), using PyTorch pretrained")
                        if not is_face_model:
                            model = self._pretrained_imagenet_model(normalizer)
                else:
                    print("Fall-back: No matching keys found in checkpoint, using PyTorch pretrained")
                    if not is_face_model:
                        model = self._pretrained_imagenet_model(normalizer)
            except Exception as e:
                print(f"Fall-back: Error loading checkpoint: {e}")
                if is_face_model:
                    print("Fall-back: Face model checkpoint failed, model may not work properly")
                else:
                    print("Fall-back: Using PyTorch pretrained model")
                    model = self._pretrained_imagenet_model(normalizer)
        elif is_face_model:
            print("Fall-back: Face model requires checkpoint, model may not work properly")
        else:
            print(f"No checkpoint for {model_type}, using PyTorch pretrained")
            model = self._pretrained_imagenet_model(normalizer)

        model = model.to(device)
        model.eval()
        self.verify_model_integrity(model, model_type)
        self.models[model_type] = model
        print(f"Model {model_type} loaded in {time.time() - start_time:.2f} seconds")
        return model

    def verify_model_integrity(self, model, model_type):
        """Run a synthetic forward pass and check the output shape.

        Returns:
            True if the output has shape (1, n_classes) for ``model_type``,
            False otherwise or if the forward pass raises (never propagates).
        """
        try:
            print(f"Fall-back: Running model integrity check for {model_type}")
            config = get_model_preprocessing(model_type)
            H = W = config['input_size']
            # Black image with a half-intensity square in the red channel.
            test_input = torch.zeros(1, 3, H, W, device=device)
            test_input[0, 0, H // 4:3 * H // 4, W // 4:3 * W // 4] = 0.5
            with torch.no_grad():
                output = model(test_input)
            expected_classes = config['n_classes']
            if output.shape != (1, expected_classes):
                print(f"Fall-back: Unexpected output shape: {output.shape}, expected (1, {expected_classes})")
                return False
            probs = F.softmax(output, dim=1)
            confidence, prediction = torch.max(probs, 1)
            print("Model integrity check passed:")
            print(f"- Output shape: {output.shape}")
            print(f"- Top prediction: Class {prediction.item()} with {confidence.item()*100:.2f}% confidence")
            return True
        except Exception as e:
            print(f"Fall-back: Model integrity check failed with error: {e}")
            return False

    @staticmethod
    def _to_pil_image(image):
        """Convert a path, numpy array, PIL image or array-like to an RGB PIL image.

        Numpy arrays: non-uint8 data with max <= 1 is scaled by 255, and values
        are clipped to [0, 255]. 3-D arrays whose first axis has size 1 or 3 are
        treated as CHW and transposed to HWC. An alpha channel is dropped, and a
        single channel is repeated to three.

        Raises:
            ValueError: missing path or an object that cannot be converted.
        """
        if isinstance(image, str):
            if not os.path.exists(image):
                raise ValueError(f"Image path does not exist: {image}")
            image = Image.open(image).convert('RGB')
        elif isinstance(image, np.ndarray):
            if image.dtype != np.uint8:
                if image.max() <= 1.0:
                    image = image * 255
                # Clip so out-of-range values saturate instead of wrapping around.
                image = np.clip(image, 0, 255).astype(np.uint8)
            if image.ndim == 3:
                if image.shape[0] in (1, 3):
                    image = np.transpose(image, (1, 2, 0))
                if image.shape[2] == 4:
                    image = image[:, :, :3]
                elif image.shape[2] == 1:
                    image = np.repeat(image, 3, axis=2)
            image = Image.fromarray(image)
        elif not isinstance(image, Image.Image):
            try:
                image = Image.fromarray(np.array(image)).convert('RGB')
            except Exception as e:
                raise ValueError(f"Cannot convert image type {type(image)} to PIL Image: {e}") from e
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image

    def inference(self, image, model_type, config):
        """Run iterative generative inference on a single image.

        Args:
            image: file path, numpy array, PIL image or array-like
                (see ``_to_pil_image``).
            model_type: model key understood by ``load_model``.
            config: dict as built by ``get_inference_configs``. Required keys:
                'loss_infer', 'n_itr', 'eps', 'step_size'. Optional keys:
                'loss_function', 'top_layer', 'inference_normalization',
                'initial_inference_noise_ratio', 'diffusion_noise_ratio',
                'iterations_to_show'.

        Returns:
            dict with:
            - 'final_image': the final image tensor.
            - 'steps': the recorded snapshots, starting with the input image.
            - 'original_class' / 'original_confidence' and 'final_class' /
              'final_confidence': top-1 class and confidence before and after.
            - 'all_categories' / 'all_confidences': top-1 class and confidence
              for each recorded snapshot.

        Raises:
            ValueError: ``config['loss_infer']`` is not implemented here,
                Prior-Guided Drift Diffusion is requested with a non-MSE loss,
                or the image cannot be read/converted.
        """
        inference_start = time.time()
        loss_infer = config['loss_infer']
        # Validate before the loop. The per-step fallback below catches every
        # exception, so these errors used to turn silently into a random walk.
        if loss_infer not in self._SUPPORTED_LOSS_INFER:
            raise ValueError(f"Loss inference method {loss_infer} not supported")
        if loss_infer == 'Prior-Guided Drift Diffusion' and config.get('loss_function', 'MSE') != 'MSE':
            raise ValueError("Prior-Guided Drift Diffusion requires MSE loss")

        model = self.load_model(model_type)
        image = self._to_pil_image(image)

        preproc_config = get_model_preprocessing(model_type)
        input_size = preproc_config['input_size']
        norm_mean = list(preproc_config['norm_mean'])
        norm_std = list(preproc_config['norm_std'])
        n_classes = preproc_config['n_classes']
        # True: the transform normalizes, so the model's built-in normalizer is
        # bypassed. False: raw [0, 1] tensors go through the built-in normalizer.
        pre_normalized = bool(config.get('inference_normalization', False))

        resize_and_crop = [transforms.Resize(input_size), transforms.CenterCrop(input_size)]
        if pre_normalized:
            transform = transforms.Compose(
                resize_and_crop + [transforms.ToTensor(), transforms.Normalize(norm_mean, norm_std)]
            )
            print(f"Fall-back: Using normalization with mean={norm_mean}, std={norm_std}")
        else:
            transform = transforms.Compose(resize_and_crop + [transforms.ToTensor()])
            print("Normalization OFF - feeding raw [0,1] tensors to model (normalization applied in the model)")

        def safe_transform(img):
            """Apply ``transform``, working around torchvision/numpy ToTensor incompatibilities."""
            try:
                return transform(img)
            except TypeError as e:
                if "expected np.ndarray" not in str(e) and "got numpy.ndarray" not in str(e):
                    raise
                print("[WARNING] Transform failed with numpy compatibility issue, using manual conversion")
                img = transforms.CenterCrop(input_size)(transforms.Resize(input_size)(img))
                # torch.tensor (not torch.from_numpy) sidesteps the numpy compatibility issue;
                # scale to [0, 1] and convert HWC -> CHW.
                img_array = np.array(img, dtype=np.uint8)
                img_tensor = torch.tensor(img_array, dtype=torch.float32).div(255.0).permute(2, 0, 1)
                if pre_normalized:
                    img_tensor = transforms.Normalize(norm_mean, norm_std)(img_tensor)
                return img_tensor

        image_tensor = safe_transform(image).unsqueeze(0).to(device)

        # Separate the built-in normalizer so every path normalizes exactly once.
        if isinstance(model, nn.Sequential) and isinstance(model[0], NormalizeByChannelMeanStd):
            normalizer, core_model = model[0], model[1]
        else:
            normalizer, core_model = None, model
        classifier = core_model if pre_normalized else model
        top_layer = config.get('top_layer', 'all')
        if top_layer == 'all':
            feature_model = classifier
        else:
            feature_model = extract_middle_layers(core_model, top_layer)
            if not pre_normalized and normalizer is not None:
                feature_model = nn.Sequential(normalizer, feature_model)
        # NOTE(review): InferStep.project clamps to [0, 1], which does not match the
        # value range of pre-normalized inputs (inference_normalization=True).

        with torch.no_grad():
            probs_orig = F.softmax(classifier(image_tensor), dim=1)
        conf_orig, classes_orig = torch.max(probs_orig, 1)

        target_classes = None
        if loss_infer == 'IncreaseConfidence':
            # Loop-invariant: the n_classes/10 least likely classes of the input.
            n_targets = max(1, n_classes // 10)
            _, least_confident_classes = torch.topk(probs_orig, k=n_targets, largest=False)
            target_classes = least_confident_classes[0].tolist()

        noisy_features = None
        if loss_infer == 'Prior-Guided Drift Diffusion':
            print("Setting up Prior-Guided Drift Diffusion...")
            noise_ratio = config.get('initial_inference_noise_ratio', 0.05)
            with torch.no_grad():  # a fixed target; no graph needed
                noisy_features = feature_model(image_tensor + noise_ratio * torch.randn_like(image_tensor))

        infer_step = InferStep(image_tensor, config['eps'], config['step_size'])
        n_itr = config['n_itr']
        iterations_to_show = set(config.get('iterations_to_show', []))
        loss_function = config.get('loss_function', 'CE')
        diffusion_noise_ratio = config.get('diffusion_noise_ratio', 0.0)

        x = image_tensor.clone().detach().requires_grad_(True)
        all_steps = [image_tensor[0].detach().cpu()]
        perceived_categories = []
        confidence_list = []

        print(f"Starting inference loop with {n_itr} iterations for {loss_infer}...")
        for i in range(n_itr):
            # Iteration 0 leaves x unchanged, so snapshot 1 is the input image itself.
            if i > 0:
                try:
                    features = feature_model(x)
                    if loss_infer == 'Prior-Guided Drift Diffusion':
                        loss = F.mse_loss(features, noisy_features)
                    else:  # 'IncreaseConfidence'
                        loss = calculate_loss(features, target_classes, loss_function)
                    grad = torch.autograd.grad(loss, x)[0]
                    diffusion_noise = diffusion_noise_ratio * torch.randn_like(x)
                    x_next = x.detach() + infer_step.step(x, grad) + diffusion_noise
                except Exception as e:
                    print(f"Fall-back: Error in gradient calculation: {e}")
                    random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
                    x_next = x.detach() + random_noise
                # Restart from a fresh leaf so the autograd graph does not grow
                # with the number of iterations.
                x = infer_step.project(x_next).detach().requires_grad_(True)

            if (i + 1) in iterations_to_show or i + 1 == n_itr:
                all_steps.append(x[0].detach().cpu())
                with torch.no_grad():
                    current_output = classifier(x)
                if isinstance(current_output, torch.Tensor) and current_output.size(-1) == n_classes:
                    current_conf, current_classes = torch.max(F.softmax(current_output, dim=1), 1)
                    perceived_categories.append(current_classes.item())
                    confidence_list.append(current_conf.item())
                else:
                    perceived_categories.append('N/A')
                    confidence_list.append(0.0)

        with torch.no_grad():
            final_probs = F.softmax(classifier(x), dim=1)
        final_conf, final_classes = torch.max(final_probs, 1)

        total_time = time.time() - inference_start
        print(f"Original top class: {classes_orig.item()} ({conf_orig.item():.4f})")
        print(f"Final top class: {final_classes.item()} ({final_conf.item():.4f})")
        print(f"Total inference time: {total_time:.2f} seconds")

        return {
            'final_image': x[0].detach().cpu(),
            'steps': all_steps,
            'original_class': classes_orig.item(),
            'original_confidence': conf_orig.item(),
            'final_class': final_classes.item(),
            'final_confidence': final_conf.item(),
            'all_categories': perceived_categories,
            'all_confidences': confidence_list,
        }
def show_inference_steps(steps, figsize=(15, 10)):
    """Plot the inference steps side by side in one matplotlib row.

    Args:
        steps: sequence of images. Torch tensors are treated as CHW and clipped
            to [0, 1]; anything else is passed to ``imshow`` unchanged.
        figsize: matplotlib figure size in inches.

    Returns:
        The matplotlib Figure. Returns None if matplotlib is unavailable or
        plotting raises; a "Fall-back:" message is printed in either case.
    """
    try:
        import matplotlib.pyplot as plt
        fig, axes = plt.subplots(1, len(steps), figsize=figsize)
        # With a single column, subplots returns one Axes rather than an array.
        panels = [axes] if len(steps) == 1 else axes
        for position, (ax, picture) in enumerate(zip(panels, steps), start=1):
            if isinstance(picture, torch.Tensor):
                picture = np.clip(picture.permute(1, 2, 0).numpy(), 0, 1)
            ax.imshow(picture)
            ax.set_title(f"Step {position}")
            ax.axis('off')
        plt.tight_layout()
        return fig
    except ImportError:
        print("Fall-back: matplotlib not available for visualization")
        return None
    except Exception as e:
        print(f"Fall-back: Visualization failed: {e}")
        return None
# Public API: the names exported by ``from <this module> import *``.
__all__ = [
    'GenerativeInferenceModel',
    'get_inference_configs',
    'show_inference_steps',
]