This commit is contained in:
Vasilije 2024-03-21 16:59:25 +01:00
parent 49b80ec898
commit 1a213a9f93
9 changed files with 1832 additions and 1183 deletions

File diff suppressed because one or more lines are too long

View file

@ -178,7 +178,12 @@ async def process_text(input_text: str, file_metadata: dict):
results = await resolve_cross_graph_references(nodes_by_layer) results = await resolve_cross_graph_references(nodes_by_layer)
relationships = graph_ready_output(results) relationships = graph_ready_output(results)
# print(relationships)
await graph_client.load_graph_from_file()
graph = graph_client.graph
connect_nodes_in_graph(graph, relationships) connect_nodes_in_graph(graph, relationships)

View file

@ -2,7 +2,7 @@
import asyncio import asyncio
from enum import Enum, auto from enum import Enum, auto
from typing import Dict, Any, Callable, List from typing import Dict, Any, Callable, List
from pydantic import BaseModel, validator from pydantic import BaseModel, field_validator
from cognee.modules.search.graph.search_adjacent import search_adjacent from cognee.modules.search.graph.search_adjacent import search_adjacent
from cognee.modules.search.vector.search_similarity import search_similarity from cognee.modules.search.vector.search_similarity import search_similarity
from cognee.modules.search.graph.search_categories import search_categories from cognee.modules.search.graph.search_categories import search_categories
@ -28,7 +28,7 @@ class SearchParameters(BaseModel):
search_type: SearchType search_type: SearchType
params: Dict[str, Any] params: Dict[str, Any]
@validator('search_type', pre=True) @field_validator('search_type', mode='before')
def convert_string_to_enum(cls, value): def convert_string_to_enum(cls, value):
if isinstance(value, str): if isinstance(value, str):
return SearchType.from_str(value) return SearchType.from_str(value)
@ -80,8 +80,8 @@ if __name__ == "__main__":
await graph_client.load_graph_from_file() await graph_client.load_graph_from_file()
graph = graph_client.graph graph = graph_client.graph
# Assuming 'graph' is your graph object, obtained from somewhere # Assuming 'graph' is your graph object, obtained from somewhere
search_type = 'ADJACENT' search_type = 'CATEGORIES'
params = {'query': 'example query', 'other_param': {"node_id": "LLM_LAYER_SUMMARY:DOCUMENT:881ecb36-2819-54c3-8147-ed80293084d6"}} params = {'query': 'Ministarstvo', 'other_param': {"node_id": "LLM_LAYER_SUMMARY:DOCUMENT:881ecb36-2819-54c3-8147-ed80293084d6"}}
results = await search(graph, search_type, params) results = await search(graph, search_type, params)
print(results) print(results)

View file

@ -3,6 +3,6 @@ from typing import Any, Dict
from pydantic import BaseModel from pydantic import BaseModel
class ScoredResult(BaseModel):
    """One scored hit returned from a vector-store search."""

    # Identifier of the matched object, stored as a string.
    id: str
    # Score attached to the hit, stored as a float.
    score: float
    # Arbitrary key/value properties carried with the hit.
    payload: Dict[str, Any]

View file

@ -49,17 +49,19 @@ class WeaviateAdapter(VectorDBInterface):
return self.get_collection(collection_name).data.insert_many(objects) return self.get_collection(collection_name).data.insert_many(objects)
async def search(self, collection_name: str, query_text: str, limit: int, with_vector: bool = False):
    """
    Run a hybrid query against a collection and wrap every returned object in a ScoredResult.

    :param collection_name: Name of the collection to query
    :param query_text: Text passed as the ``query`` argument of the hybrid query
    :param limit: Passed as the ``limit`` argument of the hybrid query
    :param with_vector: Passed as the ``include_vector`` argument of the hybrid query
    :return: A list with one ScoredResult per object in the query result
    """
    search_result = self.get_collection(collection_name).query.hybrid(
        query = query_text,
        limit = limit,
        include_vector = with_vector,
        # Ask for the score so it can be copied onto each ScoredResult below.
        return_metadata = wvc.query.MetadataQuery(score = True),
    )

    # The uuid and score are converted explicitly to str / float before being
    # handed to ScoredResult.
    return [
        ScoredResult(
            id = str(result.uuid),
            payload = result.properties,
            score = float(result.metadata.score),
        )
        for result in search_result.objects
    ]
async def batch_search(self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False): async def batch_search(self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False):

View file

@ -31,7 +31,7 @@ async def group_nodes_by_layer(node_descriptions):
return grouped_data return grouped_data
def connect_nodes_in_graph(graph: Graph, relationship_dict: dict, score_treshold:float=None) -> Graph:
    """
    For each relationship in relationship_dict whose score is above the threshold, check if both
    nodes exist in the graph based on their "unique_id" node attribute.
    If they do, create a connection (edge) between them.

    :param graph: A NetworkX graph object
    :param relationship_dict: A dictionary containing relationships between nodes. Each value is a
        list of dicts with the keys "searched_node_id", "original_id_for_search" and "score",
        plus an optional "score_metadata"
    :param score_treshold: Only relationships with a score strictly greater than this value are
        connected. Defaults to 0.9 when None
    :return: The same graph object, with the new edges added in place
    """
    if score_treshold is None:
        score_treshold = 0.9

    # Index the nodes by their "unique_id" attribute once, instead of scanning
    # the whole graph again for every relationship. Nodes without the attribute
    # can never match and are left out. If several nodes share a unique_id,
    # the first one encountered is used.
    node_key_by_unique_id = {}
    for node, attrs in graph.nodes(data = True):
        if "unique_id" in attrs:
            node_key_by_unique_id.setdefault(attrs["unique_id"], node)

    for relationships in relationship_dict.values():
        for relationship in relationships:
            searched_node_attr_id = relationship["searched_node_id"]
            score_attr_id = relationship["original_id_for_search"]
            score = relationship["score"]

            if not score > score_treshold:
                continue

            # A relationship whose two ids are identical is skipped, so no
            # self-loop is created.
            if searched_node_attr_id == score_attr_id:
                continue

            searched_node_key = node_key_by_unique_id.get(searched_node_attr_id)
            score_node_key = node_key_by_unique_id.get(score_attr_id)

            # Only connect when both endpoints were found in the graph.
            if searched_node_key is None or score_node_key is None:
                continue

            # The score is stored as the edge weight; extra metadata is kept as-is.
            graph.add_edge(
                searched_node_key,
                score_node_key,
                weight = score,
                score_metadata = relationship.get("score_metadata")
            )

    return graph
@ -97,3 +103,32 @@ def graph_ready_output(results):
}) })
return relationship_dict return relationship_dict
if __name__ == "__main__":
    import asyncio

    # NOTE(review): the other imports visible in this change come from the
    # `cognee` package; confirm that `cognitive_architecture` is still the
    # right module path for render_graph.
    from cognitive_architecture.utils import render_graph

    async def main():
        """Load the stored NetworkX graph, connect a fixed sample of relationships and print the rendered graph URL."""
        graph_client = get_graph_client(GraphDBType.NETWORKX)
        await graph_client.load_graph_from_file()
        graph = graph_client.graph

        # Hand-written sample in the shape connect_nodes_in_graph reads:
        # collection id -> list of scored relationships.
        collection_id = 'SuaGeKyKWKWyaSeiqWeWaSyuSKqieSamiyah'
        regulation_text = 'Pravilnik o izmenama i dopunama Pravilnika o sadržini, načinu i postupku izrade i način vršenja kontrole tehničke dokumentacije prema klasi i nameni objekata'
        gazette_text = 'Službeni glasnik RS, broj 77/2015'

        relationships = {
            collection_id: [
                {
                    'collection_id': collection_id,
                    'searched_node_id': 'd0bd0f6a-09e5-4308-89f6-400d66895126',
                    'score': 1.0,
                    'score_metadata': {'text': regulation_text},
                    'original_id_for_search': '2801f7b5-55bf-499b-9843-97d48f8e067a',
                },
                {
                    'collection_id': collection_id,
                    'searched_node_id': 'd0bd0f6a-09e5-4308-89f6-400d66895126',
                    'score': 0.1648828387260437,
                    'score_metadata': {'text': 'Zakon o planiranju i izgradnji'},
                    'original_id_for_search': '57966b55-33e2-4eae-a7fa-2f0237643bbe',
                },
                {
                    'collection_id': collection_id,
                    'searched_node_id': 'd0bd0f6a-09e5-4308-89f6-400d66895126',
                    'score': 0.12986786663532257,
                    'score_metadata': {'text': gazette_text},
                    'original_id_for_search': '0f626d48-4441-43c1-9060-ea7e54f6d8e2',
                },
                {
                    'collection_id': collection_id,
                    'searched_node_id': 'c9b9a460-c64a-4e2e-a4d6-aa5b3769274b',
                    'score': 1.0,
                    'score_metadata': {'text': gazette_text},
                    'original_id_for_search': '0f626d48-4441-43c1-9060-ea7e54f6d8e2',
                },
                {
                    'collection_id': collection_id,
                    'searched_node_id': 'c9b9a460-c64a-4e2e-a4d6-aa5b3769274b',
                    'score': 0.07603412866592407,
                    # The line break inside this text is written as an escape
                    # so the literal stays on a single source line.
                    'score_metadata': {'text': 'Prof. \ndr Zorana Mihajlović'},
                    'original_id_for_search': '5d064a62-3cd6-4895-9f60-1a0d8bc299e8',
                },
                {
                    'collection_id': collection_id,
                    'searched_node_id': 'c9b9a460-c64a-4e2e-a4d6-aa5b3769274b',
                    'score': 0.07226034998893738,
                    'score_metadata': {'text': 'Ministar građevinarstva, saobraćaja i infrastrukture'},
                    'original_id_for_search': 'f5d052ca-c4a0-490e-a3ac-d8ad522dea83',
                },
                {
                    'collection_id': collection_id,
                    'searched_node_id': 'bbd6d2d6-e673-4b59-a50c-516972a9d0de',
                    'score': 0.5,
                    'score_metadata': {'text': regulation_text},
                    'original_id_for_search': '2801f7b5-55bf-499b-9843-97d48f8e067a',
                },
            ],
        }

        connect_nodes_in_graph(graph, relationships)

        graph_url = await render_graph(graph, graph_type="networkx")
        print(graph_url)

    asyncio.run(main())

View file

@ -28,11 +28,18 @@ async def search_similarity(query: str, graph, other_param: str = None):
for proposition_id in out[0][0]: for proposition_id in out[0][0]:
for n, attr in graph.nodes(data = True): for n, attr in graph.nodes(data = True):
if proposition_id in n: if str(proposition_id) in str(n):
for n_, attr_ in graph.nodes(data=True): for n_, attr_ in graph.nodes(data=True):
relevant_layer = attr["layer_uuid"] relevant_layer = attr["layer_uuid"]
if attr_.get("layer_uuid") == relevant_layer: if attr_.get("layer_uuid") == relevant_layer:
relevant_context.append(attr_["description"]) relevant_context.append(attr_["description"])
def deduplicate_list(original_list):
    """Return the items of original_list with duplicates removed, keeping first-occurrence order."""
    # dict keys are unique and remember insertion order, so this keeps the
    # first occurrence of every item without a side effect hidden in a
    # comprehension condition.
    return list(dict.fromkeys(original_list))
relevant_context = deduplicate_list(relevant_context)
return relevant_context return relevant_context

1048
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "cognee" name = "cognee"
version = "0.1.0" version = "0.1.1"
description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning." description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
authors = ["Vasilije Markovic"] authors = ["Vasilije Markovic"]
readme = "README.md" readme = "README.md"
@ -18,7 +18,6 @@ classifiers = [
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "~3.10" python = "~3.10"
langchain = "^0.0.338"
openai = "1.12.0" openai = "1.12.0"
python-dotenv = "1.0.1" python-dotenv = "1.0.1"
fastapi = "^0.109.2" fastapi = "^0.109.2"