fix: Move weaviate imports where needed

This commit is contained in:
Boris Arzentar 2024-04-20 19:54:11 +02:00
parent bb679c2dd7
commit 30055cc60c
10 changed files with 36 additions and 34 deletions

View file

@ -8,8 +8,8 @@ from cognee.modules.search.vector.search_similarity import search_similarity
from cognee.modules.search.graph.search_categories import search_categories
from cognee.modules.search.graph.search_neighbour import search_neighbour
from cognee.modules.search.graph.search_summary import search_summary
from cognee.shared.data_models import GraphDBType
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
from cognee.infrastructure import infrastructure_config
class SearchType(Enum):
ADJACENT = 'ADJACENT'
@ -42,8 +42,7 @@ async def search(search_type: str, params: Dict[str, Any]) -> List:
async def specific_search(query_params: List[SearchParameters]) -> List:
graph_client = await get_graph_client(GraphDBType.NETWORKX)
await graph_client.load_graph_from_file()
graph_client = await get_graph_client(infrastructure_config.get_config()["graph_engine"])
graph = graph_client.graph
search_functions: Dict[SearchType, Callable] = {
@ -61,8 +60,7 @@ async def specific_search(query_params: List[SearchParameters]) -> List:
search_func = search_functions.get(search_param.search_type)
if search_func:
# Schedule the coroutine for execution and store the task
full_params = {**search_param.params, 'graph': graph}
task = search_func(**full_params)
task = search_func(**search_param.params, graph = graph)
search_tasks.append(task)
# Use asyncio.gather to run all scheduled tasks concurrently

View file

@ -73,6 +73,8 @@ class Config:
connect_documents: bool = False
# Database parameters
graph_database_provider: str = os.getenv("GRAPH_DB_PROVIDER", "NETWORKX")
if (
os.getenv("ENV") == "prod"
or os.getenv("ENV") == "dev"

View file

@ -84,7 +84,7 @@ class InfrastructureConfig():
config.weaviate_api_key,
embedding_engine = self.embedding_engine
)
except EnvironmentError:
except (EnvironmentError, ModuleNotFoundError):
if config.qdrant_url is None and config.qdrant_api_key is None:
raise EnvironmentError("Qdrant is not configured!")

View file

@ -5,7 +5,6 @@ from cognee.shared.data_models import GraphDBType
from cognee.infrastructure import infrastructure_config
from .graph_db_interface import GraphDBInterface
from .networkx.adapter import NetworkXAdapter
from .neo4j_driver.adapter import Neo4jAdapter
config = Config()
config.load()
@ -15,15 +14,18 @@ async def get_graph_client(graph_type: GraphDBType, graph_file_name: str = None)
"""Factory function to get the appropriate graph client based on the graph type."""
graph_file_path = f"{infrastructure_config.get_config('database_directory_path')}/{graph_file_name if graph_file_name else config.graph_filename}"
if graph_type == GraphDBType.NETWORKX:
graph_client = NetworkXAdapter(filename = graph_file_path)
await graph_client.load_graph_from_file()
return graph_client
elif graph_type == GraphDBType.NEO4J:
return Neo4jAdapter(
graph_database_url = config.graph_database_url,
graph_database_username = config.graph_database_username,
graph_database_password = config.graph_database_password
)
else:
raise ValueError("Unsupported graph database type.")
if graph_type == GraphDBType.NEO4J:
try:
from .neo4j_driver.adapter import Neo4jAdapter
return Neo4jAdapter(
graph_database_url = config.graph_database_url,
graph_database_username = config.graph_database_username,
graph_database_password = config.graph_database_password
)
except:
pass
graph_client = NetworkXAdapter(filename = graph_file_path)
await graph_client.load_graph_from_file()
return graph_client

View file

@ -2,10 +2,6 @@ import asyncio
from uuid import UUID
from typing import List, Optional
from multiprocessing import Pool
import weaviate
import weaviate.classes as wvc
import weaviate.classes.config as wvcc
from weaviate.classes.data import DataObject
from ..vector_db_interface import VectorDBInterface
from ..models.DataPoint import DataPoint
from ..models.ScoredResult import ScoredResult
@ -17,6 +13,9 @@ class WeaviateAdapter(VectorDBInterface):
embedding_engine: EmbeddingEngine = None
def __init__(self, url: str, api_key: str, embedding_engine: EmbeddingEngine):
import weaviate
import weaviate.classes as wvc
self.embedding_engine = embedding_engine
self.client = weaviate.connect_to_wcs(
@ -32,6 +31,8 @@ class WeaviateAdapter(VectorDBInterface):
return await self.embedding_engine.embed_text(data)
async def create_collection(self, collection_name: str):
import weaviate.classes.config as wvcc
event_loop = asyncio.get_event_loop()
def sync_create_collection():
@ -57,6 +58,8 @@ class WeaviateAdapter(VectorDBInterface):
return self.client.collections.get(collection_name)
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
from weaviate.classes.data import DataObject
data_vectors = await self.embed_data(
list(map(lambda data_point: data_point.get_embeddable_data(), data_points)))
@ -95,6 +98,8 @@ class WeaviateAdapter(VectorDBInterface):
limit: int = None,
with_vector: bool = False
):
import weaviate.classes as wvc
if query_text is None and query_vector is None:
raise ValueError("One of query_text or query_vector must be provided!")

View file

@ -3,9 +3,8 @@
from typing import Union, Dict
import networkx as nx
from neo4j import AsyncSession
from cognee.shared.data_models import GraphDBType
async def search_adjacent(graph: Union[nx.Graph, AsyncSession], query: str, infrastructure_config: Dict, other_param: dict = None) -> Dict[str, str]:
async def search_adjacent(graph: Union[nx.Graph, any], query: str, infrastructure_config: Dict, other_param: dict = None) -> Dict[str, str]:
"""
Find the neighbours of a given node in the graph and return their descriptions.
Supports both NetworkX graphs and Neo4j graph databases based on the configuration.

View file

@ -2,11 +2,10 @@ from typing import Union, Dict
""" Search categories in the graph and return their summary attributes. """
from neo4j import AsyncSession
from cognee.shared.data_models import GraphDBType
import networkx as nx
async def search_categories(graph: Union[nx.Graph, AsyncSession], query_label: str, infrastructure_config: Dict):
async def search_categories(graph: Union[nx.Graph, any], query_label: str, infrastructure_config: Dict):
"""
Filter nodes in the graph that contain the specified label and return their summary attributes.
This function supports both NetworkX graphs and Neo4j graph databases.

View file

@ -1,13 +1,11 @@
""" Fetches the context of a given node in the graph"""
from typing import Union, Dict
from neo4j import AsyncSession
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
import networkx as nx
from cognee.shared.data_models import GraphDBType
async def search_neighbour(graph: Union[nx.Graph, AsyncSession], id: str, infrastructure_config: Dict,
async def search_neighbour(graph: Union[nx.Graph, any], id: str, infrastructure_config: Dict,
other_param: dict = None):
"""
Search for nodes that share the same 'layer_uuid' as the specified node and return their descriptions.

View file

@ -3,10 +3,9 @@
from typing import Union, Dict
import networkx as nx
from neo4j import AsyncSession
from cognee.shared.data_models import GraphDBType
async def search_summary(graph: Union[nx.Graph, AsyncSession], query: str, infrastructure_config: Dict, other_param: str = None) -> Dict[str, str]:
async def search_summary(graph: Union[nx.Graph, any], query: str, infrastructure_config: Dict, other_param: str = None) -> Dict[str, str]:
"""
Filter nodes based on a condition (such as containing 'SUMMARY' in their identifiers) and return their summary attributes.
Supports both NetworkX graphs and Neo4j graph databases based on the configuration.

View file

@ -3,7 +3,7 @@ from cognee.infrastructure import infrastructure_config
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
async def search_similarity(query: str):
async def search_similarity(query: str, graph):
graph_db_type = infrastructure_config.get_config()["graph_engine"]
graph_client = await get_graph_client(graph_db_type)
@ -35,7 +35,7 @@ async def search_similarity(query: str):
for graph_node_data in graph_nodes:
graph_node = await graph_client.extract_node(graph_node_data["node_id"])
if "chunk_collection" not in graph_node and "chunk_id" not in graph_node:
continue