fix: Move weaviate imports where needed
This commit is contained in:
parent
bb679c2dd7
commit
30055cc60c
10 changed files with 36 additions and 34 deletions
|
|
@ -8,8 +8,8 @@ from cognee.modules.search.vector.search_similarity import search_similarity
|
|||
from cognee.modules.search.graph.search_categories import search_categories
|
||||
from cognee.modules.search.graph.search_neighbour import search_neighbour
|
||||
from cognee.modules.search.graph.search_summary import search_summary
|
||||
from cognee.shared.data_models import GraphDBType
|
||||
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
||||
from cognee.infrastructure import infrastructure_config
|
||||
|
||||
class SearchType(Enum):
|
||||
ADJACENT = 'ADJACENT'
|
||||
|
|
@ -42,8 +42,7 @@ async def search(search_type: str, params: Dict[str, Any]) -> List:
|
|||
|
||||
|
||||
async def specific_search(query_params: List[SearchParameters]) -> List:
|
||||
graph_client = await get_graph_client(GraphDBType.NETWORKX)
|
||||
await graph_client.load_graph_from_file()
|
||||
graph_client = await get_graph_client(infrastructure_config.get_config()["graph_engine"])
|
||||
graph = graph_client.graph
|
||||
|
||||
search_functions: Dict[SearchType, Callable] = {
|
||||
|
|
@ -61,8 +60,7 @@ async def specific_search(query_params: List[SearchParameters]) -> List:
|
|||
search_func = search_functions.get(search_param.search_type)
|
||||
if search_func:
|
||||
# Schedule the coroutine for execution and store the task
|
||||
full_params = {**search_param.params, 'graph': graph}
|
||||
task = search_func(**full_params)
|
||||
task = search_func(**search_param.params, graph = graph)
|
||||
search_tasks.append(task)
|
||||
|
||||
# Use asyncio.gather to run all scheduled tasks concurrently
|
||||
|
|
|
|||
|
|
@ -73,6 +73,8 @@ class Config:
|
|||
connect_documents: bool = False
|
||||
|
||||
# Database parameters
|
||||
graph_database_provider: str = os.getenv("GRAPH_DB_PROVIDER", "NETWORKX")
|
||||
|
||||
if (
|
||||
os.getenv("ENV") == "prod"
|
||||
or os.getenv("ENV") == "dev"
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ class InfrastructureConfig():
|
|||
config.weaviate_api_key,
|
||||
embedding_engine = self.embedding_engine
|
||||
)
|
||||
except EnvironmentError:
|
||||
except (EnvironmentError, ModuleNotFoundError):
|
||||
if config.qdrant_url is None and config.qdrant_api_key is None:
|
||||
raise EnvironmentError("Qdrant is not configured!")
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ from cognee.shared.data_models import GraphDBType
|
|||
from cognee.infrastructure import infrastructure_config
|
||||
from .graph_db_interface import GraphDBInterface
|
||||
from .networkx.adapter import NetworkXAdapter
|
||||
from .neo4j_driver.adapter import Neo4jAdapter
|
||||
|
||||
config = Config()
|
||||
config.load()
|
||||
|
|
@ -15,15 +14,18 @@ async def get_graph_client(graph_type: GraphDBType, graph_file_name: str = None)
|
|||
"""Factory function to get the appropriate graph client based on the graph type."""
|
||||
graph_file_path = f"{infrastructure_config.get_config('database_directory_path')}/{graph_file_name if graph_file_name else config.graph_filename}"
|
||||
|
||||
if graph_type == GraphDBType.NETWORKX:
|
||||
graph_client = NetworkXAdapter(filename = graph_file_path)
|
||||
await graph_client.load_graph_from_file()
|
||||
return graph_client
|
||||
elif graph_type == GraphDBType.NEO4J:
|
||||
return Neo4jAdapter(
|
||||
graph_database_url = config.graph_database_url,
|
||||
graph_database_username = config.graph_database_username,
|
||||
graph_database_password = config.graph_database_password
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unsupported graph database type.")
|
||||
if graph_type == GraphDBType.NEO4J:
|
||||
try:
|
||||
from .neo4j_driver.adapter import Neo4jAdapter
|
||||
|
||||
return Neo4jAdapter(
|
||||
graph_database_url = config.graph_database_url,
|
||||
graph_database_username = config.graph_database_username,
|
||||
graph_database_password = config.graph_database_password
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
graph_client = NetworkXAdapter(filename = graph_file_path)
|
||||
await graph_client.load_graph_from_file()
|
||||
return graph_client
|
||||
|
|
|
|||
|
|
@ -2,10 +2,6 @@ import asyncio
|
|||
from uuid import UUID
|
||||
from typing import List, Optional
|
||||
from multiprocessing import Pool
|
||||
import weaviate
|
||||
import weaviate.classes as wvc
|
||||
import weaviate.classes.config as wvcc
|
||||
from weaviate.classes.data import DataObject
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
from ..models.DataPoint import DataPoint
|
||||
from ..models.ScoredResult import ScoredResult
|
||||
|
|
@ -17,6 +13,9 @@ class WeaviateAdapter(VectorDBInterface):
|
|||
embedding_engine: EmbeddingEngine = None
|
||||
|
||||
def __init__(self, url: str, api_key: str, embedding_engine: EmbeddingEngine):
|
||||
import weaviate
|
||||
import weaviate.classes as wvc
|
||||
|
||||
self.embedding_engine = embedding_engine
|
||||
|
||||
self.client = weaviate.connect_to_wcs(
|
||||
|
|
@ -32,6 +31,8 @@ class WeaviateAdapter(VectorDBInterface):
|
|||
return await self.embedding_engine.embed_text(data)
|
||||
|
||||
async def create_collection(self, collection_name: str):
|
||||
import weaviate.classes.config as wvcc
|
||||
|
||||
event_loop = asyncio.get_event_loop()
|
||||
|
||||
def sync_create_collection():
|
||||
|
|
@ -57,6 +58,8 @@ class WeaviateAdapter(VectorDBInterface):
|
|||
return self.client.collections.get(collection_name)
|
||||
|
||||
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
||||
from weaviate.classes.data import DataObject
|
||||
|
||||
data_vectors = await self.embed_data(
|
||||
list(map(lambda data_point: data_point.get_embeddable_data(), data_points)))
|
||||
|
||||
|
|
@ -95,6 +98,8 @@ class WeaviateAdapter(VectorDBInterface):
|
|||
limit: int = None,
|
||||
with_vector: bool = False
|
||||
):
|
||||
import weaviate.classes as wvc
|
||||
|
||||
if query_text is None and query_vector is None:
|
||||
raise ValueError("One of query_text or query_vector must be provided!")
|
||||
|
||||
|
|
|
|||
|
|
@ -3,9 +3,8 @@
|
|||
|
||||
from typing import Union, Dict
|
||||
import networkx as nx
|
||||
from neo4j import AsyncSession
|
||||
from cognee.shared.data_models import GraphDBType
|
||||
async def search_adjacent(graph: Union[nx.Graph, AsyncSession], query: str, infrastructure_config: Dict, other_param: dict = None) -> Dict[str, str]:
|
||||
async def search_adjacent(graph: Union[nx.Graph, any], query: str, infrastructure_config: Dict, other_param: dict = None) -> Dict[str, str]:
|
||||
"""
|
||||
Find the neighbours of a given node in the graph and return their descriptions.
|
||||
Supports both NetworkX graphs and Neo4j graph databases based on the configuration.
|
||||
|
|
|
|||
|
|
@ -2,11 +2,10 @@ from typing import Union, Dict
|
|||
|
||||
""" Search categories in the graph and return their summary attributes. """
|
||||
|
||||
from neo4j import AsyncSession
|
||||
from cognee.shared.data_models import GraphDBType
|
||||
import networkx as nx
|
||||
|
||||
async def search_categories(graph: Union[nx.Graph, AsyncSession], query_label: str, infrastructure_config: Dict):
|
||||
async def search_categories(graph: Union[nx.Graph, any], query_label: str, infrastructure_config: Dict):
|
||||
"""
|
||||
Filter nodes in the graph that contain the specified label and return their summary attributes.
|
||||
This function supports both NetworkX graphs and Neo4j graph databases.
|
||||
|
|
|
|||
|
|
@ -1,13 +1,11 @@
|
|||
""" Fetches the context of a given node in the graph"""
|
||||
from typing import Union, Dict
|
||||
|
||||
from neo4j import AsyncSession
|
||||
|
||||
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
||||
import networkx as nx
|
||||
from cognee.shared.data_models import GraphDBType
|
||||
|
||||
async def search_neighbour(graph: Union[nx.Graph, AsyncSession], id: str, infrastructure_config: Dict,
|
||||
async def search_neighbour(graph: Union[nx.Graph, any], id: str, infrastructure_config: Dict,
|
||||
other_param: dict = None):
|
||||
"""
|
||||
Search for nodes that share the same 'layer_uuid' as the specified node and return their descriptions.
|
||||
|
|
|
|||
|
|
@ -3,10 +3,9 @@
|
|||
|
||||
from typing import Union, Dict
|
||||
import networkx as nx
|
||||
from neo4j import AsyncSession
|
||||
from cognee.shared.data_models import GraphDBType
|
||||
|
||||
async def search_summary(graph: Union[nx.Graph, AsyncSession], query: str, infrastructure_config: Dict, other_param: str = None) -> Dict[str, str]:
|
||||
async def search_summary(graph: Union[nx.Graph, any], query: str, infrastructure_config: Dict, other_param: str = None) -> Dict[str, str]:
|
||||
"""
|
||||
Filter nodes based on a condition (such as containing 'SUMMARY' in their identifiers) and return their summary attributes.
|
||||
Supports both NetworkX graphs and Neo4j graph databases based on the configuration.
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from cognee.infrastructure import infrastructure_config
|
|||
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
||||
|
||||
|
||||
async def search_similarity(query: str):
|
||||
async def search_similarity(query: str, graph):
|
||||
graph_db_type = infrastructure_config.get_config()["graph_engine"]
|
||||
|
||||
graph_client = await get_graph_client(graph_db_type)
|
||||
|
|
@ -35,7 +35,7 @@ async def search_similarity(query: str):
|
|||
|
||||
for graph_node_data in graph_nodes:
|
||||
graph_node = await graph_client.extract_node(graph_node_data["node_id"])
|
||||
|
||||
|
||||
if "chunk_collection" not in graph_node and "chunk_id" not in graph_node:
|
||||
continue
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue