feat: embedding triplet completion retriever
This commit is contained in:
parent
0bbb7e3781
commit
e0294b38ff
1 changed files with 122 additions and 0 deletions
|
|
@ -0,0 +1,122 @@
|
||||||
|
from typing import List
|
||||||
|
import json
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from cognee.modules.retrieval.graph_completion_retriever import GraphCompletionRetriever
|
||||||
|
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||||
|
from cognee.infrastructure.databases.vector.models.ScoredResult import ScoredResult
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TripletNode:
|
||||||
|
"""Graph node representation from triplet data."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
attributes: dict
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TripletEdge:
|
||||||
|
"""Graph edge representation from triplet data."""
|
||||||
|
|
||||||
|
node1: TripletNode
|
||||||
|
node2: TripletNode
|
||||||
|
attributes: dict
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddedTripletCompletionRetriever(GraphCompletionRetriever):
|
||||||
|
"""Retriever that uses embedded triplets from vector collection instead of brute force search.
|
||||||
|
|
||||||
|
Data Conversion Rationale:
|
||||||
|
GraphCompletionRetriever expects graph objects with .node1/.node2 having .id and .attributes,
|
||||||
|
but our triplets are stored as JSON: {"start_node": {"id", "content"}, "relationship", "end_node": {"id", "content"}}.
|
||||||
|
We convert triplet format → TripletNode/TripletEdge objects to match the expected interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, collection_name: str = "Triplets", **kwargs):
|
||||||
|
"""Initialize with configurable collection name."""
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.collection_name = collection_name
|
||||||
|
|
||||||
|
async def get_triplets(self, query: str) -> List[TripletEdge]:
|
||||||
|
"""Override parent method to use vector search on triplet collection."""
|
||||||
|
vector_engine = get_vector_engine()
|
||||||
|
search_results = await vector_engine.search(
|
||||||
|
collection_name=self.collection_name, query_text=query, limit=self.top_k
|
||||||
|
)
|
||||||
|
|
||||||
|
triplet_edges = []
|
||||||
|
for search_result in search_results:
|
||||||
|
triplet_edge = self._convert_search_result_to_edge(search_result)
|
||||||
|
triplet_edges.append(triplet_edge)
|
||||||
|
|
||||||
|
return triplet_edges
|
||||||
|
|
||||||
|
def _parse_triplet_payload(self, search_result: ScoredResult) -> dict:
|
||||||
|
"""Extract triplet data from search result payload."""
|
||||||
|
# ScoredResult.payload contains entire DataPoint structure
|
||||||
|
# Actual triplet JSON is nested at payload['payload']
|
||||||
|
triplet_json = search_result.payload["payload"]
|
||||||
|
return json.loads(triplet_json)
|
||||||
|
|
||||||
|
def _create_triplet_node(self, node_data: dict) -> TripletNode:
|
||||||
|
"""Convert triplet node data to graph node format."""
|
||||||
|
return TripletNode(
|
||||||
|
id=node_data["id"],
|
||||||
|
attributes={
|
||||||
|
"text": node_data.get("content", ""),
|
||||||
|
"name": node_data.get("content", "Unnamed Node"),
|
||||||
|
"description": node_data.get("content", ""),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_triplet_edge(
|
||||||
|
self, node1: TripletNode, node2: TripletNode, relationship: str
|
||||||
|
) -> TripletEdge:
|
||||||
|
"""Create graph edge from two nodes and relationship."""
|
||||||
|
return TripletEdge(node1=node1, node2=node2, attributes={"relationship_type": relationship})
|
||||||
|
|
||||||
|
def _convert_search_result_to_edge(self, search_result) -> TripletEdge:
|
||||||
|
"""Main conversion method: ScoredResult → TripletEdge."""
|
||||||
|
# 1. Extract triplet data from search result
|
||||||
|
triplet_data = self._parse_triplet_payload(search_result)
|
||||||
|
|
||||||
|
# 2. Extract triplet components
|
||||||
|
start_node_data = triplet_data["start_node"]
|
||||||
|
relationship = triplet_data["relationship"]
|
||||||
|
end_node_data = triplet_data["end_node"]
|
||||||
|
|
||||||
|
# 3. Create triplet nodes
|
||||||
|
start_node = self._create_triplet_node(start_node_data)
|
||||||
|
end_node = self._create_triplet_node(end_node_data)
|
||||||
|
|
||||||
|
# 4. Create triplet edge
|
||||||
|
triplet_edge = self._create_triplet_edge(start_node, end_node, relationship)
|
||||||
|
|
||||||
|
return triplet_edge
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
retriever = EmbeddedTripletCompletionRetriever(collection_name="Triplets", top_k=5)
|
||||||
|
|
||||||
|
# Test triplet retrieval with CV/resume data
|
||||||
|
query = "machine learning experience Python data scientist"
|
||||||
|
print(f"Searching for: {query}")
|
||||||
|
|
||||||
|
triplets = await retriever.get_triplets(query)
|
||||||
|
print(f"Found {len(triplets)} triplets")
|
||||||
|
|
||||||
|
# Test full context generation
|
||||||
|
context = await retriever.get_context(query)
|
||||||
|
print(f"Generated context:\n{context}")
|
||||||
|
|
||||||
|
# Test another query
|
||||||
|
query2 = "Stanford University education background"
|
||||||
|
print(f"\nSecond search: {query2}")
|
||||||
|
triplets2 = await retriever.get_triplets(query2)
|
||||||
|
print(f"Found {len(triplets2)} triplets")
|
||||||
|
|
||||||
|
asyncio.run(main())
|
||||||
Loading…
Add table
Reference in a new issue