chore: format files

This commit is contained in:
Igor Ilic 2025-09-22 11:33:19 +02:00
parent 023f5ea632
commit 87c79b52e3
3 changed files with 60 additions and 48 deletions

View file

@@ -2,13 +2,20 @@ from cognee.modules.retrieval.lexical_retriever import LexicalRetriever
import re
from collections import Counter
from typing import Optional
class JaccardChunksRetriever(LexicalRetriever):
"""
Retriever that specializes LexicalRetriever to use Jaccard similarity.
"""
def __init__(self, top_k: int = 10, with_scores: bool = False,
stop_words: Optional[list[str]] = None, multiset_jaccard: bool = False):
def __init__(
self,
top_k: int = 10,
with_scores: bool = False,
stop_words: Optional[list[str]] = None,
multiset_jaccard: bool = False,
):
"""
Parameters
----------
@@ -25,10 +32,7 @@ class JaccardChunksRetriever(LexicalRetriever):
self.multiset_jaccard = multiset_jaccard
super().__init__(
tokenizer=self._tokenizer,
scorer=self._scorer,
top_k=top_k,
with_scores=with_scores
tokenizer=self._tokenizer, scorer=self._scorer, top_k=top_k, with_scores=with_scores
)
def _tokenizer(self, text: str) -> list[str]:

View file

@@ -12,8 +12,9 @@ logger = get_logger("LexicalRetriever")
class LexicalRetriever(BaseRetriever):
def __init__(self, tokenizer: Callable, scorer: Callable, top_k: int = 10, with_scores: bool = False):
def __init__(
self, tokenizer: Callable, scorer: Callable, top_k: int = 10, with_scores: bool = False
):
if not callable(tokenizer) or not callable(scorer):
raise TypeError("tokenizer and scorer must be callables")
if not isinstance(top_k, int) or top_k <= 0:
@@ -25,51 +26,51 @@ class LexicalRetriever(BaseRetriever):
self.with_scores = bool(with_scores)
# Cache keyed by dataset context
self.chunks: dict[str, Any] = {} # {chunk_id: tokens}
self.payloads: dict[str, Any] = {} # {chunk_id: original_document}
self.chunks: dict[str, Any] = {} # {chunk_id: tokens}
self.payloads: dict[str, Any] = {} # {chunk_id: original_document}
self._initialized = False
self._init_lock = asyncio.Lock()
async def initialize(self):
"""Initialize retriever by reading all DocumentChunks from graph_engine."""
async with self._init_lock:
if self._initialized:
return
"""Initialize retriever by reading all DocumentChunks from graph_engine."""
async with self._init_lock:
if self._initialized:
return
logger.info("Initializing LexicalRetriever by loading DocumentChunks from graph engine")
logger.info("Initializing LexicalRetriever by loading DocumentChunks from graph engine")
try:
graph_engine = await get_graph_engine()
nodes, _ = await graph_engine.get_filtered_graph_data([{"type": ["DocumentChunk"]}])
except Exception as e:
logger.error("Graph engine initialization failed")
raise NoDataError("Graph engine initialization failed") from e
try:
graph_engine = await get_graph_engine()
nodes, _ = await graph_engine.get_filtered_graph_data([{"type": ["DocumentChunk"]}])
except Exception as e:
logger.error("Graph engine initialization failed")
raise NoDataError("Graph engine initialization failed") from e
chunk_count = 0
for node in nodes:
try:
chunk_id, document = node
except Exception:
logger.warning("Skipping node with unexpected shape: %r", node)
continue
chunk_count = 0
for node in nodes:
try:
chunk_id, document = node
except Exception:
logger.warning("Skipping node with unexpected shape: %r", node)
continue
if document.get("type") == "DocumentChunk" and document.get("text"):
try:
tokens = self.tokenizer(document["text"])
if not tokens:
continue
self.chunks[str(document.get("id",chunk_id))] = tokens
self.payloads[str(document.get("id",chunk_id))] = document
chunk_count += 1
except Exception as e:
logger.error("Tokenizer failed for chunk %s: %s", chunk_id, str(e))
if document.get("type") == "DocumentChunk" and document.get("text"):
try:
tokens = self.tokenizer(document["text"])
if not tokens:
continue
self.chunks[str(document.get("id", chunk_id))] = tokens
self.payloads[str(document.get("id", chunk_id))] = document
chunk_count += 1
except Exception as e:
logger.error("Tokenizer failed for chunk %s: %s", chunk_id, str(e))
if chunk_count == 0:
logger.error("Initialization completed but no valid chunks were loaded.")
raise NoDataError("No valid chunks loaded during initialization.")
if chunk_count == 0:
logger.error("Initialization completed but no valid chunks were loaded.")
raise NoDataError("No valid chunks loaded during initialization.")
self._initialized = True
logger.info("Initialized with %d document chunks", len(self.chunks))
self._initialized = True
logger.info("Initialized with %d document chunks", len(self.chunks))
async def get_context(self, query: str) -> Any:
"""Retrieves relevant chunks for the given query."""
@@ -103,7 +104,12 @@ class LexicalRetriever(BaseRetriever):
results.append((chunk_id, score))
top_results = nlargest(self.top_k, results, key=lambda x: x[1])
logger.info("Retrieved %d/%d chunks for query (len=%d)", len(top_results), len(results), len(query_tokens))
logger.info(
"Retrieved %d/%d chunks for query (len=%d)",
len(top_results),
len(results),
len(query_tokens),
)
if self.with_scores:
return [(self.payloads[chunk_id], score) for chunk_id, score in top_results]

View file

@@ -153,10 +153,12 @@ async def get_search_type_tools(
TemporalRetriever(top_k=top_k).get_completion,
TemporalRetriever(top_k=top_k).get_context,
],
SearchType.CHUNKS_LEXICAL: (lambda _r=JaccardChunksRetriever(top_k=top_k): [
_r.get_completion,
_r.get_context,
])(),
SearchType.CHUNKS_LEXICAL: (
lambda _r=JaccardChunksRetriever(top_k=top_k): [
_r.get_completion,
_r.get_context,
]
)(),
SearchType.CODING_RULES: [
CodingRulesRetriever(rules_nodeset_name=node_name).get_existing_rules,
],