Source code for matchms.similarity.BaseEmbeddingSimilarity

import pickle
from abc import abstractmethod
from pathlib import Path
from typing import Any, Iterable, List, Optional, Tuple, Union
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
from matchms.similarity.BaseSimilarity import BaseSimilarity
from matchms.typing import SpectrumType


# pynndescent is an optional dependency: fall back to None when it cannot be
# imported so that this module still loads. build_ann_index checks this name
# and raises ImportError when an ANN index is requested without it.
try:
    import pynndescent
except ImportError:
    pynndescent = None


class BaseEmbeddingSimilarity(BaseSimilarity):
    """Base class for similarity measures that work with embeddings.

    This class provides functionality for computing similarities between
    spectra based on their embeddings (vector representations). It supports
    cosine and euclidean similarity metrics, and includes approximate nearest
    neighbor (ANN) search capabilities.

    Subclasses must implement :meth:`compute_embeddings`.

    Parameters
    ----------
    similarity : str
        The similarity measure to use for comparing embeddings. Default is
        "cosine". Options are "cosine" or "euclidean".

    Attributes
    ----------
    index : object
        The ANN index object; ``None`` until an index is built or loaded.
    index_backend : str
        The backend used for ANN indexing (currently only "pynndescent" is
        supported); ``None`` until an index is built or loaded.
    index_kwargs : dict
        Additional arguments passed to the ANN index constructor; ``None``
        until an index is built or loaded.
    index_k : int
        Number of nearest neighbors used in the ANN index; ``None`` until an
        index is built or loaded.
    """

    def __init__(self, similarity: str = "cosine"):
        """Set the similarity metric and initialise an empty ANN index state.

        Parameters
        ----------
        similarity : str
            Either "cosine" or "euclidean".

        Raises
        ------
        ValueError
            If ``similarity`` is neither "cosine" nor "euclidean".
        """
        self.similarity = similarity
        self.index = None
        self.index_backend = None
        self.index_kwargs = None
        self.index_k = None

        if self.similarity == "cosine":
            self.pairwise_similarity_fn = cosine_similarity
        elif self.similarity == "euclidean":
            # Euclidean distances are turned into similarities by the same
            # conversion that is applied to ANN query results.
            self.pairwise_similarity_fn = lambda x, y: self._distances_to_similarities(
                euclidean_distances(x, y)
            )
        else:
            raise ValueError(f"Only cosine and euclidean similarities are supported for now. Got {self.similarity}.")

    @abstractmethod
    def compute_embeddings(self, spectra: Iterable[SpectrumType]) -> np.ndarray:
        """Compute embeddings for a list of spectra.

        Parameters
        ----------
        spectra:
            List of spectra to compute embeddings for.

        Returns
        -------
        np.ndarray
            Embeddings for the spectra. Shape: (n_spectra, n_embedding_features).
        """
        raise NotImplementedError("Subclasses must implement this method.")

    def get_embeddings(
        self,
        spectra: Optional[Iterable[SpectrumType]] = None,
        npy_path: Optional[Union[str, Path]] = None,
    ) -> np.ndarray:
        """Get embeddings either by computing them or loading from disk.

        Parameters
        ----------
        spectra:
            List of spectra to compute embeddings for.
        npy_path:
            Path to load/save embeddings from/to. If provided, embeddings are
            loaded from disk if the file exists, otherwise they are computed
            and saved on disk to the provided path. Because ``numpy.save``
            appends ".npy" to names that lack it, the ".npy"-suffixed variant
            of the path is also checked when looking for stored embeddings.

        Returns
        -------
        np.ndarray
            Embeddings array.

        Raises
        ------
        ValueError
            If neither spectra nor npy_path is provided, or if npy_path does
            not point to an existing file and no spectra were provided to
            compute the embeddings from.
        """
        if spectra is None and npy_path is None:
            raise ValueError("Either spectra or npy_path must be provided.")

        if npy_path is not None:
            # Check the path as given first, then the name numpy.save would
            # actually have written, so a previous store is found again.
            for candidate in (Path(npy_path), self._saved_npy_path(npy_path)):
                if candidate.exists():
                    return self.load_embeddings(candidate)

        if spectra is None:
            # Nothing on disk and nothing to compute from: fail explicitly
            # instead of handing None to compute_embeddings.
            raise ValueError(
                f"No embeddings file found at {npy_path} and no spectra were provided to compute embeddings from."
            )

        embs = self.compute_embeddings(spectra)
        if npy_path is not None:
            self.store_embeddings(npy_path, embs)
        return embs

    @staticmethod
    def _saved_npy_path(npy_path: Union[str, Path]) -> Path:
        """Return the file name ``numpy.save`` writes to for ``npy_path``.

        ``numpy.save`` appends ".npy" when the given name does not already end
        with it; this helper mirrors that rule.
        """
        raw = str(npy_path)
        return Path(raw if raw.endswith(".npy") else raw + ".npy")

    def pair(self, reference: SpectrumType, query: SpectrumType) -> float:
        """Compute similarity between a pair of spectra.

        Parameters
        ----------
        reference : SpectrumType
            Reference spectrum.
        query : SpectrumType
            Query spectrum.

        Returns
        -------
        float
            Similarity score between the spectra.
        """
        # A pair is the single entry of a 1x1 similarity matrix.
        return self.matrix([reference], [query])[0, 0]

    def matrix(
        self,
        references: List[SpectrumType],
        queries: List[SpectrumType],
        array_type: str = "numpy",
        is_symmetric: bool = True,
    ) -> np.ndarray:
        """Compute similarity matrix between reference and query spectra.

        Parameters
        ----------
        references:
            List of reference spectra.
        queries:
            List of query spectra.
        array_type:
            Type of array to return. Must be "numpy".
        is_symmetric:
            Whether the matrix is symmetric. Must be True.

        Returns
        -------
        np.ndarray
            Similarity matrix.

        Raises
        ------
        ValueError
            If array_type is not "numpy" or is_symmetric is False.
        """
        if array_type != "numpy" or not is_symmetric:
            raise ValueError("Any embedding base similarity matrix is supposed to be dense and symmetric.")

        # Compute embeddings
        embs_ref = self.compute_embeddings(references)
        embs_query = self.compute_embeddings(queries)

        # Compute pairwise similarities matrix
        return self.pairwise_similarity_fn(embs_ref, embs_query)

    def build_ann_index(
        self,
        reference_spectra: Optional[Iterable[SpectrumType]] = None,
        embeddings_path: Optional[Union[str, Path]] = None,
        k: int = 100,
        index_backend: str = "pynndescent",
        **index_kwargs,
    ) -> Any:
        """Build an ANN index for the reference spectra.

        The backend is validated before any embeddings are computed, and the
        index attributes on this object are only updated once the index has
        been built successfully.

        Parameters
        ----------
        reference_spectra : Optional[Iterable[SpectrumType]]
            List of reference spectra to build the ANN index for.
        embeddings_path : Optional[Union[str, Path]]
            If embeddings are already computed, provide the path to the numpy
            file.
        k : int, optional
            Number of nearest neighbors to use for the ANN index.
        index_backend : str, optional
            Backend to use for ANN index. Currently only "pynndescent" is
            supported.
        **index_kwargs
            Additional keyword arguments passed to the index constructor.

        Returns
        -------
        Any
            The constructed ANN index.

        Raises
        ------
        ImportError
            If pynndescent is not installed.
        ValueError
            If an unsupported index_backend is specified, or if the reference
            embeddings cannot be obtained (see :meth:`get_embeddings`).
        """
        # Fail fast on configuration problems before doing expensive work.
        if index_backend != "pynndescent":
            raise ValueError(f"Only pynndescent is supported for now. Got {index_backend}.")
        if pynndescent is None:
            raise ImportError("pynndescent is not installed. Please install it with `pip install pynndescent`.")

        # Compute (or load) reference embeddings
        embs_ref = self.get_embeddings(reference_spectra, embeddings_path)

        # Build ANN index
        index = pynndescent.NNDescent(embs_ref, metric=self.similarity, n_neighbors=k, **index_kwargs)

        # Commit all index-related state together, only after a successful
        # build, so the attributes never describe an index that does not exist.
        self.index = index
        self.index_backend = index_backend
        self.index_k = k
        self.index_kwargs = index_kwargs
        return self.index

    def get_anns(
        self,
        query_spectra: Union[Iterable[SpectrumType], np.ndarray],
        k: int = 100,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Get approximate nearest neighbors for query spectra.

        Parameters
        ----------
        query_spectra : Union[Iterable[SpectrumType], np.ndarray]
            Query spectra or their embeddings. A numpy array is treated as
            precomputed embeddings and must be 2D.
        k : int, optional
            Number of nearest neighbors to return.

        Returns
        -------
        Tuple[np.ndarray, np.ndarray]
            Neighbor indices and similarity scores.

        Raises
        ------
        ValueError
            If no index is built, k is larger than index k, the given
            embeddings array is not 2D, or the index backend is unsupported.
        """
        if self.index is None:
            raise ValueError(
                "No index built yet. Please call `build_ann_index` on your reference spectra or `load_ann_index` if it "
                "was previously built and stored using `save_ann_index`."
            )
        if k > self.index_k:
            raise ValueError(f"k ({k}) is larger than the k used to build the index ({self.index_k}).")

        if isinstance(query_spectra, np.ndarray):
            embs_query = query_spectra
            if embs_query.ndim != 2:
                raise ValueError(f"Expected 2D embeddings array, got {embs_query.ndim}D array.")
        else:
            # Compute query embeddings
            embs_query = self.compute_embeddings(query_spectra)

        # Get ANN indices
        if self.index_backend == "pynndescent":
            neighbors, distances = self.index.query(embs_query, k=k)
            similarities = self._distances_to_similarities(distances)
        else:
            raise ValueError(f"Only pynndescent is supported for now. Got {self.index_backend}.")
        return neighbors, similarities

    def get_index_anns(self) -> Tuple[np.ndarray, np.ndarray]:
        """Get nearest neighbors for all points in the index.

        Returns
        -------
        Tuple[np.ndarray, np.ndarray]
            Neighbor indices and similarity scores.

        Raises
        ------
        ValueError
            If no index is built or an unsupported index backend is used.
        """
        if self.index is None:
            # Same guard and message as get_anns, instead of the misleading
            # "unsupported backend None" error.
            raise ValueError(
                "No index built yet. Please call `build_ann_index` on your reference spectra or `load_ann_index` if it "
                "was previously built and stored using `save_ann_index`."
            )
        if self.index_backend == "pynndescent":
            neighbors, distances = self.index.neighbor_graph
            similarities = self._distances_to_similarities(distances)
            return neighbors, similarities
        raise ValueError(f"Only pynndescent is supported for now. Got {self.index_backend}.")

    def _distances_to_similarities(self, distances: np.ndarray) -> np.ndarray:
        """Convert distances to similarities based on similarity metric.

        Cosine distances are mapped to ``1 - distance``; euclidean distances
        are negated so that larger values mean "more similar".

        Parameters
        ----------
        distances : np.ndarray
            Distance matrix.

        Returns
        -------
        np.ndarray
            Similarity matrix.

        Raises
        ------
        ValueError
            If unsupported similarity metric is used.
        """
        if self.similarity == "cosine":
            return 1 - distances
        if self.similarity == "euclidean":
            return -distances
        raise ValueError(f"Only cosine and euclidean similarities are supported for now. Got {self.similarity}.")

    @staticmethod
    def load_embeddings(npy_path: Union[str, Path]) -> np.ndarray:
        """Load embeddings from a numpy file.

        Parameters
        ----------
        npy_path : Union[str, Path]
            Path to the numpy file.

        Returns
        -------
        np.ndarray
            Embeddings array.

        Raises
        ------
        ValueError
            If loaded array is not 2D.
        """
        embs = np.load(npy_path)
        if embs.ndim != 2:
            raise ValueError(f"Expected 2D embeddings array, got {embs.ndim}D array.")
        return embs

    @staticmethod
    def store_embeddings(npy_path: Union[str, Path], embs: np.ndarray) -> None:
        """Store embeddings in a numpy file.

        Note that ``numpy.save`` appends ".npy" to ``npy_path`` if the name
        does not already end with it.

        Parameters
        ----------
        npy_path : Union[str, Path]
            Path to save the embeddings to.
        embs : np.ndarray
            Embeddings array to store.
        """
        np.save(npy_path, embs)

    def save_ann_index(self, path: Union[str, Path]) -> None:
        """Save the ANN index to disk.

        The index is pickled together with the backend name, similarity
        metric, index keyword arguments and k, so that :meth:`load_ann_index`
        can restore them.

        Parameters
        ----------
        path : Union[str, Path]
            Path to save the index to.

        Raises
        ------
        ValueError
            If no index exists to save.
        """
        if self.index is None:
            raise ValueError("No index to save. Please build an index first using build_ann_index().")

        save_dict = {
            'index': self.index,
            'backend': self.index_backend,
            'similarity': self.similarity,
            'index_kwargs': self.index_kwargs,
            'index_k': self.index_k,
        }
        with open(path, 'wb') as f:
            pickle.dump(save_dict, f)

    def load_ann_index(self, path: Union[str, Path]) -> Any:
        """Load an ANN index from disk.

        .. warning::
            The file is read with :mod:`pickle`, which can execute arbitrary
            code while loading. Only load index files from trusted sources.

        Parameters
        ----------
        path : Union[str, Path]
            Path to load the index from.

        Returns
        -------
        Any
            The loaded ANN index.

        Raises
        ------
        ValueError
            If loaded index similarity metric doesn't match current metric.
        """
        # SECURITY NOTE: pickle.load is unsafe on untrusted input; `path` must
        # point to a file written by save_ann_index from a trusted source.
        with open(path, 'rb') as f:
            load_dict = pickle.load(f)

        # Validate before touching any attribute so a mismatch leaves the
        # current index state untouched.
        if load_dict['similarity'] != self.similarity:
            raise ValueError(
                f"Loaded index similarity metric ({load_dict['similarity']}) does not match "
                f"current similarity metric ({self.similarity})"
            )

        self.index = load_dict['index']
        self.index_backend = load_dict['backend']
        self.index_kwargs = load_dict['index_kwargs']
        self.index_k = load_dict['index_k']
        return self.index