Merge branch 'feat/COG-113-integrate-weviate' into feat/COG-118-remove-unused-code

This commit is contained in:
Boris Arzentar 2024-03-21 10:12:43 +01:00
commit a727cce00f
27 changed files with 895 additions and 1097 deletions

File diff suppressed because one or more lines are too long

View file

@ -1,23 +1,19 @@
import asyncio
# import logging
from typing import List, Union
from qdrant_client import models
import instructor
from openai import OpenAI
from unstructured.cleaners.core import clean
from unstructured.partition.pdf import partition_pdf
from cognee.infrastructure.databases.vector.qdrant.QDrantAdapter import CollectionConfig
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.modules.cognify.graph.add_classification_nodes import add_classification_nodes
from cognee.modules.cognify.llm.label_content import label_content
from cognee.modules.cognify.graph.add_label_nodes import add_label_nodes
from cognee.modules.cognify.graph.add_node_connections import add_node_connection, graph_ready_output, \
from cognee.modules.cognify.llm.summarize_content import summarize_content
from cognee.modules.cognify.graph.add_summary_nodes import add_summary_nodes
from cognee.modules.cognify.graph.add_node_connections import group_nodes_by_layer, graph_ready_output, \
connect_nodes_in_graph, extract_node_descriptions
from cognee.modules.cognify.graph.add_propositions import append_to_graph
from cognee.modules.cognify.graph.add_summary_nodes import add_summary_nodes
from cognee.modules.cognify.llm.add_node_connection_embeddings import process_items
from cognee.modules.cognify.llm.label_content import label_content
from cognee.modules.cognify.llm.summarize_content import summarize_content
from cognee.modules.cognify.vector.batch_search import adapted_qdrant_batch_search
from cognee.modules.cognify.llm.resolve_cross_graph_references import resolve_cross_graph_references
from cognee.modules.cognify.vector.add_propositions import add_propositions
from cognee.config import Config
@ -28,10 +24,11 @@ from cognee.shared.data_models import DefaultContentPrediction, KnowledgeGraph,
SummarizedContent, LabeledContent
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
from cognee.shared.data_models import GraphDBType
from cognee.infrastructure.databases.vector.get_vector_database import get_vector_database
from cognee.infrastructure.databases.relational import DuckDBAdapter
from cognee.modules.cognify.graph.add_document_node import add_document_node
from cognee.modules.cognify.graph.initialize_graph import initialize_graph
from cognee.infrastructure.databases.vector import CollectionConfig, VectorConfig
from cognee.infrastructure import infrastructure_config
config = Config()
config.load()
@ -76,7 +73,7 @@ async def cognify(datasets: Union[str, List[str]] = None, graphdatamodel: object
async def process_text(input_text: str, file_metadata: dict):
print(f"Processing document ({file_metadata['id']})")
classified_categories = []
try:
@ -133,31 +130,22 @@ async def process_text(input_text: str, file_metadata: dict):
# Run the async function for each set of cognitive layers
layer_graphs = await generate_graph_per_layer(input_text, cognitive_layers)
# print(layer_graphs)
print(f"Document ({file_metadata['id']}) layer graphs created")
# G = await create_semantic_graph(graph_model_instance)
await add_classification_nodes(f"DOCUMENT:{file_metadata['id']}", classified_categories[0])
# print(file_metadata['summary'])
await add_summary_nodes(f"DOCUMENT:{file_metadata['id']}", {"summary": file_metadata["summary"]})
await add_summary_nodes(f"DOCUMENT:{file_metadata['id']}", {"summary": file_metadata['summary']})
await add_label_nodes(f"DOCUMENT:{file_metadata['id']}", {"content_labels": file_metadata["content_labels"]})
# print(file_metadata['content_labels'])
await add_label_nodes(f"DOCUMENT:{file_metadata['id']}", {"content_labels": file_metadata['content_labels']})
unique_layer_uuids = await append_to_graph(layer_graphs, classified_categories[0])
await append_to_graph(layer_graphs, classified_categories[0])
print(f"Document ({file_metadata['id']}) layers connected")
print("Document categories, summaries and metadata are: ", str(classified_categories))
print(f"Document categories, summaries and metadata are ",str(classified_categories) )
print(f"Document metadata is ",str(file_metadata) )
print("Document metadata is: ", str(file_metadata))
graph_client = get_graph_client(GraphDBType.NETWORKX)
@ -165,45 +153,34 @@ async def process_text(input_text: str, file_metadata: dict):
graph = graph_client.graph
# # Extract the node descriptions
node_descriptions = await extract_node_descriptions(graph.nodes(data = True))
# print(node_descriptions)
unique_layer_uuids = set(node["layer_decomposition_uuid"] for node in node_descriptions)
nodes_by_layer = await group_nodes_by_layer(node_descriptions)
unique_layers = nodes_by_layer.keys()
collection_config = CollectionConfig(
vector_config = {
"content": models.VectorParams(
distance = models.Distance.COSINE,
size = 3072
)
},
vector_config = VectorConfig(
distance = "Cosine",
size = 3072
)
)
try:
for layer in unique_layer_uuids:
db = get_vector_database()
await db.create_collection(layer, collection_config)
db_engine = infrastructure_config.get_config()["vector_engine"]
for layer in unique_layers:
await db_engine.create_collection(layer, collection_config)
except Exception as e:
print(e)
await add_propositions(node_descriptions)
await add_propositions(nodes_by_layer)
grouped_data = await add_node_connection(node_descriptions)
# print("we are here, grouped_data", grouped_data)
results = await resolve_cross_graph_references(nodes_by_layer)
llm_client = get_llm_client()
relationships = graph_ready_output(results)
relationship_dict = await process_items(grouped_data, unique_layer_uuids, llm_client)
# print("we are here", relationship_dict[0])
results = await adapted_qdrant_batch_search(relationship_dict, db)
# print(results)
relationship_d = graph_ready_output(results)
# print(relationship_d)
connect_nodes_in_graph(graph, relationship_d)
connect_nodes_in_graph(graph, relationships)
print(f"Document ({file_metadata['id']}) processed")
@ -220,4 +197,4 @@ if __name__ == "__main__":
print(graph_url)
asyncio.run(main())
asyncio.run(main())

View file

@ -1,21 +1,32 @@
from cognee.config import Config
from .databases.relational import SqliteEngine, DatabaseEngine
from .databases.vector import WeaviateAdapter, VectorDBInterface
config = Config()
config.load()
class InfrastructureConfig():
database_engine: DatabaseEngine = None
vector_engine: VectorDBInterface = None
def get_config(self) -> dict:
if self.database_engine is None:
self.database_engine = SqliteEngine(config.db_path, config.db_name)
if self.vector_engine is None:
self.vector_engine = WeaviateAdapter(
config.weaviate_url,
config.weaviate_api_key,
config.openai_key
)
return {
"database_engine": self.database_engine
"database_engine": self.database_engine,
"vector_engine": self.vector_engine
}
def set_config(self, new_config: dict):
self.database_engine = new_config["database_engine"]
self.vector_engine = new_config["vector_engine"]
infrastructure_config = InfrastructureConfig()

View file

@ -1,2 +1,7 @@
from .get_vector_database import get_vector_database
from .qdrant import QDrantAdapter, CollectionConfig
from .qdrant import QDrantAdapter
from .models.DataPoint import DataPoint
from .models.VectorConfig import VectorConfig
from .models.CollectionConfig import CollectionConfig
from .weaviate_db import WeaviateAdapter
from .vector_db_interface import VectorDBInterface

View file

@ -1,8 +1,10 @@
from cognee.config import Config
from .qdrant import QDrantAdapter
# from .qdrant import QDrantAdapter
from .weaviate_db import WeaviateAdapter
config = Config()
config.load()
def get_vector_database():
return QDrantAdapter(config.qdrant_path, config.qdrant_url, config.qdrant_api_key)
# return QDrantAdapter(config.qdrant_path, config.qdrant_url, config.qdrant_api_key)
return WeaviateAdapter(config.weaviate_url, config.weaviate_api_key, config.openai_key)

View file

@ -0,0 +1,5 @@
from pydantic import BaseModel
from .VectorConfig import VectorConfig
class CollectionConfig(BaseModel):
vector_config: VectorConfig

View file

@ -0,0 +1,10 @@
from typing import Dict
from pydantic import BaseModel
class DataPoint(BaseModel):
id: str
payload: Dict[str, str]
embed_field: str
def get_embeddable_data(self):
return self.payload[self.embed_field]

View file

@ -0,0 +1,8 @@
from uuid import UUID
from typing import Any, Dict
from pydantic import BaseModel
class ScoredResult(BaseModel):
id: UUID
score: int
payload: Dict[str, Any]

View file

@ -0,0 +1,6 @@
from typing import Literal
from pydantic import BaseModel
class VectorConfig(BaseModel):
distance: Literal['Cosine', 'Dot']
size: int

View file

@ -1,19 +1,59 @@
from typing import List, Optional, Dict
from pydantic import BaseModel, Field
import asyncio
from typing import List, Dict
# from pydantic import BaseModel, Field
from qdrant_client import AsyncQdrantClient, models
from ..vector_db_interface import VectorDBInterface
from ..models.DataPoint import DataPoint
from ..models.VectorConfig import VectorConfig
from ..models.CollectionConfig import CollectionConfig
from cognee.infrastructure.llm.get_llm_client import get_llm_client
class CollectionConfig(BaseModel, extra = "forbid"):
vector_config: Dict[str, models.VectorParams] = Field(..., description="Vectors configuration" )
hnsw_config: Optional[models.HnswConfig] = Field(default = None, description="HNSW vector index configuration")
optimizers_config: Optional[models.OptimizersConfig] = Field(default = None, description="Optimizers configuration")
quantization_config: Optional[models.QuantizationConfig] = Field(default = None, description="Quantization configuration")
# class CollectionConfig(BaseModel, extra = "forbid"):
# vector_config: Dict[str, models.VectorParams] = Field(..., description="Vectors configuration" )
# hnsw_config: Optional[models.HnswConfig] = Field(default = None, description="HNSW vector index configuration")
# optimizers_config: Optional[models.OptimizersConfig] = Field(default = None, description="Optimizers configuration")
# quantization_config: Optional[models.QuantizationConfig] = Field(default = None, description="Quantization configuration")
async def embed_data(data: str):
llm_client = get_llm_client()
return await llm_client.async_get_embedding_with_backoff(data)
async def convert_to_qdrant_point(data_point: DataPoint):
return models.PointStruct(
id = data_point.id,
payload = data_point.payload,
vector = {
"text": await embed_data(data_point.get_embeddable_data())
}
)
def create_vector_config(vector_config: VectorConfig):
return models.VectorParams(
size = vector_config.size,
distance = vector_config.distance
)
def create_hnsw_config(hnsw_config: Dict):
if hnsw_config is not None:
return models.HnswConfig()
return None
def create_optimizers_config(optimizers_config: Dict):
if optimizers_config is not None:
return models.OptimizersConfig()
return None
def create_quantization_config(quantization_config: Dict):
if quantization_config is not None:
return models.QuantizationConfig()
return None
class QDrantAdapter(VectorDBInterface):
qdrant_url: str = None
qdrant_path: str = None
qdrant_api_key: str = None
def __init__(self, qdrant_path, qdrant_url, qdrant_api_key):
if qdrant_path is not None:
self.qdrant_path = qdrant_path
@ -46,43 +86,49 @@ class QDrantAdapter(VectorDBInterface):
return await client.create_collection(
collection_name = collection_name,
vectors_config = collection_config.vector_config,
hnsw_config = collection_config.hnsw_config,
optimizers_config = collection_config.optimizers_config,
quantization_config = collection_config.quantization_config
vectors_config = {
"text": create_vector_config(collection_config.vector_config)
}
)
async def create_data_points(self, collection_name: str, data_points):
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
client = self.get_qdrant_client()
awaitables = []
for point in data_points:
awaitables.append(convert_to_qdrant_point(point))
points = await asyncio.gather(*awaitables)
return await client.upload_points(
collection_name = collection_name,
points = data_points
points = points
)
async def search(self, collection_name: str, query_vector: List[float], limit: int, with_vector: bool = False):
async def search(self, collection_name: str, query_text: str, limit: int, with_vector: bool = False):
client = self.get_qdrant_client()
return await client.search(
collection_name = collection_name,
query_vector = (
"content", query_vector),
query_vector = models.NamedVector(
name = "text",
vector = await embed_data(query_text)
),
limit = limit,
with_vectors = with_vector
)
async def batch_search(self, collection_name: str, embeddings: List[List[float]],
with_vectors: List[bool] = None):
async def batch_search(self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False):
"""
Perform batch search in a Qdrant collection with dynamic search requests.
Args:
- collection_name (str): Name of the collection to search in.
- embeddings (List[List[float]]): List of embeddings to search for.
- limits (List[int]): List of result limits for each search request.
- with_vectors (List[bool], optional): List indicating whether to return vectors for each search request.
Defaults to None, in which case vectors are not returned.
- query_texts (List[str]): List of query texts to search for.
- limit (int): List of result limits for search requests.
- with_vectors (bool, optional): Bool indicating whether to return vectors for search requests.
Returns:
- results: The search results from Qdrant.
@ -90,30 +136,32 @@ class QDrantAdapter(VectorDBInterface):
client = self.get_qdrant_client()
# Default with_vectors to False for each request if not provided
if with_vectors is None:
with_vectors = [False] * len(embeddings)
# Ensure with_vectors list matches the length of embeddings and limits
if len(with_vectors) != len(embeddings):
raise ValueError("The length of with_vectors must match the length of embeddings and limits")
vectors = await asyncio.gather(*[embed_data(query_text) for query_text in query_texts])
# Generate dynamic search requests based on the provided embeddings
requests = [
models.SearchRequest(vector=models.NamedVector(
name="content",
vector=embedding,
),
# vector= embedding,
limit=3,
with_vector=False
) for embedding in [embeddings]
models.SearchRequest(
vector = models.NamedVector(
name = "text",
vector = vector
),
limit = limit,
with_vector = with_vectors
) for vector in vectors
]
# Perform batch search with the dynamically generated requests
results = await client.search_batch(
collection_name=collection_name,
requests=requests
collection_name = collection_name,
requests = requests
)
return results
return [filter(lambda result: result.score > 0.9, result_group) for result_group in results]
async def prune(self):
client = self.get_qdrant_client()
response = await client.get_collections()
for collection in response.collections:
await client.delete_collection(collection.name)

View file

@ -1,6 +1,7 @@
from typing import List
from typing import List, Protocol
from abc import abstractmethod
from typing import Protocol
from .models.CollectionConfig import CollectionConfig
from .models.DataPoint import DataPoint
class VectorDBInterface(Protocol):
""" Collections """
@ -8,7 +9,7 @@ class VectorDBInterface(Protocol):
async def create_collection(
self,
collection_name: str,
collection_config: object
collection_config: CollectionConfig
): raise NotImplementedError
# @abstractmethod
@ -43,7 +44,7 @@ class VectorDBInterface(Protocol):
async def create_data_points(
self,
collection_name: str,
data_points
data_points: List[DataPoint]
): raise NotImplementedError
# @abstractmethod
@ -67,12 +68,13 @@ class VectorDBInterface(Protocol):
# collection_name: str,
# data_point_id: str
# ): raise NotImplementedError
""" Search """
@abstractmethod
async def search(
self,
collection_name: str,
query_vector: List[float],
query_text: str,
limit: int,
with_vector: bool = False
@ -82,7 +84,7 @@ class VectorDBInterface(Protocol):
async def batch_search(
self,
collection_name: str,
embeddings: List[List[float]],
with_vectors: List[bool] = None
query_texts: List[str],
limit: int,
with_vectors: bool = False
): raise NotImplementedError

View file

@ -1,417 +0,0 @@
from weaviate.gql.get import HybridFusion
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.retrievers import WeaviateHybridSearchRetriever, ParentDocumentRetriever
from databases.vector.vector_db_interface import VectorDBInterface
# from langchain.text_splitter import RecursiveCharacterTextSplitter
from cognee.database.vectordb.loaders.loaders import _document_loader
class WeaviateVectorDB(VectorDBInterface):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.init_weaviate(embeddings=self.embeddings, namespace=self.namespace)
def init_weaviate(
self,
embeddings=OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API_KEY", "")),
namespace=None,
retriever_type="",
):
# Weaviate initialization logic
auth_config = weaviate.auth.AuthApiKey(
api_key=os.environ.get("WEAVIATE_API_KEY")
)
client = weaviate.Client(
url=os.environ.get("WEAVIATE_URL"),
auth_client_secret=auth_config,
additional_headers={"X-OpenAI-Api-Key": os.environ.get("OPENAI_API_KEY")},
)
if retriever_type == "single_document_context":
retriever = WeaviateHybridSearchRetriever(
client=client,
index_name=namespace,
text_key="text",
attributes=[],
embedding=embeddings,
create_schema_if_missing=True,
)
return retriever
elif retriever_type == "multi_document_context":
retriever = WeaviateHybridSearchRetriever(
client=client,
index_name=namespace,
text_key="text",
attributes=[],
embedding=embeddings,
create_schema_if_missing=True,
)
return retriever
else:
return client
# child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
# store = InMemoryStore()
# retriever = ParentDocumentRetriever(
# vectorstore=vectorstore,
# docstore=store,
# child_splitter=child_splitter,
# )
from marshmallow import Schema, fields
def create_document_structure(observation, params, metadata_schema_class=None):
"""
Create and validate a document structure with optional custom fields.
:param observation: Content of the document.
:param params: Metadata information.
:param metadata_schema_class: Custom metadata schema class (optional).
:return: A list containing the validated document data.
"""
document_data = {"metadata": params, "page_content": observation}
def get_document_schema():
class DynamicDocumentSchema(Schema):
metadata = fields.Nested(metadata_schema_class, required=True)
page_content = fields.Str(required=True)
return DynamicDocumentSchema
# Validate and deserialize, defaulting to "1.0" if not provided
CurrentDocumentSchema = get_document_schema()
loaded_document = CurrentDocumentSchema().load(document_data)
return [loaded_document]
def _stuct(self, observation, params, metadata_schema_class=None):
"""Utility function to create the document structure with optional custom fields."""
# Construct document data
document_data = {"metadata": params, "page_content": observation}
def get_document_schema():
class DynamicDocumentSchema(Schema):
metadata = fields.Nested(metadata_schema_class, required=True)
page_content = fields.Str(required=True)
return DynamicDocumentSchema
# Validate and deserialize # Default to "1.0" if not provided
CurrentDocumentSchema = get_document_schema()
loaded_document = CurrentDocumentSchema().load(document_data)
return [loaded_document]
async def add_memories(
self,
observation,
loader_settings=None,
params=None,
namespace=None,
metadata_schema_class=None,
embeddings="hybrid",
):
# Update Weaviate memories here
if namespace is None:
namespace = self.namespace
params["user_id"] = self.user_id
logging.info("User id is %s", self.user_id)
retriever = self.init_weaviate(
embeddings=OpenAIEmbeddings(),
namespace=namespace,
retriever_type="single_document_context",
)
if loader_settings:
# Assuming _document_loader returns a list of documents
documents = await _document_loader(observation, loader_settings)
logging.info("here are the docs %s", str(documents))
chunk_count = 0
for doc_list in documents:
for doc in doc_list:
chunk_count += 1
params["chunk_count"] = doc.metadata.get("chunk_count", "None")
logging.info(
"Loading document with provided loader settings %s", str(doc)
)
params["source"] = doc.metadata.get("source", "None")
logging.info("Params are %s", str(params))
retriever.add_documents(
[Document(metadata=params, page_content=doc.page_content)]
)
else:
chunk_count = 0
from cognee.database.vectordb.chunkers.chunkers import (
chunk_data,
)
documents = [
chunk_data(
chunk_strategy="VANILLA",
source_data=observation,
chunk_size=300,
chunk_overlap=20,
)
]
for doc in documents[0]:
chunk_count += 1
params["chunk_order"] = chunk_count
params["source"] = "User loaded"
logging.info(
"Loading document with default loader settings %s", str(doc)
)
logging.info("Params are %s", str(params))
retriever.add_documents(
[Document(metadata=params, page_content=doc.page_content)]
)
async def fetch_memories(
self,
observation: str,
namespace: str = None,
search_type: str = "hybrid",
params=None,
**kwargs,
):
"""
Fetch documents from weaviate.
Parameters:
- observation (str): User query.
- namespace (str, optional): Type of memory accessed.
- search_type (str, optional): Type of search ('text', 'hybrid', 'bm25', 'generate', 'generate_grouped'). Defaults to 'hybrid'.
- **kwargs: Additional parameters for flexibility.
Returns:
List of documents matching the query or an empty list in case of error.
Example:
fetch_memories(query="some query", search_type='text', additional_param='value')
"""
client = self.init_weaviate(namespace=self.namespace)
if search_type is None:
search_type = "hybrid"
if not namespace:
namespace = self.namespace
logging.info("Query on namespace %s", namespace)
params_user_id = {
"path": ["user_id"],
"operator": "Like",
"valueText": self.user_id,
}
def list_objects_of_class(class_name, schema):
return [
prop["name"]
for class_obj in schema["classes"]
if class_obj["class"] == class_name
for prop in class_obj["properties"]
]
base_query = (
client.query.get(
namespace, list(list_objects_of_class(namespace, client.schema.get()))
)
.with_additional(
["id", "creationTimeUnix", "lastUpdateTimeUnix", "score", "distance"]
)
.with_where(params_user_id)
.with_limit(10)
)
n_of_observations = kwargs.get("n_of_observations", 2)
# try:
if search_type == "text":
query_output = (
base_query.with_near_text({"concepts": [observation]})
.with_autocut(n_of_observations)
.do()
)
elif search_type == "hybrid":
query_output = (
base_query.with_hybrid(
query=observation, fusion_type=HybridFusion.RELATIVE_SCORE
)
.with_autocut(n_of_observations)
.do()
)
elif search_type == "bm25":
query_output = (
base_query.with_bm25(query=observation)
.with_autocut(n_of_observations)
.do()
)
elif search_type == "summary":
filter_object = {
"operator": "And",
"operands": [
{
"path": ["user_id"],
"operator": "Equal",
"valueText": self.user_id,
},
{
"path": ["chunk_order"],
"operator": "LessThan",
"valueNumber": 30,
},
],
}
base_query = (
client.query.get(
namespace,
list(list_objects_of_class(namespace, client.schema.get())),
)
.with_additional(
[
"id",
"creationTimeUnix",
"lastUpdateTimeUnix",
"score",
"distance",
]
)
.with_where(filter_object)
.with_limit(30)
)
query_output = (
base_query
# .with_hybrid(query=observation, fusion_type=HybridFusion.RELATIVE_SCORE)
.do()
)
elif search_type == "summary_filter_by_object_name":
filter_object = {
"operator": "And",
"operands": [
{
"path": ["user_id"],
"operator": "Equal",
"valueText": self.user_id,
},
{
"path": ["doc_id"],
"operator": "Equal",
"valueText": params,
},
],
}
base_query = (
client.query.get(
namespace,
list(list_objects_of_class(namespace, client.schema.get())),
)
.with_additional(
[
"id",
"creationTimeUnix",
"lastUpdateTimeUnix",
"score",
"distance",
]
)
.with_where(filter_object)
.with_limit(30)
.with_hybrid(query=observation, fusion_type=HybridFusion.RELATIVE_SCORE)
)
query_output = base_query.do()
return query_output
elif search_type == "generate":
generate_prompt = kwargs.get("generate_prompt", "")
query_output = (
base_query.with_generate(single_prompt=observation)
.with_near_text({"concepts": [observation]})
.with_autocut(n_of_observations)
.do()
)
elif search_type == "generate_grouped":
generate_prompt = kwargs.get("generate_prompt", "")
query_output = (
base_query.with_generate(grouped_task=observation)
.with_near_text({"concepts": [observation]})
.with_autocut(n_of_observations)
.do()
)
else:
logging.error(f"Invalid search_type: {search_type}")
return []
# except Exception as e:
# logging.error(f"Error executing query: {str(e)}")
# return []
return query_output
async def delete_memories(self, namespace: str, params: dict = None):
if namespace is None:
namespace = self.namespace
client = self.init_weaviate(namespace=self.namespace)
if params:
where_filter = {
"path": ["id"],
"operator": "Equal",
"valueText": params.get("id", None),
}
return client.batch.delete_objects(
class_name=self.namespace,
# Same `where` filter as in the GraphQL API
where=where_filter,
)
else:
# Delete all objects
return client.batch.delete_objects(
class_name=namespace,
where={
"path": ["version"],
"operator": "Equal",
"valueText": "1.0",
},
)
async def count_memories(self, namespace: str = None, params: dict = None) -> int:
"""
Count memories in a Weaviate database.
Args:
namespace (str, optional): The Weaviate namespace to count memories in. If not provided, uses the default namespace.
Returns:
int: The number of memories in the specified namespace.
"""
if namespace is None:
namespace = self.namespace
client = self.init_weaviate(namespace=namespace)
try:
object_count = client.query.aggregate(namespace).with_meta_count().do()
return object_count
except Exception as e:
logging.info(f"Error counting memories: {str(e)}")
# Handle the error or log it
return 0
def update_memories(self, observation, namespace: str, params: dict = None):
client = self.init_weaviate(namespace=self.namespace)
client.data_object.update(
data_object={
# "text": observation,
"user_id": str(self.user_id),
"version": params.get("version", None) or "",
"agreement_id": params.get("agreement_id", None) or "",
"privacy_policy": params.get("privacy_policy", None) or "",
"terms_of_service": params.get("terms_of_service", None) or "",
"format": params.get("format", None) or "",
"schema_version": params.get("schema_version", None) or "",
"checksum": params.get("checksum", None) or "",
"owner": params.get("owner", None) or "",
"license": params.get("license", None) or "",
"validity_start": params.get("validity_start", None) or "",
"validity_end": params.get("validity_end", None) or ""
# **source_metadata,
},
class_name="Test",
uuid=params.get("id", None),
consistency_level=weaviate.data.replication.ConsistencyLevel.ALL, # default QUORUM
)
return

View file

@ -0,0 +1,72 @@
from typing import List
from multiprocessing import Pool
import weaviate
import weaviate.classes as wvc
import weaviate.classes.config as wvcc
from weaviate.classes.data import DataObject
from ..vector_db_interface import VectorDBInterface
from ..models.DataPoint import DataPoint
from ..models.ScoredResult import ScoredResult
class WeaviateAdapter(VectorDBInterface):
async_pool: Pool = None
def __init__(self, url: str, api_key: str, openai_api_key: str):
self.client = weaviate.connect_to_wcs(
cluster_url = url,
auth_credentials = weaviate.auth.AuthApiKey(api_key),
headers = {
"X-OpenAI-Api-Key": openai_api_key
},
additional_config = wvc.init.AdditionalConfig(timeou = wvc.init.Timeout(init = 30))
)
async def create_collection(self, collection_name: str, collection_config: dict):
return self.client.collections.create(
name = collection_name,
vectorizer_config = wvcc.Configure.Vectorizer.text2vec_openai(),
generative_config = wvcc.Configure.Generative.openai(),
properties = [
wvcc.Property(
name = "text",
data_type = wvcc.DataType.TEXT
)
]
)
def get_collection(self, collection_name: str):
return self.client.collections.get(collection_name)
async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
def convert_to_weaviate_data_points(data_point: DataPoint):
return DataObject(
uuid = data_point.id,
properties = data_point.payload
)
objects = list(map(convert_to_weaviate_data_points, data_points))
return self.get_collection(collection_name).data.insert_many(objects)
async def search(self, collection_name: str, query_text: str, limit: int, with_vector: bool = False):
search_result = self.get_collection(collection_name).query.bm25(
query = query_text,
limit = limit,
include_vector = with_vector,
return_metadata = wvc.query.MetadataQuery(score = True),
)
return list(map(lambda result: ScoredResult(
id = result.uuid,
payload = result.properties,
score = str(result.metadata.score)
), search_result.objects))
async def batch_search(self, collection_name: str, query_texts: List[str], limit: int, with_vectors: bool = False):
def query_search(query_text):
return self.search(collection_name, query_text, limit = limit, with_vector = with_vectors)
return [await query_search(query_text) for query_text in query_texts]
async def prune(self):
self.client.collections.delete_all()

View file

@ -0,0 +1 @@
from .WeaviateAdapter import WeaviateAdapter

View file

@ -3,20 +3,26 @@ from cognee.infrastructure.databases.graph.get_graph_client import get_graph_cli
from cognee.shared.data_models import GraphDBType
async def extract_node_descriptions(data):
async def extract_node_descriptions(node):
descriptions = []
for node_id, attributes in data:
if 'description' in attributes and 'unique_id' in attributes:
descriptions.append({'node_id': attributes['unique_id'], 'description': attributes['description'], 'layer_uuid': attributes['layer_uuid'], 'layer_decomposition_uuid': attributes['layer_decomposition_uuid'] })
for node_id, attributes in node:
if "description" in attributes and "unique_id" in attributes:
descriptions.append({
"node_id": attributes["unique_id"],
"description": attributes["description"],
"layer_uuid": attributes["layer_uuid"],
"layer_decomposition_uuid": attributes["layer_decomposition_uuid"]
})
return descriptions
async def add_node_connection(node_descriptions):
async def group_nodes_by_layer(node_descriptions):
grouped_data = {}
for item in node_descriptions:
uuid = item['layer_decomposition_uuid']
uuid = item["layer_decomposition_uuid"]
if uuid not in grouped_data:
grouped_data[uuid] = []
@ -35,19 +41,19 @@ def connect_nodes_in_graph(graph: Graph, relationship_dict: dict) -> Graph:
"""
for id, relationships in relationship_dict.items():
for relationship in relationships:
searched_node_attr_id = relationship['searched_node_id']
score_attr_id = relationship['original_id_for_search']
score = relationship['score']
searched_node_attr_id = relationship["searched_node_id"]
score_attr_id = relationship["original_id_for_search"]
score = relationship["score"]
# Initialize node keys for both searched_node and score_node
searched_node_key, score_node_key = None, None
# Find nodes in the graph that match the searched_node_id and score_id from their attributes
for node, attrs in graph.nodes(data=True):
if 'unique_id' in attrs: # Ensure there is an 'id' attribute
if attrs['unique_id'] == searched_node_attr_id:
for node, attrs in graph.nodes(data = True):
if "unique_id" in attrs: # Ensure there is an "id" attribute
if attrs["unique_id"] == searched_node_attr_id:
searched_node_key = node
elif attrs['unique_id'] == score_attr_id:
elif attrs["unique_id"] == score_attr_id:
score_node_key = node
# If both nodes are found, no need to continue checking other nodes
@ -57,9 +63,13 @@ def connect_nodes_in_graph(graph: Graph, relationship_dict: dict) -> Graph:
# Check if both nodes were found in the graph
if searched_node_key is not None and score_node_key is not None:
# If both nodes exist, create an edge between them
# You can customize the edge attributes as needed, here we use 'score' as an attribute
graph.add_edge(searched_node_key, score_node_key, weight=score,
score_metadata=relationship.get('score_metadata'))
# You can customize the edge attributes as needed, here we use "score" as an attribute
graph.add_edge(
searched_node_key,
score_node_key,
weight = score,
score_metadata = relationship.get("score_metadata")
)
return graph
@ -67,31 +77,23 @@ def connect_nodes_in_graph(graph: Graph, relationship_dict: dict) -> Graph:
def graph_ready_output(results):
relationship_dict = {}
for result_tuple in results:
uuid, scored_points_list, desc, node_id = result_tuple
# Unpack the tuple
for result in results:
layer_id = result["layer_id"]
layer_nodes = result["layer_nodes"]
# Ensure there's a list to collect related items for this uuid
if uuid not in relationship_dict:
relationship_dict[uuid] = []
if layer_id not in relationship_dict:
relationship_dict[layer_id] = []
for node in layer_nodes: # Iterate over the list of ScoredPoint lists
for score_point in node["score_points"]:
# Append a new dictionary to the list associated with the uuid
relationship_dict[layer_id].append({
"collection_id": layer_id,
"searched_node_id": node["id"],
"score": score_point.score,
"score_metadata": score_point.payload,
"original_id_for_search": score_point.id,
})
for scored_points in scored_points_list: # Iterate over the list of ScoredPoint lists
for scored_point in scored_points: # Iterate over each ScoredPoint object
if scored_point.score > 0.9: # Check the score condition
# Append a new dictionary to the list associated with the uuid
relationship_dict[uuid].append({
'collection_name_uuid': uuid,
'searched_node_id': scored_point.id,
'score': scored_point.score,
'score_metadata': scored_point.payload,
'original_id_for_search': node_id,
})
return relationship_dict
if __name__ == '__main__':
graph_client = get_graph_client(GraphDBType.NETWORKX)
add_node_connection(graph_client)

View file

@ -3,6 +3,7 @@ import uuid
import json
from datetime import datetime
from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client, GraphDBType
from cognee.shared.encode_uuid import encode_uuid
async def add_propositions(
@ -69,7 +70,6 @@ async def add_propositions(
async def append_to_graph(layer_graphs, required_layers):
# Generate a UUID for the overall layer
layer_uuid = uuid.uuid4()
decomposition_uuids = set()
# Extract category name from required_layers data
data_type = required_layers["data_type"]
@ -84,9 +84,7 @@ async def append_to_graph(layer_graphs, required_layers):
layer_description = json.loads(layer_json)
# Generate a UUID for this particular layer decomposition
layer_decomposition_uuid = uuid.uuid4()
decomposition_uuids.add(layer_decomposition_uuid)
layer_decomposition_id = encode_uuid(uuid.uuid4())
# Assuming append_data_to_graph is defined elsewhere and appends data to graph_client
# You would pass relevant information from knowledge_graph along with other details to this function
@ -96,11 +94,9 @@ async def append_to_graph(layer_graphs, required_layers):
layer_description,
knowledge_graph,
layer_uuid,
layer_decomposition_uuid
layer_decomposition_id
)
return decomposition_uuids
# if __name__ == "__main__":
# import asyncio

View file

@ -22,10 +22,10 @@ async def process_attribute(graph_client, parent_id: Optional[str], attribute: s
if isinstance(value, BaseModel):
node_id = await generate_node_id(value)
node_data = value.dict(exclude={"default_relationship"})
node_data = value.model_dump(exclude = {"default_relationship"})
# Use the specified default relationship for the edge between the parent node and the current node
relationship_data = value.default_relationship.dict() if hasattr(value, "default_relationship") else {}
relationship_data = value.default_relationship.model_dump() if hasattr(value, "default_relationship") else {}
await add_node_and_edge(graph_client, parent_id, node_id, node_data, relationship_data)
@ -41,7 +41,7 @@ async def process_attribute(graph_client, parent_id: Optional[str], attribute: s
async def create_dynamic(graph_model) :
root_id = await generate_node_id(graph_model)
node_data = graph_model.dict(exclude = {"default_relationship", "id"})
node_data = graph_model.model_dump(exclude = {"default_relationship", "id"})
graph_client = get_graph_client(GraphDBType.NETWORKX)

View file

@ -1,30 +0,0 @@
import asyncio
async def process_items(grouped_data, unique_layer_uuids, llm_client):
    """Embed every node description once per *other* layer it must be compared against.

    For each group in ``grouped_data``, every item's description is embedded
    (concurrently) once for each layer uuid in ``unique_layer_uuids`` other
    than the item's own group.

    Returns a list of ``[target_id, embedding, description, node_id, group_id]``
    entries — self comparisons (target == own group) are excluded.
    """
    pending_tasks = []
    # Parallel list: metadata for the task at the same index.
    pending_info = []  # (target_id, node_id, group_id, description)

    for group_id, items in grouped_data.items():
        for item in items:
            # Every layer except the item's own group is a comparison target.
            for target_id in (layer for layer in unique_layer_uuids if layer != group_id):
                pending_tasks.append(asyncio.create_task(
                    llm_client.async_get_embedding_with_backoff(item['description'], "text-embedding-3-large")
                ))
                pending_info.append((target_id, item['node_id'], group_id, item['description']))

    # Run all embedding calls concurrently; results align with pending_info by index.
    embeddings = await asyncio.gather(*pending_tasks)

    return [
        [target_id, embedding, description, node_id, group_id]
        for (target_id, node_id, group_id, description), embedding in zip(pending_info, embeddings)
    ]

View file

@ -11,7 +11,7 @@ async def classify_into_categories(text_input: str, system_prompt_file: str, res
llm_output = await llm_client.acreate_structured_output(text_input, system_prompt, response_model)
return extract_categories(llm_output.dict())
return extract_categories(llm_output.model_dump())
def extract_categories(llm_output) -> List[dict]:
# Extract the first subclass from the list (assuming there could be more)

View file

@ -0,0 +1,39 @@
from typing import Dict, List
from cognee.infrastructure.databases.vector import get_vector_database
from cognee.infrastructure import infrastructure_config
async def resolve_cross_graph_references(nodes_by_layer: Dict):
    """Search each layer's node descriptions against every OTHER layer.

    For every (layer, other-layer) pair, delegates to ``get_nodes_by_layer``
    with the other layer's id and this layer's nodes, collecting one result
    dict per pair, in layer-iteration order.
    """
    cross_layer_results = []

    for layer_id, layer_nodes in nodes_by_layer.items():
        for other_layer_id in nodes_by_layer.keys():
            # Skip self-comparison: a layer is never searched against itself.
            if other_layer_id == layer_id:
                continue
            cross_layer_results.append(await get_nodes_by_layer(other_layer_id, layer_nodes))

    return cross_layer_results
async def get_nodes_by_layer(layer_id: str, layer_nodes: List):
    """Batch-search the nodes' descriptions in the vector collection ``layer_id``.

    Returns a dict with the layer id and the nodes paired (by position) with
    their top-3 score points from the vector engine.
    """
    vector_engine = infrastructure_config.get_config()["vector_engine"]

    # One query per node description, searched against the target layer's collection.
    descriptions = [node["description"] for node in layer_nodes]
    score_points = await vector_engine.batch_search(layer_id, descriptions, limit = 3)

    return {
        "layer_id": layer_id,
        "layer_nodes": connect_score_points_to_node(score_points, layer_nodes),
    }
def connect_score_points_to_node(score_points, layer_nodes):
    """Pair each layer node (by position) with its batch-search score points.

    ``score_points[i]`` is assumed to be the result set for ``layer_nodes[i]``;
    an IndexError propagates if score_points is shorter than layer_nodes.
    """
    paired = []
    for node_index, node in enumerate(layer_nodes):
        paired.append({
            "id": node["node_id"],
            "score_points": score_points[node_index],
        })
    return paired

View file

@ -1,40 +1,28 @@
import asyncio
from qdrant_client import models
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.databases.vector import get_vector_database
from cognee.infrastructure.databases.vector import DataPoint
from cognee.infrastructure import infrastructure_config
async def get_embeddings(texts:list):
""" Get embeddings for a list of texts"""
client = get_llm_client()
tasks = [client.async_get_embedding_with_backoff(text, "text-embedding-3-large") for text in texts]
return await asyncio.gather(*tasks)
async def add_proposition_to_vector_store(id, metadata, embeddings, collection_name):
""" Upload a single embedding to a collection in Qdrant."""
client = get_vector_database()
await client.create_data_points(
collection_name = collection_name,
data_points = [
models.PointStruct(
id = id,
payload = metadata,
vector = {"content" : embeddings}
)
]
def convert_to_data_point(node):
    """Build a vector-store DataPoint from a graph node dict.

    The node's "description" text goes under the "text" payload key, which is
    also declared as the field to embed.
    """
    payload = { "text": node["description"] }
    return DataPoint(id = node["node_id"], payload = payload, embed_field = "text")
async def add_propositions(nodes_by_layer):
vector_engine = infrastructure_config.get_config()["vector_engine"]
async def add_propositions(node_descriptions):
for item in node_descriptions:
embeddings = await get_embeddings([item["description"]])
awaitables = []
await add_proposition_to_vector_store(
id = item["node_id"],
metadata = {
"meta": item["description"]
},
embeddings = embeddings[0],
collection_name = item["layer_decomposition_uuid"]
for layer_id, layer_nodes in nodes_by_layer.items():
awaitables.append(
vector_engine.create_data_points(
collection_name = layer_id,
data_points = list(map(convert_to_data_point, layer_nodes))
)
)
return await asyncio.gather(*awaitables)

View file

@ -1,24 +0,0 @@
async def adapted_qdrant_batch_search(results_to_check, vector_client):
    """Run a vector batch search for each pending cross-layer comparison.

    Each entry in ``results_to_check`` is indexed as:
    [0] collection id, [1] embedding, [2] node id, [3] target.
    Failed searches are logged and skipped; successes are returned as
    ``(id, search_results, node_id, target)`` tuples.

    NOTE(review): ``limits`` is sized by ``len(embedding)`` — one limit per
    vector component rather than per query. Looks suspicious, but preserved
    exactly as the original behaved; confirm against the vector client API.
    """
    search_results_list = []

    for result in results_to_check:
        id, embedding, node_id, target = result[0], result[1], result[2], result[3]
        limits = [3] * len(embedding)
        try:
            id_search_results = await vector_client.batch_search(
                collection_name = id,
                embeddings = embedding,
                with_vectors = limits,
            )
        except Exception as e:
            print(f"Error during batch search for ID {id}: {e}")
            continue
        search_results_list.append((id, id_search_results, node_id, target))

    return search_results_list

View file

@ -1,7 +1,7 @@
""" This module contains the function to find the neighbours of a given node in the graph"""
async def search_adjacent(graph,query:str, other_param:dict = None)->dict:
async def search_adjacent(graph, query: str, other_param: dict = None) -> dict:
""" Find the neighbours of a given node in the graph
:param graph: A NetworkX graph object

View file

@ -1,21 +1,19 @@
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.modules.cognify.graph.add_node_connections import extract_node_descriptions
from cognee.infrastructure.databases.vector.get_vector_database import get_vector_database
from cognee.infrastructure import infrastructure_config
async def search_similarity(query:str ,graph,other_param:str = None):
async def search_similarity(query: str, graph, other_param: str = None):
node_descriptions = await extract_node_descriptions(graph.nodes(data = True))
unique_layer_uuids = set(node["layer_decomposition_uuid"] for node in node_descriptions)
client = get_llm_client()
out = []
query = await client.async_get_embedding_with_backoff(query)
for id in unique_layer_uuids:
vector_client = get_vector_database()
result = await vector_client.search(id, query,10)
for id in unique_layer_uuids:
vector_engine = infrastructure_config.get_config()["vector_engine"]
result = await vector_engine.search(id, query, 10)
if result:
result_ = [ result_.id for result_ in result]
@ -25,13 +23,16 @@ async def search_similarity(query:str ,graph,other_param:str = None):
relevant_context = []
if len(out) == 0:
return []
for proposition_id in out[0][0]:
for n,attr in graph.nodes(data=True):
for n, attr in graph.nodes(data = True):
if proposition_id in n:
for n_, attr_ in graph.nodes(data=True):
relevant_layer = attr['layer_uuid']
relevant_layer = attr["layer_uuid"]
if attr_.get('layer_uuid') == relevant_layer:
relevant_context.append(attr_['description'])
if attr_.get("layer_uuid") == relevant_layer:
relevant_context.append(attr_["description"])
return relevant_context

View file

@ -0,0 +1,14 @@
from uuid import UUID
def encode_uuid(uuid: UUID) -> str:
uuid_int = uuid.int
base = 52
charset = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
encoded = ''
while len(encoded) < 36:
uuid_int, remainder = divmod(uuid_int, base)
uuid_int = uuid_int * 8
encoded = charset[remainder] + encoded
return encoded

74
poetry.lock generated
View file

@ -412,6 +412,20 @@ tests = ["attrs[tests-no-zope]", "zope-interface"]
tests-mypy = ["mypy (>=1.6)", "pytest-mypy-plugins"]
tests-no-zope = ["attrs[tests-mypy]", "cloudpickle", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"]
[[package]]
name = "authlib"
version = "1.3.0"
description = "The ultimate Python library in building OAuth and OpenID Connect servers and clients."
optional = false
python-versions = ">=3.8"
files = [
{file = "Authlib-1.3.0-py2.py3-none-any.whl", hash = "sha256:9637e4de1fb498310a56900b3e2043a206b03cb11c05422014b0302cbc814be3"},
{file = "Authlib-1.3.0.tar.gz", hash = "sha256:959ea62a5b7b5123c5059758296122b57cd2585ae2ed1c0622c21b371ffdae06"},
]
[package.dependencies]
cryptography = "*"
[[package]]
name = "babel"
version = "2.14.0"
@ -1955,6 +1969,21 @@ files = [
[package.extras]
protobuf = ["grpcio-tools (>=1.62.1)"]
[[package]]
name = "grpcio-health-checking"
version = "1.62.1"
description = "Standard Health Checking Service for gRPC"
optional = false
python-versions = ">=3.6"
files = [
{file = "grpcio-health-checking-1.62.1.tar.gz", hash = "sha256:9e56180a941b1d32a077d7491e0611d0483c396358afd5349bf00152612e4583"},
{file = "grpcio_health_checking-1.62.1-py3-none-any.whl", hash = "sha256:9ce761c09fc383e7aa2f7e6c0b0b65d5a1157c1b98d1f5871f7c38aca47d49b9"},
]
[package.dependencies]
grpcio = ">=1.62.1"
protobuf = ">=4.21.6"
[[package]]
name = "grpcio-tools"
version = "1.62.1"
@ -7690,6 +7719,28 @@ h11 = ">=0.8"
[package.extras]
standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"]
[[package]]
name = "validators"
version = "0.22.0"
description = "Python Data Validation for Humans™"
optional = false
python-versions = ">=3.8"
files = [
{file = "validators-0.22.0-py3-none-any.whl", hash = "sha256:61cf7d4a62bbae559f2e54aed3b000cea9ff3e2fdbe463f51179b92c58c9585a"},
{file = "validators-0.22.0.tar.gz", hash = "sha256:77b2689b172eeeb600d9605ab86194641670cdb73b60afd577142a9397873370"},
]
[package.extras]
docs-offline = ["myst-parser (>=2.0.0)", "pypandoc-binary (>=1.11)", "sphinx (>=7.1.1)"]
docs-online = ["mkdocs (>=1.5.2)", "mkdocs-git-revision-date-localized-plugin (>=1.2.0)", "mkdocs-material (>=9.2.6)", "mkdocstrings[python] (>=0.22.0)", "pyaml (>=23.7.0)"]
hooks = ["pre-commit (>=3.3.3)"]
package = ["build (>=1.0.0)", "twine (>=4.0.2)"]
runner = ["tox (>=4.11.1)"]
sast = ["bandit[toml] (>=1.7.5)"]
testing = ["pytest (>=7.4.0)"]
tooling = ["black (>=23.7.0)", "pyright (>=1.1.325)", "ruff (>=0.0.287)"]
tooling-extras = ["pyaml (>=23.7.0)", "pypandoc-binary (>=1.11)", "pytest (>=7.4.0)"]
[[package]]
name = "watchdog"
version = "4.0.0"
@ -7742,6 +7793,27 @@ files = [
{file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"},
]
[[package]]
name = "weaviate-client"
version = "4.5.4"
description = "A python native Weaviate client"
optional = false
python-versions = ">=3.8"
files = [
{file = "weaviate-client-4.5.4.tar.gz", hash = "sha256:fc53dc73cd53df453c5e6dc758e49a6a1549212d6670ddd013392107120692f8"},
{file = "weaviate_client-4.5.4-py3-none-any.whl", hash = "sha256:f6d3a6b759e5aa0d3350067490526ea38b9274ae4043b4a3ae0064c28d56883f"},
]
[package.dependencies]
authlib = ">=1.2.1,<2.0.0"
grpcio = ">=1.57.0,<2.0.0"
grpcio-health-checking = ">=1.57.0,<2.0.0"
grpcio-tools = ">=1.57.0,<2.0.0"
httpx = "0.27.0"
pydantic = ">=2.5.0,<3.0.0"
requests = ">=2.30.0,<3.0.0"
validators = "0.22.0"
[[package]]
name = "webcolors"
version = "1.13"
@ -8057,4 +8129,4 @@ weaviate = []
[metadata]
lock-version = "2.0"
python-versions = "~3.10"
content-hash = "37a0db9a6a86b71a35c91ac5ef86204d76529033260032917906a907bffc8216"
content-hash = "d742617c6e8a62dc9bff8656fa97955f63c991720805388190b205da66e4712a"

View file

@ -51,6 +51,7 @@ qdrant-client = "^1.8.0"
duckdb-engine = "^0.11.2"
graphistry = "^0.33.5"
tenacity = "^8.2.3"
weaviate-client = "^4.5.4"
[tool.poetry.extras]
dbt = ["dbt-core", "dbt-redshift", "dbt-bigquery", "dbt-duckdb", "dbt-snowflake", "dbt-athena-community", "dbt-databricks"]