# fossil_app/explanations.py
# Explanations: align families with classifier top-5; accept target_labels from UI
# (commit 81d08ae, author piperod91)
import xplique
import tensorflow as tf
from xplique.attributions import (
Saliency,
GradientInput,
IntegratedGradients,
SmoothGrad,
VarGrad,
SquareGrad,
GradCAM,
Occlusion,
Rise,
GuidedBackprop,
GradCAMPP,
Lime,
KernelShap,
SobolAttributionMethod,
HsicAttributionMethod,
)
from xplique.attributions.global_sensitivity_analysis import LatinHypercube
import numpy as np
import matplotlib.pyplot as plt
from inference_resnet import inference_resnet_finer, preprocess
from labels import lookup_140
import cv2
# Batch size used by sampling-based attribution methods (e.g. Rise);
# kept at 1 to bound memory use.
BATCH_SIZE = 1
# Reverse lookup: family name -> class index.
# lookup_140 maps class index -> family name; this inverts it so that
# UI-provided family names can be resolved back to model output indices.
_FAMILY_TO_INDEX = {v: k for k, v in lookup_140.items()}
def letterbox_preprocess(img, size):
    """
    Letterbox an image onto a (size, size) black canvas without distortion.

    The input is converted to float32 in [0, 1], grayscale inputs are
    broadcast to 3 channels, and the aspect ratio is preserved — the image
    is only scaled down/up to fit and surrounded by zero padding, never
    tiled or duplicated.

    :param img: image array of shape (H, W) or (H, W, 3).
    :param size: side length of the square output canvas.
    :return: (canvas, (top, left, content_h, content_w)) — the float32
        (size, size, 3) letterboxed image plus the location/extent of the
        resized content, so a heatmap computed on the canvas can be mapped
        back to the original (h, w) image.
    """
    arr = np.asarray(img, np.float32) / 255.0
    if arr.ndim == 2:
        arr = np.stack([arr] * 3, axis=-1)
    orig_h, orig_w = arr.shape[:2]
    ratio = min(size / orig_h, size / orig_w)
    new_h = int(round(orig_h * ratio))
    new_w = int(round(orig_w * ratio))
    # cv2.resize takes dsize as (width, height)
    scaled = cv2.resize(arr, (new_w, new_h), interpolation=cv2.INTER_AREA)
    pad_top = (size - new_h) // 2
    pad_left = (size - new_w) // 2
    canvas = np.zeros((size, size, 3), dtype=np.float32)
    canvas[pad_top : pad_top + new_h, pad_left : pad_left + new_w] = scaled
    return np.clip(canvas, 0, 1).astype(np.float32), (pad_top, pad_left, new_h, new_w)
def preprocess_image(image, output_size=(300, 300)):
    """
    Pad an image to a square with black borders, then resize it.

    :param image: array of shape (height, width[, channels]).
    :param output_size: (width, height) target, as cv2.resize expects.
    :return: the squared-and-resized image.
    """
    height, width = image.shape[:2]
    # Pad the shorter axis symmetrically so the image becomes (near-)square.
    if height > width:
        pad = (height - width) // 2
        squared = cv2.copyMakeBorder(image, 0, 0, pad, pad,
                                     cv2.BORDER_CONSTANT, value=[0, 0, 0])
    else:
        pad = (width - height) // 2
        squared = cv2.copyMakeBorder(image, pad, pad, 0, 0,
                                     cv2.BORDER_CONSTANT, value=[0, 0, 0])
    return cv2.resize(squared, output_size, interpolation=cv2.INTER_AREA)
def transform(image, original_size, output_size):
    """
    Map an XAI output back to the original image geometry, then square it.

    Resizes ``image`` to ``original_size``, pads the shorter axis with black
    borders to make it square (mirroring the classifier's preprocessing),
    and finally resizes to ``output_size``.

    :param image: 2-D or 3-D array (e.g. an attribution map).
    :param original_size: (h, w) of the original input image.
    :param output_size: (width, height) of the final output, as cv2 expects.
    :return: the transformed image.
    """
    h, w = original_size
    # BUG FIX: cv2.resize expects dsize as (width, height). The original
    # passed (h, w), which swapped the axes for any non-square image and
    # broke the square-padding logic below.
    image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
    if h > w:
        padding = (h - w) // 2
        image = cv2.copyMakeBorder(image, 0, 0, padding, padding,
                                   cv2.BORDER_CONSTANT, value=[0, 0, 0])
    else:
        padding = (w - h) // 2
        image = cv2.copyMakeBorder(image, padding, padding, 0, 0,
                                   cv2.BORDER_CONSTANT, value=[0, 0, 0])
    return cv2.resize(image, output_size, interpolation=cv2.INTER_AREA)
def show(img, original_size, output_size, p=False, **kwargs):
    """
    Display an image or attribution map with matplotlib after mapping it
    back to the original geometry via ``transform``.

    :param img: image or attribution map; a leading singleton batch axis is
        dropped, a trailing single channel is squeezed, and 3-channel input
        is flipped BGR -> RGB.
    :param original_size: (h, w) of the original input image.
    :param output_size: (width, height) of the displayed image.
    :param p: percentile for symmetric clipping; False disables clipping.
    :param kwargs: forwarded to ``plt.imshow`` (e.g. cmap, alpha).
    """
    # drop a leading singleton batch axis
    if img.shape[0] == 1:
        img = img[0]
    # squeeze a trailing single channel; flip BGR -> RGB for 3 channels
    if img.shape[-1] == 1:
        img = img[:, :, 0]
    elif img.shape[-1] == 3:
        img = img[:, :, ::-1]
    # Normalize to [0, 1] on a float copy. BUG FIX: the original used
    # `img -= img.min(); img /= img.max()`, which mutated the caller's
    # array in place, raised TypeError for integer dtypes (in-place true
    # division), and divided by zero for constant images.
    if img.max() > 1 or img.min() < 0:
        img = np.asarray(img, dtype=np.float32)
        img = img - img.min()
        peak = img.max()
        if peak > 0:
            img = img / peak
    # optional percentile clipping
    if p is not False:
        img = np.clip(img, np.percentile(img, p), np.percentile(img, 100 - p))
    img = transform(img, original_size=original_size, output_size=output_size)
    plt.imshow(img, **kwargs)
    plt.axis('off')
def explain(
    model,
    input_image,
    h,
    w,
    explain_method,
    nb_samples,
    size=600,
    n_classes=171,
    heatmap_alpha=0.22,
    target_labels=None,
):
    """
    Generate per-class attribution heatmaps for one image and save them as PNGs.

    :param model: full model; ``model.output[1]`` is taken as the
        classification head and is the only part explained.
    :param input_image: raw image array (any size; letterboxed to ``size``).
    :param h: original image height (unused here; kept for API compatibility).
    :param w: original image width (unused here; kept for API compatibility).
    :param explain_method: one of "Sobol", "HSIC", "Rise", "Saliency".
    :param nb_samples: number of samples for the Rise explainer.
    :param size: side length of the square model input.
    :param n_classes: number of classes for the one-hot explanation targets.
    :param heatmap_alpha: opacity of the heatmap overlay.
    :param target_labels: optional family names from the classifier UI; when
        provided and resolvable via ``_FAMILY_TO_INDEX`` these classes are
        explained, otherwise we fall back to the top-5 of a forward pass.
    :return: (classes, explanations) — explained family names and the saved
        PNG path(s); a bare path (not a list) if exactly one was produced.
    """
    print('using explain_method:', explain_method)
    # we only need the classification part of the model
    class_model = tf.keras.Model(model.input, model.output[1])

    explainers = []
    if explain_method == "Sobol":
        explainers.append(SobolAttributionMethod(class_model, grid_size=8, nb_design=32))
    if explain_method == "HSIC":
        explainers.append(HsicAttributionMethod(
            class_model, grid_size=7, nb_design=1500,
            sampler=LatinHypercube(binary=True)))
    if explain_method == "Rise":
        explainers.append(Rise(class_model, nb_samples=nb_samples,
                               batch_size=BATCH_SIZE, grid_size=15,
                               preservation_probability=0.5))
    if explain_method == "Saliency":
        explainers.append(Saliency(class_model))

    # Resize to fit (size, size) with black padding only — no tiling/duplication.
    X, (content_top, content_left, content_h, content_w) = letterbox_preprocess(input_image, size)
    # Mask: 1 in content region, 0 in zero-padded areas (so we don't show heatmap on padding)
    content_mask = np.zeros((size, size), dtype=np.float32)
    content_mask[content_top:content_top + content_h,
                 content_left:content_left + content_w] = 1.0

    # Determine which classes to explain. The original duplicated the
    # top-5 fallback verbatim in two branches; compute it exactly once.
    top_5_indices = None
    if target_labels:
        indices = [_FAMILY_TO_INDEX[name] for name in target_labels
                   if name in _FAMILY_TO_INDEX]
        if indices:
            top_5_indices = np.array(indices, dtype=int)
    if top_5_indices is None:
        # No usable UI labels: use the top-5 classes from a forward pass.
        predictions = class_model.predict(np.array([X]))
        top_5_indices = np.argsort(predictions[0])[-5:][::-1]
    classes = [lookup_140[i] for i in top_5_indices]

    X = np.expand_dims(X, 0)
    explanations = []
    for e, explainer in enumerate(explainers):
        print(f'{e}/{len(explainers)}')
        for i, Y in enumerate(top_5_indices):
            Y = tf.one_hot([Y], n_classes)
            print(f'{i}/{len(top_5_indices)}')
            phi = np.abs(explainer(X, Y))[0]
            # collapse a channel axis to a single 2-D attribution map
            if len(phi.shape) == 3:
                phi = np.mean(phi, -1)
            # Gaussian smoothing de-speckles the raw attribution map.
            phi_smoothed = cv2.GaussianBlur(phi, (5, 5), sigmaX=1.0, sigmaY=1.0)

            # Overlay heatmap on the preprocessed image (same coords as phi).
            plt.clf()
            prep = np.asarray(X[0]).copy()
            if prep.max() > 1.0:
                prep = prep.astype(np.float32) / 255.0
            prep = np.clip(prep, 0, 1)
            plt.imshow(prep)

            # Normalize, percentile-clip, then renormalize for display.
            phi_display = np.abs(phi_smoothed)
            if phi_display.max() > 1 or phi_display.min() < 0:
                phi_display = phi_display - phi_display.min()
                if phi_display.max() > 0:
                    phi_display = phi_display / phi_display.max()
            phi_display = np.clip(phi_display,
                                  np.percentile(phi_display, 1),
                                  np.percentile(phi_display, 99))
            if phi_display.max() > 0:
                phi_display = phi_display / phi_display.max()

            # Heatmap only over content: zero alpha on the letterbox padding.
            try:
                # 'managua' needs a recent matplotlib — TODO confirm min version
                heatmap_cmap = plt.get_cmap('managua')
            except (ValueError, TypeError):
                heatmap_cmap = plt.get_cmap('magma')
            rgba = heatmap_cmap(phi_display)
            rgba[..., 3] = rgba[..., 3] * content_mask * heatmap_alpha
            plt.imshow(rgba)
            plt.axis('off')
            plt.savefig(f'phi_{e}{i}.png')
            explanations.append(f'phi_{e}{i}.png')

    print('Done')
    # Historical quirk callers depend on: a single explanation is returned
    # as a bare path rather than a one-element list.
    if len(explanations) == 1:
        explanations = explanations[0]
    return classes, explanations