fix: Move weaviate imports where needed
This commit is contained in:
parent
bb679c2dd7
commit
30055cc60c
10 changed files with 36 additions and 34 deletions
|
|
@ -8,8 +8,8 @@ from cognee.modules.search.vector.search_similarity import search_similarity
|
||||||
from cognee.modules.search.graph.search_categories import search_categories
|
from cognee.modules.search.graph.search_categories import search_categories
|
||||||
from cognee.modules.search.graph.search_neighbour import search_neighbour
|
from cognee.modules.search.graph.search_neighbour import search_neighbour
|
||||||
from cognee.modules.search.graph.search_summary import search_summary
|
from cognee.modules.search.graph.search_summary import search_summary
|
||||||
from cognee.shared.data_models import GraphDBType
|
|
||||||
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
||||||
|
from cognee.infrastructure import infrastructure_config
|
||||||
|
|
||||||
class SearchType(Enum):
|
class SearchType(Enum):
|
||||||
ADJACENT = 'ADJACENT'
|
ADJACENT = 'ADJACENT'
|
||||||
|
|
@ -42,8 +42,7 @@ async def search(search_type: str, params: Dict[str, Any]) -> List:
|
||||||
|
|
||||||
|
|
||||||
async def specific_search(query_params: List[SearchParameters]) -> List:
|
async def specific_search(query_params: List[SearchParameters]) -> List:
|
||||||
graph_client = await get_graph_client(GraphDBType.NETWORKX)
|
graph_client = await get_graph_client(infrastructure_config.get_config()["graph_engine"])
|
||||||
await graph_client.load_graph_from_file()
|
|
||||||
graph = graph_client.graph
|
graph = graph_client.graph
|
||||||
|
|
||||||
search_functions: Dict[SearchType, Callable] = {
|
search_functions: Dict[SearchType, Callable] = {
|
||||||
|
|
@ -61,8 +60,7 @@ async def specific_search(query_params: List[SearchParameters]) -> List:
|
||||||
search_func = search_functions.get(search_param.search_type)
|
search_func = search_functions.get(search_param.search_type)
|
||||||
if search_func:
|
if search_func:
|
||||||
# Schedule the coroutine for execution and store the task
|
# Schedule the coroutine for execution and store the task
|
||||||
full_params = {**search_param.params, 'graph': graph}
|
task = search_func(**search_param.params, graph = graph)
|
||||||
task = search_func(**full_params)
|
|
||||||
search_tasks.append(task)
|
search_tasks.append(task)
|
||||||
|
|
||||||
# Use asyncio.gather to run all scheduled tasks concurrently
|
# Use asyncio.gather to run all scheduled tasks concurrently
|
||||||
|
|
|
||||||
|
|
@ -73,6 +73,8 @@ class Config:
|
||||||
connect_documents: bool = False
|
connect_documents: bool = False
|
||||||
|
|
||||||
# Database parameters
|
# Database parameters
|
||||||
|
graph_database_provider: str = os.getenv("GRAPH_DB_PROVIDER", "NETWORKX")
|
||||||
|
|
||||||
if (
|
if (
|
||||||
os.getenv("ENV") == "prod"
|
os.getenv("ENV") == "prod"
|
||||||
or os.getenv("ENV") == "dev"
|
or os.getenv("ENV") == "dev"
|
||||||
|
|
|
||||||
|
|
@ -84,7 +84,7 @@ class InfrastructureConfig():
|
||||||
config.weaviate_api_key,
|
config.weaviate_api_key,
|
||||||
embedding_engine = self.embedding_engine
|
embedding_engine = self.embedding_engine
|
||||||
)
|
)
|
||||||
except EnvironmentError:
|
except (EnvironmentError, ModuleNotFoundError):
|
||||||
if config.qdrant_url is None and config.qdrant_api_key is None:
|
if config.qdrant_url is None and config.qdrant_api_key is None:
|
||||||
raise EnvironmentError("Qdrant is not configured!")
|
raise EnvironmentError("Qdrant is not configured!")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ from cognee.shared.data_models import GraphDBType
|
||||||
from cognee.infrastructure import infrastructure_config
|
from cognee.infrastructure import infrastructure_config
|
||||||
from .graph_db_interface import GraphDBInterface
|
from .graph_db_interface import GraphDBInterface
|
||||||
from .networkx.adapter import NetworkXAdapter
|
from .networkx.adapter import NetworkXAdapter
|
||||||
from .neo4j_driver.adapter import Neo4jAdapter
|
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
config.load()
|
config.load()
|
||||||
|
|
@ -15,15 +14,18 @@ async def get_graph_client(graph_type: GraphDBType, graph_file_name: str = None)
|
||||||
"""Factory function to get the appropriate graph client based on the graph type."""
|
"""Factory function to get the appropriate graph client based on the graph type."""
|
||||||
graph_file_path = f"{infrastructure_config.get_config('database_directory_path')}/{graph_file_name if graph_file_name else config.graph_filename}"
|
graph_file_path = f"{infrastructure_config.get_config('database_directory_path')}/{graph_file_name if graph_file_name else config.graph_filename}"
|
||||||
|
|
||||||
if graph_type == GraphDBType.NETWORKX:
|
if graph_type == GraphDBType.NEO4J:
|
||||||
graph_client = NetworkXAdapter(filename = graph_file_path)
|
try:
|
||||||
await graph_client.load_graph_from_file()
|
from .neo4j_driver.adapter import Neo4jAdapter
|
||||||
return graph_client
|
|
||||||
elif graph_type == GraphDBType.NEO4J:
|
return Neo4jAdapter(
|
||||||
return Neo4jAdapter(
|
graph_database_url = config.graph_database_url,
|
||||||
graph_database_url = config.graph_database_url,
|
graph_database_username = config.graph_database_username,
|
||||||
graph_database_username = config.graph_database_username,
|
graph_database_password = config.graph_database_password
|
||||||
graph_database_password = config.graph_database_password
|
)
|
||||||
)
|
except:
|
||||||
else:
|
pass
|
||||||
raise ValueError("Unsupported graph database type.")
|
|
||||||
|
graph_client = NetworkXAdapter(filename = graph_file_path)
|
||||||
|
await graph_client.load_graph_from_file()
|
||||||
|
return graph_client
|
||||||
|
|
|
||||||
|
|
@ -2,10 +2,6 @@ import asyncio
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from multiprocessing import Pool
|
from multiprocessing import Pool
|
||||||
import weaviate
|
|
||||||
import weaviate.classes as wvc
|
|
||||||
import weaviate.classes.config as wvcc
|
|
||||||
from weaviate.classes.data import DataObject
|
|
||||||
from ..vector_db_interface import VectorDBInterface
|
from ..vector_db_interface import VectorDBInterface
|
||||||
from ..models.DataPoint import DataPoint
|
from ..models.DataPoint import DataPoint
|
||||||
from ..models.ScoredResult import ScoredResult
|
from ..models.ScoredResult import ScoredResult
|
||||||
|
|
@ -17,6 +13,9 @@ class WeaviateAdapter(VectorDBInterface):
|
||||||
embedding_engine: EmbeddingEngine = None
|
embedding_engine: EmbeddingEngine = None
|
||||||
|
|
||||||
def __init__(self, url: str, api_key: str, embedding_engine: EmbeddingEngine):
|
def __init__(self, url: str, api_key: str, embedding_engine: EmbeddingEngine):
|
||||||
|
import weaviate
|
||||||
|
import weaviate.classes as wvc
|
||||||
|
|
||||||
self.embedding_engine = embedding_engine
|
self.embedding_engine = embedding_engine
|
||||||
|
|
||||||
self.client = weaviate.connect_to_wcs(
|
self.client = weaviate.connect_to_wcs(
|
||||||
|
|
@ -32,6 +31,8 @@ class WeaviateAdapter(VectorDBInterface):
|
||||||
return await self.embedding_engine.embed_text(data)
|
return await self.embedding_engine.embed_text(data)
|
||||||
|
|
||||||
async def create_collection(self, collection_name: str):
|
async def create_collection(self, collection_name: str):
|
||||||
|
import weaviate.classes.config as wvcc
|
||||||
|
|
||||||
event_loop = asyncio.get_event_loop()
|
event_loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
def sync_create_collection():
|
def sync_create_collection():
|
||||||
|
|
@ -57,6 +58,8 @@ class WeaviateAdapter(VectorDBInterface):
|
||||||
return self.client.collections.get(collection_name)
|
return self.client.collections.get(collection_name)
|
||||||
|
|
||||||
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
||||||
|
from weaviate.classes.data import DataObject
|
||||||
|
|
||||||
data_vectors = await self.embed_data(
|
data_vectors = await self.embed_data(
|
||||||
list(map(lambda data_point: data_point.get_embeddable_data(), data_points)))
|
list(map(lambda data_point: data_point.get_embeddable_data(), data_points)))
|
||||||
|
|
||||||
|
|
@ -95,6 +98,8 @@ class WeaviateAdapter(VectorDBInterface):
|
||||||
limit: int = None,
|
limit: int = None,
|
||||||
with_vector: bool = False
|
with_vector: bool = False
|
||||||
):
|
):
|
||||||
|
import weaviate.classes as wvc
|
||||||
|
|
||||||
if query_text is None and query_vector is None:
|
if query_text is None and query_vector is None:
|
||||||
raise ValueError("One of query_text or query_vector must be provided!")
|
raise ValueError("One of query_text or query_vector must be provided!")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,8 @@
|
||||||
|
|
||||||
from typing import Union, Dict
|
from typing import Union, Dict
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from neo4j import AsyncSession
|
|
||||||
from cognee.shared.data_models import GraphDBType
|
from cognee.shared.data_models import GraphDBType
|
||||||
async def search_adjacent(graph: Union[nx.Graph, AsyncSession], query: str, infrastructure_config: Dict, other_param: dict = None) -> Dict[str, str]:
|
async def search_adjacent(graph: Union[nx.Graph, any], query: str, infrastructure_config: Dict, other_param: dict = None) -> Dict[str, str]:
|
||||||
"""
|
"""
|
||||||
Find the neighbours of a given node in the graph and return their descriptions.
|
Find the neighbours of a given node in the graph and return their descriptions.
|
||||||
Supports both NetworkX graphs and Neo4j graph databases based on the configuration.
|
Supports both NetworkX graphs and Neo4j graph databases based on the configuration.
|
||||||
|
|
|
||||||
|
|
@ -2,11 +2,10 @@ from typing import Union, Dict
|
||||||
|
|
||||||
""" Search categories in the graph and return their summary attributes. """
|
""" Search categories in the graph and return their summary attributes. """
|
||||||
|
|
||||||
from neo4j import AsyncSession
|
|
||||||
from cognee.shared.data_models import GraphDBType
|
from cognee.shared.data_models import GraphDBType
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
|
|
||||||
async def search_categories(graph: Union[nx.Graph, AsyncSession], query_label: str, infrastructure_config: Dict):
|
async def search_categories(graph: Union[nx.Graph, any], query_label: str, infrastructure_config: Dict):
|
||||||
"""
|
"""
|
||||||
Filter nodes in the graph that contain the specified label and return their summary attributes.
|
Filter nodes in the graph that contain the specified label and return their summary attributes.
|
||||||
This function supports both NetworkX graphs and Neo4j graph databases.
|
This function supports both NetworkX graphs and Neo4j graph databases.
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,11 @@
|
||||||
""" Fetches the context of a given node in the graph"""
|
""" Fetches the context of a given node in the graph"""
|
||||||
from typing import Union, Dict
|
from typing import Union, Dict
|
||||||
|
|
||||||
from neo4j import AsyncSession
|
|
||||||
|
|
||||||
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from cognee.shared.data_models import GraphDBType
|
from cognee.shared.data_models import GraphDBType
|
||||||
|
|
||||||
async def search_neighbour(graph: Union[nx.Graph, AsyncSession], id: str, infrastructure_config: Dict,
|
async def search_neighbour(graph: Union[nx.Graph, any], id: str, infrastructure_config: Dict,
|
||||||
other_param: dict = None):
|
other_param: dict = None):
|
||||||
"""
|
"""
|
||||||
Search for nodes that share the same 'layer_uuid' as the specified node and return their descriptions.
|
Search for nodes that share the same 'layer_uuid' as the specified node and return their descriptions.
|
||||||
|
|
|
||||||
|
|
@ -3,10 +3,9 @@
|
||||||
|
|
||||||
from typing import Union, Dict
|
from typing import Union, Dict
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from neo4j import AsyncSession
|
|
||||||
from cognee.shared.data_models import GraphDBType
|
from cognee.shared.data_models import GraphDBType
|
||||||
|
|
||||||
async def search_summary(graph: Union[nx.Graph, AsyncSession], query: str, infrastructure_config: Dict, other_param: str = None) -> Dict[str, str]:
|
async def search_summary(graph: Union[nx.Graph, any], query: str, infrastructure_config: Dict, other_param: str = None) -> Dict[str, str]:
|
||||||
"""
|
"""
|
||||||
Filter nodes based on a condition (such as containing 'SUMMARY' in their identifiers) and return their summary attributes.
|
Filter nodes based on a condition (such as containing 'SUMMARY' in their identifiers) and return their summary attributes.
|
||||||
Supports both NetworkX graphs and Neo4j graph databases based on the configuration.
|
Supports both NetworkX graphs and Neo4j graph databases based on the configuration.
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from cognee.infrastructure import infrastructure_config
|
||||||
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
||||||
|
|
||||||
|
|
||||||
async def search_similarity(query: str):
|
async def search_similarity(query: str, graph):
|
||||||
graph_db_type = infrastructure_config.get_config()["graph_engine"]
|
graph_db_type = infrastructure_config.get_config()["graph_engine"]
|
||||||
|
|
||||||
graph_client = await get_graph_client(graph_db_type)
|
graph_client = await get_graph_client(graph_db_type)
|
||||||
|
|
@ -35,7 +35,7 @@ async def search_similarity(query: str):
|
||||||
|
|
||||||
for graph_node_data in graph_nodes:
|
for graph_node_data in graph_nodes:
|
||||||
graph_node = await graph_client.extract_node(graph_node_data["node_id"])
|
graph_node = await graph_client.extract_node(graph_node_data["node_id"])
|
||||||
|
|
||||||
if "chunk_collection" not in graph_node and "chunk_id" not in graph_node:
|
if "chunk_collection" not in graph_node and "chunk_id" not in graph_node:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue