chore: format files

This commit is contained in:
Igor Ilic 2025-09-22 11:33:19 +02:00
parent 023f5ea632
commit 87c79b52e3
3 changed files with 60 additions and 48 deletions

View file

@@ -2,13 +2,20 @@ from cognee.modules.retrieval.lexical_retriever import LexicalRetriever
import re import re
from collections import Counter from collections import Counter
from typing import Optional from typing import Optional
class JaccardChunksRetriever(LexicalRetriever): class JaccardChunksRetriever(LexicalRetriever):
""" """
Retriever that specializes LexicalRetriever to use Jaccard similarity. Retriever that specializes LexicalRetriever to use Jaccard similarity.
""" """
def __init__(self, top_k: int = 10, with_scores: bool = False, def __init__(
stop_words: Optional[list[str]] = None, multiset_jaccard: bool = False): self,
top_k: int = 10,
with_scores: bool = False,
stop_words: Optional[list[str]] = None,
multiset_jaccard: bool = False,
):
""" """
Parameters Parameters
---------- ----------
@@ -25,10 +32,7 @@ class JaccardChunksRetriever(LexicalRetriever):
self.multiset_jaccard = multiset_jaccard self.multiset_jaccard = multiset_jaccard
super().__init__( super().__init__(
tokenizer=self._tokenizer, tokenizer=self._tokenizer, scorer=self._scorer, top_k=top_k, with_scores=with_scores
scorer=self._scorer,
top_k=top_k,
with_scores=with_scores
) )
def _tokenizer(self, text: str) -> list[str]: def _tokenizer(self, text: str) -> list[str]:

View file

@@ -12,8 +12,9 @@ logger = get_logger("LexicalRetriever")
class LexicalRetriever(BaseRetriever): class LexicalRetriever(BaseRetriever):
def __init__(
def __init__(self, tokenizer: Callable, scorer: Callable, top_k: int = 10, with_scores: bool = False): self, tokenizer: Callable, scorer: Callable, top_k: int = 10, with_scores: bool = False
):
if not callable(tokenizer) or not callable(scorer): if not callable(tokenizer) or not callable(scorer):
raise TypeError("tokenizer and scorer must be callables") raise TypeError("tokenizer and scorer must be callables")
if not isinstance(top_k, int) or top_k <= 0: if not isinstance(top_k, int) or top_k <= 0:
@@ -25,51 +26,51 @@ class LexicalRetriever(BaseRetriever):
self.with_scores = bool(with_scores) self.with_scores = bool(with_scores)
# Cache keyed by dataset context # Cache keyed by dataset context
self.chunks: dict[str, Any] = {} # {chunk_id: tokens} self.chunks: dict[str, Any] = {} # {chunk_id: tokens}
self.payloads: dict[str, Any] = {} # {chunk_id: original_document} self.payloads: dict[str, Any] = {} # {chunk_id: original_document}
self._initialized = False self._initialized = False
self._init_lock = asyncio.Lock() self._init_lock = asyncio.Lock()
async def initialize(self): async def initialize(self):
"""Initialize retriever by reading all DocumentChunks from graph_engine.""" """Initialize retriever by reading all DocumentChunks from graph_engine."""
async with self._init_lock: async with self._init_lock:
if self._initialized: if self._initialized:
return return
logger.info("Initializing LexicalRetriever by loading DocumentChunks from graph engine") logger.info("Initializing LexicalRetriever by loading DocumentChunks from graph engine")
try: try:
graph_engine = await get_graph_engine() graph_engine = await get_graph_engine()
nodes, _ = await graph_engine.get_filtered_graph_data([{"type": ["DocumentChunk"]}]) nodes, _ = await graph_engine.get_filtered_graph_data([{"type": ["DocumentChunk"]}])
except Exception as e: except Exception as e:
logger.error("Graph engine initialization failed") logger.error("Graph engine initialization failed")
raise NoDataError("Graph engine initialization failed") from e raise NoDataError("Graph engine initialization failed") from e
chunk_count = 0 chunk_count = 0
for node in nodes: for node in nodes:
try: try:
chunk_id, document = node chunk_id, document = node
except Exception: except Exception:
logger.warning("Skipping node with unexpected shape: %r", node) logger.warning("Skipping node with unexpected shape: %r", node)
continue continue
if document.get("type") == "DocumentChunk" and document.get("text"): if document.get("type") == "DocumentChunk" and document.get("text"):
try: try:
tokens = self.tokenizer(document["text"]) tokens = self.tokenizer(document["text"])
if not tokens: if not tokens:
continue continue
self.chunks[str(document.get("id",chunk_id))] = tokens self.chunks[str(document.get("id", chunk_id))] = tokens
self.payloads[str(document.get("id",chunk_id))] = document self.payloads[str(document.get("id", chunk_id))] = document
chunk_count += 1 chunk_count += 1
except Exception as e: except Exception as e:
logger.error("Tokenizer failed for chunk %s: %s", chunk_id, str(e)) logger.error("Tokenizer failed for chunk %s: %s", chunk_id, str(e))
if chunk_count == 0: if chunk_count == 0:
logger.error("Initialization completed but no valid chunks were loaded.") logger.error("Initialization completed but no valid chunks were loaded.")
raise NoDataError("No valid chunks loaded during initialization.") raise NoDataError("No valid chunks loaded during initialization.")
self._initialized = True self._initialized = True
logger.info("Initialized with %d document chunks", len(self.chunks)) logger.info("Initialized with %d document chunks", len(self.chunks))
async def get_context(self, query: str) -> Any: async def get_context(self, query: str) -> Any:
"""Retrieves relevant chunks for the given query.""" """Retrieves relevant chunks for the given query."""
@@ -103,7 +104,12 @@ class LexicalRetriever(BaseRetriever):
results.append((chunk_id, score)) results.append((chunk_id, score))
top_results = nlargest(self.top_k, results, key=lambda x: x[1]) top_results = nlargest(self.top_k, results, key=lambda x: x[1])
logger.info("Retrieved %d/%d chunks for query (len=%d)", len(top_results), len(results), len(query_tokens)) logger.info(
"Retrieved %d/%d chunks for query (len=%d)",
len(top_results),
len(results),
len(query_tokens),
)
if self.with_scores: if self.with_scores:
return [(self.payloads[chunk_id], score) for chunk_id, score in top_results] return [(self.payloads[chunk_id], score) for chunk_id, score in top_results]

View file

@@ -153,10 +153,12 @@ async def get_search_type_tools(
TemporalRetriever(top_k=top_k).get_completion, TemporalRetriever(top_k=top_k).get_completion,
TemporalRetriever(top_k=top_k).get_context, TemporalRetriever(top_k=top_k).get_context,
], ],
SearchType.CHUNKS_LEXICAL: (lambda _r=JaccardChunksRetriever(top_k=top_k): [ SearchType.CHUNKS_LEXICAL: (
_r.get_completion, lambda _r=JaccardChunksRetriever(top_k=top_k): [
_r.get_context, _r.get_completion,
])(), _r.get_context,
]
)(),
SearchType.CODING_RULES: [ SearchType.CODING_RULES: [
CodingRulesRetriever(rules_nodeset_name=node_name).get_existing_rules, CodingRulesRetriever(rules_nodeset_name=node_name).get_existing_rules,
], ],