chore: Fixes some of the issues based on PR review + restructures things
This commit is contained in:
parent
676cdfcc84
commit
a59517409c
6 changed files with 26 additions and 80 deletions
|
|
@ -192,8 +192,6 @@ class PGVectorAdapter(SQLAlchemyAdapter, VectorDBInterface):
|
|||
# Get PGVectorDataPoint Table from database
|
||||
PGVectorDataPoint = await self.get_table(collection_name)
|
||||
|
||||
closest_items = []
|
||||
|
||||
# Use async session to connect to the database
|
||||
async with self.get_async_session() as session:
|
||||
# Find closest vectors to query_vector
|
||||
|
|
|
|||
|
|
@ -164,9 +164,9 @@ class CogneeGraph(CogneeAbstractGraph):
|
|||
source_node = self.get_node(edge.node1.id)
|
||||
target_node = self.get_node(edge.node2.id)
|
||||
|
||||
source_distance = source_node.attributes.get("vector_distance", 0) if source_node else 0
|
||||
target_distance = target_node.attributes.get("vector_distance", 0) if target_node else 0
|
||||
edge_distance = edge.attributes.get("vector_distance", 0)
|
||||
source_distance = source_node.attributes.get("vector_distance", 1) if source_node else 1
|
||||
target_distance = target_node.attributes.get("vector_distance", 1) if target_node else 1
|
||||
edge_distance = edge.attributes.get("vector_distance", 1)
|
||||
|
||||
total_distance = source_distance + target_distance + edge_distance
|
||||
|
||||
|
|
|
|||
|
|
@ -1,15 +1,11 @@
|
|||
import asyncio
|
||||
from uuid import UUID
|
||||
from enum import Enum
|
||||
from typing import Callable, Dict
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from typing import Dict, List
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.users.permissions.methods import get_document_ids_for_user
|
||||
from cognee.modules.graph.cognee_graph.CogneeGraph import CogneeGraph
|
||||
from cognee.infrastructure.databases.vector import get_vector_engine
|
||||
from cognee.infrastructure.databases.graph import get_graph_engine
|
||||
|
||||
from cognee.shared.utils import send_telemetry
|
||||
|
||||
def format_triplets(edges):
|
||||
print("\n\n\n")
|
||||
|
|
@ -44,24 +40,22 @@ def format_triplets(edges):
|
|||
triplet = (
|
||||
f"Node1: {node1_info}\n"
|
||||
f"Edge: {edge_info}\n"
|
||||
f"Node2: {node2_info}\n\n\n" # Add three blank lines for separation
|
||||
f"Node2: {node2_info}\n\n\n"
|
||||
)
|
||||
triplets.append(triplet)
|
||||
|
||||
return "".join(triplets)
|
||||
|
||||
|
||||
async def two_step_retriever(query: Dict[str, str], user: User = None) -> list:
|
||||
async def brute_force_triplet_search(query: str, user: User = None, top_k = 5) -> list:
|
||||
if user is None:
|
||||
user = await get_default_user()
|
||||
|
||||
if user is None:
|
||||
raise PermissionError("No user found in the system. Please create a user.")
|
||||
|
||||
own_document_ids = await get_document_ids_for_user(user.id)
|
||||
retrieved_results = await run_two_step_retriever(query, user)
|
||||
retrieved_results = await brute_force_search(query, user, top_k)
|
||||
|
||||
filtered_search_results = []
|
||||
|
||||
return retrieved_results
|
||||
|
||||
|
|
@ -82,18 +76,22 @@ def delete_duplicated_vector_db_elements(collections, results): #:TODO: This is
|
|||
return results_dict
|
||||
|
||||
|
||||
async def run_two_step_retriever(query: str, user, community_filter = []) -> list:
|
||||
async def brute_force_search(query: str, user: User, top_k: int, collections: List[str] = None) -> list:
|
||||
if collections is None:
|
||||
collections = ["entity_name", "text_summary_text", "entity_type_name", "document_chunk_text"]
|
||||
|
||||
vector_engine = get_vector_engine()
|
||||
graph_engine = await get_graph_engine()
|
||||
|
||||
collections = ["Entity_name", "TextSummary_text", 'EntityType_name', 'DocumentChunk_text']
|
||||
send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id)
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[vector_engine.get_distances_of_collection(collection, query_text=query) for collection in collections]
|
||||
)
|
||||
|
||||
############################################# This part is a quick fix til we don't fix the vector db inconsistency
|
||||
node_distances = delete_duplicated_vector_db_elements(collections, results)# :TODO: Change when vector db is fixed
|
||||
# results_dict = {collection: result for collection, result in zip(collections, results)}
|
||||
############################################# :TODO: Change when vector db does not contain duplicates
|
||||
node_distances = delete_duplicated_vector_db_elements(collections, results)
|
||||
# node_distances = {collection: result for collection, result in zip(collections, results)}
|
||||
##############################################
|
||||
|
||||
memory_fragment = CogneeGraph()
|
||||
|
|
@ -104,16 +102,16 @@ async def run_two_step_retriever(query: str, user, community_filter = []) -> lis
|
|||
'name',
|
||||
'type',
|
||||
'text'],
|
||||
edge_properties_to_project=['id',
|
||||
'relationship_name'])
|
||||
edge_properties_to_project=['relationship_name'])
|
||||
|
||||
await memory_fragment.map_vector_distances_to_graph_nodes(node_distances=node_distances)
|
||||
|
||||
await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query)# :TODO: This should be coming from vector db
|
||||
#:TODO: Change when vectordb contains edge embeddings
|
||||
await memory_fragment.map_vector_distances_to_graph_edges(vector_engine, query)
|
||||
|
||||
results = await memory_fragment.calculate_top_triplet_importances(k=5)
|
||||
results = await memory_fragment.calculate_top_triplet_importances(k=top_k)
|
||||
|
||||
print(format_triplets(results))
|
||||
print(f'Query was the following:{query}' )
|
||||
send_telemetry("cognee.brute_force_triplet_search EXECUTION STARTED", user.id)
|
||||
|
||||
#:TODO: Once we have Edge pydantic models we should retrieve the exact edge and node objects from graph db
|
||||
return results
|
||||
|
|
@ -1,25 +0,0 @@
|
|||
from uuid import UUID
|
||||
from enum import Enum
|
||||
from typing import Callable, Dict
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.users.permissions.methods import get_document_ids_for_user
|
||||
|
||||
async def two_step_retriever(query: Dict[str, str], user: User = None) -> list:
|
||||
if user is None:
|
||||
user = await get_default_user()
|
||||
|
||||
if user is None:
|
||||
raise PermissionError("No user found in the system. Please create a user.")
|
||||
|
||||
own_document_ids = await get_document_ids_for_user(user.id)
|
||||
retrieved_results = await diffusion_retriever(query, user)
|
||||
|
||||
filtered_search_results = []
|
||||
|
||||
|
||||
return retrieved_results
|
||||
|
||||
async def diffusion_retriever(query: str, user, community_filter = []) -> list:
|
||||
raise(NotImplementedError)
|
||||
|
|
@ -1,25 +0,0 @@
|
|||
from uuid import UUID
|
||||
from enum import Enum
|
||||
from typing import Callable, Dict
|
||||
from cognee.shared.utils import send_telemetry
|
||||
from cognee.modules.users.models import User
|
||||
from cognee.modules.users.methods import get_default_user
|
||||
from cognee.modules.users.permissions.methods import get_document_ids_for_user
|
||||
|
||||
async def two_step_retriever(query: Dict[str, str], user: User = None) -> list:
|
||||
if user is None:
|
||||
user = await get_default_user()
|
||||
|
||||
if user is None:
|
||||
raise PermissionError("No user found in the system. Please create a user.")
|
||||
|
||||
own_document_ids = await get_document_ids_for_user(user.id)
|
||||
retrieved_results = await g_retriever(query, user)
|
||||
|
||||
filtered_search_results = []
|
||||
|
||||
|
||||
return retrieved_results
|
||||
|
||||
async def g_retriever(query: str, user, community_filter = []) -> list:
|
||||
raise(NotImplementedError)
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
import cognee
|
||||
import asyncio
|
||||
from cognee.pipelines.retriever.two_steps_retriever import two_step_retriever
|
||||
from cognee.pipelines.retriever.brute_force_triplet_search import brute_force_triplet_search
|
||||
|
||||
job_1 = """
|
||||
CV 1: Relevant
|
||||
|
|
@ -181,13 +181,13 @@ async def main(enable_steps):
|
|||
|
||||
# Step 4: Query insights
|
||||
if enable_steps.get("retriever"):
|
||||
await two_step_retriever('Who has Phd?')
|
||||
await brute_force_triplet_search('Who has Phd?')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Flags to enable/disable steps
|
||||
|
||||
rebuild_kg = True
|
||||
rebuild_kg = False
|
||||
retrieve = True
|
||||
steps_to_enable = {
|
||||
"prune_data": rebuild_kg,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue