Merge remote-tracking branch 'origin/main' into code-graph

This commit is contained in:
Boris Arzentar 2024-11-27 22:54:49 +01:00
commit d885a047ac
14 changed files with 260 additions and 41 deletions

View file

@ -142,12 +142,11 @@ class LanceDBAdapter(VectorDBInterface):
score = 0,
) for result in results.to_dict("index").values()]
async def get_distances_of_collection(
self,
collection_name: str,
query_text: str = None,
query_vector: List[float] = None,
with_vector: bool = False
async def get_distance_from_collection_elements(
self,
collection_name: str,
query_text: str = None,
query_vector: List[float] = None
):
if query_text is None and query_vector is None:
raise ValueError("One of query_text or query_vector must be provided!")

View file

@ -176,7 +176,7 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
) for result in results
]
async def get_distances_of_collection(
async def get_distance_from_collection_elements(
self,
collection_name: str,
query_text: str = None,
@ -192,8 +192,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
# Get PGVectorDataPoint Table from database
PGVectorDataPoint = await self.get_table(collection_name)
closest_items = []
# Use async session to connect to the database
async with self.get_async_session() as session:
# Find closest vectors to query_vector

View file

@ -142,6 +142,41 @@ class QDrantAdapter(VectorDBInterface):
await client.close()
return results
async def get_distance_from_collection_elements(
    self,
    collection_name: str,
    query_text: str = None,
    query_vector: List[float] = None,
    with_vector: bool = False
) -> List[ScoredResult]:
    """Score every element of a Qdrant collection against a query.

    Exactly one of *query_text* / *query_vector* is required; when only text is
    given it is embedded first. Qdrant similarity scores are converted to a
    distance-like value via ``1 - score``.

    Args:
        collection_name: Name of the Qdrant collection to search.
        query_text: Text to embed and use as the query (optional).
        query_vector: Pre-computed query embedding (optional).
        with_vector: Whether to include stored vectors in the response.

    Returns:
        List of ScoredResult with UUID ids and ``1 - score`` distances.

    Raises:
        ValueError: If neither query_text nor query_vector is supplied.
    """
    if query_text is None and query_vector is None:
        raise ValueError("One of query_text or query_vector must be provided!")

    client = self.get_qdrant_client()
    try:
        # Embed lazily: only pay for an embedding call when no vector was supplied.
        if query_vector is None:
            query_vector = (await self.embed_data([query_text]))[0]

        results = await client.search(
            collection_name = collection_name,
            query_vector = models.NamedVector(
                name = "text",
                vector = query_vector,
            ),
            with_vectors = with_vector
        )
    finally:
        # Always release the client — the original only closed it on the
        # success path, leaking the connection when embedding/search raised.
        await client.close()

    return [
        ScoredResult(
            id = UUID(result.id),
            payload = {
                **result.payload,
                "id": UUID(result.id),
            },
            score = 1 - result.score,
        ) for result in results
    ]
async def search(
self,
collection_name: str,

View file

@ -1,18 +1,6 @@
from typing import List
def normalize_distances(result_values: List[dict]) -> List[float]:
min_value = 100
max_value = 0
for result in result_values:
value = float(result["_distance"])
if value > max_value:
max_value = value
if value < min_value:
min_value = value
normalized_values = []
min_value = min(result["_distance"] for result in result_values)
max_value = max(result["_distance"] for result in result_values)
@ -23,4 +11,4 @@ def normalize_distances(result_values: List[dict]) -> List[float]:
normalized_values = [(result["_distance"] - min_value) / (max_value - min_value) for result in
result_values]
return normalized_values
return normalized_values

View file

@ -153,6 +153,36 @@ class WeaviateAdapter(VectorDBInterface):
return await future
async def get_distance_from_collection_elements(
    self,
    collection_name: str,
    query_text: str = None,
    query_vector: List[float] = None,
    with_vector: bool = False
) -> List[ScoredResult]:
    """Score every element of a Weaviate collection against a query.

    One of *query_text* / *query_vector* must be given; a missing vector is
    produced by embedding the text. Returned scores are ``1 - hybrid score``.

    Raises:
        ValueError: If neither query_text nor query_vector is supplied.
    """
    import weaviate.classes as wvc

    if query_text is None and query_vector is None:
        raise ValueError("One of query_text or query_vector must be provided!")

    if query_vector is None:
        query_vector = (await self.embed_data([query_text]))[0]

    collection = self.get_collection(collection_name)
    # query=None makes the hybrid call a pure vector search.
    response = collection.query.hybrid(
        query=None,
        vector=query_vector,
        include_vector=with_vector,
        return_metadata=wvc.query.MetadataQuery(score=True),
    )

    scored = []
    for obj in response.objects:
        scored.append(
            ScoredResult(
                id=UUID(str(obj.uuid)),
                payload=obj.properties,
                score=1 - float(obj.metadata.score),
            )
        )
    return scored
async def search(
self,
collection_name: str,

View file

@ -42,7 +42,7 @@ class CogneeGraph(CogneeAbstractGraph):
def get_node(self, node_id: str) -> Node:
return self.nodes.get(node_id, None)
def get_edges_of_node(self, node_id: str) -> List[Edge]:
def get_edges_from_node(self, node_id: str) -> List[Edge]:
node = self.get_node(node_id)
if node:
return node.skeleton_edges
@ -50,16 +50,18 @@ class CogneeGraph(CogneeAbstractGraph):
raise ValueError(f"Node with id {node_id} does not exist.")
def get_edges(self)-> List[Edge]:
return edges
return self.edges
async def project_graph_from_db(self,
adapter: Union[GraphDBInterface],
node_properties_to_project: List[str],
edge_properties_to_project: List[str],
directed = True,
node_dimension = 1,
edge_dimension = 1,
memory_fragment_filter = []) -> None:
async def project_graph_from_db(
self,
adapter: Union[GraphDBInterface],
node_properties_to_project: List[str],
edge_properties_to_project: List[str],
directed = True,
node_dimension = 1,
edge_dimension = 1,
memory_fragment_filter = [],
) -> None:
if node_dimension < 1 or edge_dimension < 1:
raise ValueError("Dimensions must be positive integers")
@ -158,15 +160,15 @@ class CogneeGraph(CogneeAbstractGraph):
print(f"Error mapping vector distances to edges: {ex}")
async def calculate_top_triplet_importances(self, k = int) -> List:
async def calculate_top_triplet_importances(self, k: int) -> List:
min_heap = []
for i, edge in enumerate(self.edges):
source_node = self.get_node(edge.node1.id)
target_node = self.get_node(edge.node2.id)
source_distance = source_node.attributes.get("vector_distance", 0) if source_node else 0
target_distance = target_node.attributes.get("vector_distance", 0) if target_node else 0
edge_distance = edge.attributes.get("vector_distance", 0)
source_distance = source_node.attributes.get("vector_distance", 1) if source_node else 1
target_distance = target_node.attributes.get("vector_distance", 1) if target_node else 1
edge_distance = edge.attributes.get("vector_distance", 1)
total_distance = source_distance + target_distance + edge_distance

View file

View file

@ -0,0 +1,150 @@
import asyncio
import logging
from typing import List
from cognee.modules.users.models import User
from cognee.modules.users.methods import get_default_user
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
from cognee.infrastructure.databases.vector import get_vector_engine
from cognee.infrastructure.databases.graph import get_graph_engine
from cognee.shared.utils import send_telemetry
def format_triplets(edges):
    """Render graph edges as human-readable "Node1 / Edge / Node2" triplet strings.

    Each edge is expected to expose ``node1``, ``node2`` and ``attributes``;
    nodes expose ``attributes`` dicts. None-valued attributes are dropped so
    the output only shows populated fields.

    Args:
        edges: Iterable of edge objects.

    Returns:
        A single string containing one formatted triplet per edge.
    """
    # NOTE: the original defined a nested `filter_attributes` helper that was
    # never called (dead code) — removed.
    print("\n\n\n")

    triplets = []
    for edge in edges:
        # Keep only non-None properties so empty fields don't clutter output.
        node1_info = {key: value for key, value in edge.node1.attributes.items() if value is not None}
        node2_info = {key: value for key, value in edge.node2.attributes.items() if value is not None}
        edge_info = {key: value for key, value in edge.attributes.items() if value is not None}

        triplets.append(
            f"Node1: {node1_info}\n"
            f"Edge: {edge_info}\n"
            f"Node2: {node2_info}\n\n\n"
        )

    return "".join(triplets)
async def brute_force_triplet_search(query: str, user: User = None, top_k = 5) -> list:
    """Resolve the acting user (falling back to the default one) and run the search.

    Args:
        query: The search query string.
        user: Optional user; when omitted the system default user is used.
        top_k: Number of top triplets to retrieve.

    Returns:
        The top triplet results from the graph.

    Raises:
        PermissionError: If no user was supplied and no default user exists.
    """
    if user is None:
        user = await get_default_user()

    if user is None:
        raise PermissionError("No user found in the system. Please create a user.")

    return await brute_force_search(query, user, top_k)
def delete_duplicated_vector_db_elements(collections, results):  # TODO: temporary workaround for vector db duplicates
    """Drop duplicate results (by ``id``) per collection, keeping first occurrences.

    Args:
        collections: Collection names, aligned positionally with *results*.
        results: Per-collection lists of result objects exposing an ``id``.

    Returns:
        Dict mapping each collection name to its de-duplicated result list.
    """
    results_dict = {}
    # Bug fix: the original reused the `results` parameter name as the loop
    # variable, shadowing the argument — renamed for clarity and safety.
    for collection, collection_results in zip(collections, results):
        seen_ids = set()
        unique_results = []
        for result in collection_results:
            if result.id not in seen_ids:
                unique_results.append(result)
                seen_ids.add(result.id)
            else:
                print(f"Duplicate found in collection '{collection}': {result.id}")
        results_dict[collection] = unique_results
    return results_dict
async def brute_force_search(
    query: str,
    user: User,
    top_k: int,
    collections: List[str] = None
) -> list:
    """
    Performs a brute force search to retrieve the top triplets from the graph.

    Args:
        query (str): The search query.
        user (User): The user performing the search.
        top_k (int): The number of top results to retrieve.
        collections (Optional[List[str]]): List of collections to query. Defaults to predefined collections.

    Returns:
        list: The top triplet results.

    Raises:
        ValueError: If the query is empty/non-string or top_k is not positive.
        RuntimeError: If engine initialization or the search itself fails.
    """
    if not query or not isinstance(query, str):
        raise ValueError("The query must be a non-empty string.")
    if top_k <= 0:
        raise ValueError("top_k must be a positive integer.")

    if collections is None:
        collections = ["entity_name", "text_summary_text", "entity_type_name", "document_chunk_text"]

    try:
        vector_engine = get_vector_engine()
        graph_engine = await get_graph_engine()
    except Exception as e:
        logging.error("Failed to initialize engines: %s", e)
        raise RuntimeError("Initialization error") from e

    send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id)

    try:
        # Query all collections concurrently.
        results = await asyncio.gather(
            *[vector_engine.get_distance_from_collection_elements(collection, query_text=query) for collection in collections]
        )

        ############################################# :TODO: Change when vector db does not contain duplicates
        node_distances = delete_duplicated_vector_db_elements(collections, results)
        # node_distances = {collection: result for collection, result in zip(collections, results)}
        ##############################################

        memory_fragment = CogneeGraph()

        await memory_fragment.project_graph_from_db(
            graph_engine,
            node_properties_to_project=['id', 'description', 'name', 'type', 'text'],
            edge_properties_to_project=['relationship_name'],
        )

        await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)

        #:TODO: Change when vectordb contains edge embeddings
        await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query)

        results = await memory_fragment.calculate_top_triplet_importances(k=top_k)

        # Bug fix: the original re-sent "EXECUTION STARTED" here; this is the
        # completion event of the search.
        send_telemetry("cognee.brute_force_triplet_search EXECUTION COMPLETED", user.id)

        #:TODO: Once we have Edge pydantic models we should retrieve the exact edge and node objects from graph db
        return results
    except Exception as e:
        logging.error("Error during brute force search for user: %s, query: %s. Error: %s", user.id, query, e)
        send_telemetry("cognee.brute_force_triplet_search EXECUTION FAILED", user.id)
        raise RuntimeError("An error occurred during brute force search") from e

View file

@ -4,6 +4,7 @@ import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
logging.basicConfig(level=logging.DEBUG)
@ -61,6 +62,9 @@ async def main():
assert len(history) == 6, "Search history is not correct."
results = await brute_force_triplet_search('What is a quantum computer?')
assert len(results) > 0
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

View file

@ -3,6 +3,7 @@ import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
logging.basicConfig(level=logging.DEBUG)
@ -89,6 +90,9 @@ async def main():
history = await cognee.get_search_history()
assert len(history) == 6, "Search history is not correct."
results = await brute_force_triplet_search('What is a quantum computer?')
assert len(results) > 0
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

View file

@ -5,6 +5,7 @@ import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
logging.basicConfig(level=logging.DEBUG)
@ -61,6 +62,9 @@ async def main():
history = await cognee.get_search_history()
assert len(history) == 6, "Search history is not correct."
results = await brute_force_triplet_search('What is a quantum computer?')
assert len(results) > 0
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

View file

@ -3,6 +3,7 @@ import logging
import pathlib
import cognee
from cognee.api.v1.search import SearchType
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
logging.basicConfig(level=logging.DEBUG)
@ -59,6 +60,9 @@ async def main():
history = await cognee.get_search_history()
assert len(history) == 6, "Search history is not correct."
results = await brute_force_triplet_search('What is a quantum computer?')
assert len(results) > 0
await cognee.prune.prune_data()
assert not os.path.isdir(data_directory_path), "Local data files are not deleted"

View file

@ -77,11 +77,11 @@ def test_get_edges_success(setup_graph):
graph.add_node(node2)
edge = Edge(node1, node2)
graph.add_edge(edge)
assert edge in graph.get_edges_of_node("node1")
assert edge in graph.get_edges_from_node("node1")
def test_get_edges_nonexistent_node(setup_graph):
"""Test retrieving edges for a nonexistent node raises an exception."""
graph = setup_graph
with pytest.raises(ValueError, match="Node with id nonexistent does not exist."):
graph.get_edges_of_node("nonexistent")
graph.get_edges_from_node("nonexistent")

View file

@ -1,6 +1,7 @@
import cognee
import asyncio
from cognee.pipelines.retriever.two_steps_retriever import two_step_retriever
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
from cognee.modules.retrieval.brute_force_triplet_search import format_triplets
job_1 = """
CV 1: Relevant
@ -181,8 +182,8 @@ async def main(enable_steps):
# Step 4: Query insights
if enable_steps.get("retriever"):
await two_step_retriever('Who has Phd?')
results = await brute_force_triplet_search('Who has the most experience with graphic design?')
print(format_triplets(results))
if __name__ == '__main__':
# Flags to enable/disable steps