Spaces:
Running on CPU Upgrade
| import xplique | |
| import tensorflow as tf | |
| from xplique.attributions import ( | |
| Saliency, | |
| GradientInput, | |
| IntegratedGradients, | |
| SmoothGrad, | |
| VarGrad, | |
| SquareGrad, | |
| GradCAM, | |
| Occlusion, | |
| Rise, | |
| GuidedBackprop, | |
| GradCAMPP, | |
| Lime, | |
| KernelShap, | |
| SobolAttributionMethod, | |
| HsicAttributionMethod, | |
| ) | |
| from xplique.attributions.global_sensitivity_analysis import LatinHypercube | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from inference_resnet import inference_resnet_finer, preprocess | |
| from labels import lookup_140 | |
| import cv2 | |
# Number of inputs per forward pass for batched explainers (Rise etc.).
BATCH_SIZE = 1
# Reverse mapping: family name -> class index (lookup_140 maps index -> name).
_FAMILY_TO_INDEX = {v: k for k, v in lookup_140.items()}
def letterbox_preprocess(img, size):
    """
    Resize ``img`` to fit inside a (size, size) canvas, preserving aspect
    ratio, and fill the remainder with black. No tiling/duplication.

    :param img: HxW or HxWx3 image array with values in 0-255.
    :param size: side length of the square output canvas.
    :return: tuple of
        - float32 (size, size, 3) array with values in [0, 1], and
        - (top, left, content_h, content_w): where the resized content sits
          inside the canvas, so a heatmap computed on the canvas can be
          mapped back to the original (h, w) image.
    """
    img = np.asarray(img, np.float32) / 255.0
    # Promote grayscale to 3 channels so downstream models get RGB input.
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    h, w = img.shape[:2]
    scale = min(size / h, size / w)
    # Clamp to >= 1 px: for extreme aspect ratios round(h * scale) can hit 0,
    # which cv2.resize rejects with an error.
    content_h = max(1, int(round(h * scale)))
    content_w = max(1, int(round(w * scale)))
    # cv2.resize takes dsize as (width, height).
    resized = cv2.resize(img, (content_w, content_h), interpolation=cv2.INTER_AREA)
    # Center the resized content on the black canvas.
    top = (size - content_h) // 2
    left = (size - content_w) // 2
    out = np.zeros((size, size, 3), dtype=np.float32)
    out[top : top + content_h, left : left + content_w] = resized
    return np.clip(out, 0, 1).astype(np.float32), (top, left, content_h, content_w)
def preprocess_image(image, output_size=(300, 300)):
    """
    Pad ``image`` to a square with black borders, then resize to output_size.

    :param image: HxW or HxWxC array (shape (height, width[, channels])).
    :param output_size: (width, height) target passed to cv2.resize.
    :return: the padded-and-resized square image.
    """
    h, w = image.shape[:2]
    diff = abs(h - w)
    # Split the difference between the two sides; give the extra pixel of an
    # odd difference to the second side so the result is exactly square
    # (using diff // 2 on both sides left odd differences 1 px short).
    first = diff // 2
    second = diff - first
    if h > w:
        # Taller than wide: pad left/right.
        image_padded = cv2.copyMakeBorder(image, 0, 0, first, second, cv2.BORDER_CONSTANT, value=[0, 0, 0])
    else:
        # Wider than tall (or square): pad top/bottom.
        image_padded = cv2.copyMakeBorder(image, first, second, 0, 0, cv2.BORDER_CONSTANT, value=[0, 0, 0])
    # Final resize to the requested display size.
    image_resized = cv2.resize(image_padded, output_size, interpolation=cv2.INTER_AREA)
    return image_resized
def transform(image, original_size, output_size):
    """
    Resize an XAI heatmap back to the original image scale, pad it to a
    square (mirroring preprocess_image's padding), then resize to output_size.

    :param image: 2-D or 3-D heatmap array in canvas coordinates.
    :param original_size: (h, w) of the original image.
    :param output_size: (width, height) of the final display image.
    :return: the transformed heatmap.
    """
    h, w = original_size
    # cv2.resize expects dsize as (width, height); the original passed (h, w),
    # which transposed the output dimensions for non-square images.
    image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
    if h > w:
        # Taller than wide: pad left/right to square.
        padding = (h - w) // 2
        image = cv2.copyMakeBorder(image, 0, 0, padding, padding, cv2.BORDER_CONSTANT, value=[0, 0, 0])
    else:
        # Wider than tall (or square): pad top/bottom to square.
        padding = (w - h) // 2
        image = cv2.copyMakeBorder(image, padding, padding, 0, 0, cv2.BORDER_CONSTANT, value=[0, 0, 0])
    image = cv2.resize(image, output_size, interpolation=cv2.INTER_AREA)
    return image
def show(img, original_size, output_size, p=False, **kwargs):
    """
    Display an image/heatmap with matplotlib after normalizing to [0, 1],
    optional percentile clipping, and mapping back to the original geometry.

    :param img: array; may be batched (leading axis of size 1),
        single-channel (squeezed), or 3-channel (channel order reversed,
        i.e. BGR -> RGB, for display).
    :param original_size: (h, w) forwarded to transform().
    :param output_size: final display size forwarded to transform().
    :param p: False, or a percentile in [0, 50) used to clip outliers.
    :param kwargs: forwarded to plt.imshow.
    """
    # Work on a float copy so we never mutate the caller's array in place
    # (the original normalized with in-place -=/ /= on the argument).
    img = np.array(img, dtype=np.float32, copy=True)
    # Drop a leading batch axis of size 1.
    if img.shape[0] == 1:
        img = img[0]
    # Squeeze single-channel images; reverse channel order for 3-channel ones.
    if img.shape[-1] == 1:
        img = img[:, :, 0]
    elif img.shape[-1] == 3:
        img = img[:, :, ::-1]
    # Min-max normalize to [0, 1]; guard against a constant image, which
    # would otherwise divide by zero after the shift.
    if img.max() > 1 or img.min() < 0:
        img -= img.min()
        if img.max() > 0:
            img /= img.max()
    # Optional percentile clipping to suppress outliers.
    if p is not False:
        img = np.clip(img, np.percentile(img, p), np.percentile(img, 100 - p))
    img = transform(img, original_size=original_size, output_size=output_size)
    plt.imshow(img, **kwargs)
    plt.axis('off')
def explain(
    model,
    input_image,
    h,
    w,
    explain_method,
    nb_samples,
    size=600,
    n_classes=171,
    heatmap_alpha=0.22,
    target_labels=None,
):
    """
    Compute attribution heatmaps for the top predicted classes of one image
    and save them as PNG overlays.

    :param model: full model; only its second output (model.output[1],
        presumably the classification head — confirm against the model
        definition) is explained.
    :param input_image: original image array, letterboxed internally.
    :param h: original image height (not used in this function body).
    :param w: original image width (not used in this function body).
    :param explain_method: one of "Sobol", "HSIC", "Rise", "Saliency";
        anything else yields no explainers and no output images.
    :param nb_samples: number of samples for the Rise explainer.
    :param size: side length of the square canvas the image is fitted onto.
    :param n_classes: number of classes used for the one-hot targets.
    :param heatmap_alpha: opacity multiplier for the heatmap overlay.
    :param target_labels: optional list of family names to explain; falls
        back to the model's top-5 predictions when missing or unresolvable.
    :return: (classes, explanations) — the explained class names and the
        saved PNG filename(s) (a single string when exactly one was saved).
    """
    print('using explain_method:',explain_method)
    # we only need the classification part of the model
    class_model = tf.keras.Model(model.input, model.output[1])
    # Build the requested explainer; exactly one branch matches per call.
    explainers = []
    if explain_method=="Sobol":
        explainers.append(SobolAttributionMethod(class_model, grid_size=8, nb_design=32))
    if explain_method=="HSIC":
        explainers.append(HsicAttributionMethod(class_model,
                          grid_size=7, nb_design=1500,
                          sampler = LatinHypercube(binary=True)))
    if explain_method=="Rise":
        explainers.append(Rise(class_model,nb_samples = nb_samples, batch_size = BATCH_SIZE,grid_size=15,
                          preservation_probability=0.5))
    if explain_method=="Saliency":
        explainers.append(Saliency(class_model))
    # explainers = [
    #     #Sobol, RISE, HSIC, Saliency
    #     #IntegratedGradients(class_model, steps=50, batch_size=BATCH_SIZE),
    #     #SmoothGrad(class_model, nb_samples=50, batch_size=BATCH_SIZE),
    #     #GradCAM(class_model),
    #     SobolAttributionMethod(class_model, grid_size=8, nb_design=32),
    #     HsicAttributionMethod(class_model,
    #                           grid_size=7, nb_design=1500,
    #                           sampler = LatinHypercube(binary=True)),
    #     Saliency(class_model),
    #     Rise(class_model,nb_samples = 5000, batch_size = BATCH_SIZE,grid_size=15,
    #          preservation_probability=0.5),
    #     #
    # ]
    # Resize to fit (size, size) with black padding only—no tiling/duplication.
    X, (content_top, content_left, content_h, content_w) = letterbox_preprocess(input_image, size)
    # Mask: 1 in content region, 0 in zero-padded areas (so we don't show heatmap on padding)
    content_mask = np.zeros((size, size), dtype=np.float32)
    content_mask[content_top : content_top + content_h, content_left : content_left + content_w] = 1.0
    # Determine which classes to explain:
    # - If target_labels are provided (from classifier output), use those.
    # - Otherwise, fall back to top-5 classes from this forward pass.
    classes = []
    if target_labels:
        # Resolve family names to class indices; names missing from the
        # lookup are silently dropped.
        indices = []
        for name in target_labels:
            idx = _FAMILY_TO_INDEX.get(name)
            if idx is not None:
                indices.append(idx)
        if indices:
            top_5_indices = np.array(indices, dtype=int)
            classes = [lookup_140[i] for i in top_5_indices]
        else:
            # No requested name resolved: fall back to the model's top-5.
            predictions = class_model.predict(np.array([X]))
            top_5_indices = np.argsort(predictions[0])[-5:][::-1]
            classes = [lookup_140[i] for i in top_5_indices]
    else:
        # No targets requested: explain the model's top-5 predictions.
        predictions = class_model.predict(np.array([X]))
        top_5_indices = np.argsort(predictions[0])[-5:][::-1]
        classes = [lookup_140[i] for i in top_5_indices]
    #print(top_5_indices)
    # Add the batch axis expected by the explainers.
    X = np.expand_dims(X, 0)
    explanations = []
    for e,explainer in enumerate(explainers):
        print(f'{e}/{len(explainers)}')
        for i,Y in enumerate(top_5_indices):
            # One-hot target for this class; xplique explainers take (X, Y).
            Y = tf.one_hot([Y], n_classes)
            print(f'{i}/{len(top_5_indices)}')
            phi = np.abs(explainer(X, Y))[0]
            # Collapse a channel axis, if present, to a 2-D heatmap.
            if len(phi.shape) == 3:
                phi = np.mean(phi, -1)
            #apply Gaussian smoothing
            phi_smoothed = cv2.GaussianBlur(phi, (5, 5), sigmaX=1.0, sigmaY=1.0)
            # Overlay heatmap on preprocessed image (same coords as phi) for testing
            plt.clf()
            prep = np.asarray(X[0]).copy()
            # Bring the background image into [0, 1] for imshow.
            if prep.max() > 1.0:
                prep = prep.astype(np.float32) / 255.0
            prep = np.clip(prep, 0, 1)
            plt.imshow(prep)
            phi_display = np.abs(phi_smoothed)
            # Min-max normalize when values fall outside [0, 1].
            if phi_display.max() > 1 or phi_display.min() < 0:
                phi_display = phi_display - phi_display.min()
                if phi_display.max() > 0:
                    phi_display = phi_display / phi_display.max()
            # Clip the 1st/99th percentiles to suppress outliers, then rescale.
            phi_display = np.clip(phi_display, np.percentile(phi_display, 1), np.percentile(phi_display, 99))
            if phi_display.max() > 0:
                phi_display = phi_display / phi_display.max()
            # Heatmap only over content: mask out zero-padded areas (alpha=0 there)
            try:
                # 'managua' only exists in newer matplotlib; fall back to 'magma'.
                heatmap_cmap = plt.get_cmap('managua')
            except (ValueError, TypeError):
                heatmap_cmap = plt.get_cmap('magma')
            rgba = heatmap_cmap(phi_display)
            # Scale alpha by the content mask so padding stays fully transparent.
            rgba[..., 3] = rgba[..., 3] * content_mask * heatmap_alpha
            plt.imshow(rgba)
            plt.axis('off')
            plt.savefig(f'phi_{e}{i}.png')
            explanations.append(f'phi_{e}{i}.png')
        # avg=[]
        # for i,Y in enumerate(top_5_indices):
        #     Y = tf.one_hot([Y], n_classes)
        #     print(f'{i}/{len(top_5_indices)}')
        #     phi = np.abs(explainer(X, Y))[0]
        #     if len(phi.shape) == 3:
        #         phi = np.mean(phi, -1)
        #     show(X[0][:,size_repetitions:2*size_repetitions,:])
        #     show(phi[:,size_repetitions:2*size_repetitions], p=1, alpha=0.4)
        #     plt.savefig(f'phi_6.png')
        #     avg.append(f'phi_6.png')
    print('Done')
    # Unwrap a single result so callers get a plain filename string.
    if len(explanations)==1:
        explanations = explanations[0]
    # return explanations,avg
    return classes,explanations