216 lines
6.7 KiB
Python
Executable File
216 lines
6.7 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Visualisation des embeddings avec t-SNE ou UMAP.
|
|
|
|
Ce script permet de visualiser les embeddings dans un espace 2D
|
|
pour comprendre comment le modèle groupe les workflows similaires.
|
|
"""
|
|
|
|
import sys
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from pathlib import Path
|
|
from PIL import Image
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent))
|
|
|
|
from geniusia2.core.embedders import EmbeddingManager, FAISSIndex
|
|
|
|
try:
|
|
from sklearn.manifold import TSNE
|
|
TSNE_AVAILABLE = True
|
|
except ImportError:
|
|
TSNE_AVAILABLE = False
|
|
print("⚠️ scikit-learn not installed. Install with: pip install scikit-learn")
|
|
|
|
try:
|
|
import umap
|
|
UMAP_AVAILABLE = True
|
|
except ImportError:
|
|
UMAP_AVAILABLE = False
|
|
|
|
|
|
def _load_label_font(size=16):
    """Return the DejaVu Sans font at *size*, or Pillow's built-in default font."""
    from PIL import ImageFont

    try:
        return ImageFont.truetype(
            "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", size
        )
    except (OSError, ImportError):
        # OSError: the font file is missing or unreadable on this machine.
        # ImportError: Pillow was built without FreeType support.
        # Either way the demo can still run with the default font.
        return ImageFont.load_default()


def _render_button(text, fill, font):
    """Draw one 200x100 image holding a filled rectangle with *text* on it."""
    from PIL import Image, ImageDraw

    img = Image.new('RGB', (200, 100), color=(240, 240, 240))
    draw = ImageDraw.Draw(img)
    draw.rectangle([20, 20, 180, 80], fill=fill, outline=(50, 50, 50), width=2)

    # Measure the rendered text so it can be centred on the image midpoint.
    bbox = draw.textbbox((0, 0), text, font=font)
    w, h = bbox[2] - bbox[0], bbox[3] - bbox[1]
    draw.text((100 - w // 2, 50 - h // 2), text, fill=(255, 255, 255), font=font)
    return img


def create_sample_images():
    """Create the sample button images used by the visualization.

    Three groups of three buttons are generated; the groups differ only by
    fill colour and caption.

    Returns:
        A list of ``(category, text, image)`` tuples, in the order blue,
        red, green, where ``category`` is one of ``'blue_button'``,
        ``'red_button'`` or ``'green_button'``, ``text`` is the caption and
        ``image`` is the PIL image.
    """
    # (category, rectangle fill colour, captions)
    button_groups = (
        ('blue_button', (100, 150, 255), ("Submit", "OK", "Send")),
        ('red_button', (255, 100, 100), ("Cancel", "Close", "Abort")),
        ('green_button', (100, 200, 100), ("Apply", "Confirm", "Accept")),
    )

    # The font does not depend on the button, so load it once.
    font = _load_label_font()

    samples = []
    for category, fill, captions in button_groups:
        for text in captions:
            samples.append((category, text, _render_button(text, fill, font)))
    return samples
|
|
|
|
|
|
def visualize_embeddings(method='tsne'):
    """Embed the sample images, reduce the embeddings to 2D and plot them.

    The plot is saved to ``embeddings_visualization_<method>.png`` in the
    current directory and then shown with ``plt.show()``.

    Args:
        method: ``'tsne'`` or ``'umap'``. If UMAP is requested but not
            installed, the function falls back to t-SNE.

    Returns:
        0 on success, 1 when the required reduction library is missing.
    """
    print("\n" + "="*70)
    print("VISUALISATION DES EMBEDDINGS")
    print("="*70)

    # Check dependencies. The UMAP fallback is resolved first so that the
    # t-SNE check below also reports a failed fallback instead of exiting
    # without an explanation.
    if method == 'umap' and not UMAP_AVAILABLE:
        print("❌ UMAP requires umap-learn")
        print(" Install with: pip install umap-learn")
        print(" Falling back to t-SNE...")
        method = 'tsne'

    if method == 'tsne' and not TSNE_AVAILABLE:
        print("❌ t-SNE requires scikit-learn")
        print(" Install with: pip install scikit-learn")
        return 1

    # 1. Create samples
    print("\n1. Création des échantillons...")
    samples = create_sample_images()
    print(f" ✓ Créé {len(samples)} images")

    # 2. Generate embeddings
    print("\n2. Génération des embeddings...")
    manager = EmbeddingManager(model_name="clip")

    embeddings = []
    labels = []
    texts = []
    for category, text, img in samples:
        embeddings.append(manager.embed(img))
        labels.append(category)
        texts.append(text)

    embeddings = np.array(embeddings)
    print(f" ✓ Généré {len(embeddings)} embeddings de dimension {embeddings.shape[1]}")

    # 3. Reduce dimensionality
    print(f"\n3. Réduction de dimensionnalité ({method.upper()})...")

    if method == 'tsne':
        # Perplexity is capped below the sample count, which t-SNE requires.
        reducer = TSNE(
            n_components=2,
            random_state=42,
            perplexity=min(5, len(embeddings) - 1),
        )
    else:  # umap
        reducer = umap.UMAP(n_components=2, random_state=42)

    embeddings_2d = reducer.fit_transform(embeddings)
    print(f" ✓ Réduit à 2D: {embeddings_2d.shape}")

    # 4. Plot
    print("\n4. Création du graphique...")

    plt.figure(figsize=(12, 8))

    color_map = {
        'blue_button': 'blue',
        'red_button': 'red',
        'green_button': 'green',
    }

    # Plot one scatter series per category. Categories are visited in
    # first-seen order (not set order) so the legend and draw order are
    # the same on every run.
    label_array = np.array(labels)
    for category in dict.fromkeys(labels):
        mask = label_array == category
        plt.scatter(
            embeddings_2d[mask, 0],
            embeddings_2d[mask, 1],
            c=color_map[category],
            label=category.replace('_', ' ').title(),
            s=200,
            alpha=0.6,
            edgecolors='black',
            linewidth=2,
        )

    # Annotate every point with its button caption.
    for caption, (x, y) in zip(texts, embeddings_2d):
        plt.annotate(
            caption,
            (x, y),
            xytext=(5, 5),
            textcoords='offset points',
            fontsize=9,
            bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.3),
        )

    plt.title(f'Visualisation des Embeddings ({method.upper()})\n'
              f'CLIP ViT-B/32 - {len(embeddings)} échantillons',
              fontsize=14, fontweight='bold')
    plt.xlabel('Dimension 1', fontsize=12)
    plt.ylabel('Dimension 2', fontsize=12)
    plt.legend(fontsize=10, loc='best')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    # Save
    output_file = f'embeddings_visualization_{method}.png'
    plt.savefig(output_file, dpi=150, bbox_inches='tight')
    print(f" ✓ Sauvegardé: {output_file}")

    # Show
    print("\n5. Affichage...")
    print(" (Ferme la fenêtre pour continuer)")
    plt.show()

    print("\n✅ Visualisation terminée!")
    return 0
|
|
|
|
|
|
def main():
    """Parse the command line and run the embedding visualization.

    Returns:
        The status code returned by ``visualize_embeddings``.
    """
    import argparse

    cli = argparse.ArgumentParser(description='Visualiser les embeddings')

    # Settings for the single supported option, --method.
    method_option = {
        'choices': ['tsne', 'umap'],
        'default': 'tsne',
        'help': 'Méthode de réduction de dimensionnalité (default: tsne)',
    }
    cli.add_argument('--method', **method_option)

    options = cli.parse_args()
    return visualize_embeddings(method=options.method)
|
|
|
|
|
|
# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|