chore: format files
This commit is contained in:
parent
023f5ea632
commit
87c79b52e3
3 changed files with 60 additions and 48 deletions
|
|
@ -2,13 +2,20 @@ from cognee.modules.retrieval.lexical_retriever import LexicalRetriever
|
||||||
import re
|
import re
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
class JaccardChunksRetriever(LexicalRetriever):
|
class JaccardChunksRetriever(LexicalRetriever):
|
||||||
"""
|
"""
|
||||||
Retriever that specializes LexicalRetriever to use Jaccard similarity.
|
Retriever that specializes LexicalRetriever to use Jaccard similarity.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, top_k: int = 10, with_scores: bool = False,
|
def __init__(
|
||||||
stop_words: Optional[list[str]] = None, multiset_jaccard: bool = False):
|
self,
|
||||||
|
top_k: int = 10,
|
||||||
|
with_scores: bool = False,
|
||||||
|
stop_words: Optional[list[str]] = None,
|
||||||
|
multiset_jaccard: bool = False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
|
@ -25,10 +32,7 @@ class JaccardChunksRetriever(LexicalRetriever):
|
||||||
self.multiset_jaccard = multiset_jaccard
|
self.multiset_jaccard = multiset_jaccard
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
tokenizer=self._tokenizer,
|
tokenizer=self._tokenizer, scorer=self._scorer, top_k=top_k, with_scores=with_scores
|
||||||
scorer=self._scorer,
|
|
||||||
top_k=top_k,
|
|
||||||
with_scores=with_scores
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _tokenizer(self, text: str) -> list[str]:
|
def _tokenizer(self, text: str) -> list[str]:
|
||||||
|
|
|
||||||
|
|
@ -12,8 +12,9 @@ logger = get_logger("LexicalRetriever")
|
||||||
|
|
||||||
|
|
||||||
class LexicalRetriever(BaseRetriever):
|
class LexicalRetriever(BaseRetriever):
|
||||||
|
def __init__(
|
||||||
def __init__(self, tokenizer: Callable, scorer: Callable, top_k: int = 10, with_scores: bool = False):
|
self, tokenizer: Callable, scorer: Callable, top_k: int = 10, with_scores: bool = False
|
||||||
|
):
|
||||||
if not callable(tokenizer) or not callable(scorer):
|
if not callable(tokenizer) or not callable(scorer):
|
||||||
raise TypeError("tokenizer and scorer must be callables")
|
raise TypeError("tokenizer and scorer must be callables")
|
||||||
if not isinstance(top_k, int) or top_k <= 0:
|
if not isinstance(top_k, int) or top_k <= 0:
|
||||||
|
|
@ -25,51 +26,51 @@ class LexicalRetriever(BaseRetriever):
|
||||||
self.with_scores = bool(with_scores)
|
self.with_scores = bool(with_scores)
|
||||||
|
|
||||||
# Cache keyed by dataset context
|
# Cache keyed by dataset context
|
||||||
self.chunks: dict[str, Any] = {} # {chunk_id: tokens}
|
self.chunks: dict[str, Any] = {} # {chunk_id: tokens}
|
||||||
self.payloads: dict[str, Any] = {} # {chunk_id: original_document}
|
self.payloads: dict[str, Any] = {} # {chunk_id: original_document}
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
self._init_lock = asyncio.Lock()
|
self._init_lock = asyncio.Lock()
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""Initialize retriever by reading all DocumentChunks from graph_engine."""
|
"""Initialize retriever by reading all DocumentChunks from graph_engine."""
|
||||||
async with self._init_lock:
|
async with self._init_lock:
|
||||||
if self._initialized:
|
if self._initialized:
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info("Initializing LexicalRetriever by loading DocumentChunks from graph engine")
|
logger.info("Initializing LexicalRetriever by loading DocumentChunks from graph engine")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
graph_engine = await get_graph_engine()
|
graph_engine = await get_graph_engine()
|
||||||
nodes, _ = await graph_engine.get_filtered_graph_data([{"type": ["DocumentChunk"]}])
|
nodes, _ = await graph_engine.get_filtered_graph_data([{"type": ["DocumentChunk"]}])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Graph engine initialization failed")
|
logger.error("Graph engine initialization failed")
|
||||||
raise NoDataError("Graph engine initialization failed") from e
|
raise NoDataError("Graph engine initialization failed") from e
|
||||||
|
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
try:
|
try:
|
||||||
chunk_id, document = node
|
chunk_id, document = node
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Skipping node with unexpected shape: %r", node)
|
logger.warning("Skipping node with unexpected shape: %r", node)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if document.get("type") == "DocumentChunk" and document.get("text"):
|
if document.get("type") == "DocumentChunk" and document.get("text"):
|
||||||
try:
|
try:
|
||||||
tokens = self.tokenizer(document["text"])
|
tokens = self.tokenizer(document["text"])
|
||||||
if not tokens:
|
if not tokens:
|
||||||
continue
|
continue
|
||||||
self.chunks[str(document.get("id",chunk_id))] = tokens
|
self.chunks[str(document.get("id", chunk_id))] = tokens
|
||||||
self.payloads[str(document.get("id",chunk_id))] = document
|
self.payloads[str(document.get("id", chunk_id))] = document
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Tokenizer failed for chunk %s: %s", chunk_id, str(e))
|
logger.error("Tokenizer failed for chunk %s: %s", chunk_id, str(e))
|
||||||
|
|
||||||
if chunk_count == 0:
|
if chunk_count == 0:
|
||||||
logger.error("Initialization completed but no valid chunks were loaded.")
|
logger.error("Initialization completed but no valid chunks were loaded.")
|
||||||
raise NoDataError("No valid chunks loaded during initialization.")
|
raise NoDataError("No valid chunks loaded during initialization.")
|
||||||
|
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
logger.info("Initialized with %d document chunks", len(self.chunks))
|
logger.info("Initialized with %d document chunks", len(self.chunks))
|
||||||
|
|
||||||
async def get_context(self, query: str) -> Any:
|
async def get_context(self, query: str) -> Any:
|
||||||
"""Retrieves relevant chunks for the given query."""
|
"""Retrieves relevant chunks for the given query."""
|
||||||
|
|
@ -103,7 +104,12 @@ class LexicalRetriever(BaseRetriever):
|
||||||
results.append((chunk_id, score))
|
results.append((chunk_id, score))
|
||||||
|
|
||||||
top_results = nlargest(self.top_k, results, key=lambda x: x[1])
|
top_results = nlargest(self.top_k, results, key=lambda x: x[1])
|
||||||
logger.info("Retrieved %d/%d chunks for query (len=%d)", len(top_results), len(results), len(query_tokens))
|
logger.info(
|
||||||
|
"Retrieved %d/%d chunks for query (len=%d)",
|
||||||
|
len(top_results),
|
||||||
|
len(results),
|
||||||
|
len(query_tokens),
|
||||||
|
)
|
||||||
|
|
||||||
if self.with_scores:
|
if self.with_scores:
|
||||||
return [(self.payloads[chunk_id], score) for chunk_id, score in top_results]
|
return [(self.payloads[chunk_id], score) for chunk_id, score in top_results]
|
||||||
|
|
|
||||||
|
|
@ -153,10 +153,12 @@ async def get_search_type_tools(
|
||||||
TemporalRetriever(top_k=top_k).get_completion,
|
TemporalRetriever(top_k=top_k).get_completion,
|
||||||
TemporalRetriever(top_k=top_k).get_context,
|
TemporalRetriever(top_k=top_k).get_context,
|
||||||
],
|
],
|
||||||
SearchType.CHUNKS_LEXICAL: (lambda _r=JaccardChunksRetriever(top_k=top_k): [
|
SearchType.CHUNKS_LEXICAL: (
|
||||||
_r.get_completion,
|
lambda _r=JaccardChunksRetriever(top_k=top_k): [
|
||||||
_r.get_context,
|
_r.get_completion,
|
||||||
])(),
|
_r.get_context,
|
||||||
|
]
|
||||||
|
)(),
|
||||||
SearchType.CODING_RULES: [
|
SearchType.CODING_RULES: [
|
||||||
CodingRulesRetriever(rules_nodeset_name=node_name).get_existing_rules,
|
CodingRulesRetriever(rules_nodeset_name=node_name).get_existing_rules,
|
||||||
],
|
],
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue