Initial commit
This commit is contained in:
309
geniusia2/core/embedders/faiss_index.py
Normal file
309
geniusia2/core/embedders/faiss_index.py
Normal file
@@ -0,0 +1,309 @@
|
||||
"""
|
||||
FAISS index wrapper with proper dimension handling and persistence.
|
||||
|
||||
This module provides a robust wrapper around FAISS for storing and searching
|
||||
image embeddings, with proper error handling for dimension mismatches and
|
||||
reliable save/load functionality.
|
||||
"""
|
||||
|
||||
import pickle
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import faiss
|
||||
except ImportError:
|
||||
faiss = None
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FAISSIndex:
    """
    Wrapper around FAISS index with metadata storage and dimension validation.

    This class handles:
    - Dimension validation on add/search operations
    - Metadata storage alongside embeddings
    - Reliable persistence (save/load)
    - Automatic index rebuilding on dimension changes
    """

    def __init__(self, dimension: int):
        """
        Initialize a new FAISS index.

        Args:
            dimension: Embedding dimension (e.g., 512 for CLIP, 768 for Pix2Struct)

        Raises:
            ImportError: If FAISS is not installed
            ValueError: If dimension is invalid
        """
        if faiss is None:
            raise ImportError(
                "FAISS is not installed. "
                "Install it with: pip install faiss-cpu or faiss-gpu"
            )

        if dimension <= 0:
            raise ValueError(f"Dimension must be positive, got {dimension}")

        self.dimension = dimension
        # Exact (brute-force) L2 index; requires no training phase.
        self.index = faiss.IndexFlatL2(dimension)
        # metadata[i] corresponds to the i-th vector ever added to the index.
        self.metadata: List[Dict[str, Any]] = []

        # Lazy %-args: the message is only formatted if INFO is enabled.
        logger.info("FAISSIndex created with dimension=%d", dimension)

    def add(self, embeddings: np.ndarray, metadata: List[Dict[str, Any]]):
        """
        Add embeddings to the index with associated metadata.

        Args:
            embeddings: Array of shape (N, dimension) containing N embeddings
            metadata: List of N metadata dictionaries

        Raises:
            ValueError: If dimensions don't match or array shapes are invalid
        """
        # Validate input shape
        if embeddings.ndim == 1:
            # Single embedding, reshape to (1, dimension)
            embeddings = embeddings.reshape(1, -1)
        elif embeddings.ndim != 2:
            raise ValueError(
                f"Embeddings must be 1D or 2D array, got shape {embeddings.shape}"
            )

        # Validate dimension
        if embeddings.shape[1] != self.dimension:
            raise ValueError(
                f"Embedding dimension {embeddings.shape[1]} doesn't match "
                f"index dimension {self.dimension}"
            )

        # Validate metadata count
        if len(metadata) != embeddings.shape[0]:
            raise ValueError(
                f"Number of metadata entries ({len(metadata)}) doesn't match "
                f"number of embeddings ({embeddings.shape[0]})"
            )

        # FAISS requires C-contiguous float32 input. astype() keeps the source
        # memory order ('K'), so force contiguity and dtype in one step.
        self.index.add(np.ascontiguousarray(embeddings, dtype=np.float32))

        # Store metadata in the same order the vectors were added.
        self.metadata.extend(metadata)

        logger.debug(
            "Added %d embeddings to index (total: %d)",
            embeddings.shape[0],
            self.index.ntotal,
        )

    def search(
        self,
        query: np.ndarray,
        k: int = 5
    ) -> List[Dict[str, Any]]:
        """
        Search for the k most similar embeddings.

        Args:
            query: Query embedding of shape (dimension,) or (1, dimension)
            k: Number of results to return

        Returns:
            List of dicts with keys:
            - 'index': Index in the FAISS index
            - 'distance': L2 distance
            - 'similarity': Similarity score (1 / (1 + distance))
            - 'metadata': Associated metadata dict

        Raises:
            ValueError: If query dimension doesn't match index dimension
        """
        if self.index.ntotal == 0:
            logger.warning("Search called on empty index")
            return []

        # Reshape query if needed
        if query.ndim == 1:
            query = query.reshape(1, -1)
        elif query.ndim != 2:
            raise ValueError(
                f"Query must be 1D or 2D array, got shape {query.shape}"
            )

        # Validate dimension
        if query.shape[1] != self.dimension:
            raise ValueError(
                f"Query dimension {query.shape[1]} doesn't match "
                f"index dimension {self.dimension}"
            )

        # Clamp k to the number of stored vectors; FAISS errors on k <= 0,
        # so treat a non-positive request as "no results".
        k = min(k, self.index.ntotal)
        if k <= 0:
            return []

        # Same contiguity requirement as add().
        distances, indices = self.index.search(
            np.ascontiguousarray(query, dtype=np.float32), k
        )

        # Format results
        results = []
        for dist, idx in zip(distances[0], indices[0]):
            # FAISS returns -1 if not enough results
            if 0 <= idx < len(self.metadata):
                results.append({
                    'index': int(idx),
                    'distance': float(dist),
                    'similarity': float(1.0 / (1.0 + dist)),
                    'metadata': self.metadata[idx]
                })

        return results

    def save(self, path: str):
        """
        Save index and metadata to disk.

        Args:
            path: Base path for saving (will create .index and .metadata files)

        Raises:
            RuntimeError: If save operation fails
        """
        try:
            path_obj = Path(path)
            path_obj.parent.mkdir(parents=True, exist_ok=True)

            # Save FAISS index
            index_file = f"{path}.index"
            faiss.write_index(self.index, index_file)

            # Save metadata (dimension is stored so load() can validate it).
            metadata_file = f"{path}.metadata"
            with open(metadata_file, 'wb') as f:
                pickle.dump({
                    'dimension': self.dimension,
                    'metadata': self.metadata
                }, f)

            logger.info(
                "Saved index with %d embeddings to %s", self.index.ntotal, path
            )

        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Failed to save index: {e}") from e

    def load(self, path: str):
        """
        Load index and metadata from disk.

        Args:
            path: Base path for loading (will read .index and .metadata files)

        Raises:
            FileNotFoundError: If files don't exist
            RuntimeError: If load operation fails or dimension mismatch
        """
        try:
            index_file = f"{path}.index"
            metadata_file = f"{path}.metadata"

            # Check files exist
            if not Path(index_file).exists():
                raise FileNotFoundError(f"Index file not found: {index_file}")
            if not Path(metadata_file).exists():
                raise FileNotFoundError(f"Metadata file not found: {metadata_file}")

            # Load FAISS index
            loaded_index = faiss.read_index(index_file)

            # SECURITY: pickle can execute arbitrary code on load -- only read
            # metadata files this application wrote itself, never untrusted input.
            with open(metadata_file, 'rb') as f:
                data = pickle.load(f)

            loaded_dimension = data['dimension']
            loaded_metadata = data['metadata']

            # Validate dimension
            if loaded_dimension != self.dimension:
                raise RuntimeError(
                    f"Loaded index dimension ({loaded_dimension}) doesn't match "
                    f"current dimension ({self.dimension}). "
                    f"Use rebuild_if_needed() to handle dimension changes."
                )

            # Swap in the loaded state only after every check has passed, so a
            # failed load never leaves the object half-updated.
            self.index = loaded_index
            self.metadata = loaded_metadata

            logger.info(
                "Loaded index with %d embeddings from %s", self.index.ntotal, path
            )

        except (FileNotFoundError, RuntimeError):
            # Documented exception types propagate unchanged.
            raise
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Failed to load index: {e}") from e

    def rebuild_if_needed(self, new_dimension: int) -> bool:
        """
        Rebuild index if dimension has changed.

        This creates a new empty index with the new dimension.
        Old embeddings are lost and need to be regenerated.

        Args:
            new_dimension: New embedding dimension

        Returns:
            bool: True if index was rebuilt, False if dimension unchanged
        """
        if new_dimension == self.dimension:
            return False

        logger.warning(
            "Rebuilding FAISS index: dimension changed from %d to %d. "
            "Old embeddings (%d) will be lost.",
            self.dimension,
            new_dimension,
            self.index.ntotal,
        )

        # Create new index
        self.dimension = new_dimension
        self.index = faiss.IndexFlatL2(new_dimension)
        self.metadata = []

        return True

    def clear(self):
        """Clear all embeddings from the index."""
        self.index = faiss.IndexFlatL2(self.dimension)
        self.metadata = []
        logger.info("Index cleared")

    def get_stats(self) -> Dict[str, Any]:
        """
        Get index statistics.

        Returns:
            Dict with keys: num_embeddings, dimension, is_trained
        """
        return {
            'num_embeddings': self.index.ntotal,
            'dimension': self.dimension,
            'is_trained': self.index.is_trained
        }

    def __len__(self) -> int:
        """Return number of embeddings in the index."""
        return self.index.ntotal

    def __repr__(self) -> str:
        """String representation of the index."""
        return (
            f"FAISSIndex(dimension={self.dimension}, "
            f"num_embeddings={self.index.ntotal})"
        )
|
||||
Reference in New Issue
Block a user