chore: format files
This commit is contained in:
parent
023f5ea632
commit
87c79b52e3
3 changed files with 60 additions and 48 deletions
|
|
@ -2,13 +2,20 @@ from cognee.modules.retrieval.lexical_retriever import LexicalRetriever
|
|||
import re
|
||||
from collections import Counter
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class JaccardChunksRetriever(LexicalRetriever):
|
||||
"""
|
||||
Retriever that specializes LexicalRetriever to use Jaccard similarity.
|
||||
"""
|
||||
|
||||
def __init__(self, top_k: int = 10, with_scores: bool = False,
|
||||
stop_words: Optional[list[str]] = None, multiset_jaccard: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
top_k: int = 10,
|
||||
with_scores: bool = False,
|
||||
stop_words: Optional[list[str]] = None,
|
||||
multiset_jaccard: bool = False,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
|
|
@ -25,10 +32,7 @@ class JaccardChunksRetriever(LexicalRetriever):
|
|||
self.multiset_jaccard = multiset_jaccard
|
||||
|
||||
super().__init__(
|
||||
tokenizer=self._tokenizer,
|
||||
scorer=self._scorer,
|
||||
top_k=top_k,
|
||||
with_scores=with_scores
|
||||
tokenizer=self._tokenizer, scorer=self._scorer, top_k=top_k, with_scores=with_scores
|
||||
)
|
||||
|
||||
def _tokenizer(self, text: str) -> list[str]:
|
||||
|
|
|
|||
|
|
@ -12,8 +12,9 @@ logger = get_logger("LexicalRetriever")
|
|||
|
||||
|
||||
class LexicalRetriever(BaseRetriever):
|
||||
|
||||
def __init__(self, tokenizer: Callable, scorer: Callable, top_k: int = 10, with_scores: bool = False):
|
||||
def __init__(
|
||||
self, tokenizer: Callable, scorer: Callable, top_k: int = 10, with_scores: bool = False
|
||||
):
|
||||
if not callable(tokenizer) or not callable(scorer):
|
||||
raise TypeError("tokenizer and scorer must be callables")
|
||||
if not isinstance(top_k, int) or top_k <= 0:
|
||||
|
|
@ -25,51 +26,51 @@ class LexicalRetriever(BaseRetriever):
|
|||
self.with_scores = bool(with_scores)
|
||||
|
||||
# Cache keyed by dataset context
|
||||
self.chunks: dict[str, Any] = {} # {chunk_id: tokens}
|
||||
self.payloads: dict[str, Any] = {} # {chunk_id: original_document}
|
||||
self.chunks: dict[str, Any] = {} # {chunk_id: tokens}
|
||||
self.payloads: dict[str, Any] = {} # {chunk_id: original_document}
|
||||
self._initialized = False
|
||||
self._init_lock = asyncio.Lock()
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize retriever by reading all DocumentChunks from graph_engine."""
|
||||
async with self._init_lock:
|
||||
if self._initialized:
|
||||
return
|
||||
"""Initialize retriever by reading all DocumentChunks from graph_engine."""
|
||||
async with self._init_lock:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
logger.info("Initializing LexicalRetriever by loading DocumentChunks from graph engine")
|
||||
logger.info("Initializing LexicalRetriever by loading DocumentChunks from graph engine")
|
||||
|
||||
try:
|
||||
graph_engine = await get_graph_engine()
|
||||
nodes, _ = await graph_engine.get_filtered_graph_data([{"type": ["DocumentChunk"]}])
|
||||
except Exception as e:
|
||||
logger.error("Graph engine initialization failed")
|
||||
raise NoDataError("Graph engine initialization failed") from e
|
||||
try:
|
||||
graph_engine = await get_graph_engine()
|
||||
nodes, _ = await graph_engine.get_filtered_graph_data([{"type": ["DocumentChunk"]}])
|
||||
except Exception as e:
|
||||
logger.error("Graph engine initialization failed")
|
||||
raise NoDataError("Graph engine initialization failed") from e
|
||||
|
||||
chunk_count = 0
|
||||
for node in nodes:
|
||||
try:
|
||||
chunk_id, document = node
|
||||
except Exception:
|
||||
logger.warning("Skipping node with unexpected shape: %r", node)
|
||||
continue
|
||||
chunk_count = 0
|
||||
for node in nodes:
|
||||
try:
|
||||
chunk_id, document = node
|
||||
except Exception:
|
||||
logger.warning("Skipping node with unexpected shape: %r", node)
|
||||
continue
|
||||
|
||||
if document.get("type") == "DocumentChunk" and document.get("text"):
|
||||
try:
|
||||
tokens = self.tokenizer(document["text"])
|
||||
if not tokens:
|
||||
continue
|
||||
self.chunks[str(document.get("id",chunk_id))] = tokens
|
||||
self.payloads[str(document.get("id",chunk_id))] = document
|
||||
chunk_count += 1
|
||||
except Exception as e:
|
||||
logger.error("Tokenizer failed for chunk %s: %s", chunk_id, str(e))
|
||||
if document.get("type") == "DocumentChunk" and document.get("text"):
|
||||
try:
|
||||
tokens = self.tokenizer(document["text"])
|
||||
if not tokens:
|
||||
continue
|
||||
self.chunks[str(document.get("id", chunk_id))] = tokens
|
||||
self.payloads[str(document.get("id", chunk_id))] = document
|
||||
chunk_count += 1
|
||||
except Exception as e:
|
||||
logger.error("Tokenizer failed for chunk %s: %s", chunk_id, str(e))
|
||||
|
||||
if chunk_count == 0:
|
||||
logger.error("Initialization completed but no valid chunks were loaded.")
|
||||
raise NoDataError("No valid chunks loaded during initialization.")
|
||||
if chunk_count == 0:
|
||||
logger.error("Initialization completed but no valid chunks were loaded.")
|
||||
raise NoDataError("No valid chunks loaded during initialization.")
|
||||
|
||||
self._initialized = True
|
||||
logger.info("Initialized with %d document chunks", len(self.chunks))
|
||||
self._initialized = True
|
||||
logger.info("Initialized with %d document chunks", len(self.chunks))
|
||||
|
||||
async def get_context(self, query: str) -> Any:
|
||||
"""Retrieves relevant chunks for the given query."""
|
||||
|
|
@ -103,7 +104,12 @@ class LexicalRetriever(BaseRetriever):
|
|||
results.append((chunk_id, score))
|
||||
|
||||
top_results = nlargest(self.top_k, results, key=lambda x: x[1])
|
||||
logger.info("Retrieved %d/%d chunks for query (len=%d)", len(top_results), len(results), len(query_tokens))
|
||||
logger.info(
|
||||
"Retrieved %d/%d chunks for query (len=%d)",
|
||||
len(top_results),
|
||||
len(results),
|
||||
len(query_tokens),
|
||||
)
|
||||
|
||||
if self.with_scores:
|
||||
return [(self.payloads[chunk_id], score) for chunk_id, score in top_results]
|
||||
|
|
|
|||
|
|
@ -153,10 +153,12 @@ async def get_search_type_tools(
|
|||
TemporalRetriever(top_k=top_k).get_completion,
|
||||
TemporalRetriever(top_k=top_k).get_context,
|
||||
],
|
||||
SearchType.CHUNKS_LEXICAL: (lambda _r=JaccardChunksRetriever(top_k=top_k): [
|
||||
_r.get_completion,
|
||||
_r.get_context,
|
||||
])(),
|
||||
SearchType.CHUNKS_LEXICAL: (
|
||||
lambda _r=JaccardChunksRetriever(top_k=top_k): [
|
||||
_r.get_completion,
|
||||
_r.get_context,
|
||||
]
|
||||
)(),
|
||||
SearchType.CODING_RULES: [
|
||||
CodingRulesRetriever(rules_nodeset_name=node_name).get_existing_rules,
|
||||
],
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue