| """Complete generative inference module with model loading and inference capabilities.""" | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision.transforms as transforms | |
| import torchvision.models as models | |
| from torchvision.models.resnet import ResNet50_Weights | |
| from PIL import Image | |
| import numpy as np | |
| import os | |
| import requests | |
| import time | |
| import copy | |
| from collections import OrderedDict | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple, Union | |
# Check for available hardware acceleration
if torch.cuda.is_available():
    device = torch.device("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = torch.device("mps")  # Use Apple Metal Performance Shaders on M-series Macs
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
# Constants for model URLs
MODEL_URLS = {
    'resnet50_robust': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps3.ckpt',
    'resnet50_standard': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps0.ckpt',
    'resnet50_robust_face': 'https://huggingface.co/ttoosi/resnet50_robust_face/resolve/main/resnet50_vggface2_L2_eps_0.50_checkpoint150.pt'
}
# Model-specific preprocessing configurations
MODEL_CONFIGS = {
    'resnet50_robust_face': {
        'input_size': 112,
        'norm_mean': [0.5, 0.5, 0.5],
        'norm_std': [0.5, 0.5, 0.5],
        'n_classes': 500,
        'dataset': 'VGGFace2'
    },
    'resnet50_standard': {
        'input_size': 224,
        'norm_mean': [0.485, 0.456, 0.406],
        'norm_std': [0.229, 0.224, 0.225],
        'n_classes': 1000,
        'dataset': 'ImageNet'
    },
    'resnet50_robust': {
        'input_size': 224,
        'norm_mean': [0.485, 0.456, 0.406],
        'norm_std': [0.229, 0.224, 0.225],
        'n_classes': 1000,
        'dataset': 'ImageNet'
    }
}
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
def get_iterations_to_show(n_itr):
    """Generate a sorted list of iteration checkpoints to display, based on total iterations."""
    if n_itr <= 50:
        milestones = [1, 5, 10, 20, 30, 40, 50, n_itr]
    elif n_itr <= 100:
        milestones = [1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, n_itr]
    elif n_itr <= 200:
        milestones = [1, 5, 10, 20, 30, 40, 50, 75, 100, 125, 150, 175, 200, n_itr]
    elif n_itr <= 500:
        milestones = [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500, n_itr]
    else:
        milestones = [1, 5, 10, 20, 30, 40, 50, 75, 100, 150, 200, 250, 300, 350, 400, 450, 500,
                      int(n_itr * 0.6), int(n_itr * 0.7), int(n_itr * 0.8), int(n_itr * 0.9), n_itr]
    # Deduplicate and drop milestones past n_itr (e.g. n_itr=30 would otherwise include 40 and 50)
    return sorted({m for m in milestones if m <= n_itr})
def get_inference_configs(inference_type='IncreaseConfidence', eps=0.5, n_itr=50, step_size=1.0):
    """Generate an inference configuration with customizable parameters."""
    config = {
        'loss_infer': inference_type,
        'n_itr': n_itr,
        'eps': eps,
        'step_size': step_size,
        'diffusion_noise_ratio': 0.0,
        'initial_inference_noise_ratio': 0.0,
        'top_layer': 'all',
        'inference_normalization': False,
        'recognition_normalization': False,
        'iterations_to_show': get_iterations_to_show(n_itr),
        'misc_info': {'keep_grads': False}
    }
    if inference_type == 'IncreaseConfidence':
        config['loss_function'] = 'CE'
    elif inference_type == 'Prior-Guided Drift Diffusion':
        config['loss_function'] = 'MSE'
        config['initial_inference_noise_ratio'] = 0.05
        config['diffusion_noise_ratio'] = 0.01
        config['top_layer'] = 'layer4'
    elif inference_type == 'GradModulation':
        config['loss_function'] = 'CE'
        config['misc_info']['grad_modulation'] = 0.5
    elif inference_type == 'CompositionalFusion':
        config['loss_function'] = 'CE'
        config['misc_info']['positive_classes'] = []
        config['misc_info']['negative_classes'] = []
    return config
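
# Example (illustrative, values taken from the branches above): a Prior-Guided Drift
# Diffusion config overrides the defaults like so:
#   cfg = get_inference_configs('Prior-Guided Drift Diffusion', eps=0.5, n_itr=100)
#   cfg['loss_function']                 -> 'MSE'
#   cfg['top_layer']                     -> 'layer4'  (features compared at this layer)
#   cfg['initial_inference_noise_ratio'] -> 0.05
#   cfg['diffusion_noise_ratio']         -> 0.01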
def get_model_preprocessing(model_type: str) -> Dict:
    """Get the preprocessing configuration for a specific model type."""
    if model_type not in MODEL_CONFIGS:
        print(f"Fall-back: Unknown model type {model_type}, using ImageNet defaults")
        return MODEL_CONFIGS['resnet50_standard']
    return MODEL_CONFIGS[model_type]
class NormalizeByChannelMeanStd(nn.Module):
    """Normalization layer that applies per-channel mean/std inside the model."""

    def __init__(self, mean, std):
        super().__init__()
        if not isinstance(mean, torch.Tensor):
            mean = torch.tensor(mean)
        if not isinstance(std, torch.Tensor):
            std = torch.tensor(std)
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)

    def forward(self, tensor):
        return self.normalize_fn(tensor, self.mean, self.std)

    def normalize_fn(self, tensor, mean, std):
        """Differentiable version of torchvision.transforms.functional.normalize."""
        mean = mean[None, :, None, None]
        std = std[None, :, None, None]
        return tensor.sub(mean).div(std)
class InferStep:
    """Gradient-based update step for iterative inference."""

    def __init__(self, orig_image: torch.Tensor, eps: float, step_size: float):
        self.orig_image = orig_image
        self.eps = eps
        self.step_size = step_size

    def project(self, x: torch.Tensor) -> torch.Tensor:
        """Project x back into the eps-ball around the original image, then into valid [0, 1] range."""
        diff = x - self.orig_image
        diff = torch.clamp(diff, -self.eps, self.eps)
        return torch.clamp(self.orig_image + diff, 0, 1)

    def step(self, x: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
        """Return an L2-normalized gradient step; the caller adds it to x."""
        dim = len(x.shape) - 1
        grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=1).reshape(-1, *([1] * dim))
        scaled_grad = grad / (grad_norm + 1e-10)
        return scaled_grad * self.step_size
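
# Note: project() clamps each pixel difference independently, i.e. it projects onto an
# L-infinity ball of radius eps, while step() returns an L2-normalized gradient direction
# scaled by step_size. A single update therefore looks like:
#   x = infer_step.project(x + infer_step.step(x, grad))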
def extract_middle_layers(model: nn.Module, layer_index: Union[str, int]) -> nn.Module:
    """Extract layers from a model up to (and including) the specified layer."""
    if isinstance(layer_index, str) and layer_index == 'all':
        return model
    # Handle ResNet layer extraction
    modules = list(model.named_children())
    cutoff_idx = next(
        (i for i, (name, _) in enumerate(modules) if name == str(layer_index)),
        None
    )
    if cutoff_idx is not None:
        return nn.Sequential(OrderedDict(modules[:cutoff_idx + 1]))
    else:
        print(f"Fall-back: Module {layer_index} not found, using full model")
        return model
def calculate_loss(output_model: torch.Tensor, class_indices: List[int], loss_inference: str) -> torch.Tensor:
    """Calculate the mean loss over the specified target class indices."""
    losses = []
    for idx in class_indices:
        target = torch.full((1,), idx, dtype=torch.long, device=output_model.device)
        if loss_inference == 'CE':
            loss = nn.CrossEntropyLoss()(output_model, target)
        elif loss_inference == 'MSE':
            one_hot_target = torch.zeros_like(output_model)
            one_hot_target[0, target] = 1
            loss = nn.MSELoss()(output_model, one_hot_target)
        else:
            raise ValueError(f"Unsupported loss_inference: {loss_inference}")
        losses.append(loss)
    return torch.stack(losses).mean()
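
# Example (illustrative): for logits of shape (1, n_classes), calculate_loss(logits, [3, 7], 'CE')
# returns the mean cross-entropy toward classes 3 and 7; with 'MSE' each target is a one-hot vector.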
def download_model(model_type):
    """Download the checkpoint for model_type if it is not already cached locally."""
    if model_type not in MODEL_URLS or MODEL_URLS[model_type] is None:
        return None
    os.makedirs("models", exist_ok=True)
    if model_type == 'resnet50_robust_face':
        model_path = Path("models/resnet50_vggface2_L2_eps_0.50_checkpoint150.pt")
    else:
        model_path = Path(f"models/{model_type}.pt")
    if not model_path.exists():
        print(f"Downloading {model_type} model...")
        url = MODEL_URLS[model_type]
        response = requests.get(url, stream=True, timeout=60)  # Timeout guards against a hung connection
        if response.status_code == 200:
            with open(model_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            print(f"Model downloaded and saved to {model_path}")
        else:
            raise RuntimeError(f"Failed to download model: {response.status_code}")
    return model_path
class GenerativeInferenceModel:
    """Complete generative inference model with model loading and inference."""

    def __init__(self):
        self.models = {}
        self.model_preproc = {}
        self.labels = self.get_imagenet_labels()

    def get_imagenet_labels(self):
        """Get ImageNet labels."""
        url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
        try:
            response = requests.get(url, timeout=10)  # Timeout prevents hanging
            if response.status_code == 200:
                return response.json()
            else:
                print("Fall-back: Failed to fetch ImageNet labels, using placeholder")
                return [f"class_{i}" for i in range(1000)]
        except Exception as e:
            print(f"Fall-back: Error fetching labels: {e}")
            return [f"class_{i}" for i in range(1000)]
    def load_model(self, model_type):
        """Load and cache models for different model types."""
        if model_type in self.models:
            print(f"Using cached {model_type} model")
            return self.models[model_type]
        start_time = time.time()
        # Get model-specific preprocessing config
        preproc_config = get_model_preprocessing(model_type)
        self.model_preproc[model_type] = preproc_config
        # Create normalizer
        normalizer = NormalizeByChannelMeanStd(
            preproc_config['norm_mean'],
            preproc_config['norm_std']
        ).to(device)
        # Create base model architecture
        num_classes = preproc_config['n_classes']
        resnet = models.resnet50(num_classes=num_classes)
        model = nn.Sequential(normalizer, resnet)
        # Download and load checkpoint
        model_path = download_model(model_type)
        if model_path:
            print(f"Loading {model_type} model from {model_path}...")
            try:
                checkpoint = torch.load(model_path, map_location=device)
                # Handle different checkpoint formats
                if 'model' in checkpoint:
                    state_dict = checkpoint['model']
                    print("Using 'model' key from checkpoint")
                elif 'state_dict' in checkpoint:
                    state_dict = checkpoint['state_dict']
                    print("Using 'state_dict' key from checkpoint")
                else:
                    state_dict = checkpoint
                    print("Using checkpoint directly as state_dict")
                # Extract ResNet state dict
                resnet_state_dict = {}
                resnet_keys = set(resnet.state_dict().keys())
                # For the face model, prioritize the 'module.model.model.' structure
                # (observed in the actual checkpoint)
                if model_type == 'resnet50_robust_face':
                    # Check for 'module.model.model.' structure first (face checkpoints use this)
                    module_model_model_keys = [key for key in state_dict.keys() if key.startswith('module.model.model.')]
                    if module_model_model_keys:
                        print(f"Found 'module.model.model.' structure with {len(module_model_model_keys)} parameters (face model)")
                        for source_key, value in state_dict.items():
                            if source_key.startswith('module.model.model.'):
                                target_key = source_key[len('module.model.model.'):]
                                if target_key in resnet_keys:
                                    resnet_state_dict[target_key] = value
                        print(f"Extracted {len(resnet_state_dict)} parameters from module.model.model.")
                    # Also check for 'module.model.' structure as fallback
                    if len(resnet_state_dict) < len(resnet_keys):
                        module_model_keys = [key for key in state_dict.keys() if key.startswith('module.model.') and not key.startswith('module.model.model.')]
                        if module_model_keys:
                            print(f"Found additional 'module.model.' structure with {len(module_model_keys)} parameters")
                            for source_key, value in state_dict.items():
                                if source_key.startswith('module.model.') and not source_key.startswith('module.model.model.'):
                                    target_key = source_key[len('module.model.'):]
                                    # Remove extra 'model.' prefix if present
                                    if target_key.startswith('model.'):
                                        target_key = target_key[len('model.'):]
                                    if target_key in resnet_keys and target_key not in resnet_state_dict:
                                        resnet_state_dict[target_key] = value
                            print(f"Now have {len(resnet_state_dict)} parameters after adding module.model. keys")
                # Handle different key prefixes in checkpoints (for other models)
                if len(resnet_state_dict) == 0:
                    prefixes_to_try = ['', 'module.', 'model.', 'attacker.model.', 'attacker.']
                    for source_key, value in state_dict.items():
                        target_key = source_key
                        # Try removing various prefixes
                        for prefix in prefixes_to_try:
                            if source_key.startswith(prefix):
                                target_key = source_key[len(prefix):]
                                break
                        # Handle nested model keys
                        if target_key.startswith('model.'):
                            target_key = target_key[len('model.'):]
                        # If the target key is in ResNet keys, add it
                        if target_key in resnet_keys:
                            resnet_state_dict[target_key] = value
                # Load the state dict
                if resnet_state_dict:
                    missing_keys, unexpected_keys = resnet.load_state_dict(resnet_state_dict, strict=False)
                    loaded_percent = (len(resnet_state_dict) / len(resnet_keys)) * 100
                    print(f"Model loading: {len(resnet_state_dict)}/{len(resnet_keys)} parameters ({loaded_percent:.1f}%)")
                    if loaded_percent < 50:
                        print(f"Fall-back: Loading too incomplete ({loaded_percent:.1f}%), using PyTorch pretrained")
                        if model_type != 'resnet50_robust_face':
                            resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
                            model = nn.Sequential(normalizer, resnet)
                else:
                    print("Fall-back: No matching keys found in checkpoint, using PyTorch pretrained")
                    if model_type != 'resnet50_robust_face':
                        resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
                        model = nn.Sequential(normalizer, resnet)
            except Exception as e:
                print(f"Fall-back: Error loading checkpoint: {e}")
                if model_type != 'resnet50_robust_face':
                    print("Fall-back: Using PyTorch pretrained model")
                    resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
                    model = nn.Sequential(normalizer, resnet)
                else:
                    print("Fall-back: Face model checkpoint failed, model may not work properly")
        else:
            # Use PyTorch's pretrained model for ImageNet models
            if model_type != 'resnet50_robust_face':
                print(f"No checkpoint for {model_type}, using PyTorch pretrained")
                resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
                model = nn.Sequential(normalizer, resnet)
            else:
                print("Fall-back: Face model requires checkpoint, model may not work properly")
        model = model.to(device)
        model.eval()
        # Verify model
        self.verify_model_integrity(model, model_type)
        # Cache the model
        self.models[model_type] = model
        end_time = time.time()
        print(f"Model {model_type} loaded in {end_time - start_time:.2f} seconds")
        return model
    def verify_model_integrity(self, model, model_type):
        """Run a forward pass on a synthetic input to verify the model's output shape."""
        try:
            print(f"Running model integrity check for {model_type}")
            config = get_model_preprocessing(model_type)
            H = W = config['input_size']
            # Synthetic test input: a mid-gray square in the red channel on a black background
            test_input = torch.zeros(1, 3, H, W, device=device)
            test_input[0, 0, H // 4:3 * H // 4, W // 4:3 * W // 4] = 0.5
            with torch.no_grad():
                output = model(test_input)
            expected_classes = config['n_classes']
            if output.shape != (1, expected_classes):
                print(f"Fall-back: Unexpected output shape: {output.shape}, expected (1, {expected_classes})")
                return False
            probs = F.softmax(output, dim=1)
            confidence, prediction = torch.max(probs, 1)
            print("Model integrity check passed:")
            print(f"- Output shape: {output.shape}")
            print(f"- Top prediction: Class {prediction.item()} with {confidence.item() * 100:.2f}% confidence")
            return True
        except Exception as e:
            print(f"Fall-back: Model integrity check failed with error: {e}")
            return False
    def inference(self, image, model_type, config):
        """Run generative inference."""
        inference_start = time.time()
        # Load the model
        model = self.load_model(model_type)
        # Handle image input
        if isinstance(image, str):
            if os.path.exists(image):
                image = Image.open(image).convert('RGB')
            else:
                raise ValueError(f"Image path does not exist: {image}")
        elif isinstance(image, np.ndarray):
            if image.dtype != np.uint8:
                if image.max() <= 1.0:
                    image = (image * 255).astype(np.uint8)
                else:
                    image = image.astype(np.uint8)
            if len(image.shape) == 3:
                if image.shape[0] == 3 or image.shape[0] == 1:
                    image = np.transpose(image, (1, 2, 0))  # CHW -> HWC
                if image.shape[2] == 4:
                    image = image[:, :, :3]  # Drop alpha channel
                elif image.shape[2] == 1:
                    image = np.repeat(image, 3, axis=2)  # Grayscale -> RGB
            image = Image.fromarray(image)
        elif not isinstance(image, Image.Image):
            try:
                image = Image.fromarray(np.array(image)).convert('RGB')
            except Exception as e:
                raise ValueError(f"Cannot convert image type {type(image)} to PIL Image: {e}")
        if isinstance(image, Image.Image) and image.mode != 'RGB':
            image = image.convert('RGB')
        # Get preprocessing config
        preproc_config = get_model_preprocessing(model_type)
        input_size = preproc_config['input_size']
        norm_mean = torch.tensor(preproc_config['norm_mean'])
        norm_std = torch.tensor(preproc_config['norm_std'])
        n_classes = preproc_config['n_classes']
        # Create transform
        if config.get('inference_normalization', False):
            transform = transforms.Compose([
                transforms.Resize(input_size),
                transforms.CenterCrop(input_size),
                transforms.ToTensor(),
                transforms.Normalize(norm_mean.tolist(), norm_std.tolist()),
            ])
            print(f"Using transform-level normalization with mean={norm_mean.tolist()}, std={norm_std.tolist()}")
        else:
            transform = transforms.Compose([
                transforms.Resize(input_size),
                transforms.CenterCrop(input_size),
                transforms.ToTensor(),
            ])
            print("Normalization OFF - feeding raw [0, 1] tensors to the model (normalization is applied inside the model)")
        # Helper to apply the transform, with a manual fallback for numpy compatibility issues
        def safe_transform(img):
            try:
                return transform(img)
            except TypeError as e:
                if "expected np.ndarray" in str(e) or "got numpy.ndarray" in str(e):
                    # Fallback: manually convert PIL to tensor
                    print("[WARNING] Transform failed with numpy compatibility issue, using manual conversion")
                    # Apply resize and center crop manually
                    resize_transform = transforms.Resize(input_size)
                    crop_transform = transforms.CenterCrop(input_size)
                    img = crop_transform(resize_transform(img))
                    # Convert PIL -> numpy -> tensor; torch.tensor() copies the data,
                    # sidestepping torch.from_numpy() compatibility issues
                    img_array = np.array(img, dtype=np.uint8)
                    # Scale to [0, 1] and convert from HWC to CHW layout
                    img_tensor = torch.tensor(img_array, dtype=torch.float32).div(255.0).permute(2, 0, 1)
                    # Apply normalization if needed
                    if config.get('inference_normalization', False):
                        img_tensor = transforms.Normalize(norm_mean.tolist(), norm_std.tolist())(img_tensor)
                    return img_tensor
                else:
                    raise
        # Prepare image tensor with safe transform
        image_tensor = safe_transform(image).unsqueeze(0).to(device)
        image_tensor.requires_grad = True
        # Get model components
        is_sequential = isinstance(model, nn.Sequential)
        if is_sequential and isinstance(model[0], NormalizeByChannelMeanStd):
            core_model = model[1]
        else:
            core_model = model
        # Prepare model for layer extraction
        if config.get('top_layer', 'all') != 'all':
            new_model = extract_middle_layers(core_model, config['top_layer'])
        else:
            new_model = model
        # Get original predictions
        with torch.no_grad():
            if config.get('inference_normalization', False):
                output_original = model(image_tensor)
            else:
                output_original = core_model(image_tensor)
            probs_orig = F.softmax(output_original, dim=1)
            conf_orig, classes_orig = torch.max(probs_orig, 1)
        # Get the least confident classes for IncreaseConfidence
        if config['loss_infer'] == 'IncreaseConfidence':
            _, least_confident_classes = torch.topk(probs_orig, k=int(n_classes / 10), largest=False)
        # Setup for Prior-Guided Drift Diffusion
        noisy_features = None
        if config['loss_infer'] == 'Prior-Guided Drift Diffusion':
            print("Setting up Prior-Guided Drift Diffusion...")
            added_noise = config.get('initial_inference_noise_ratio', 0.05) * torch.randn_like(image_tensor).to(device)
            noisy_image_tensor = image_tensor + added_noise
            # Detach so the noisy features act as a fixed target rather than part of the graph
            noisy_features = new_model(noisy_image_tensor).detach()
        # Initialize inference step
        infer_step = InferStep(image_tensor, config['eps'], config['step_size'])
        # Storage for inference steps
        x = image_tensor.clone().detach().requires_grad_(True)
        all_steps = [image_tensor[0].detach().cpu()]
        selected_inferred_patterns = []
        perceived_categories = []
        confidence_list = []
        # Main inference loop
        print(f"Starting inference loop with {config['n_itr']} iterations for {config['loss_infer']}...")
        for i in range(config['n_itr']):
            # Reset gradients
            x.grad = None
            if i == 0:
                # Get predictions for first iteration
                if config.get('inference_normalization', False):
                    output = model(x)
                else:
                    output = core_model(x)
                if isinstance(output, torch.Tensor) and output.size(-1) == n_classes:
                    probs = F.softmax(output, dim=1)
                    conf, classes = torch.max(probs, 1)
                else:
                    probs = 0
                    conf = 0
                    classes = 'N/A'
            else:
                # Calculate loss and gradients
                try:
                    # Forward pass through new_model for feature extraction
                    features = new_model(x)
                    if config['loss_infer'] == 'Prior-Guided Drift Diffusion':
                        assert config.get('loss_function', 'MSE') == 'MSE', "Prior-Guided Drift Diffusion requires MSE loss"
                        if noisy_features is not None:
                            loss = F.mse_loss(features, noisy_features)
                            grad = torch.autograd.grad(loss, x)[0]
                            adjusted_grad = infer_step.step(x, grad)
                        else:
                            raise ValueError("Noisy features not computed for Prior-Guided Drift Diffusion")
                    elif config['loss_infer'] == 'IncreaseConfidence':
                        # Calculate loss using the least confident classes
                        num_target_classes = min(int(n_classes / 10), least_confident_classes.size(1))
                        target_classes = least_confident_classes[0, :num_target_classes]
                        loss = calculate_loss(features, target_classes.tolist(), config.get('loss_function', 'CE'))
                        grad = torch.autograd.grad(loss, x, retain_graph=True)[0]
                        adjusted_grad = infer_step.step(x, grad)
                    else:
                        raise ValueError(f"Loss inference method {config['loss_infer']} not supported")
                    if grad is None:
                        print("Fall-back: Direct gradient calculation failed")
                        random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
                        x = infer_step.project(x.clone() + random_noise)
                    else:
                        # Add diffusion noise if specified
                        diffusion_noise = config.get('diffusion_noise_ratio', 0.0) * torch.randn_like(x).to(device)
                        x = infer_step.project(x.clone() + adjusted_grad + diffusion_noise)
                    # Detach so the computation graph does not accumulate across iterations
                    x = x.detach().requires_grad_(True)
                except Exception as e:
                    print(f"Fall-back: Error in gradient calculation: {e}")
                    random_noise = (torch.rand_like(x) - 0.5) * 2 * config['step_size']
                    x = infer_step.project(x.clone() + random_noise).detach().requires_grad_(True)
            # Store the step if it is in iterations_to_show
            if i + 1 in config.get('iterations_to_show', []) or i + 1 == config['n_itr']:
                all_steps.append(x[0].detach().cpu())
                selected_inferred_patterns.append(x[0].detach().cpu())
                # Get current predictions
                with torch.no_grad():
                    if config.get('inference_normalization', False):
                        current_output = model(x)
                    else:
                        current_output = core_model(x)
                    if isinstance(current_output, torch.Tensor) and current_output.size(-1) == n_classes:
                        current_probs = F.softmax(current_output, dim=1)
                        current_conf, current_classes = torch.max(current_probs, 1)
                        perceived_categories.append(current_classes.item())
                        confidence_list.append(current_conf.item())
                    else:
                        perceived_categories.append('N/A')
                        confidence_list.append(0.0)
        # Final predictions
        with torch.no_grad():
            if config.get('inference_normalization', False):
                final_output = model(x)
            else:
                final_output = core_model(x)
            final_probs = F.softmax(final_output, dim=1)
            final_conf, final_classes = torch.max(final_probs, 1)
        total_time = time.time() - inference_start
        print(f"Original top class: {classes_orig.item()} ({conf_orig.item():.4f})")
        print(f"Final top class: {final_classes.item()} ({final_conf.item():.4f})")
        print(f"Total inference time: {total_time:.2f} seconds")
        # Return results as a dictionary
        return {
            'final_image': x[0].detach().cpu(),
            'steps': all_steps,
            'original_class': classes_orig.item(),
            'original_confidence': conf_orig.item(),
            'final_class': final_classes.item(),
            'final_confidence': final_conf.item(),
            'all_categories': perceived_categories,
            'all_confidences': confidence_list,
        }
def show_inference_steps(steps, figsize=(15, 10)):
    """Show inference steps using matplotlib."""
    try:
        import matplotlib.pyplot as plt
        n_steps = len(steps)
        fig, axes = plt.subplots(1, n_steps, figsize=figsize)
        if n_steps == 1:
            axes = [axes]
        for i, step_img in enumerate(steps):
            if isinstance(step_img, torch.Tensor):
                img = step_img.permute(1, 2, 0).numpy()
                img = np.clip(img, 0, 1)
            else:
                img = step_img
            axes[i].imshow(img)
            # steps[0] is the original input image; the rest are stored inference steps
            axes[i].set_title("Original" if i == 0 else f"Step {i}")
            axes[i].axis('off')
        plt.tight_layout()
        return fig
    except ImportError:
        print("Fall-back: matplotlib not available for visualization")
        return None
    except Exception as e:
        print(f"Fall-back: Visualization failed: {e}")
        return None
# Export the main classes and functions
__all__ = ['GenerativeInferenceModel', 'get_inference_configs', 'show_inference_steps']
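
# Minimal usage sketch (illustrative; assumes a local RGB image at the hypothetical
# path 'example.jpg' — substitute your own file. The checkpoint is downloaded on first use):
if __name__ == "__main__":
    engine = GenerativeInferenceModel()
    cfg = get_inference_configs('IncreaseConfidence', eps=0.5, n_itr=50, step_size=1.0)
    results = engine.inference('example.jpg', 'resnet50_robust', cfg)
    print(f"Class {results['original_class']} -> {results['final_class']} "
          f"(confidence {results['final_confidence']:.4f})")
    fig = show_inference_steps(results['steps'])
    if fig is not None:
        fig.savefig('inference_steps.png')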