Merge branch 'feat/COG-113-integrate-weviate' into feat/COG-118-remove-unused-code
This commit is contained in:
commit
a727cce00f
27 changed files with 895 additions and 1097 deletions
File diff suppressed because one or more lines are too long
|
|
@ -1,23 +1,19 @@
|
|||
import asyncio
|
||||
# import logging
|
||||
from typing import List, Union
|
||||
from qdrant_client import models
|
||||
import instructor
|
||||
from openai import OpenAI
|
||||
from unstructured.cleaners.core import clean
|
||||
from unstructured.partition.pdf import partition_pdf
|
||||
from cognee.infrastructure.databases.vector.qdrant.QDrantAdapter import CollectionConfig
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.modules.cognify.graph.add_classification_nodes import add_classification_nodes
|
||||
from cognee.modules.cognify.llm.label_content import label_content
|
||||
from cognee.modules.cognify.graph.add_label_nodes import add_label_nodes
|
||||
from cognee.modules.cognify.graph.add_node_connections import add_node_connection, graph_ready_output, \
|
||||
from cognee.modules.cognify.llm.summarize_content import summarize_content
|
||||
from cognee.modules.cognify.graph.add_summary_nodes import add_summary_nodes
|
||||
from cognee.modules.cognify.graph.add_node_connections import group_nodes_by_layer, graph_ready_output, \
|
||||
connect_nodes_in_graph, extract_node_descriptions
|
||||
from cognee.modules.cognify.graph.add_propositions import append_to_graph
|
||||
from cognee.modules.cognify.graph.add_summary_nodes import add_summary_nodes
|
||||
from cognee.modules.cognify.llm.add_node_connection_embeddings import process_items
|
||||
from cognee.modules.cognify.llm.label_content import label_content
|
||||
from cognee.modules.cognify.llm.summarize_content import summarize_content
|
||||
from cognee.modules.cognify.vector.batch_search import adapted_qdrant_batch_search
|
||||
from cognee.modules.cognify.llm.resolve_cross_graph_references import resolve_cross_graph_references
|
||||
from cognee.modules.cognify.vector.add_propositions import add_propositions
|
||||
|
||||
from cognee.config import Config
|
||||
|
|
@ -28,10 +24,11 @@ from cognee.shared.data_models import DefaultContentPrediction, KnowledgeGraph,
|
|||
SummarizedContent, LabeledContent
|
||||
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
|
||||
from cognee.shared.data_models import GraphDBType
|
||||
from cognee.infrastructure.databases.vector.get_vector_database import get_vector_database
|
||||
from cognee.infrastructure.databases.relational import DuckDBAdapter
|
||||
from cognee.modules.cognify.graph.add_document_node import add_document_node
|
||||
from cognee.modules.cognify.graph.initialize_graph import initialize_graph
|
||||
from cognee.infrastructure.databases.vector import CollectionConfig, VectorConfig
|
||||
from cognee.infrastructure import infrastructure_config
|
||||
|
||||
config = Config()
|
||||
config.load()
|
||||
|
|
@ -76,7 +73,7 @@ async def cognify(datasets: Union[str, List[str]] = None, graphdatamodel: object
|
|||
|
||||
async def process_text(input_text: str, file_metadata: dict):
|
||||
print(f"Processing document ({file_metadata['id']})")
|
||||
|
||||
|
||||
classified_categories = []
|
||||
|
||||
try:
|
||||
|
|
@ -133,31 +130,22 @@ async def process_text(input_text: str, file_metadata: dict):
|
|||
|
||||
# Run the async function for each set of cognitive layers
|
||||
layer_graphs = await generate_graph_per_layer(input_text, cognitive_layers)
|
||||
# print(layer_graphs)
|
||||
|
||||
print(f"Document ({file_metadata['id']}) layer graphs created")
|
||||
|
||||
# G = await create_semantic_graph(graph_model_instance)
|
||||
|
||||
await add_classification_nodes(f"DOCUMENT:{file_metadata['id']}", classified_categories[0])
|
||||
|
||||
# print(file_metadata['summary'])
|
||||
await add_summary_nodes(f"DOCUMENT:{file_metadata['id']}", {"summary": file_metadata["summary"]})
|
||||
|
||||
await add_summary_nodes(f"DOCUMENT:{file_metadata['id']}", {"summary": file_metadata['summary']})
|
||||
await add_label_nodes(f"DOCUMENT:{file_metadata['id']}", {"content_labels": file_metadata["content_labels"]})
|
||||
|
||||
# print(file_metadata['content_labels'])
|
||||
|
||||
await add_label_nodes(f"DOCUMENT:{file_metadata['id']}", {"content_labels": file_metadata['content_labels']})
|
||||
|
||||
unique_layer_uuids = await append_to_graph(layer_graphs, classified_categories[0])
|
||||
await append_to_graph(layer_graphs, classified_categories[0])
|
||||
|
||||
print(f"Document ({file_metadata['id']}) layers connected")
|
||||
|
||||
print("Document categories, summaries and metadata are: ", str(classified_categories))
|
||||
|
||||
|
||||
print(f"Document categories, summaries and metadata are ",str(classified_categories) )
|
||||
|
||||
print(f"Document metadata is ",str(file_metadata) )
|
||||
print("Document metadata is: ", str(file_metadata))
|
||||
|
||||
graph_client = get_graph_client(GraphDBType.NETWORKX)
|
||||
|
||||
|
|
@ -165,45 +153,34 @@ async def process_text(input_text: str, file_metadata: dict):
|
|||
|
||||
graph = graph_client.graph
|
||||
|
||||
# # Extract the node descriptions
|
||||
node_descriptions = await extract_node_descriptions(graph.nodes(data = True))
|
||||
# print(node_descriptions)
|
||||
|
||||
unique_layer_uuids = set(node["layer_decomposition_uuid"] for node in node_descriptions)
|
||||
nodes_by_layer = await group_nodes_by_layer(node_descriptions)
|
||||
|
||||
unique_layers = nodes_by_layer.keys()
|
||||
|
||||
collection_config = CollectionConfig(
|
||||
vector_config = {
|
||||
"content": models.VectorParams(
|
||||
distance = models.Distance.COSINE,
|
||||
size = 3072
|
||||
)
|
||||
},
|
||||
vector_config = VectorConfig(
|
||||
distance = "Cosine",
|
||||
size = 3072
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
for layer in unique_layer_uuids:
|
||||
db = get_vector_database()
|
||||
await db.create_collection(layer, collection_config)
|
||||
db_engine = infrastructure_config.get_config()["vector_engine"]
|
||||
|
||||
for layer in unique_layers:
|
||||
await db_engine.create_collection(layer, collection_config)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
await add_propositions(node_descriptions)
|
||||
await add_propositions(nodes_by_layer)
|
||||
|
||||
grouped_data = await add_node_connection(node_descriptions)
|
||||
# print("we are here, grouped_data", grouped_data)
|
||||
results = await resolve_cross_graph_references(nodes_by_layer)
|
||||
|
||||
llm_client = get_llm_client()
|
||||
relationships = graph_ready_output(results)
|
||||
|
||||
relationship_dict = await process_items(grouped_data, unique_layer_uuids, llm_client)
|
||||
# print("we are here", relationship_dict[0])
|
||||
|
||||
results = await adapted_qdrant_batch_search(relationship_dict, db)
|
||||
# print(results)
|
||||
|
||||
relationship_d = graph_ready_output(results)
|
||||
# print(relationship_d)
|
||||
|
||||
connect_nodes_in_graph(graph, relationship_d)
|
||||
connect_nodes_in_graph(graph, relationships)
|
||||
|
||||
print(f"Document ({file_metadata['id']}) processed")
|
||||
|
||||
|
|
@ -220,4 +197,4 @@ if __name__ == "__main__":
|
|||
print(graph_url)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
asyncio.run(main())
|
||||
|
|
|
|||
|
|
@ -1,21 +1,32 @@
|
|||
from cognee.config import Config
|
||||
from .databases.relational import SqliteEngine, DatabaseEngine
|
||||
from .databases.vector import WeaviateAdapter, VectorDBInterface
|
||||
|
||||
config = Config()
|
||||
config.load()
|
||||
|
||||
class InfrastructureConfig():
|
||||
database_engine: DatabaseEngine = None
|
||||
vector_engine: VectorDBInterface = None
|
||||
|
||||
def get_config(self) -> dict:
|
||||
if self.database_engine is None:
|
||||
self.database_engine = SqliteEngine(config.db_path, config.db_name)
|
||||
|
||||
if self.vector_engine is None:
|
||||
self.vector_engine = WeaviateAdapter(
|
||||
config.weaviate_url,
|
||||
config.weaviate_api_key,
|
||||
config.openai_key
|
||||
)
|
||||
|
||||
return {
|
||||
"database_engine": self.database_engine
|
||||
"database_engine": self.database_engine,
|
||||
"vector_engine": self.vector_engine
|
||||
}
|
||||
|
||||
def set_config(self, new_config: dict):
|
||||
self.database_engine = new_config["database_engine"]
|
||||
self.vector_engine = new_config["vector_engine"]
|
||||
|
||||
infrastructure_config = InfrastructureConfig()
|
||||
|
|
|
|||
|
|
@ -1,2 +1,7 @@
|
|||
from .get_vector_database import get_vector_database
|
||||
from .qdrant import QDrantAdapter, CollectionConfig
|
||||
from .qdrant import QDrantAdapter
|
||||
from .models.DataPoint import DataPoint
|
||||
from .models.VectorConfig import VectorConfig
|
||||
from .models.CollectionConfig import CollectionConfig
|
||||
from .weaviate_db import WeaviateAdapter
|
||||
from .vector_db_interface import VectorDBInterface
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from cognee.config import Config
|
||||
from .qdrant import QDrantAdapter
|
||||
# from .qdrant import QDrantAdapter
|
||||
from .weaviate_db import WeaviateAdapter
|
||||
|
||||
config = Config()
|
||||
config.load()
|
||||
|
||||
def get_vector_database():
|
||||
return QDrantAdapter(config.qdrant_path, config.qdrant_url, config.qdrant_api_key)
|
||||
# return QDrantAdapter(config.qdrant_path, config.qdrant_url, config.qdrant_api_key)
|
||||
return WeaviateAdapter(config.weaviate_url, config.weaviate_api_key, config.openai_key)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
from pydantic import BaseModel
|
||||
from .VectorConfig import VectorConfig
|
||||
|
||||
class CollectionConfig(BaseModel):
|
||||
vector_config: VectorConfig
|
||||
10
cognee/infrastructure/databases/vector/models/DataPoint.py
Normal file
10
cognee/infrastructure/databases/vector/models/DataPoint.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
from typing import Dict
|
||||
from pydantic import BaseModel
|
||||
|
||||
class DataPoint(BaseModel):
|
||||
id: str
|
||||
payload: Dict[str, str]
|
||||
embed_field: str
|
||||
|
||||
def get_embeddable_data(self):
|
||||
return self.payload[self.embed_field]
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
from uuid import UUID
|
||||
from typing import Any, Dict
|
||||
from pydantic import BaseModel
|
||||
|
||||
class ScoredResult(BaseModel):
|
||||
id: UUID
|
||||
score: int
|
||||
payload: Dict[str, Any]
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
from typing import Literal
|
||||
from pydantic import BaseModel
|
||||
|
||||
class VectorConfig(BaseModel):
|
||||
distance: Literal['Cosine', 'Dot']
|
||||
size: int
|
||||
|
|
@ -1,19 +1,59 @@
|
|||
from typing import List, Optional, Dict
|
||||
from pydantic import BaseModel, Field
|
||||
import asyncio
|
||||
from typing import List, Dict
|
||||
# from pydantic import BaseModel, Field
|
||||
from qdrant_client import AsyncQdrantClient, models
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
from ..models.DataPoint import DataPoint
|
||||
from ..models.VectorConfig import VectorConfig
|
||||
from ..models.CollectionConfig import CollectionConfig
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
|
||||
class CollectionConfig(BaseModel, extra = "forbid"):
|
||||
vector_config: Dict[str, models.VectorParams] = Field(..., description="Vectors configuration" )
|
||||
hnsw_config: Optional[models.HnswConfig] = Field(default = None, description="HNSW vector index configuration")
|
||||
optimizers_config: Optional[models.OptimizersConfig] = Field(default = None, description="Optimizers configuration")
|
||||
quantization_config: Optional[models.QuantizationConfig] = Field(default = None, description="Quantization configuration")
|
||||
# class CollectionConfig(BaseModel, extra = "forbid"):
|
||||
# vector_config: Dict[str, models.VectorParams] = Field(..., description="Vectors configuration" )
|
||||
# hnsw_config: Optional[models.HnswConfig] = Field(default = None, description="HNSW vector index configuration")
|
||||
# optimizers_config: Optional[models.OptimizersConfig] = Field(default = None, description="Optimizers configuration")
|
||||
# quantization_config: Optional[models.QuantizationConfig] = Field(default = None, description="Quantization configuration")
|
||||
|
||||
async def embed_data(data: str):
|
||||
llm_client = get_llm_client()
|
||||
|
||||
return await llm_client.async_get_embedding_with_backoff(data)
|
||||
|
||||
async def convert_to_qdrant_point(data_point: DataPoint):
|
||||
return models.PointStruct(
|
||||
id = data_point.id,
|
||||
payload = data_point.payload,
|
||||
vector = {
|
||||
"text": await embed_data(data_point.get_embeddable_data())
|
||||
}
|
||||
)
|
||||
|
||||
def create_vector_config(vector_config: VectorConfig):
|
||||
return models.VectorParams(
|
||||
size = vector_config.size,
|
||||
distance = vector_config.distance
|
||||
)
|
||||
|
||||
def create_hnsw_config(hnsw_config: Dict):
|
||||
if hnsw_config is not None:
|
||||
return models.HnswConfig()
|
||||
return None
|
||||
|
||||
def create_optimizers_config(optimizers_config: Dict):
|
||||
if optimizers_config is not None:
|
||||
return models.OptimizersConfig()
|
||||
return None
|
||||
|
||||
def create_quantization_config(quantization_config: Dict):
|
||||
if quantization_config is not None:
|
||||
return models.QuantizationConfig()
|
||||
return None
|
||||
|
||||
class QDrantAdapter(VectorDBInterface):
|
||||
qdrant_url: str = None
|
||||
qdrant_path: str = None
|
||||
qdrant_api_key: str = None
|
||||
|
||||
|
||||
def __init__(self, qdrant_path, qdrant_url, qdrant_api_key):
|
||||
if qdrant_path is not None:
|
||||
self.qdrant_path = qdrant_path
|
||||
|
|
@ -46,43 +86,49 @@ class QDrantAdapter(VectorDBInterface):
|
|||
|
||||
return await client.create_collection(
|
||||
collection_name = collection_name,
|
||||
vectors_config = collection_config.vector_config,
|
||||
hnsw_config = collection_config.hnsw_config,
|
||||
optimizers_config = collection_config.optimizers_config,
|
||||
quantization_config = collection_config.quantization_config
|
||||
vectors_config = {
|
||||
"text": create_vector_config(collection_config.vector_config)
|
||||
}
|
||||
)
|
||||
|
||||
async def create_data_points(self, collection_name: str, data_points):
|
||||
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
|
||||
client = self.get_qdrant_client()
|
||||
|
||||
awaitables = []
|
||||
|
||||
for point in data_points:
|
||||
awaitables.append(convert_to_qdrant_point(point))
|
||||
|
||||
points = await asyncio.gather(*awaitables)
|
||||
|
||||
return await client.upload_points(
|
||||
collection_name = collection_name,
|
||||
points = data_points
|
||||
points = points
|
||||
)
|
||||
|
||||
async def search(self, collection_name: str, query_vector: List[float], limit: int, with_vector: bool = False):
|
||||
async def search(self, collection_name: str, query_text: str, limit: int, with_vector: bool = False):
|
||||
client = self.get_qdrant_client()
|
||||
|
||||
return await client.search(
|
||||
collection_name = collection_name,
|
||||
query_vector = (
|
||||
"content", query_vector),
|
||||
query_vector = models.NamedVector(
|
||||
name = "text",
|
||||
vector = await embed_data(query_text)
|
||||
),
|
||||
limit = limit,
|
||||
with_vectors = with_vector
|
||||
)
|
||||
|
||||
|
||||
async def batch_search(self, collection_name: str, embeddings: List[List[float]],
|
||||
with_vectors: List[bool] = None):
|
||||
async def batch_search(self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False):
|
||||
"""
|
||||
Perform batch search in a Qdrant collection with dynamic search requests.
|
||||
|
||||
Args:
|
||||
- collection_name (str): Name of the collection to search in.
|
||||
- embeddings (List[List[float]]): List of embeddings to search for.
|
||||
- limits (List[int]): List of result limits for each search request.
|
||||
- with_vectors (List[bool], optional): List indicating whether to return vectors for each search request.
|
||||
Defaults to None, in which case vectors are not returned.
|
||||
- query_texts (List[str]): List of query texts to search for.
|
||||
- limit (int): List of result limits for search requests.
|
||||
- with_vectors (bool, optional): Bool indicating whether to return vectors for search requests.
|
||||
|
||||
Returns:
|
||||
- results: The search results from Qdrant.
|
||||
|
|
@ -90,30 +136,32 @@ class QDrantAdapter(VectorDBInterface):
|
|||
|
||||
client = self.get_qdrant_client()
|
||||
|
||||
# Default with_vectors to False for each request if not provided
|
||||
if with_vectors is None:
|
||||
with_vectors = [False] * len(embeddings)
|
||||
|
||||
# Ensure with_vectors list matches the length of embeddings and limits
|
||||
if len(with_vectors) != len(embeddings):
|
||||
raise ValueError("The length of with_vectors must match the length of embeddings and limits")
|
||||
vectors = await asyncio.gather(*[embed_data(query_text) for query_text in query_texts])
|
||||
|
||||
# Generate dynamic search requests based on the provided embeddings
|
||||
requests = [
|
||||
models.SearchRequest(vector=models.NamedVector(
|
||||
name="content",
|
||||
vector=embedding,
|
||||
),
|
||||
# vector= embedding,
|
||||
limit=3,
|
||||
with_vector=False
|
||||
) for embedding in [embeddings]
|
||||
models.SearchRequest(
|
||||
vector = models.NamedVector(
|
||||
name = "text",
|
||||
vector = vector
|
||||
),
|
||||
limit = limit,
|
||||
with_vector = with_vectors
|
||||
) for vector in vectors
|
||||
]
|
||||
|
||||
# Perform batch search with the dynamically generated requests
|
||||
results = await client.search_batch(
|
||||
collection_name=collection_name,
|
||||
requests=requests
|
||||
collection_name = collection_name,
|
||||
requests = requests
|
||||
)
|
||||
|
||||
return results
|
||||
return [filter(lambda result: result.score > 0.9, result_group) for result_group in results]
|
||||
|
||||
async def prune(self):
|
||||
client = self.get_qdrant_client()
|
||||
|
||||
response = await client.get_collections()
|
||||
|
||||
for collection in response.collections:
|
||||
await client.delete_collection(collection.name)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from typing import List
|
||||
from typing import List, Protocol
|
||||
from abc import abstractmethod
|
||||
from typing import Protocol
|
||||
from .models.CollectionConfig import CollectionConfig
|
||||
from .models.DataPoint import DataPoint
|
||||
|
||||
class VectorDBInterface(Protocol):
|
||||
""" Collections """
|
||||
|
|
@ -8,7 +9,7 @@ class VectorDBInterface(Protocol):
|
|||
async def create_collection(
|
||||
self,
|
||||
collection_name: str,
|
||||
collection_config: object
|
||||
collection_config: CollectionConfig
|
||||
): raise NotImplementedError
|
||||
|
||||
# @abstractmethod
|
||||
|
|
@ -43,7 +44,7 @@ class VectorDBInterface(Protocol):
|
|||
async def create_data_points(
|
||||
self,
|
||||
collection_name: str,
|
||||
data_points
|
||||
data_points: List[DataPoint]
|
||||
): raise NotImplementedError
|
||||
|
||||
# @abstractmethod
|
||||
|
|
@ -67,12 +68,13 @@ class VectorDBInterface(Protocol):
|
|||
# collection_name: str,
|
||||
# data_point_id: str
|
||||
# ): raise NotImplementedError
|
||||
|
||||
""" Search """
|
||||
@abstractmethod
|
||||
async def search(
|
||||
self,
|
||||
collection_name: str,
|
||||
query_vector: List[float],
|
||||
query_text: str,
|
||||
limit: int,
|
||||
with_vector: bool = False
|
||||
|
||||
|
|
@ -82,7 +84,7 @@ class VectorDBInterface(Protocol):
|
|||
async def batch_search(
|
||||
self,
|
||||
collection_name: str,
|
||||
embeddings: List[List[float]],
|
||||
with_vectors: List[bool] = None
|
||||
query_texts: List[str],
|
||||
limit: int,
|
||||
with_vectors: bool = False
|
||||
): raise NotImplementedError
|
||||
|
||||
|
|
|
|||
|
|
@ -1,417 +0,0 @@
|
|||
from weaviate.gql.get import HybridFusion
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain.retrievers import WeaviateHybridSearchRetriever, ParentDocumentRetriever
|
||||
from databases.vector.vector_db_interface import VectorDBInterface
|
||||
# from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from cognee.database.vectordb.loaders.loaders import _document_loader
|
||||
|
||||
class WeaviateVectorDB(VectorDBInterface):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.init_weaviate(embeddings=self.embeddings, namespace=self.namespace)
|
||||
|
||||
def init_weaviate(
|
||||
self,
|
||||
embeddings=OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY", "")),
|
||||
namespace=None,
|
||||
retriever_type="",
|
||||
):
|
||||
# Weaviate initialization logic
|
||||
auth_config = weaviate.auth.AuthApiKey(
|
||||
api_key=os.environ.get("WEAVIATE_API_KEY")
|
||||
)
|
||||
client = weaviate.Client(
|
||||
url=os.environ.get("WEAVIATE_URL"),
|
||||
auth_client_secret=auth_config,
|
||||
additional_headers={"X-OpenAI-Api-Key": os.environ.get("OPENAI_API_KEY")},
|
||||
)
|
||||
|
||||
if retriever_type == "single_document_context":
|
||||
retriever = WeaviateHybridSearchRetriever(
|
||||
client=client,
|
||||
index_name=namespace,
|
||||
text_key="text",
|
||||
attributes=[],
|
||||
embedding=embeddings,
|
||||
create_schema_if_missing=True,
|
||||
)
|
||||
return retriever
|
||||
elif retriever_type == "multi_document_context":
|
||||
retriever = WeaviateHybridSearchRetriever(
|
||||
client=client,
|
||||
index_name=namespace,
|
||||
text_key="text",
|
||||
attributes=[],
|
||||
embedding=embeddings,
|
||||
create_schema_if_missing=True,
|
||||
)
|
||||
return retriever
|
||||
else:
|
||||
return client
|
||||
# child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
|
||||
# store = InMemoryStore()
|
||||
# retriever = ParentDocumentRetriever(
|
||||
# vectorstore=vectorstore,
|
||||
# docstore=store,
|
||||
# child_splitter=child_splitter,
|
||||
# )
|
||||
|
||||
from marshmallow import Schema, fields
|
||||
|
||||
def create_document_structure(observation, params, metadata_schema_class=None):
|
||||
"""
|
||||
Create and validate a document structure with optional custom fields.
|
||||
|
||||
:param observation: Content of the document.
|
||||
:param params: Metadata information.
|
||||
:param metadata_schema_class: Custom metadata schema class (optional).
|
||||
:return: A list containing the validated document data.
|
||||
"""
|
||||
document_data = {"metadata": params, "page_content": observation}
|
||||
|
||||
def get_document_schema():
|
||||
class DynamicDocumentSchema(Schema):
|
||||
metadata = fields.Nested(metadata_schema_class, required=True)
|
||||
page_content = fields.Str(required=True)
|
||||
|
||||
return DynamicDocumentSchema
|
||||
|
||||
# Validate and deserialize, defaulting to "1.0" if not provided
|
||||
CurrentDocumentSchema = get_document_schema()
|
||||
loaded_document = CurrentDocumentSchema().load(document_data)
|
||||
return [loaded_document]
|
||||
|
||||
def _stuct(self, observation, params, metadata_schema_class=None):
|
||||
"""Utility function to create the document structure with optional custom fields."""
|
||||
# Construct document data
|
||||
document_data = {"metadata": params, "page_content": observation}
|
||||
|
||||
def get_document_schema():
|
||||
class DynamicDocumentSchema(Schema):
|
||||
metadata = fields.Nested(metadata_schema_class, required=True)
|
||||
page_content = fields.Str(required=True)
|
||||
|
||||
return DynamicDocumentSchema
|
||||
|
||||
# Validate and deserialize # Default to "1.0" if not provided
|
||||
CurrentDocumentSchema = get_document_schema()
|
||||
loaded_document = CurrentDocumentSchema().load(document_data)
|
||||
return [loaded_document]
|
||||
|
||||
async def add_memories(
|
||||
self,
|
||||
observation,
|
||||
loader_settings=None,
|
||||
params=None,
|
||||
namespace=None,
|
||||
metadata_schema_class=None,
|
||||
embeddings="hybrid",
|
||||
):
|
||||
# Update Weaviate memories here
|
||||
if namespace is None:
|
||||
namespace = self.namespace
|
||||
params["user_id"] = self.user_id
|
||||
logging.info("User id is %s", self.user_id)
|
||||
retriever = self.init_weaviate(
|
||||
embeddings=OpenAIEmbeddings(),
|
||||
namespace=namespace,
|
||||
retriever_type="single_document_context",
|
||||
)
|
||||
if loader_settings:
|
||||
# Assuming _document_loader returns a list of documents
|
||||
documents = await _document_loader(observation, loader_settings)
|
||||
logging.info("here are the docs %s", str(documents))
|
||||
chunk_count = 0
|
||||
for doc_list in documents:
|
||||
for doc in doc_list:
|
||||
chunk_count += 1
|
||||
params["chunk_count"] = doc.metadata.get("chunk_count", "None")
|
||||
logging.info(
|
||||
"Loading document with provided loader settings %s", str(doc)
|
||||
)
|
||||
params["source"] = doc.metadata.get("source", "None")
|
||||
logging.info("Params are %s", str(params))
|
||||
retriever.add_documents(
|
||||
[Document(metadata=params, page_content=doc.page_content)]
|
||||
)
|
||||
else:
|
||||
chunk_count = 0
|
||||
from cognee.database.vectordb.chunkers.chunkers import (
|
||||
chunk_data,
|
||||
)
|
||||
|
||||
documents = [
|
||||
chunk_data(
|
||||
chunk_strategy="VANILLA",
|
||||
source_data=observation,
|
||||
chunk_size=300,
|
||||
chunk_overlap=20,
|
||||
)
|
||||
]
|
||||
for doc in documents[0]:
|
||||
chunk_count += 1
|
||||
params["chunk_order"] = chunk_count
|
||||
params["source"] = "User loaded"
|
||||
logging.info(
|
||||
"Loading document with default loader settings %s", str(doc)
|
||||
)
|
||||
logging.info("Params are %s", str(params))
|
||||
retriever.add_documents(
|
||||
[Document(metadata=params, page_content=doc.page_content)]
|
||||
)
|
||||
|
||||
async def fetch_memories(
|
||||
self,
|
||||
observation: str,
|
||||
namespace: str = None,
|
||||
search_type: str = "hybrid",
|
||||
params=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Fetch documents from weaviate.
|
||||
|
||||
Parameters:
|
||||
- observation (str): User query.
|
||||
- namespace (str, optional): Type of memory accessed.
|
||||
- search_type (str, optional): Type of search ('text', 'hybrid', 'bm25', 'generate', 'generate_grouped'). Defaults to 'hybrid'.
|
||||
- **kwargs: Additional parameters for flexibility.
|
||||
|
||||
Returns:
|
||||
List of documents matching the query or an empty list in case of error.
|
||||
|
||||
Example:
|
||||
fetch_memories(query="some query", search_type='text', additional_param='value')
|
||||
"""
|
||||
client = self.init_weaviate(namespace=self.namespace)
|
||||
if search_type is None:
|
||||
search_type = "hybrid"
|
||||
|
||||
if not namespace:
|
||||
namespace = self.namespace
|
||||
|
||||
logging.info("Query on namespace %s", namespace)
|
||||
|
||||
params_user_id = {
|
||||
"path": ["user_id"],
|
||||
"operator": "Like",
|
||||
"valueText": self.user_id,
|
||||
}
|
||||
|
||||
def list_objects_of_class(class_name, schema):
|
||||
return [
|
||||
prop["name"]
|
||||
for class_obj in schema["classes"]
|
||||
if class_obj["class"] == class_name
|
||||
for prop in class_obj["properties"]
|
||||
]
|
||||
|
||||
base_query = (
|
||||
client.query.get(
|
||||
namespace, list(list_objects_of_class(namespace, client.schema.get()))
|
||||
)
|
||||
.with_additional(
|
||||
["id", "creationTimeUnix", "lastUpdateTimeUnix", "score", "distance"]
|
||||
)
|
||||
.with_where(params_user_id)
|
||||
.with_limit(10)
|
||||
)
|
||||
|
||||
n_of_observations = kwargs.get("n_of_observations", 2)
|
||||
|
||||
# try:
|
||||
if search_type == "text":
|
||||
query_output = (
|
||||
base_query.with_near_text({"concepts": [observation]})
|
||||
.with_autocut(n_of_observations)
|
||||
.do()
|
||||
)
|
||||
elif search_type == "hybrid":
|
||||
query_output = (
|
||||
base_query.with_hybrid(
|
||||
query=observation, fusion_type=HybridFusion.RELATIVE_SCORE
|
||||
)
|
||||
.with_autocut(n_of_observations)
|
||||
.do()
|
||||
)
|
||||
elif search_type == "bm25":
|
||||
query_output = (
|
||||
base_query.with_bm25(query=observation)
|
||||
.with_autocut(n_of_observations)
|
||||
.do()
|
||||
)
|
||||
elif search_type == "summary":
|
||||
filter_object = {
|
||||
"operator": "And",
|
||||
"operands": [
|
||||
{
|
||||
"path": ["user_id"],
|
||||
"operator": "Equal",
|
||||
"valueText": self.user_id,
|
||||
},
|
||||
{
|
||||
"path": ["chunk_order"],
|
||||
"operator": "LessThan",
|
||||
"valueNumber": 30,
|
||||
},
|
||||
],
|
||||
}
|
||||
base_query = (
|
||||
client.query.get(
|
||||
namespace,
|
||||
list(list_objects_of_class(namespace, client.schema.get())),
|
||||
)
|
||||
.with_additional(
|
||||
[
|
||||
"id",
|
||||
"creationTimeUnix",
|
||||
"lastUpdateTimeUnix",
|
||||
"score",
|
||||
"distance",
|
||||
]
|
||||
)
|
||||
.with_where(filter_object)
|
||||
.with_limit(30)
|
||||
)
|
||||
query_output = (
|
||||
base_query
|
||||
# .with_hybrid(query=observation, fusion_type=HybridFusion.RELATIVE_SCORE)
|
||||
.do()
|
||||
)
|
||||
|
||||
elif search_type == "summary_filter_by_object_name":
|
||||
filter_object = {
|
||||
"operator": "And",
|
||||
"operands": [
|
||||
{
|
||||
"path": ["user_id"],
|
||||
"operator": "Equal",
|
||||
"valueText": self.user_id,
|
||||
},
|
||||
{
|
||||
"path": ["doc_id"],
|
||||
"operator": "Equal",
|
||||
"valueText": params,
|
||||
},
|
||||
],
|
||||
}
|
||||
base_query = (
|
||||
client.query.get(
|
||||
namespace,
|
||||
list(list_objects_of_class(namespace, client.schema.get())),
|
||||
)
|
||||
.with_additional(
|
||||
[
|
||||
"id",
|
||||
"creationTimeUnix",
|
||||
"lastUpdateTimeUnix",
|
||||
"score",
|
||||
"distance",
|
||||
]
|
||||
)
|
||||
.with_where(filter_object)
|
||||
.with_limit(30)
|
||||
.with_hybrid(query=observation, fusion_type=HybridFusion.RELATIVE_SCORE)
|
||||
)
|
||||
query_output = base_query.do()
|
||||
|
||||
return query_output
|
||||
elif search_type == "generate":
|
||||
generate_prompt = kwargs.get("generate_prompt", "")
|
||||
query_output = (
|
||||
base_query.with_generate(single_prompt=observation)
|
||||
.with_near_text({"concepts": [observation]})
|
||||
.with_autocut(n_of_observations)
|
||||
.do()
|
||||
)
|
||||
elif search_type == "generate_grouped":
|
||||
generate_prompt = kwargs.get("generate_prompt", "")
|
||||
query_output = (
|
||||
base_query.with_generate(grouped_task=observation)
|
||||
.with_near_text({"concepts": [observation]})
|
||||
.with_autocut(n_of_observations)
|
||||
.do()
|
||||
)
|
||||
else:
|
||||
logging.error(f"Invalid search_type: {search_type}")
|
||||
return []
|
||||
# except Exception as e:
|
||||
# logging.error(f"Error executing query: {str(e)}")
|
||||
# return []
|
||||
|
||||
return query_output
|
||||
|
||||
async def delete_memories(self, namespace: str, params: dict = None):
|
||||
if namespace is None:
|
||||
namespace = self.namespace
|
||||
client = self.init_weaviate(namespace=self.namespace)
|
||||
if params:
|
||||
where_filter = {
|
||||
"path": ["id"],
|
||||
"operator": "Equal",
|
||||
"valueText": params.get("id", None),
|
||||
}
|
||||
return client.batch.delete_objects(
|
||||
class_name=self.namespace,
|
||||
# Same `where` filter as in the GraphQL API
|
||||
where=where_filter,
|
||||
)
|
||||
else:
|
||||
# Delete all objects
|
||||
return client.batch.delete_objects(
|
||||
class_name=namespace,
|
||||
where={
|
||||
"path": ["version"],
|
||||
"operator": "Equal",
|
||||
"valueText": "1.0",
|
||||
},
|
||||
)
|
||||
|
||||
async def count_memories(self, namespace: str = None, params: dict = None) -> int:
|
||||
"""
|
||||
Count memories in a Weaviate database.
|
||||
|
||||
Args:
|
||||
namespace (str, optional): The Weaviate namespace to count memories in. If not provided, uses the default namespace.
|
||||
|
||||
Returns:
|
||||
int: The number of memories in the specified namespace.
|
||||
"""
|
||||
if namespace is None:
|
||||
namespace = self.namespace
|
||||
|
||||
client = self.init_weaviate(namespace=namespace)
|
||||
|
||||
try:
|
||||
object_count = client.query.aggregate(namespace).with_meta_count().do()
|
||||
return object_count
|
||||
except Exception as e:
|
||||
logging.info(f"Error counting memories: {str(e)}")
|
||||
# Handle the error or log it
|
||||
return 0
|
||||
|
||||
def update_memories(self, observation, namespace: str, params: dict = None):
|
||||
client = self.init_weaviate(namespace=self.namespace)
|
||||
|
||||
client.data_object.update(
|
||||
data_object={
|
||||
# "text": observation,
|
||||
"user_id": str(self.user_id),
|
||||
"version": params.get("version", None) or "",
|
||||
"agreement_id": params.get("agreement_id", None) or "",
|
||||
"privacy_policy": params.get("privacy_policy", None) or "",
|
||||
"terms_of_service": params.get("terms_of_service", None) or "",
|
||||
"format": params.get("format", None) or "",
|
||||
"schema_version": params.get("schema_version", None) or "",
|
||||
"checksum": params.get("checksum", None) or "",
|
||||
"owner": params.get("owner", None) or "",
|
||||
"license": params.get("license", None) or "",
|
||||
"validity_start": params.get("validity_start", None) or "",
|
||||
"validity_end": params.get("validity_end", None) or ""
|
||||
# **source_metadata,
|
||||
},
|
||||
class_name="Test",
|
||||
uuid=params.get("id", None),
|
||||
consistency_level=weaviate.data.replication.ConsistencyLevel.ALL, # default QUORUM
|
||||
)
|
||||
return
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
from typing import List
|
||||
from multiprocessing import Pool
|
||||
import weaviate
|
||||
import weaviate.classes as wvc
|
||||
import weaviate.classes.config as wvcc
|
||||
from weaviate.classes.data import DataObject
|
||||
from ..vector_db_interface import VectorDBInterface
|
||||
from ..models.DataPoint import DataPoint
|
||||
from ..models.ScoredResult import ScoredResult
|
||||
|
||||
class WeaviateAdapter(VectorDBInterface):
    """VectorDBInterface implementation backed by a Weaviate Cloud (WCS) cluster.

    Relies on Weaviate's built-in OpenAI modules for vectorization and
    generation, so the OpenAI API key is forwarded with every request.
    """

    # Reserved for future parallel work; currently unused.
    async_pool: Pool = None

    def __init__(self, url: str, api_key: str, openai_api_key: str):
        """Connect to a WCS cluster at `url`, authenticating with `api_key`."""
        self.client = weaviate.connect_to_wcs(
            cluster_url = url,
            auth_credentials = weaviate.auth.AuthApiKey(api_key),
            headers = {
                "X-OpenAI-Api-Key": openai_api_key
            },
            # Fix: keyword was misspelled "timeou", so the intended 30-second
            # init timeout was never applied.
            additional_config = wvc.init.AdditionalConfig(timeout = wvc.init.Timeout(init = 30))
        )

    async def create_collection(self, collection_name: str, collection_config: dict):
        """Create a collection with an OpenAI vectorizer and a single "text" property.

        NOTE(review): collection_config is currently unused — confirm whether
        it should drive the vectorizer/property configuration.
        """
        return self.client.collections.create(
            name = collection_name,
            vectorizer_config = wvcc.Configure.Vectorizer.text2vec_openai(),
            generative_config = wvcc.Configure.Generative.openai(),
            properties = [
                wvcc.Property(
                    name = "text",
                    data_type = wvcc.DataType.TEXT
                )
            ]
        )

    def get_collection(self, collection_name: str):
        """Return a handle to an existing collection."""
        return self.client.collections.get(collection_name)

    async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
        """Bulk-insert data points; each DataPoint's payload becomes object properties."""
        def convert_to_weaviate_data_points(data_point: DataPoint):
            return DataObject(
                uuid = data_point.id,
                properties = data_point.payload
            )

        objects = list(map(convert_to_weaviate_data_points, data_points))

        return self.get_collection(collection_name).data.insert_many(objects)

    async def search(self, collection_name: str, query_text: str, limit: int, with_vector: bool = False):
        """BM25 keyword search over a collection, mapped to ScoredResult entries."""
        search_result = self.get_collection(collection_name).query.bm25(
            query = query_text,
            limit = limit,
            include_vector = with_vector,
            return_metadata = wvc.query.MetadataQuery(score = True),
        )

        return list(map(lambda result: ScoredResult(
            id = result.uuid,
            payload = result.properties,
            score = str(result.metadata.score)
        ), search_result.objects))

    async def batch_search(self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False):
        """Run search() once per query text, preserving input order in the results."""
        def query_search(query_text):
            return self.search(collection_name, query_text, limit = limit, with_vector = with_vectors)

        return [await query_search(query_text) for query_text in query_texts]

    async def prune(self):
        """Delete every collection (and all of its data) in the cluster."""
        self.client.collections.delete_all()
|
||||
|
|
@ -0,0 +1 @@
|
|||
from .WeaviateAdapter import WeaviateAdapter
|
||||
|
|
@ -3,20 +3,26 @@ from cognee.infrastructure.databases.graph.get_graph_client import get_graph_cli
|
|||
from cognee.shared.data_models import GraphDBType
|
||||
|
||||
|
||||
async def extract_node_descriptions(node):
    """Collect description records from (node_id, attributes) pairs.

    Only attribute dicts carrying both "description" and "unique_id" are
    included; each record also copies the layer identifiers used downstream
    for grouping and cross-layer search.
    """
    collected = []

    for _, attributes in node:
        # Guard clause: skip nodes that lack the required fields.
        if "description" not in attributes or "unique_id" not in attributes:
            continue

        collected.append({
            "node_id": attributes["unique_id"],
            "description": attributes["description"],
            "layer_uuid": attributes["layer_uuid"],
            "layer_decomposition_uuid": attributes["layer_decomposition_uuid"]
        })

    return collected
|
||||
|
||||
|
||||
async def add_node_connection(node_descriptions):
|
||||
async def group_nodes_by_layer(node_descriptions):
|
||||
grouped_data = {}
|
||||
|
||||
for item in node_descriptions:
|
||||
uuid = item['layer_decomposition_uuid']
|
||||
uuid = item["layer_decomposition_uuid"]
|
||||
|
||||
if uuid not in grouped_data:
|
||||
grouped_data[uuid] = []
|
||||
|
|
@ -35,19 +41,19 @@ def connect_nodes_in_graph(graph: Graph, relationship_dict: dict) -> Graph:
|
|||
"""
|
||||
for id, relationships in relationship_dict.items():
|
||||
for relationship in relationships:
|
||||
searched_node_attr_id = relationship['searched_node_id']
|
||||
score_attr_id = relationship['original_id_for_search']
|
||||
score = relationship['score']
|
||||
searched_node_attr_id = relationship["searched_node_id"]
|
||||
score_attr_id = relationship["original_id_for_search"]
|
||||
score = relationship["score"]
|
||||
|
||||
# Initialize node keys for both searched_node and score_node
|
||||
searched_node_key, score_node_key = None, None
|
||||
|
||||
# Find nodes in the graph that match the searched_node_id and score_id from their attributes
|
||||
for node, attrs in graph.nodes(data=True):
|
||||
if 'unique_id' in attrs: # Ensure there is an 'id' attribute
|
||||
if attrs['unique_id'] == searched_node_attr_id:
|
||||
for node, attrs in graph.nodes(data = True):
|
||||
if "unique_id" in attrs: # Ensure there is an "id" attribute
|
||||
if attrs["unique_id"] == searched_node_attr_id:
|
||||
searched_node_key = node
|
||||
elif attrs['unique_id'] == score_attr_id:
|
||||
elif attrs["unique_id"] == score_attr_id:
|
||||
score_node_key = node
|
||||
|
||||
# If both nodes are found, no need to continue checking other nodes
|
||||
|
|
@ -57,9 +63,13 @@ def connect_nodes_in_graph(graph: Graph, relationship_dict: dict) -> Graph:
|
|||
# Check if both nodes were found in the graph
|
||||
if searched_node_key is not None and score_node_key is not None:
|
||||
# If both nodes exist, create an edge between them
|
||||
# You can customize the edge attributes as needed, here we use 'score' as an attribute
|
||||
graph.add_edge(searched_node_key, score_node_key, weight=score,
|
||||
score_metadata=relationship.get('score_metadata'))
|
||||
# You can customize the edge attributes as needed, here we use "score" as an attribute
|
||||
graph.add_edge(
|
||||
searched_node_key,
|
||||
score_node_key,
|
||||
weight = score,
|
||||
score_metadata = relationship.get("score_metadata")
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
|
@ -67,31 +77,23 @@ def connect_nodes_in_graph(graph: Graph, relationship_dict: dict) -> Graph:
|
|||
def graph_ready_output(results):
    """Flatten layer search results into a relationship dict keyed by layer id.

    Each entry records which node was searched, the matched node's id, the
    match score and its metadata — ready to be turned into graph edges.
    """
    relationships_by_layer = {}

    for layer_result in results:
        layer_id = layer_result["layer_id"]

        # Ensure there's a list to collect related items for this layer.
        edges = relationships_by_layer.setdefault(layer_id, [])

        for layer_node in layer_result["layer_nodes"]:
            for score_point in layer_node["score_points"]:
                edges.append({
                    "collection_id": layer_id,
                    "searched_node_id": layer_node["id"],
                    "score": score_point.score,
                    "score_metadata": score_point.payload,
                    "original_id_for_search": score_point.id,
                })

    return relationships_by_layer
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
graph_client = get_graph_client(GraphDBType.NETWORKX)
|
||||
add_node_connection(graph_client)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import uuid
|
|||
import json
|
||||
from datetime import datetime
|
||||
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client, GraphDBType
|
||||
from cognee.shared.encode_uuid import encode_uuid
|
||||
|
||||
|
||||
async def add_propositions(
|
||||
|
|
@ -69,7 +70,6 @@ async def add_propositions(
|
|||
async def append_to_graph(layer_graphs, required_layers):
|
||||
# Generate a UUID for the overall layer
|
||||
layer_uuid = uuid.uuid4()
|
||||
decomposition_uuids = set()
|
||||
# Extract category name from required_layers data
|
||||
data_type = required_layers["data_type"]
|
||||
|
||||
|
|
@ -84,9 +84,7 @@ async def append_to_graph(layer_graphs, required_layers):
|
|||
layer_description = json.loads(layer_json)
|
||||
|
||||
# Generate a UUID for this particular layer decomposition
|
||||
layer_decomposition_uuid = uuid.uuid4()
|
||||
|
||||
decomposition_uuids.add(layer_decomposition_uuid)
|
||||
layer_decomposition_id = encode_uuid(uuid.uuid4())
|
||||
|
||||
# Assuming append_data_to_graph is defined elsewhere and appends data to graph_client
|
||||
# You would pass relevant information from knowledge_graph along with other details to this function
|
||||
|
|
@ -96,11 +94,9 @@ async def append_to_graph(layer_graphs, required_layers):
|
|||
layer_description,
|
||||
knowledge_graph,
|
||||
layer_uuid,
|
||||
layer_decomposition_uuid
|
||||
layer_decomposition_id
|
||||
)
|
||||
|
||||
return decomposition_uuids
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# import asyncio
|
||||
|
|
|
|||
|
|
@ -22,10 +22,10 @@ async def process_attribute(graph_client, parent_id: Optional[str], attribute: s
|
|||
if isinstance(value, BaseModel):
|
||||
node_id = await generate_node_id(value)
|
||||
|
||||
node_data = value.dict(exclude={"default_relationship"})
|
||||
node_data = value.model_dump(exclude = {"default_relationship"})
|
||||
|
||||
# Use the specified default relationship for the edge between the parent node and the current node
|
||||
relationship_data = value.default_relationship.dict() if hasattr(value, "default_relationship") else {}
|
||||
relationship_data = value.default_relationship.model_dump() if hasattr(value, "default_relationship") else {}
|
||||
|
||||
await add_node_and_edge(graph_client, parent_id, node_id, node_data, relationship_data)
|
||||
|
||||
|
|
@ -41,7 +41,7 @@ async def process_attribute(graph_client, parent_id: Optional[str], attribute: s
|
|||
async def create_dynamic(graph_model) :
|
||||
root_id = await generate_node_id(graph_model)
|
||||
|
||||
node_data = graph_model.dict(exclude = {"default_relationship", "id"})
|
||||
node_data = graph_model.model_dump(exclude = {"default_relationship", "id"})
|
||||
|
||||
graph_client = get_graph_client(GraphDBType.NETWORKX)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,30 +0,0 @@
|
|||
import asyncio
|
||||
|
||||
async def process_items(grouped_data, unique_layer_uuids, llm_client):
|
||||
results_to_check = [] # This will hold results excluding self comparisons
|
||||
tasks = [] # List to hold all tasks
|
||||
task_to_info = {} # Dictionary to map tasks to their corresponding group id and item info
|
||||
|
||||
# Iterate through each group in grouped_data
|
||||
for group_id, items in grouped_data.items():
|
||||
# Filter unique_layer_uuids to exclude the current group_id
|
||||
target_uuids = [uuid for uuid in unique_layer_uuids if uuid != group_id]
|
||||
|
||||
# Process each item in the group
|
||||
for item in items:
|
||||
# For each target UUID, create an async task for the item's embedding retrieval
|
||||
for target_id in target_uuids:
|
||||
task = asyncio.create_task(llm_client.async_get_embedding_with_backoff(item['description'], "text-embedding-3-large"))
|
||||
tasks.append(task)
|
||||
# Map the task to the target id, item's node_id, and description for later retrieval
|
||||
task_to_info[task] = (target_id, item['node_id'], group_id, item['description'])
|
||||
|
||||
# Await all tasks to complete and gather results
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# Process the results, associating them with their target id, node id, and description
|
||||
for task, embedding in zip(tasks, results):
|
||||
target_id, node_id, group_id, description = task_to_info[task]
|
||||
results_to_check.append([target_id, embedding, description, node_id, group_id])
|
||||
|
||||
return results_to_check
|
||||
|
|
@ -11,7 +11,7 @@ async def classify_into_categories(text_input: str, system_prompt_file: str, res
|
|||
|
||||
llm_output = await llm_client.acreate_structured_output(text_input, system_prompt, response_model)
|
||||
|
||||
return extract_categories(llm_output.dict())
|
||||
return extract_categories(llm_output.model_dump())
|
||||
|
||||
def extract_categories(llm_output) -> List[dict]:
|
||||
# Extract the first subclass from the list (assuming there could be more)
|
||||
|
|
|
|||
39
cognee/modules/cognify/llm/resolve_cross_graph_references.py
Normal file
39
cognee/modules/cognify/llm/resolve_cross_graph_references.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
from typing import Dict, List
|
||||
from cognee.infrastructure.databases.vector import get_vector_database
|
||||
from cognee.infrastructure import infrastructure_config
|
||||
|
||||
async def resolve_cross_graph_references(nodes_by_layer: Dict):
    """For each layer, search every OTHER layer for nodes related to its nodes.

    Returns one result entry per (layer, other-layer) pair, in layer order.
    """
    all_layer_ids = list(nodes_by_layer.keys())
    cross_references = []

    for layer_id, layer_nodes in nodes_by_layer.items():
        for other_layer_id in all_layer_ids:
            # Skip self-comparison — only cross-layer references are wanted.
            if other_layer_id == layer_id:
                continue

            cross_references.append(await get_nodes_by_layer(other_layer_id, layer_nodes))

    return cross_references
|
||||
|
||||
async def get_nodes_by_layer(layer_id: str, layer_nodes: List):
    """Batch-search the collection named `layer_id` with each node's description.

    Returns a dict pairing the layer id with its nodes, each node annotated
    with the score points found for its description.
    """
    vector_engine = infrastructure_config.get_config()["vector_engine"]

    descriptions = [node["description"] for node in layer_nodes]

    score_points = await vector_engine.batch_search(
        layer_id,
        descriptions,
        limit = 3
    )

    return {
        "layer_id": layer_id,
        "layer_nodes": connect_score_points_to_node(score_points, layer_nodes)
    }
|
||||
|
||||
def connect_score_points_to_node(score_points, layer_nodes):
    """Pair each layer node with its search results by position.

    score_points[i] is assumed to hold the results for layer_nodes[i].
    """
    paired = []

    for index, node in enumerate(layer_nodes):
        paired.append({
            "id": node["node_id"],
            "score_points": score_points[index]
        })

    return paired
|
||||
|
|
@ -1,40 +1,28 @@
|
|||
import asyncio
|
||||
|
||||
from qdrant_client import models
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.infrastructure.databases.vector import get_vector_database
|
||||
from cognee.infrastructure.databases.vector import DataPoint
|
||||
from cognee.infrastructure import infrastructure_config
|
||||
|
||||
async def get_embeddings(texts:list):
|
||||
""" Get embeddings for a list of texts"""
|
||||
client = get_llm_client()
|
||||
tasks = [client.async_get_embedding_with_backoff(text, "text-embedding-3-large") for text in texts]
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
async def add_proposition_to_vector_store(id, metadata, embeddings, collection_name):
|
||||
""" Upload a single embedding to a collection in Qdrant."""
|
||||
client = get_vector_database()
|
||||
|
||||
await client.create_data_points(
|
||||
collection_name = collection_name,
|
||||
data_points = [
|
||||
models.PointStruct(
|
||||
id = id,
|
||||
payload = metadata,
|
||||
vector = {"content" : embeddings}
|
||||
)
|
||||
]
|
||||
def convert_to_data_point(node):
    """Wrap a node-description dict into a vector-store DataPoint.

    The node's description becomes the embeddable "text" payload field.
    """
    node_text = node["description"]

    return DataPoint(
        id = node["node_id"],
        payload = {
            "text": node_text
        },
        embed_field = "text"
    )
|
||||
|
||||
async def add_propositions(nodes_by_layer):
    """Store each layer's node descriptions as data points, one collection per layer.

    All insertions are issued concurrently and awaited together.
    """
    vector_engine = infrastructure_config.get_config()["vector_engine"]

    pending_inserts = [
        vector_engine.create_data_points(
            collection_name = layer_id,
            data_points = [convert_to_data_point(node) for node in layer_nodes]
        )
        for layer_id, layer_nodes in nodes_by_layer.items()
    ]

    return await asyncio.gather(*pending_inserts)
|
||||
|
|
|
|||
|
|
@ -1,24 +0,0 @@
|
|||
|
||||
async def adapted_qdrant_batch_search(results_to_check, vector_client):
|
||||
search_results_list = []
|
||||
|
||||
for result in results_to_check:
|
||||
id = result[0]
|
||||
embedding = result[1]
|
||||
node_id = result[2]
|
||||
target = result[3]
|
||||
|
||||
# Assuming each result in results_to_check contains a single embedding
|
||||
limits = [3] * len(embedding) # Set a limit of 3 results for this embedding
|
||||
|
||||
try:
|
||||
#Perform the batch search for this id with its embedding
|
||||
#Assuming qdrant_batch_search function accepts a single embedding and a list of limits
|
||||
#qdrant_batch_search
|
||||
id_search_results = await vector_client.batch_search(collection_name = id, embeddings = embedding, with_vectors = limits)
|
||||
search_results_list.append((id, id_search_results, node_id, target))
|
||||
except Exception as e:
|
||||
print(f"Error during batch search for ID {id}: {e}")
|
||||
continue
|
||||
|
||||
return search_results_list
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
""" This module contains the function to find the neighbours of a given node in the graph"""
|
||||
|
||||
|
||||
async def search_adjacent(graph,query:str, other_param:dict = None)->dict:
|
||||
async def search_adjacent(graph, query: str, other_param: dict = None) -> dict:
|
||||
""" Find the neighbours of a given node in the graph
|
||||
:param graph: A NetworkX graph object
|
||||
|
||||
|
|
|
|||
|
|
@ -1,21 +1,19 @@
|
|||
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.modules.cognify.graph.add_node_connections import extract_node_descriptions
|
||||
from cognee.infrastructure.databases.vector.get_vector_database import get_vector_database
|
||||
from cognee.infrastructure import infrastructure_config
|
||||
|
||||
async def search_similarity(query:str ,graph,other_param:str = None):
|
||||
|
||||
async def search_similarity(query: str, graph, other_param: str = None):
|
||||
node_descriptions = await extract_node_descriptions(graph.nodes(data = True))
|
||||
|
||||
unique_layer_uuids = set(node["layer_decomposition_uuid"] for node in node_descriptions)
|
||||
|
||||
client = get_llm_client()
|
||||
out = []
|
||||
query = await client.async_get_embedding_with_backoff(query)
|
||||
for id in unique_layer_uuids:
|
||||
vector_client = get_vector_database()
|
||||
|
||||
result = await vector_client.search(id, query,10)
|
||||
for id in unique_layer_uuids:
|
||||
vector_engine = infrastructure_config.get_config()["vector_engine"]
|
||||
|
||||
result = await vector_engine.search(id, query, 10)
|
||||
|
||||
if result:
|
||||
result_ = [ result_.id for result_ in result]
|
||||
|
|
@ -25,13 +23,16 @@ async def search_similarity(query:str ,graph,other_param:str = None):
|
|||
|
||||
relevant_context = []
|
||||
|
||||
if len(out) == 0:
|
||||
return []
|
||||
|
||||
for proposition_id in out[0][0]:
|
||||
for n,attr in graph.nodes(data=True):
|
||||
for n, attr in graph.nodes(data = True):
|
||||
if proposition_id in n:
|
||||
for n_, attr_ in graph.nodes(data=True):
|
||||
relevant_layer = attr['layer_uuid']
|
||||
relevant_layer = attr["layer_uuid"]
|
||||
|
||||
if attr_.get('layer_uuid') == relevant_layer:
|
||||
relevant_context.append(attr_['description'])
|
||||
if attr_.get("layer_uuid") == relevant_layer:
|
||||
relevant_context.append(attr_["description"])
|
||||
|
||||
return relevant_context
|
||||
|
|
|
|||
14
cognee/shared/encode_uuid.py
Normal file
14
cognee/shared/encode_uuid.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from uuid import UUID
|
||||
|
||||
def encode_uuid(uuid: UUID) -> str:
    """Deterministically map a UUID to a 36-character lowercase/uppercase-letter string.

    Repeatedly takes the value modulo 52 to pick a letter, building the
    string from the least-significant digit outward.
    """
    _ALPHABET = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

    remaining = uuid.int
    digits = []

    for _ in range(36):
        remaining, digit = divmod(remaining, 52)
        # Quirk preserved from the original: the quotient is scaled by 8
        # between digits, so this is NOT a plain base-52 encoding.
        remaining *= 8
        digits.append(_ALPHABET[digit])

    # Digits were produced least-significant first; the original prepends
    # each character, so reverse before joining.
    return "".join(reversed(digits))
|
||||
74
poetry.lock
generated
74
poetry.lock
generated
|
|
@ -412,6 +412,20 @@ tests = ["attrs[tests-no-zope]", "zope-interface"]
|
|||
tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"]
|
||||
tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"]
|
||||
|
||||
[[package]]
|
||||
name = "authlib"
|
||||
version = "1.3.0"
|
||||
description = "The ultimate Python library in building OAuth and OpenID Connect servers and clients."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "Authlib-1.3.0-py2.py3-none-any.whl", hash = "sha256:9637e4de1fb498310a56900b3e2043a206b03cb11c05422014b0302cbc814be3"},
|
||||
{file = "Authlib-1.3.0.tar.gz", hash = "sha256:959ea62a5b7b5123c5059758296122b57cd2585ae2ed1c0622c21b371ffdae06"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
cryptography = "*"
|
||||
|
||||
[[package]]
|
||||
name = "babel"
|
||||
version = "2.14.0"
|
||||
|
|
@ -1955,6 +1969,21 @@ files = [
|
|||
[package.extras]
|
||||
protobuf = ["grpcio-tools (>=1.62.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "grpcio-health-checking"
|
||||
version = "1.62.1"
|
||||
description = "Standard Health Checking Service for gRPC"
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "grpcio-health-checking-1.62.1.tar.gz", hash = "sha256:9e56180a941b1d32a077d7491e0611d0483c396358afd5349bf00152612e4583"},
|
||||
{file = "grpcio_health_checking-1.62.1-py3-none-any.whl", hash = "sha256:9ce761c09fc383e7aa2f7e6c0b0b65d5a1157c1b98d1f5871f7c38aca47d49b9"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
grpcio = ">=1.62.1"
|
||||
protobuf = ">=4.21.6"
|
||||
|
||||
[[package]]
|
||||
name = "grpcio-tools"
|
||||
version = "1.62.1"
|
||||
|
|
@ -7690,6 +7719,28 @@ h11 = ">=0.8"
|
|||
[package.extras]
|
||||
standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"]
|
||||
|
||||
[[package]]
|
||||
name = "validators"
|
||||
version = "0.22.0"
|
||||
description = "Python Data Validation for Humans™"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "validators-0.22.0-py3-none-any.whl", hash = "sha256:61cf7d4a62bbae559f2e54aed3b000cea9ff3e2fdbe463f51179b92c58c9585a"},
|
||||
{file = "validators-0.22.0.tar.gz", hash = "sha256:77b2689b172eeeb600d9605ab86194641670cdb73b60afd577142a9397873370"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
docs-offline = ["myst-parser (>=2.0.0)", "pypandoc-binary (>=1.11)", "sphinx (>=7.1.1)"]
|
||||
docs-online = ["mkdocs (>=1.5.2)", "mkdocs-git-revision-date-localized-plugin (>=1.2.0)", "mkdocs-material (>=9.2.6)", "mkdocstrings[python] (>=0.22.0)", "pyaml (>=23.7.0)"]
|
||||
hooks = ["pre-commit (>=3.3.3)"]
|
||||
package = ["build (>=1.0.0)", "twine (>=4.0.2)"]
|
||||
runner = ["tox (>=4.11.1)"]
|
||||
sast = ["bandit[toml] (>=1.7.5)"]
|
||||
testing = ["pytest (>=7.4.0)"]
|
||||
tooling = ["black (>=23.7.0)", "pyright (>=1.1.325)", "ruff (>=0.0.287)"]
|
||||
tooling-extras = ["pyaml (>=23.7.0)", "pypandoc-binary (>=1.11)", "pytest (>=7.4.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "watchdog"
|
||||
version = "4.0.0"
|
||||
|
|
@ -7742,6 +7793,27 @@ files = [
|
|||
{file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "weaviate-client"
|
||||
version = "4.5.4"
|
||||
description = "A python native Weaviate client"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "weaviate-client-4.5.4.tar.gz", hash = "sha256:fc53dc73cd53df453c5e6dc758e49a6a1549212d6670ddd013392107120692f8"},
|
||||
{file = "weaviate_client-4.5.4-py3-none-any.whl", hash = "sha256:f6d3a6b759e5aa0d3350067490526ea38b9274ae4043b4a3ae0064c28d56883f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
authlib = ">=1.2.1,<2.0.0"
|
||||
grpcio = ">=1.57.0,<2.0.0"
|
||||
grpcio-health-checking = ">=1.57.0,<2.0.0"
|
||||
grpcio-tools = ">=1.57.0,<2.0.0"
|
||||
httpx = "0.27.0"
|
||||
pydantic = ">=2.5.0,<3.0.0"
|
||||
requests = ">=2.30.0,<3.0.0"
|
||||
validators = "0.22.0"
|
||||
|
||||
[[package]]
|
||||
name = "webcolors"
|
||||
version = "1.13"
|
||||
|
|
@ -8057,4 +8129,4 @@ weaviate = []
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "~3.10"
|
||||
content-hash = "37a0db9a6a86b71a35c91ac5ef86204d76529033260032917906a907bffc8216"
|
||||
content-hash = "d742617c6e8a62dc9bff8656fa97955f63c991720805388190b205da66e4712a"
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ qdrant-client = "^1.8.0"
|
|||
duckdb-engine = "^0.11.2"
|
||||
graphistry = "^0.33.5"
|
||||
tenacity = "^8.2.3"
|
||||
weaviate-client = "^4.5.4"
|
||||
|
||||
[tool.poetry.extras]
|
||||
dbt = ["dbt-core", "dbt-redshift", "dbt-bigquery", "dbt-duckdb", "dbt-snowflake", "dbt-athena-community", "dbt-databricks"]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue