diff --git a/cognee/modules/search/methods/search.py b/cognee/modules/search/methods/search.py index becfb669c..5202fd8ee 100644 --- a/cognee/modules/search/methods/search.py +++ b/cognee/modules/search/methods/search.py @@ -173,6 +173,7 @@ async def search( graphs = prepared_search_results["graphs"] context = prepared_search_results["context"] datasets = prepared_search_results["datasets"] + citations = prepared_search_results.get("citations", []) if only_context: search_result_dict = { @@ -180,6 +181,7 @@ async def search( "dataset_id": datasets[0].id, "dataset_name": datasets[0].name, "dataset_tenant_id": datasets[0].tenant_id, + "citations": citations, } if verbose: # Include graphs only in verbose mode @@ -192,6 +194,7 @@ async def search( "dataset_id": datasets[0].id, "dataset_name": datasets[0].name, "dataset_tenant_id": datasets[0].tenant_id, + "citations": citations, } if verbose: # Include graphs only in verbose mode diff --git a/cognee/modules/search/types/SearchResult.py b/cognee/modules/search/types/SearchResult.py index 8ea5d3990..eb5826aab 100644 --- a/cognee/modules/search/types/SearchResult.py +++ b/cognee/modules/search/types/SearchResult.py @@ -19,3 +19,4 @@ class SearchResult(BaseModel): search_result: Any dataset_id: Optional[UUID] dataset_name: Optional[str] + citations: Optional[List[Dict[str, Any]]] = None diff --git a/cognee/modules/search/utils/prepare_search_result.py b/cognee/modules/search/utils/prepare_search_result.py index b854a318d..876fb8353 100644 --- a/cognee/modules/search/utils/prepare_search_result.py +++ b/cognee/modules/search/utils/prepare_search_result.py @@ -55,9 +55,28 @@ async def prepare_search_result(search_result): if isinstance(results, List) and len(results) > 0 and isinstance(results[0], Edge): result_graph = transform_context_to_graph(results) + citations = [] + if isinstance(context, List) and len(context) > 0 and isinstance(context[0], Edge): + seen_ids = set() + for edge in context: + for node in [edge.node1, edge.node2]: + if node.id not in seen_ids: + seen_ids.add(node.id) + # Extract attributes, prioritizing commonly used citation fields + citation = { + "id": str(node.id), + "text": node.attributes.get("text", ""), + "metadata": { + k: v for k, v in node.attributes.items() + if k not in ["text", "vector_distance"] + } + } + citations.append(citation) + return { "result": result_graph or results[0] if results and len(results) == 1 else results, "graphs": graphs, "context": context_texts, "datasets": datasets, + "citations": citations, }