diff --git a/cognee/modules/retrieval/jaccard_retrival.py b/cognee/modules/retrieval/jaccard_retrival.py
index 91d2b67f7..b864d62d9 100644
--- a/cognee/modules/retrieval/jaccard_retrival.py
+++ b/cognee/modules/retrieval/jaccard_retrival.py
@@ -2,13 +2,20 @@ from cognee.modules.retrieval.lexical_retriever import LexicalRetriever
 import re
 from collections import Counter
 from typing import Optional
+
+
 class JaccardChunksRetriever(LexicalRetriever):
     """
     Retriever that specializes LexicalRetriever to use Jaccard similarity.
     """
 
-    def __init__(self, top_k: int = 10, with_scores: bool = False,
-                 stop_words: Optional[list[str]] = None, multiset_jaccard: bool = False):
+    def __init__(
+        self,
+        top_k: int = 10,
+        with_scores: bool = False,
+        stop_words: Optional[list[str]] = None,
+        multiset_jaccard: bool = False,
+    ):
         """
         Parameters
         ----------
@@ -25,10 +32,7 @@ class JaccardChunksRetriever(LexicalRetriever):
         self.multiset_jaccard = multiset_jaccard
 
         super().__init__(
-            tokenizer=self._tokenizer,
-            scorer=self._scorer,
-            top_k=top_k,
-            with_scores=with_scores
+            tokenizer=self._tokenizer, scorer=self._scorer, top_k=top_k, with_scores=with_scores
         )
 
     def _tokenizer(self, text: str) -> list[str]:
diff --git a/cognee/modules/retrieval/lexical_retriever.py b/cognee/modules/retrieval/lexical_retriever.py
index 2292b64c8..fa80fe320 100644
--- a/cognee/modules/retrieval/lexical_retriever.py
+++ b/cognee/modules/retrieval/lexical_retriever.py
@@ -12,8 +12,9 @@ logger = get_logger("LexicalRetriever")
 
 
 class LexicalRetriever(BaseRetriever):
-
-    def __init__(self, tokenizer: Callable, scorer: Callable, top_k: int = 10, with_scores: bool = False):
+    def __init__(
+        self, tokenizer: Callable, scorer: Callable, top_k: int = 10, with_scores: bool = False
+    ):
         if not callable(tokenizer) or not callable(scorer):
             raise TypeError("tokenizer and scorer must be callables")
         if not isinstance(top_k, int) or top_k <= 0:
@@ -25,51 +26,51 @@ class LexicalRetriever(BaseRetriever):
         self.with_scores = bool(with_scores)
 
         # Cache keyed by dataset context
-        self.chunks: dict[str, Any] = {} # {chunk_id: tokens}
-        self.payloads: dict[str, Any] = {} # {chunk_id: original_document}
+        self.chunks: dict[str, Any] = {}  # {chunk_id: tokens}
+        self.payloads: dict[str, Any] = {}  # {chunk_id: original_document}
         self._initialized = False
         self._init_lock = asyncio.Lock()
 
     async def initialize(self):
-            """Initialize retriever by reading all DocumentChunks from graph_engine."""
-            async with self._init_lock:
-                if self._initialized:
-                    return
+        """Initialize retriever by reading all DocumentChunks from graph_engine."""
+        async with self._init_lock:
+            if self._initialized:
+                return
 
-                logger.info("Initializing LexicalRetriever by loading DocumentChunks from graph engine")
+            logger.info("Initializing LexicalRetriever by loading DocumentChunks from graph engine")
 
-                try:
-                    graph_engine = await get_graph_engine()
-                    nodes, _ = await graph_engine.get_filtered_graph_data([{"type": ["DocumentChunk"]}])
-                except Exception as e:
-                    logger.error("Graph engine initialization failed")
-                    raise NoDataError("Graph engine initialization failed") from e
+            try:
+                graph_engine = await get_graph_engine()
+                nodes, _ = await graph_engine.get_filtered_graph_data([{"type": ["DocumentChunk"]}])
+            except Exception as e:
+                logger.error("Graph engine initialization failed")
+                raise NoDataError("Graph engine initialization failed") from e
 
-                chunk_count = 0
-                for node in nodes:
-                    try:
-                        chunk_id, document = node
-                    except Exception:
-                        logger.warning("Skipping node with unexpected shape: %r", node)
-                        continue
+            chunk_count = 0
+            for node in nodes:
+                try:
+                    chunk_id, document = node
+                except Exception:
+                    logger.warning("Skipping node with unexpected shape: %r", node)
+                    continue
 
-                    if document.get("type") == "DocumentChunk" and document.get("text"):
-                        try:
-                            tokens = self.tokenizer(document["text"])
-                            if not tokens:
-                                continue
-                            self.chunks[str(document.get("id",chunk_id))] = tokens
-                            self.payloads[str(document.get("id",chunk_id))] = document
-                            chunk_count += 1
-                        except Exception as e:
-                            logger.error("Tokenizer failed for chunk %s: %s", chunk_id, str(e))
+                if document.get("type") == "DocumentChunk" and document.get("text"):
+                    try:
+                        tokens = self.tokenizer(document["text"])
+                        if not tokens:
+                            continue
+                        self.chunks[str(document.get("id", chunk_id))] = tokens
+                        self.payloads[str(document.get("id", chunk_id))] = document
+                        chunk_count += 1
+                    except Exception as e:
+                        logger.error("Tokenizer failed for chunk %s: %s", chunk_id, str(e))
 
-                if chunk_count == 0:
-                    logger.error("Initialization completed but no valid chunks were loaded.")
-                    raise NoDataError("No valid chunks loaded during initialization.")
+            if chunk_count == 0:
+                logger.error("Initialization completed but no valid chunks were loaded.")
+                raise NoDataError("No valid chunks loaded during initialization.")
 
-                self._initialized = True
-                logger.info("Initialized with %d document chunks", len(self.chunks))
+            self._initialized = True
+            logger.info("Initialized with %d document chunks", len(self.chunks))
 
     async def get_context(self, query: str) -> Any:
         """Retrieves relevant chunks for the given query."""
@@ -103,7 +104,12 @@ class LexicalRetriever(BaseRetriever):
             results.append((chunk_id, score))
 
         top_results = nlargest(self.top_k, results, key=lambda x: x[1])
-        logger.info("Retrieved %d/%d chunks for query (len=%d)", len(top_results), len(results), len(query_tokens))
+        logger.info(
+            "Retrieved %d/%d chunks for query (len=%d)",
+            len(top_results),
+            len(results),
+            len(query_tokens),
+        )
 
         if self.with_scores:
             return [(self.payloads[chunk_id], score) for chunk_id, score in top_results]
diff --git a/cognee/modules/search/methods/get_search_type_tools.py b/cognee/modules/search/methods/get_search_type_tools.py
index c5ea53a62..9cf67785e 100644
--- a/cognee/modules/search/methods/get_search_type_tools.py
+++ b/cognee/modules/search/methods/get_search_type_tools.py
@@ -153,10 +153,12 @@ async def get_search_type_tools(
             TemporalRetriever(top_k=top_k).get_completion,
             TemporalRetriever(top_k=top_k).get_context,
         ],
-        SearchType.CHUNKS_LEXICAL: (lambda _r=JaccardChunksRetriever(top_k=top_k): [
-            _r.get_completion,
-            _r.get_context,
-        ])(),
+        SearchType.CHUNKS_LEXICAL: (
+            lambda _r=JaccardChunksRetriever(top_k=top_k): [
+                _r.get_completion,
+                _r.get_context,
+            ]
+        )(),
         SearchType.CODING_RULES: [
             CodingRulesRetriever(rules_nodeset_name=node_name).get_existing_rules,
         ],