"""
FAISS index wrapper with proper dimension handling and persistence.

This module provides a robust wrapper around FAISS for storing and searching
image embeddings, with proper error handling for dimension mismatches and
reliable save/load functionality.
"""
import logging
import pickle
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np

# FAISS is an optional dependency; FAISSIndex.__init__ raises a helpful
# ImportError when it is missing instead of failing at import time.
try:
    import faiss
except ImportError:
    faiss = None


logger = logging.getLogger(__name__)
class FAISSIndex:
    """
    Wrapper around a FAISS index with metadata storage and dimension validation.

    This class handles:
    - Dimension validation on add/search operations
    - Metadata storage alongside embeddings
    - Reliable persistence (save/load)
    - Automatic index rebuilding on dimension changes
    """

    def __init__(self, dimension: int):
        """
        Initialize a new FAISS index.

        Args:
            dimension: Embedding dimension (e.g., 512 for CLIP, 768 for Pix2Struct)

        Raises:
            ImportError: If FAISS is not installed
            ValueError: If dimension is invalid
        """
        if faiss is None:
            raise ImportError(
                "FAISS is not installed. "
                "Install it with: pip install faiss-cpu or faiss-gpu"
            )

        if dimension <= 0:
            raise ValueError(f"Dimension must be positive, got {dimension}")

        self.dimension = dimension
        # Exact (brute-force) L2 index: no training step, is_trained is always True.
        self.index = faiss.IndexFlatL2(dimension)
        # metadata[i] is the metadata for the i-th vector added to self.index.
        self.metadata: List[Dict[str, Any]] = []

        logger.info("FAISSIndex created with dimension=%d", dimension)

    def add(self, embeddings: np.ndarray, metadata: List[Dict[str, Any]]) -> None:
        """
        Add embeddings to the index with associated metadata.

        Args:
            embeddings: Array of shape (N, dimension) containing N embeddings
            metadata: List of N metadata dictionaries

        Raises:
            ValueError: If dimensions don't match or array shapes are invalid
        """
        # Accept a single 1D embedding by promoting it to shape (1, dimension).
        if embeddings.ndim == 1:
            embeddings = embeddings.reshape(1, -1)
        elif embeddings.ndim != 2:
            raise ValueError(
                f"Embeddings must be 1D or 2D array, got shape {embeddings.shape}"
            )

        if embeddings.shape[1] != self.dimension:
            raise ValueError(
                f"Embedding dimension {embeddings.shape[1]} doesn't match "
                f"index dimension {self.dimension}"
            )

        # Metadata must stay 1:1 with vectors, since search looks results up by row index.
        if len(metadata) != embeddings.shape[0]:
            raise ValueError(
                f"Number of metadata entries ({len(metadata)}) doesn't match "
                f"number of embeddings ({embeddings.shape[0]})"
            )

        # FAISS requires float32 input.
        self.index.add(embeddings.astype(np.float32))
        self.metadata.extend(metadata)

        logger.debug(
            "Added %d embeddings to index (total: %d)",
            embeddings.shape[0],
            self.index.ntotal,
        )

    def search(
        self,
        query: np.ndarray,
        k: int = 5
    ) -> List[Dict[str, Any]]:
        """
        Search for the k most similar embeddings.

        Args:
            query: Query embedding of shape (dimension,) or (1, dimension)
            k: Number of results to return

        Returns:
            List of dicts with keys:
            - 'index': Index in the FAISS index
            - 'distance': L2 distance
            - 'similarity': Similarity score (1 / (1 + distance))
            - 'metadata': Associated metadata dict

        Raises:
            ValueError: If query dimension doesn't match index dimension
        """
        if self.index.ntotal == 0:
            logger.warning("Search called on empty index")
            return []

        # Accept a 1D query by promoting it to shape (1, dimension).
        if query.ndim == 1:
            query = query.reshape(1, -1)
        elif query.ndim != 2:
            raise ValueError(
                f"Query must be 1D or 2D array, got shape {query.shape}"
            )

        if query.shape[1] != self.dimension:
            raise ValueError(
                f"Query dimension {query.shape[1]} doesn't match "
                f"index dimension {self.dimension}"
            )

        # Clamp k to what is actually stored; a non-positive k means no results.
        k = min(k, self.index.ntotal)
        if k <= 0:
            return []

        distances, indices = self.index.search(query.astype(np.float32), k)

        results = []
        for dist, idx in zip(distances[0], indices[0]):
            # FAISS pads with -1 when fewer than k results exist; the upper
            # bound guards against metadata drifting out of sync with the index.
            if 0 <= idx < len(self.metadata):
                results.append({
                    'index': int(idx),
                    'distance': float(dist),
                    'similarity': float(1.0 / (1.0 + dist)),
                    'metadata': self.metadata[idx]
                })

        return results

    def save(self, path: str) -> None:
        """
        Save index and metadata to disk.

        Args:
            path: Base path for saving (will create .index and .metadata files)

        Raises:
            RuntimeError: If save operation fails
        """
        try:
            path_obj = Path(path)
            path_obj.parent.mkdir(parents=True, exist_ok=True)

            # The FAISS index and the Python-side metadata are persisted as
            # two sibling files sharing the same base path.
            index_file = f"{path}.index"
            faiss.write_index(self.index, index_file)

            metadata_file = f"{path}.metadata"
            with open(metadata_file, 'wb') as f:
                # Dimension is stored so load() can reject mismatched indexes.
                pickle.dump({
                    'dimension': self.dimension,
                    'metadata': self.metadata
                }, f)

            logger.info(
                "Saved index with %d embeddings to %s", self.index.ntotal, path
            )

        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise RuntimeError(f"Failed to save index: {e}") from e

    def load(self, path: str) -> None:
        """
        Load index and metadata from disk.

        Args:
            path: Base path for loading (will read .index and .metadata files)

        Raises:
            FileNotFoundError: If files don't exist
            RuntimeError: If load operation fails or dimension mismatch
        """
        try:
            index_file = f"{path}.index"
            metadata_file = f"{path}.metadata"

            if not Path(index_file).exists():
                raise FileNotFoundError(f"Index file not found: {index_file}")
            if not Path(metadata_file).exists():
                raise FileNotFoundError(f"Metadata file not found: {metadata_file}")

            loaded_index = faiss.read_index(index_file)

            # NOTE(security): pickle.load executes arbitrary code from the file;
            # only load index files from trusted sources.
            with open(metadata_file, 'rb') as f:
                data = pickle.load(f)

            loaded_dimension = data['dimension']
            loaded_metadata = data['metadata']

            if loaded_dimension != self.dimension:
                raise RuntimeError(
                    f"Loaded index dimension ({loaded_dimension}) doesn't match "
                    f"current dimension ({self.dimension}). "
                    f"Use rebuild_if_needed() to handle dimension changes."
                )

            # Only mutate state after every validation passed, so a failed
            # load leaves the current index intact.
            self.index = loaded_index
            self.metadata = loaded_metadata

            logger.info(
                "Loaded index with %d embeddings from %s", self.index.ntotal, path
            )

        except Exception as e:
            if isinstance(e, (FileNotFoundError, RuntimeError)):
                raise
            raise RuntimeError(f"Failed to load index: {e}") from e

    def rebuild_if_needed(self, new_dimension: int) -> bool:
        """
        Rebuild index if dimension has changed.

        This creates a new empty index with the new dimension.
        Old embeddings are lost and need to be regenerated.

        Args:
            new_dimension: New embedding dimension

        Returns:
            bool: True if index was rebuilt, False if dimension unchanged
        """
        if new_dimension == self.dimension:
            return False

        logger.warning(
            "Rebuilding FAISS index: dimension changed from %d to %d. "
            "Old embeddings (%d) will be lost.",
            self.dimension,
            new_dimension,
            self.index.ntotal,
        )

        self.dimension = new_dimension
        self.index = faiss.IndexFlatL2(new_dimension)
        self.metadata = []

        return True

    def clear(self) -> None:
        """Clear all embeddings from the index."""
        self.index = faiss.IndexFlatL2(self.dimension)
        self.metadata = []
        logger.info("Index cleared")

    def get_stats(self) -> Dict[str, Any]:
        """
        Get index statistics.

        Returns:
            Dict with keys: num_embeddings, dimension, is_trained
        """
        return {
            'num_embeddings': self.index.ntotal,
            'dimension': self.dimension,
            'is_trained': self.index.is_trained
        }

    def __len__(self) -> int:
        """Return number of embeddings in the index."""
        return self.index.ntotal

    def __repr__(self) -> str:
        """String representation of the index."""
        return (
            f"FAISSIndex(dimension={self.dimension}, "
            f"num_embeddings={self.index.ntotal})"
        )