Initial commit
This commit is contained in:
215
visualize_embeddings.py
Executable file
215
visualize_embeddings.py
Executable file
@@ -0,0 +1,215 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Visualisation des embeddings avec t-SNE ou UMAP.
|
||||
|
||||
Ce script permet de visualiser les embeddings dans un espace 2D
|
||||
pour comprendre comment le modèle groupe les workflows similaires.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from geniusia2.core.embedders import EmbeddingManager, FAISSIndex
|
||||
|
||||
try:
|
||||
from sklearn.manifold import TSNE
|
||||
TSNE_AVAILABLE = True
|
||||
except ImportError:
|
||||
TSNE_AVAILABLE = False
|
||||
print("⚠️ scikit-learn not installed. Install with: pip install scikit-learn")
|
||||
|
||||
try:
|
||||
import umap
|
||||
UMAP_AVAILABLE = True
|
||||
except ImportError:
|
||||
UMAP_AVAILABLE = False
|
||||
|
||||
|
||||
def create_sample_images():
|
||||
"""Crée des images d'exemple pour la visualisation."""
|
||||
from PIL import ImageDraw, ImageFont
|
||||
|
||||
samples = []
|
||||
|
||||
# Groupe 1: Boutons bleus (Submit, OK, Send)
|
||||
for text in ["Submit", "OK", "Send"]:
|
||||
img = Image.new('RGB', (200, 100), color=(240, 240, 240))
|
||||
draw = ImageDraw.Draw(img)
|
||||
draw.rectangle([20, 20, 180, 80], fill=(100, 150, 255), outline=(50, 50, 50), width=2)
|
||||
try:
|
||||
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 16)
|
||||
except:
|
||||
font = ImageFont.load_default()
|
||||
bbox = draw.textbbox((0, 0), text, font=font)
|
||||
w, h = bbox[2] - bbox[0], bbox[3] - bbox[1]
|
||||
draw.text((100 - w//2, 50 - h//2), text, fill=(255, 255, 255), font=font)
|
||||
samples.append(('blue_button', text, img))
|
||||
|
||||
# Groupe 2: Boutons rouges (Cancel, Close, Abort)
|
||||
for text in ["Cancel", "Close", "Abort"]:
|
||||
img = Image.new('RGB', (200, 100), color=(240, 240, 240))
|
||||
draw = ImageDraw.Draw(img)
|
||||
draw.rectangle([20, 20, 180, 80], fill=(255, 100, 100), outline=(50, 50, 50), width=2)
|
||||
try:
|
||||
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 16)
|
||||
except:
|
||||
font = ImageFont.load_default()
|
||||
bbox = draw.textbbox((0, 0), text, font=font)
|
||||
w, h = bbox[2] - bbox[0], bbox[3] - bbox[1]
|
||||
draw.text((100 - w//2, 50 - h//2), text, fill=(255, 255, 255), font=font)
|
||||
samples.append(('red_button', text, img))
|
||||
|
||||
# Groupe 3: Boutons verts (Apply, Confirm, Accept)
|
||||
for text in ["Apply", "Confirm", "Accept"]:
|
||||
img = Image.new('RGB', (200, 100), color=(240, 240, 240))
|
||||
draw = ImageDraw.Draw(img)
|
||||
draw.rectangle([20, 20, 180, 80], fill=(100, 200, 100), outline=(50, 50, 50), width=2)
|
||||
try:
|
||||
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 16)
|
||||
except:
|
||||
font = ImageFont.load_default()
|
||||
bbox = draw.textbbox((0, 0), text, font=font)
|
||||
w, h = bbox[2] - bbox[0], bbox[3] - bbox[1]
|
||||
draw.text((100 - w//2, 50 - h//2), text, fill=(255, 255, 255), font=font)
|
||||
samples.append(('green_button', text, img))
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
def visualize_embeddings(method='tsne'):
|
||||
"""Visualise les embeddings en 2D."""
|
||||
print("\n" + "="*70)
|
||||
print("VISUALISATION DES EMBEDDINGS")
|
||||
print("="*70)
|
||||
|
||||
# Check dependencies
|
||||
if method == 'tsne' and not TSNE_AVAILABLE:
|
||||
print("❌ t-SNE requires scikit-learn")
|
||||
print(" Install with: pip install scikit-learn")
|
||||
return 1
|
||||
|
||||
if method == 'umap' and not UMAP_AVAILABLE:
|
||||
print("❌ UMAP requires umap-learn")
|
||||
print(" Install with: pip install umap-learn")
|
||||
print(" Falling back to t-SNE...")
|
||||
method = 'tsne'
|
||||
if not TSNE_AVAILABLE:
|
||||
return 1
|
||||
|
||||
# 1. Create samples
|
||||
print("\n1. Création des échantillons...")
|
||||
samples = create_sample_images()
|
||||
print(f" ✓ Créé {len(samples)} images")
|
||||
|
||||
# 2. Generate embeddings
|
||||
print("\n2. Génération des embeddings...")
|
||||
manager = EmbeddingManager(model_name="clip")
|
||||
|
||||
embeddings = []
|
||||
labels = []
|
||||
texts = []
|
||||
|
||||
for category, text, img in samples:
|
||||
emb = manager.embed(img)
|
||||
embeddings.append(emb)
|
||||
labels.append(category)
|
||||
texts.append(text)
|
||||
|
||||
embeddings = np.array(embeddings)
|
||||
print(f" ✓ Généré {len(embeddings)} embeddings de dimension {embeddings.shape[1]}")
|
||||
|
||||
# 3. Reduce dimensionality
|
||||
print(f"\n3. Réduction de dimensionnalité ({method.upper()})...")
|
||||
|
||||
if method == 'tsne':
|
||||
reducer = TSNE(n_components=2, random_state=42, perplexity=min(5, len(embeddings)-1))
|
||||
else: # umap
|
||||
reducer = umap.UMAP(n_components=2, random_state=42)
|
||||
|
||||
embeddings_2d = reducer.fit_transform(embeddings)
|
||||
print(f" ✓ Réduit à 2D: {embeddings_2d.shape}")
|
||||
|
||||
# 4. Plot
|
||||
print("\n4. Création du graphique...")
|
||||
|
||||
plt.figure(figsize=(12, 8))
|
||||
|
||||
# Color map
|
||||
color_map = {
|
||||
'blue_button': 'blue',
|
||||
'red_button': 'red',
|
||||
'green_button': 'green'
|
||||
}
|
||||
|
||||
# Plot points
|
||||
for category in set(labels):
|
||||
mask = np.array(labels) == category
|
||||
plt.scatter(
|
||||
embeddings_2d[mask, 0],
|
||||
embeddings_2d[mask, 1],
|
||||
c=color_map[category],
|
||||
label=category.replace('_', ' ').title(),
|
||||
s=200,
|
||||
alpha=0.6,
|
||||
edgecolors='black',
|
||||
linewidth=2
|
||||
)
|
||||
|
||||
# Add text labels
|
||||
for i, (x, y) in enumerate(embeddings_2d):
|
||||
plt.annotate(
|
||||
texts[i],
|
||||
(x, y),
|
||||
xytext=(5, 5),
|
||||
textcoords='offset points',
|
||||
fontsize=9,
|
||||
bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.3)
|
||||
)
|
||||
|
||||
plt.title(f'Visualisation des Embeddings ({method.upper()})\n'
|
||||
f'CLIP ViT-B/32 - {len(embeddings)} échantillons',
|
||||
fontsize=14, fontweight='bold')
|
||||
plt.xlabel('Dimension 1', fontsize=12)
|
||||
plt.ylabel('Dimension 2', fontsize=12)
|
||||
plt.legend(fontsize=10, loc='best')
|
||||
plt.grid(True, alpha=0.3)
|
||||
plt.tight_layout()
|
||||
|
||||
# Save
|
||||
output_file = f'embeddings_visualization_{method}.png'
|
||||
plt.savefig(output_file, dpi=150, bbox_inches='tight')
|
||||
print(f" ✓ Sauvegardé: {output_file}")
|
||||
|
||||
# Show
|
||||
print("\n5. Affichage...")
|
||||
print(" (Ferme la fenêtre pour continuer)")
|
||||
plt.show()
|
||||
|
||||
print("\n✅ Visualisation terminée!")
|
||||
return 0
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function."""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='Visualiser les embeddings')
|
||||
parser.add_argument(
|
||||
'--method',
|
||||
choices=['tsne', 'umap'],
|
||||
default='tsne',
|
||||
help='Méthode de réduction de dimensionnalité (default: tsne)'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return visualize_embeddings(method=args.method)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
Reference in New Issue
Block a user