fix: Move weaviate imports where needed

This commit is contained in:
Boris Arzentar 2024-04-20 19:54:11 +02:00
parent bb679c2dd7
commit 30055cc60c
10 changed files with 36 additions and 34 deletions

View file

@ -8,8 +8,8 @@ from cognee.modules.search.vector.search_similarity import search_similarity
from cognee.modules.search.graph.search_categories import search_categories from cognee.modules.search.graph.search_categories import search_categories
from cognee.modules.search.graph.search_neighbour import search_neighbour from cognee.modules.search.graph.search_neighbour import search_neighbour
from cognee.modules.search.graph.search_summary import search_summary from cognee.modules.search.graph.search_summary import search_summary
from cognee.shared.data_models import GraphDBType
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
from cognee.infrastructure import infrastructure_config
class SearchType(Enum): class SearchType(Enum):
ADJACENT = 'ADJACENT' ADJACENT = 'ADJACENT'
@ -42,8 +42,7 @@ async def search(search_type: str, params: Dict[str, Any]) -> List:
async def specific_search(query_params: List[SearchParameters]) -> List: async def specific_search(query_params: List[SearchParameters]) -> List:
graph_client = await get_graph_client(GraphDBType.NETWORKX) graph_client = await get_graph_client(infrastructure_config.get_config()["graph_engine"])
await graph_client.load_graph_from_file()
graph = graph_client.graph graph = graph_client.graph
search_functions: Dict[SearchType, Callable] = { search_functions: Dict[SearchType, Callable] = {
@ -61,8 +60,7 @@ async def specific_search(query_params: List[SearchParameters]) -> List:
search_func = search_functions.get(search_param.search_type) search_func = search_functions.get(search_param.search_type)
if search_func: if search_func:
# Schedule the coroutine for execution and store the task # Schedule the coroutine for execution and store the task
full_params = {**search_param.params, 'graph': graph} task = search_func(**search_param.params, graph = graph)
task = search_func(**full_params)
search_tasks.append(task) search_tasks.append(task)
# Use asyncio.gather to run all scheduled tasks concurrently # Use asyncio.gather to run all scheduled tasks concurrently

View file

@ -73,6 +73,8 @@ class Config:
connect_documents: bool = False connect_documents: bool = False
# Database parameters # Database parameters
graph_database_provider: str = os.getenv("GRAPH_DB_PROVIDER", "NETWORKX")
if ( if (
os.getenv("ENV") == "prod" os.getenv("ENV") == "prod"
or os.getenv("ENV") == "dev" or os.getenv("ENV") == "dev"

View file

@ -84,7 +84,7 @@ class InfrastructureConfig():
config.weaviate_api_key, config.weaviate_api_key,
embedding_engine = self.embedding_engine embedding_engine = self.embedding_engine
) )
except EnvironmentError: except (EnvironmentError, ModuleNotFoundError):
if config.qdrant_url is None and config.qdrant_api_key is None: if config.qdrant_url is None and config.qdrant_api_key is None:
raise EnvironmentError("Qdrant is not configured!") raise EnvironmentError("Qdrant is not configured!")

View file

@ -5,7 +5,6 @@ from cognee.shared.data_models import GraphDBType
from cognee.infrastructure import infrastructure_config from cognee.infrastructure import infrastructure_config
from .graph_db_interface import GraphDBInterface from .graph_db_interface import GraphDBInterface
from .networkx.adapter import NetworkXAdapter from .networkx.adapter import NetworkXAdapter
from .neo4j_driver.adapter import Neo4jAdapter
config = Config() config = Config()
config.load() config.load()
@ -15,15 +14,18 @@ async def get_graph_client(graph_type: GraphDBType, graph_file_name: str = None)
"""Factory function to get the appropriate graph client based on the graph type.""" """Factory function to get the appropriate graph client based on the graph type."""
graph_file_path = f"{infrastructure_config.get_config('database_directory_path')}/{graph_file_name if graph_file_name else config.graph_filename}" graph_file_path = f"{infrastructure_config.get_config('database_directory_path')}/{graph_file_name if graph_file_name else config.graph_filename}"
if graph_type == GraphDBType.NETWORKX: if graph_type == GraphDBType.NEO4J:
graph_client = NetworkXAdapter(filename = graph_file_path) try:
await graph_client.load_graph_from_file() from .neo4j_driver.adapter import Neo4jAdapter
return graph_client
elif graph_type == GraphDBType.NEO4J: return Neo4jAdapter(
return Neo4jAdapter( graph_database_url = config.graph_database_url,
graph_database_url = config.graph_database_url, graph_database_username = config.graph_database_username,
graph_database_username = config.graph_database_username, graph_database_password = config.graph_database_password
graph_database_password = config.graph_database_password )
) except:
else: pass
raise ValueError("Unsupported graph database type.")
graph_client = NetworkXAdapter(filename = graph_file_path)
await graph_client.load_graph_from_file()
return graph_client

View file

@ -2,10 +2,6 @@ import asyncio
from uuid import UUID from uuid import UUID
from typing import List, Optional from typing import List, Optional
from multiprocessing import Pool from multiprocessing import Pool
import weaviate
import weaviate.classes as wvc
import weaviate.classes.config as wvcc
from weaviate.classes.data import DataObject
from ..vector_db_interface import VectorDBInterface from ..vector_db_interface import VectorDBInterface
from ..models.DataPoint import DataPoint from ..models.DataPoint import DataPoint
from ..models.ScoredResult import ScoredResult from ..models.ScoredResult import ScoredResult
@ -17,6 +13,9 @@ class WeaviateAdapter(VectorDBInterface):
embedding_engine: EmbeddingEngine = None embedding_engine: EmbeddingEngine = None
def __init__(self, url: str, api_key: str, embedding_engine: EmbeddingEngine): def __init__(self, url: str, api_key: str, embedding_engine: EmbeddingEngine):
import weaviate
import weaviate.classes as wvc
self.embedding_engine = embedding_engine self.embedding_engine = embedding_engine
self.client = weaviate.connect_to_wcs( self.client = weaviate.connect_to_wcs(
@ -32,6 +31,8 @@ class WeaviateAdapter(VectorDBInterface):
return await self.embedding_engine.embed_text(data) return await self.embedding_engine.embed_text(data)
async def create_collection(self, collection_name: str): async def create_collection(self, collection_name: str):
import weaviate.classes.config as wvcc
event_loop = asyncio.get_event_loop() event_loop = asyncio.get_event_loop()
def sync_create_collection(): def sync_create_collection():
@ -57,6 +58,8 @@ class WeaviateAdapter(VectorDBInterface):
return self.client.collections.get(collection_name) return self.client.collections.get(collection_name)
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]): async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
from weaviate.classes.data import DataObject
data_vectors = await self.embed_data( data_vectors = await self.embed_data(
list(map(lambda data_point: data_point.get_embeddable_data(), data_points))) list(map(lambda data_point: data_point.get_embeddable_data(), data_points)))
@ -95,6 +98,8 @@ class WeaviateAdapter(VectorDBInterface):
limit: int = None, limit: int = None,
with_vector: bool = False with_vector: bool = False
): ):
import weaviate.classes as wvc
if query_text is None and query_vector is None: if query_text is None and query_vector is None:
raise ValueError("One of query_text or query_vector must be provided!") raise ValueError("One of query_text or query_vector must be provided!")

View file

@ -3,9 +3,8 @@
from typing import Union, Dict from typing import Union, Dict
import networkx as nx import networkx as nx
from neo4j import AsyncSession
from cognee.shared.data_models import GraphDBType from cognee.shared.data_models import GraphDBType
async def search_adjacent(graph: Union[nx.Graph, AsyncSession], query: str, infrastructure_config: Dict, other_param: dict = None) -> Dict[str, str]: async def search_adjacent(graph: Union[nx.Graph, any], query: str, infrastructure_config: Dict, other_param: dict = None) -> Dict[str, str]:
""" """
Find the neighbours of a given node in the graph and return their descriptions. Find the neighbours of a given node in the graph and return their descriptions.
Supports both NetworkX graphs and Neo4j graph databases based on the configuration. Supports both NetworkX graphs and Neo4j graph databases based on the configuration.

View file

@ -2,11 +2,10 @@ from typing import Union, Dict
""" Search categories in the graph and return their summary attributes. """ """ Search categories in the graph and return their summary attributes. """
from neo4j import AsyncSession
from cognee.shared.data_models import GraphDBType from cognee.shared.data_models import GraphDBType
import networkx as nx import networkx as nx
async def search_categories(graph: Union[nx.Graph, AsyncSession], query_label: str, infrastructure_config: Dict): async def search_categories(graph: Union[nx.Graph, any], query_label: str, infrastructure_config: Dict):
""" """
Filter nodes in the graph that contain the specified label and return their summary attributes. Filter nodes in the graph that contain the specified label and return their summary attributes.
This function supports both NetworkX graphs and Neo4j graph databases. This function supports both NetworkX graphs and Neo4j graph databases.

View file

@ -1,13 +1,11 @@
""" Fetches the context of a given node in the graph""" """ Fetches the context of a given node in the graph"""
from typing import Union, Dict from typing import Union, Dict
from neo4j import AsyncSession
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
import networkx as nx import networkx as nx
from cognee.shared.data_models import GraphDBType from cognee.shared.data_models import GraphDBType
async def search_neighbour(graph: Union[nx.Graph, AsyncSession], id: str, infrastructure_config: Dict, async def search_neighbour(graph: Union[nx.Graph, any], id: str, infrastructure_config: Dict,
other_param: dict = None): other_param: dict = None):
""" """
Search for nodes that share the same 'layer_uuid' as the specified node and return their descriptions. Search for nodes that share the same 'layer_uuid' as the specified node and return their descriptions.

View file

@ -3,10 +3,9 @@
from typing import Union, Dict from typing import Union, Dict
import networkx as nx import networkx as nx
from neo4j import AsyncSession
from cognee.shared.data_models import GraphDBType from cognee.shared.data_models import GraphDBType
async def search_summary(graph: Union[nx.Graph, AsyncSession], query: str, infrastructure_config: Dict, other_param: str = None) -> Dict[str, str]: async def search_summary(graph: Union[nx.Graph, any], query: str, infrastructure_config: Dict, other_param: str = None) -> Dict[str, str]:
""" """
Filter nodes based on a condition (such as containing 'SUMMARY' in their identifiers) and return their summary attributes. Filter nodes based on a condition (such as containing 'SUMMARY' in their identifiers) and return their summary attributes.
Supports both NetworkX graphs and Neo4j graph databases based on the configuration. Supports both NetworkX graphs and Neo4j graph databases based on the configuration.

View file

@ -3,7 +3,7 @@ from cognee.infrastructure import infrastructure_config
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
async def search_similarity(query: str): async def search_similarity(query: str, graph):
graph_db_type = infrastructure_config.get_config()["graph_engine"] graph_db_type = infrastructure_config.get_config()["graph_engine"]
graph_client = await get_graph_client(graph_db_type) graph_client = await get_graph_client(graph_db_type)
@ -35,7 +35,7 @@ async def search_similarity(query: str):
for graph_node_data in graph_nodes: for graph_node_data in graph_nodes:
graph_node = await graph_client.extract_node(graph_node_data["node_id"]) graph_node = await graph_client.extract_node(graph_node_data["node_id"])
if "chunk_collection" not in graph_node and "chunk_id" not in graph_node: if "chunk_collection" not in graph_node and "chunk_id" not in graph_node:
continue continue