fix deployment

Vasilije 2024-05-17 10:09:43 +02:00
commit 79311ee510
24 changed files with 415 additions and 8614 deletions

.gitignore (vendored): 2 changed lines
View file

@@ -160,7 +160,7 @@ cython_debug/
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+.idea/
 .vscode/
 cognee/data/

View file

@@ -41,14 +41,24 @@ Join our <a href="https://discord.gg/NQPKmU5CCg">Discord</a> community
 ## 📦 Installation

-With pip:
+### With pip
+
+```bash
+pip install cognee
+```
+
+Use Weaviate vector storage:

 ```bash
 pip install "cognee[weaviate]"
 ```

-With poetry:
+### With poetry
+
+```bash
+poetry add cognee
+```
+
+Use Weaviate vector storage:

 ```bash
 poetry add "cognee[weaviate]"
 ```

View file

@@ -21,7 +21,6 @@ from cognee.modules.cognify.graph.add_node_connections import group_nodes_by_lay
     graph_ready_output, connect_nodes_in_graph
 from cognee.modules.cognify.llm.resolve_cross_graph_references import resolve_cross_graph_references
 from cognee.infrastructure.databases.graph.get_graph_client import get_graph_client
-from cognee.modules.cognify.graph.add_label_nodes import add_label_nodes
 from cognee.modules.cognify.graph.add_cognitive_layers import add_cognitive_layers
 # from cognee.modules.cognify.graph.initialize_graph import initialize_graph
 from cognee.infrastructure.files.utils.guess_file_type import guess_file_type, FileTypeException
@@ -49,7 +48,7 @@ logger = logging.getLogger("cognify")
 async def cognify(datasets: Union[str, List[str]] = None):
     """This function is responsible for the cognitive processing of the content."""
     # Has to be loaded in advance, multithreading doesn't work without it.
-    nltk.download('stopwords', quiet=True)
+    nltk.download("stopwords", quiet=True)
     stopwords.ensure_loaded()

     graph_db_type = infrastructure_config.get_config()["graph_engine"]
@@ -91,6 +90,12 @@ async def cognify(datasets: Union[str, List[str]] = None):
         for dataset_name, file_metadata in files_batch:
             with open(file_metadata["file_path"], "rb") as file:
                 try:
+                    document_id = await add_document_node(
+                        graph_client,
+                        parent_node_id = f"DefaultGraphModel__{USER_ID}",
+                        document_metadata = file_metadata,
+                    )
                     file_type = guess_file_type(file)
                     text = extract_text_from_file(file, file_type)
                     if text is None:
@@ -101,17 +106,18 @@ async def cognify(datasets: Union[str, List[str]] = None):
                         data_chunks[dataset_name] = []

                     for subchunk in subchunks:
-                        data_chunks[dataset_name].append(
-                            dict(text=subchunk, chunk_id=str(uuid4()), file_metadata=file_metadata))
+                        data_chunks[dataset_name].append(dict(document_id = document_id, chunk_id = str(uuid4()), text = subchunk))
                 except FileTypeException:
                     logger.warning("File (%s) has an unknown file type. We are skipping it.", file_metadata["id"])

     added_chunks = await add_data_chunks(data_chunks)

-    # await asyncio.gather(
-    #     *[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"]) for chunk in
-    #       added_chunks]
-    # )
+    await asyncio.gather(
+        *[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"]) for chunk in
+          added_chunks]
+    )

     batch_size = 20
     file_count = 0
@@ -193,23 +199,44 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
     classified_categories= [{'data_type': 'text', 'category_name': 'Source code in various programming languages'}]

-    print(f"Chunk ({chunk_id}) classified.")
-    #
-    # print("document_id", document_id)
-    #
-    # content_summary = await get_content_summary(input_text)
-    # await add_summary_nodes(graph_client, document_id, content_summary)
-
-    print(f"Chunk ({chunk_id}) summarized.")
-    #
-    cognitive_layers = await get_cognitive_layers(input_text, classified_categories)
-    cognitive_layers = (await add_cognitive_layers(graph_client, document_id, cognitive_layers))[:2]
-    #
-    layer_graphs = await get_layer_graphs(input_text, cognitive_layers)
-    await add_cognitive_layer_graphs(graph_client, chunk_collection, chunk_id, layer_graphs)
-
-    print("got here 4444")
+    await asyncio.gather(
+        *[process_text(chunk["document_id"], chunk["chunk_id"], chunk["collection"], chunk["text"]) for chunk in added_chunks]
+    )
+
+    return graph_client.graph
+
+# async def process_text(document_id: str, chunk_id: str, chunk_collection: str, input_text: str):
+#     raw_document_id = document_id.split("__")[-1]
+#
+#     print(f"Processing chunk ({chunk_id}) from document ({raw_document_id}).")
+#
+#     graph_client = await get_graph_client(infrastructure_config.get_config()["graph_engine"])
+#
+#     classified_categories = await get_content_categories(input_text)
+#     await add_classification_nodes(
+#         graph_client,
+#         parent_node_id = document_id,
+#         categories = classified_categories,
+#     )
+# >>>>>>> origin/main
+#
+#     print(f"Chunk ({chunk_id}) classified.")
+#
+#     # print("document_id", document_id)
+#     #
+#     # content_summary = await get_content_summary(input_text)
+#     # await add_summary_nodes(graph_client, document_id, content_summary)
+#
+#     print(f"Chunk ({chunk_id}) summarized.")
+#     #
+#     cognitive_layers = await get_cognitive_layers(input_text, classified_categories)
+#     cognitive_layers = (await add_cognitive_layers(graph_client, document_id, cognitive_layers))[:2]
+#     #
+#     layer_graphs = await get_layer_graphs(input_text, cognitive_layers)
+#     await add_cognitive_layer_graphs(graph_client, chunk_collection, chunk_id, layer_graphs)
+#
+# <<<<<<< HEAD
+#     print("got here 4444")
+#     #
 #     if infrastructure_config.get_config()["connect_documents"] is True:
 #         db_engine = infrastructure_config.get_config()["database_engine"]
@@ -240,6 +267,37 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
 #     send_telemetry("cognee.cognify")
 #
 #     print(f"Chunk ({chunk_id}) cognified.")
+# =======
+#     if infrastructure_config.get_config()["connect_documents"] is True:
+#         db_engine = infrastructure_config.get_config()["database_engine"]
+#         relevant_documents_to_connect = db_engine.fetch_cognify_data(excluded_document_id = raw_document_id)
+#
+#         list_of_nodes = []
+#
+#         relevant_documents_to_connect.append({
+#             "layer_id": document_id,
+#         })
+#
+#         for document in relevant_documents_to_connect:
+#             node_descriptions_to_match = await graph_client.extract_node_description(document["layer_id"])
+#             list_of_nodes.extend(node_descriptions_to_match)
+#
+#         nodes_by_layer = await group_nodes_by_layer(list_of_nodes)
+#
+#         results = await resolve_cross_graph_references(nodes_by_layer)
+#
+#         relationships = graph_ready_output(results)
+#
+#         await connect_nodes_in_graph(
+#             graph_client,
+#             relationships,
+#             score_threshold = infrastructure_config.get_config()["intra_layer_score_treshold"]
+#         )
+#
+#         send_telemetry("cognee.cognify")
+#
+#         print(f"Chunk ({chunk_id}) cognified.")
+# >>>>>>> origin/main

 if __name__ == "__main__":
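For reference, each chunk produced by the reworked loop above now carries the graph document id instead of the raw file metadata. A minimal sketch of that shape, with illustrative values only:

```python
from uuid import uuid4

# Illustrative only: "DOCUMENT__example" stands in for the id returned by add_document_node.
document_id = "DOCUMENT__example"
data_chunks = {"main_dataset": []}

for subchunk in ["first chunk of text", "second chunk of text"]:
    data_chunks["main_dataset"].append(
        dict(document_id = document_id, chunk_id = str(uuid4()), text = subchunk)
    )
```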

View file

@@ -1,7 +1,6 @@
 from cognee.config import Config
 from .databases.relational import DuckDBAdapter, DatabaseEngine
 from .databases.vector.vector_db_interface import VectorDBInterface
-from .databases.vector.qdrant.QDrantAdapter import QDrantAdapter
 from .databases.vector.embeddings.DefaultEmbeddingEngine import DefaultEmbeddingEngine
 from .llm.llm_interface import LLMInterface
 from .llm.openai.adapter import OpenAIAdapter
@@ -85,6 +84,12 @@ class InfrastructureConfig():
         if (config_entity is None or config_entity == "llm_engine") and self.llm_engine is None:
             self.llm_engine = OpenAIAdapter(config.openai_key, config.openai_model)

+        if (config_entity is None or config_entity == "database_directory_path") and self.database_directory_path is None:
+            self.database_directory_path = self.system_root_directory + "/" + config.db_path
+
+        if (config_entity is None or config_entity == "database_file_path") and self.database_file_path is None:
+            self.database_file_path = self.system_root_directory + "/" + config.db_path + "/" + config.db_name
+
         if (config_entity is None or config_entity == "vector_engine") and self.vector_engine is None:
             try:
                 from .databases.vector.weaviate_db import WeaviateAdapter
@@ -98,17 +103,24 @@ class InfrastructureConfig():
                     embedding_engine = self.embedding_engine
                 )
             except (EnvironmentError, ModuleNotFoundError):
-                self.vector_engine = QDrantAdapter(
-                    qdrant_url = config.qdrant_url,
-                    qdrant_api_key = config.qdrant_api_key,
-                    embedding_engine = self.embedding_engine
-                )
-
-        if (config_entity is None or config_entity == "database_directory_path") and self.database_directory_path is None:
-            self.database_directory_path = self.system_root_directory + "/" + config.db_path
-
-        if (config_entity is None or config_entity == "database_file_path") and self.database_file_path is None:
-            self.database_file_path = self.system_root_directory + "/" + config.db_path + "/" + config.db_name
+                if config.qdrant_url and config.qdrant_api_key:
+                    from .databases.vector.qdrant.QDrantAdapter import QDrantAdapter
+
+                    self.vector_engine = QDrantAdapter(
+                        qdrant_url = config.qdrant_url,
+                        qdrant_api_key = config.qdrant_api_key,
+                        embedding_engine = self.embedding_engine
+                    )
+                else:
+                    from .databases.vector.lancedb.LanceDBAdapter import LanceDBAdapter
+
+                    lance_db_path = self.database_directory_path + "/cognee.lancedb"
+                    LocalStorage.ensure_directory_exists(lance_db_path)
+
+                    self.vector_engine = LanceDBAdapter(
+                        uri = lance_db_path,
+                        api_key = None,
+                        embedding_engine = self.embedding_engine,
+                    )

         if config_entity is not None:
             return getattr(self, config_entity)
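The fallback order introduced here is: Weaviate when its client is importable and configured, Qdrant when both qdrant_url and qdrant_api_key are set, otherwise the local LanceDB store. A rough sketch of that decision, using placeholder flags rather than the real config object:

```python
def resolve_vector_engine(weaviate_ok: bool, qdrant_url: str, qdrant_api_key: str) -> str:
    # Simplified stand-in for the try/except chain in InfrastructureConfig.get_config();
    # returns which adapter would be constructed rather than constructing it.
    if weaviate_ok:
        return "WeaviateAdapter"
    if qdrant_url and qdrant_api_key:
        return "QDrantAdapter"
    return "LanceDBAdapter"  # file-backed default under database_directory_path

assert resolve_vector_engine(False, "", "") == "LanceDBAdapter"
```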

View file

@@ -4,8 +4,6 @@ from typing import List
 import instructor
 from openai import AsyncOpenAI
 from fastembed import TextEmbedding
-from fastembed import TextEmbedding
-from openai import AsyncOpenAI
 from cognee.config import Config
 from cognee.root_dir import get_absolute_path

View file

@@ -1,7 +1,7 @@
-from typing import List, Protocol
+from typing import Protocol

 class EmbeddingEngine(Protocol):
-    async def embed_text(self, text: str) -> List[float]:
+    async def embed_text(self, text: list[str]) -> list[list[float]]:
         raise NotImplementedError()

     def get_vector_size(self) -> int:
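The protocol is now batch-oriented: it takes a list of strings and returns one vector per input. A minimal conforming stub, assuming a fixed vector size instead of a real model:

```python
class FakeEmbeddingEngine:
    """Illustrative stand-in that satisfies the updated EmbeddingEngine protocol."""

    def __init__(self, vector_size: int = 384):
        self.vector_size = vector_size

    async def embed_text(self, text: list[str]) -> list[list[float]]:
        # One zero vector per input string.
        return [[0.0] * self.vector_size for _ in text]

    def get_vector_size(self) -> int:
        return self.vector_size
```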

View file

@@ -1,74 +1,145 @@
+from typing import List, Optional, get_type_hints, Generic, TypeVar
 import asyncio
+
+from pydantic import BaseModel, Field
 import lancedb
-from typing import List, Optional
+from lancedb.pydantic import Vector, LanceModel
+
+from cognee.infrastructure.files.storage import LocalStorage
+from ..models.ScoredResult import ScoredResult
+from ..vector_db_interface import VectorDBInterface, DataPoint
+from ..embeddings.EmbeddingEngine import EmbeddingEngine
-import asyncio
-import lancedb
-from pathlib import Path
-import tempfile

-class LanceDBAdapter:
-    def __init__(self, uri: Optional[str] = None, api_key: Optional[str] = None):
-        if uri:
-            self.uri = uri
-        else:
-            # Create a temporary directory for the LanceDB 'in-memory' simulation
-            self.temp_dir = tempfile.mkdtemp(suffix='.lancedb')
-            self.uri = f"file://{self.temp_dir}"
+class LanceDBAdapter(VectorDBInterface):
+    connection: lancedb.AsyncConnection = None
+
+    def __init__(
+        self,
+        uri: Optional[str],
+        api_key: Optional[str],
+        embedding_engine: EmbeddingEngine,
+    ):
+        self.uri = uri
         self.api_key = api_key
-        self.db = None
-
-    async def connect(self):
-        # Asynchronously connect to a LanceDB database, effectively in-memory if no URI is provided
-        self.db = await lancedb.connect_async(self.uri, api_key=self.api_key)
-
-    async def disconnect(self):
-        # Disconnect and clean up the database if it was set up as temporary
-        await self.db.close()
-        if hasattr(self, 'temp_dir'):
-            Path(self.temp_dir).unlink(missing_ok=True) # Remove the temporary directory
-
-    async def create_table(self, table_name: str, schema=None, data=None):
-        if not await self.table_exists(table_name):
-            return await self.db.create_table(name=table_name, schema=schema, data=data)
-        else:
-            raise ValueError(f"Table {table_name} already exists")
-
-    async def table_exists(self, table_name: str) -> bool:
-        table_names = await self.db.table_names()
-        return table_name in table_names
-
-    async def insert_data(self, table_name: str, data_points: List[dict]):
-        table = await self.db.open_table(table_name)
-        await table.add(data_points)
-
-    async def query_data(self, table_name: str, query=None, limit=10):
-        # Asynchronously query data from a table
-        table = await self.db.open_table(table_name)
-        if query:
-            query_result = await table.query().where(query).limit(limit).to_pandas()
-        else:
-            query_result = await table.query().limit(limit).to_pandas()
-        return query_result
-
-    async def vector_search(self, table_name: str, query_vector: List[float], limit=10):
-        # Perform an asynchronous vector search
-        table = await self.db.open_table(table_name)
-        query_result = await table.vector_search().nearest_to(query_vector).limit(limit).to_pandas()
-        return query_result
-
-async def main():
-    # Example without providing a URI, simulates in-memory behavior
-    adapter = LanceDBAdapter()
-    await adapter.connect()
-    try:
-        await adapter.create_table("my_table")
-        data_points = [{"id": 1, "text": "example", "vector": [0.1, 0.2, 0.3]}]
-        await adapter.insert_data("my_table", data_points)
-    finally:
-        await adapter.disconnect()
-
-if __name__ == "__main__":
-    asyncio.run(main())
+        self.embedding_engine = embedding_engine
+
+    async def get_connection(self):
+        if self.connection is None:
+            self.connection = await lancedb.connect_async(self.uri, api_key = self.api_key)
+
+        return self.connection
+
+    async def embed_data(self, data: list[str]) -> list[list[float]]:
+        return await self.embedding_engine.embed_text(data)
+
+    async def collection_exists(self, collection_name: str) -> bool:
+        connection = await self.get_connection()
+        collection_names = await connection.table_names()
+        return collection_name in collection_names
+
+    async def create_collection(self, collection_name: str, payload_schema: BaseModel):
+        data_point_types = get_type_hints(DataPoint)
+
+        class LanceDataPoint(LanceModel):
+            id: data_point_types["id"] = Field(...)
+            vector: Vector(self.embedding_engine.get_vector_size())
+            payload: payload_schema
+
+        if not await self.collection_exists(collection_name):
+            connection = await self.get_connection()
+            return await connection.create_table(
+                name = collection_name,
+                schema = LanceDataPoint,
+                exist_ok = True,
+            )
+
+    async def create_data_points(self, collection_name: str, data_points: List[DataPoint]):
+        connection = await self.get_connection()
+
+        if not await self.collection_exists(collection_name):
+            await self.create_collection(
+                collection_name,
+                payload_schema = type(data_points[0].payload),
+            )
+
+        collection = await connection.open_table(collection_name)
+
+        data_vectors = await self.embed_data(
+            [data_point.get_embeddable_data() for data_point in data_points]
+        )
+
+        IdType = TypeVar("IdType")
+        PayloadSchema = TypeVar("PayloadSchema")
+
+        class LanceDataPoint(LanceModel, Generic[IdType, PayloadSchema]):
+            id: IdType
+            vector: Vector(self.embedding_engine.get_vector_size())
+            payload: PayloadSchema
+
+        lance_data_points = [
+            LanceDataPoint[type(data_point.id), type(data_point.payload)](
+                id = data_point.id,
+                vector = data_vectors[data_index],
+                payload = data_point.payload,
+            ) for (data_index, data_point) in enumerate(data_points)
+        ]
+
+        await collection.add(lance_data_points)
+
+    async def retrieve(self, collection_name: str, data_point_id: str):
+        connection = await self.get_connection()
+        collection = await connection.open_table(collection_name)
+
+        results = await collection.query().where(f"id = '{data_point_id}'").to_pandas()
+        result = results.to_dict("index")[0]
+
+        return ScoredResult(
+            id = result["id"],
+            payload = result["payload"],
+            score = 1,
+        )
+
+    async def search(
+        self,
+        collection_name: str,
+        query_text: str = None,
+        query_vector: List[float] = None,
+        limit: int = 10,
+        with_vector: bool = False,
+    ):
+        if query_text is None and query_vector is None:
+            raise ValueError("One of query_text or query_vector must be provided!")
+
+        if query_text and not query_vector:
+            query_vector = (await self.embedding_engine.embed_text([query_text]))[0]
+
+        connection = await self.get_connection()
+        collection = await connection.open_table(collection_name)
+
+        results = await collection.vector_search(query_vector).limit(limit).to_pandas()
+
+        return [ScoredResult(
+            id = str(result["id"]),
+            score = float(result["_distance"]),
+            payload = result["payload"],
+        ) for result in results.to_dict("index").values()]
+
+    async def batch_search(
+        self,
+        collection_name: str,
+        query_texts: List[str],
+        limit: int = None,
+        with_vector: bool = False,
+    ):
+        query_vectors = await self.embedding_engine.embed_text(query_texts)
+
+        return asyncio.gather(
+            *[self.search(
+                collection_name = collection_name,
+                query_vector = query_vector,
+                limit = limit,
+                with_vector = with_vector,
+            ) for query_vector in query_vectors]
+        )
+
+    async def prune(self):
+        # Clean up the database if it was set up as temporary
+        if self.uri.startswith("/"):
+            LocalStorage.remove_all(self.uri) # Remove the temporary directory and files inside
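A rough usage sketch for the new adapter, assuming an instance built elsewhere with a real embedding engine; collection and payload names are illustrative:

```python
import asyncio
from pydantic import BaseModel

class ChunkPayload(BaseModel):
    value: str  # matches the default embed_field on DataPoint

async def demo(adapter) -> None:
    # `adapter` stands in for a configured LanceDBAdapter instance.
    await adapter.create_collection("demo_collection", payload_schema = ChunkPayload)
    hits = await adapter.search("demo_collection", query_text = "example", limit = 5)
    for hit in hits:
        print(hit.id, hit.score)

# asyncio.run(demo(adapter)) once an adapter instance exists.
```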

View file

@@ -1,10 +1,13 @@
-from typing import Union
+from typing import Generic, TypeVar
 from pydantic import BaseModel

-class DataPoint(BaseModel):
+PayloadSchema = TypeVar("PayloadSchema", bound = BaseModel)
+
+class DataPoint(BaseModel, Generic[PayloadSchema]):
     id: str
-    payload: dict[str, Union[str, dict[str, str]]]
+    payload: PayloadSchema
     embed_field: str = "value"

     def get_embeddable_data(self):
-        return self.payload[self.embed_field]
+        if hasattr(self.payload, self.embed_field):
+            return getattr(self.payload, self.embed_field)
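DataPoint is now generic over its payload model and reads the embeddable field as an attribute. A small sketch with a hypothetical payload schema, assuming the import path used elsewhere in this commit:

```python
from pydantic import BaseModel
from cognee.infrastructure.databases.vector import DataPoint

class TextPayload(BaseModel):
    text: str  # hypothetical schema; the name matches embed_field below

point = DataPoint[TextPayload](
    id = "chunk-1",
    payload = TextPayload(text = "some chunk text"),
    embed_field = "text",
)

assert point.get_embeddable_data() == "some chunk text"
```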

View file

@@ -0,0 +1,3 @@
+from typing import TypeVar
+
+PayloadSchema = TypeVar("PayloadSchema")

View file

@@ -68,6 +68,7 @@ class QDrantAdapter(VectorDBInterface):
     async def create_collection(
         self,
         collection_name: str,
+        payload_schema = None,
     ):
         client = self.get_qdrant_client()
@@ -93,7 +94,7 @@ class QDrantAdapter(VectorDBInterface):
         def convert_to_qdrant_point(data_point: DataPoint):
             return models.PointStruct(
                 id = data_point.id,
-                payload = data_point.payload,
+                payload = data_point.payload.dict(),
                 vector = {
                     "text": data_vectors[data_points.index(data_point)]
                 }
@@ -110,9 +111,9 @@ class QDrantAdapter(VectorDBInterface):
         return result

-    async def retrieve(self, collection_name: str, data_id: str):
+    async def retrieve(self, collection_name: str, data_point_id: str):
         client = self.get_qdrant_client()
-        results = await client.retrieve(collection_name, [data_id], with_payload = True)
+        results = await client.retrieve(collection_name, [data_point_id], with_payload = True)
         await client.close()
         return results[0] if len(results) > 0 else None

View file

@@ -1,42 +1,21 @@
 from typing import List, Protocol, Optional
 from abc import abstractmethod
 from .models.DataPoint import DataPoint
+from .models.PayloadSchema import PayloadSchema

 class VectorDBInterface(Protocol):
     """ Collections """
+    @abstractmethod
+    async def collection_exists(self, collection_name: str) -> bool:
+        raise NotImplementedError
+
     @abstractmethod
     async def create_collection(
         self,
-        collection_name: str
+        collection_name: str,
+        payload_schema: Optional[PayloadSchema] = None,
     ): raise NotImplementedError

-    # @abstractmethod
-    # async def update_collection(
-    #     self,
-    #     collection_name: str,
-    #     collection_config: object
-    # ): raise NotImplementedError
-
-    # @abstractmethod
-    # async def delete_collection(
-    #     self,
-    #     collection_name: str
-    # ): raise NotImplementedError
-
-    # @abstractmethod
-    # async def create_vector_index(
-    #     self,
-    #     collection_name: str,
-    #     vector_index_config: object
-    # ): raise NotImplementedError
-
-    # @abstractmethod
-    # async def create_data_index(
-    #     self,
-    #     collection_name: str,
-    #     vector_index_config: object
-    # ): raise NotImplementedError

     """ Data points """
     @abstractmethod
     async def create_data_points(
@@ -45,27 +24,12 @@ class VectorDBInterface(Protocol):
         data_points: List[DataPoint]
     ): raise NotImplementedError

-    # @abstractmethod
-    # async def get_data_point(
-    #     self,
-    #     collection_name: str,
-    #     data_point_id: str
-    # ): raise NotImplementedError
+    @abstractmethod
+    async def retrieve(
+        self,
+        collection_name: str,
+        data_point_id: str
+    ): raise NotImplementedError

-    # @abstractmethod
-    # async def update_data_point(
-    #     self,
-    #     collection_name: str,
-    #     data_point_id: str,
-    #     payload: object
-    # ): raise NotImplementedError
-
-    # @abstractmethod
-    # async def delete_data_point(
-    #     self,
-    #     collection_name: str,
-    #     data_point_id: str
-    # ): raise NotImplementedError

     """ Search """
     @abstractmethod

View file

@@ -21,9 +21,6 @@ class WeaviateAdapter(VectorDBInterface):
         self.client = weaviate.connect_to_wcs(
             cluster_url=url,
             auth_credentials=weaviate.auth.AuthApiKey(api_key),
-            # headers = {
-            #     "X-OpenAI-Api-Key": openai_api_key
-            # },
             additional_config=wvc.init.AdditionalConfig(timeout=wvc.init.Timeout(init=30))
         )
@@ -31,20 +28,23 @@ class WeaviateAdapter(VectorDBInterface):
         return await self.embedding_engine.embed_text(data)

     async def collection_exists(self, collection_name: str) -> bool:
-        event_loop = asyncio.get_event_loop()
-        def sync_collection_exists():
-            return self.client.collections.exists(collection_name)
-        return await event_loop.run_in_executor(None, sync_collection_exists)
+        future = asyncio.Future()
+        future.set_result(self.client.collections.exists(collection_name))
+
+        return await future

-    async def create_collection(self, collection_name: str):
+    async def create_collection(
+        self,
+        collection_name: str,
+        payload_schema = None,
+    ):
         import weaviate.classes.config as wvcc
-        event_loop = asyncio.get_event_loop()
-        def sync_create_collection():
-            return self.client.collections.create(
+        future = asyncio.Future()
+        future.set_result(
+            self.client.collections.create(
                 name=collection_name,
                 properties=[
                     wvcc.Property(
@@ -54,13 +54,9 @@ class WeaviateAdapter(VectorDBInterface):
                     )
                 ]
             )
+        )

-        # try:
-        result = await event_loop.run_in_executor(None, sync_create_collection)
-        # finally:
-        #     event_loop.shutdown_executor()
-        return result
+        return await future

     def get_collection(self, collection_name: str):
         return self.client.collections.get(collection_name)
@@ -73,30 +69,26 @@ class WeaviateAdapter(VectorDBInterface):
         def convert_to_weaviate_data_points(data_point: DataPoint):
             return DataObject(
-                uuid=data_point.id,
-                properties=data_point.payload,
-                vector=data_vectors[data_points.index(data_point)]
+                uuid = data_point.id,
+                properties = data_point.payload.dict(),
+                vector = data_vectors[data_points.index(data_point)]
             )

         objects = list(map(convert_to_weaviate_data_points, data_points))
         return self.get_collection(collection_name).data.insert_many(objects)

-    async def retrieve(self, collection_name: str, data_id: str):
-        def sync_retrieve():
-            return self.get_collection(collection_name).query.fetch_object_by_id(UUID(data_id))
-        event_loop = asyncio.get_event_loop()
-        # try:
-        data_point = await event_loop.run_in_executor(None, sync_retrieve)
-        # finally:
-        #     event_loop.shutdown_executor()
+    async def retrieve(self, collection_name: str, data_point_id: str):
+        future = asyncio.Future()
+
+        data_point = self.get_collection(collection_name).query.fetch_object_by_id(UUID(data_point_id))

         data_point.payload = data_point.properties
         del data_point.properties

-        return data_point
+        future.set_result(data_point)
+        return await future

     async def search(
         self,
@@ -114,7 +106,6 @@ class WeaviateAdapter(VectorDBInterface):
         if query_vector is None:
             query_vector = (await self.embed_data([query_text]))[0]

-        # def sync_search():
         search_result = self.get_collection(collection_name).query.hybrid(
             query = None,
             vector = query_vector,

View file

@@ -1,6 +1,7 @@
 from datetime import datetime
 from uuid import uuid4
 from typing import List, Tuple, TypedDict
+from pydantic import BaseModel
 from cognee.infrastructure import infrastructure_config
 from cognee.infrastructure.databases.vector import DataPoint
 from cognee.utils import extract_pos_tags, extract_named_entities, extract_sentiment_vader
@@ -113,14 +114,22 @@ async def add_cognitive_layer_graphs(
     await graph_client.add_edges(graph_edges)

+    class References(BaseModel):
+        node_id: str
+        cognitive_layer: str
+
+    class PayloadSchema(BaseModel):
+        value: str
+        references: References
+
     try:
-        await vector_client.create_collection(layer_id)
+        await vector_client.create_collection(layer_id, payload_schema = PayloadSchema)
     except Exception:
         # It's ok if the collection already exists.
         pass

     data_points = [
-        DataPoint(
+        DataPoint[PayloadSchema](
             id = str(uuid4()),
             payload = dict(
                 value = node_data["name"],

View file

@@ -1,6 +1,9 @@
 import json
 import logging
 from typing import TypedDict
+
+from pydantic import BaseModel, Field
+
 from cognee.infrastructure import infrastructure_config
 from cognee.infrastructure.databases.vector import DataPoint
@@ -14,12 +17,15 @@ async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
     identified_chunks = []

+    class PayloadSchema(BaseModel):
+        text: str = Field(...)
+
     for (dataset_name, chunks) in dataset_data_chunks.items():
         try:
-            # if not await vector_client.collection_exists(dataset_name):
-            #     logging.error(f"Creating collection {str(dataset_name)}")
-            await vector_client.create_collection(dataset_name)
-        except Exception:
+            await vector_client.create_collection(dataset_name, payload_schema = PayloadSchema)
+        except Exception as error:
+            print(error)
             pass

         dataset_chunks = [
@@ -27,35 +33,21 @@ async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
                 chunk_id = chunk["chunk_id"],
                 collection = dataset_name,
                 text = chunk["text"],
-                file_metadata = chunk["file_metadata"],
+                document_id = chunk["document_id"],
             ) for chunk in chunks
         ]

         identified_chunks.extend(dataset_chunks)

-        # if not await vector_client.collection_exists(dataset_name):
-        try:
-            logging.error("Collection still not found. Creating collection again.")
-            await vector_client.create_collection(dataset_name)
-        except:
-            pass
-
-        async def create_collection_retry(dataset_name, dataset_chunks):
-            await vector_client.create_data_points(
-                dataset_name,
-                [
-                    DataPoint(
-                        id = chunk["chunk_id"],
-                        payload = dict(text = chunk["text"]),
-                        embed_field = "text"
-                    ) for chunk in dataset_chunks
-                ],
-            )
-
-        try:
-            await create_collection_retry(dataset_name, dataset_chunks)
-        except Exception:
-            logging.error("Collection not found in create data points.")
-            await create_collection_retry(dataset_name, dataset_chunks)
+        await vector_client.create_data_points(
+            dataset_name,
+            [
+                DataPoint[PayloadSchema](
+                    id = chunk["chunk_id"],
+                    payload = PayloadSchema.parse_obj(dict(text = chunk["text"])),
+                    embed_field = "text",
+                ) for chunk in dataset_chunks
+            ],
+        )

     return identified_chunks
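Put together with the generic DataPoint, each chunk dict is now converted into a typed point exactly once, with the chunk text as its payload. A condensed sketch of that conversion, using an illustrative chunk:

```python
from pydantic import BaseModel, Field
from cognee.infrastructure.databases.vector import DataPoint

class PayloadSchema(BaseModel):
    text: str = Field(...)

# Illustrative chunk dict with the new document_id field.
chunk = dict(chunk_id = "chunk-1", collection = "main_dataset", text = "chunk body", document_id = "DOCUMENT__1")

point = DataPoint[PayloadSchema](
    id = chunk["chunk_id"],
    payload = PayloadSchema.parse_obj(dict(text = chunk["text"])),
    embed_field = "text",
)
```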

View file

@@ -1,13 +1,10 @@
 from cognee.shared.data_models import Document
-from cognee.modules.cognify.graph.add_label_nodes import add_label_nodes
 from cognee.infrastructure.databases.graph.graph_db_interface import GraphDBInterface

 async def add_document_node(graph_client: GraphDBInterface, parent_node_id, document_metadata):
     document_id = f"DOCUMENT__{document_metadata['id']}"
     document = await graph_client.extract_node(document_id)

     if not document:
@@ -21,6 +18,13 @@ async def add_document_node(graph_client: GraphDBInterface, parent_node_id, docu
     await graph_client.add_node(document_id, document)

-    await graph_client.add_edge(parent_node_id, document_id, "has_document", dict(relationship_name = "has_document"))
-
-    await add_label_nodes(graph_client, document_id, document_metadata["keywords"].split("|"))
+    await graph_client.add_edge(
+        parent_node_id,
+        document_id,
+        "has_document",
+        dict(relationship_name = "has_document"),
+    )

     return document_id

View file

@@ -1,10 +1,11 @@
 from uuid import uuid4
 from typing import List
 from datetime import datetime
+from pydantic import BaseModel
 from cognee.infrastructure import infrastructure_config
 from cognee.infrastructure.databases.vector import DataPoint

-async def add_label_nodes(graph_client, parent_node_id: str, chunk_id: str, keywords: List[str]) -> None:
+async def add_label_nodes(graph_client, parent_node_id: str, keywords: List[str]) -> None:
     vector_client = infrastructure_config.get_config("vector_engine")

     keyword_nodes = []
@@ -16,7 +17,6 @@ async def add_label_nodes(graph_client, parent_node_id: str, chunk_id: str, keyw
             keyword_id,
             dict(
                 id = keyword_id,
-                chunk_id = chunk_id,
                 name = keyword.lower().capitalize(),
                 keyword = keyword.lower(),
                 type = "Keyword",
@@ -36,25 +36,57 @@ async def add_label_nodes(graph_client, parent_node_id: str, chunk_id: str, keyw
         ) for (keyword_id, __) in keyword_nodes
     ])

+    class References(BaseModel):
+        node_id: str
+        cognitive_layer: str
+
+    class PayloadSchema(BaseModel):
+        value: str
+        references: References
+
     # Add data to vector
-    # keyword_data_points = [
-    #     DataPoint(
-    #         id = str(uuid4()),
-    #         payload = dict(
-    #             value = keyword_data["keyword"],
-    #             references = dict(
-    #                 node_id = keyword_node_id,
-    #                 cognitive_layer = parent_node_id,
-    #             ),
-    #         ),
-    #         embed_field = "value"
-    #     ) for (keyword_node_id, keyword_data) in keyword_nodes
-    # ]
-    #
-    # try:
-    #     await vector_client.create_collection(parent_node_id)
-    # except Exception:
-    #     # It's ok if the collection already exists.
-    #     pass
-    #
-    # await vector_client.create_data_points(parent_node_id, keyword_data_points)
+    keyword_data_points = [
+        DataPoint(
+            id = str(uuid4()),
+            payload = dict(
+                value = keyword_data["keyword"],
+                references = dict(
+                    node_id = keyword_node_id,
+                    cognitive_layer = parent_node_id,
+                ),
+            ),
+            embed_field = "value"
+        ) for (keyword_node_id, keyword_data) in keyword_nodes
+    ]
+
+    try:
+        await vector_client.create_collection(parent_node_id)
+    except Exception:
+        # It's ok if the collection already exists.
+        pass
+
+    await vector_client.create_data_points(parent_node_id, keyword_data_points)
+
+    keyword_data_points = [
+        DataPoint[PayloadSchema](
+            id = str(uuid4()),
+            payload = dict(
+                value = keyword_data["keyword"],
+                references = dict(
+                    node_id = keyword_node_id,
+                    cognitive_layer = parent_node_id,
+                ),
+            ),
+            embed_field = "value"
+        ) for (keyword_node_id, keyword_data) in keyword_nodes
+    ]
+
+    try:
+        await vector_client.create_collection(parent_node_id, payload_schema = PayloadSchema)
+    except Exception:
+        # It's ok if the collection already exists.
+        pass
+
+    await vector_client.create_data_points(parent_node_id, keyword_data_points)

View file

@@ -7,18 +7,18 @@ from .extract_content_graph import extract_content_graph

 logger = logging.getLogger("extract_knowledge_graph(text: str)")

 async def extract_knowledge_graph(text: str, cognitive_layer, graph_model):
-    # try:
-    #     compiled_extract_knowledge_graph = ExtractKnowledgeGraph()
-    #     compiled_extract_knowledge_graph.load(get_absolute_path("./programs/extract_knowledge_graph/extract_knowledge_graph.json"))
-    #     event_loop = asyncio.get_event_loop()
-    #     def sync_extract_knowledge_graph():
-    #         return compiled_extract_knowledge_graph(context = text, question = "")
-    #     return (await event_loop.run_in_executor(None, sync_extract_knowledge_graph)).graph
-    #     # return compiled_extract_knowledge_graph(text, question = "").graph
-    # except Exception as error:
-    #     logger.error("Error extracting graph from content: %s", error, exc_info = True)
+    try:
+        compiled_extract_knowledge_graph = ExtractKnowledgeGraph()
+        compiled_extract_knowledge_graph.load(get_absolute_path("./programs/extract_knowledge_graph/extract_knowledge_graph.json"))
+        event_loop = asyncio.get_event_loop()
+        def sync_extract_knowledge_graph():
+            return compiled_extract_knowledge_graph(context = text, question = "")
+        return (await event_loop.run_in_executor(None, sync_extract_knowledge_graph)).graph
+        # return compiled_extract_knowledge_graph(text, question = "").graph
+    except Exception as error:
+        # TODO: Log error to Sentry
         return await extract_content_graph(text, cognitive_layer, graph_model)
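The re-enabled path runs a synchronous compiled program on the default executor so the event loop is not blocked, and falls back to extract_content_graph on failure. A generic sketch of that pattern with a placeholder blocking call:

```python
import asyncio

def blocking_program(context: str) -> str:
    # Placeholder for a compiled, synchronous call such as the DSPy program above.
    return context.upper()

async def run_off_loop(context: str) -> str:
    event_loop = asyncio.get_event_loop()
    return await event_loop.run_in_executor(None, lambda: blocking_program(context))

# asyncio.run(run_off_loop("some text"))
```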

View file

@@ -1 +0,0 @@
-from .create_vector_memory import create_vector_memory

View file

@@ -1,7 +0,0 @@
-from cognee.infrastructure.databases.vector.qdrant.adapter import CollectionConfig
-from cognee.infrastructure.databases.vector.get_vector_database import get_vector_database
-
-async def create_vector_memory(memory_name: str, collection_config: CollectionConfig):
-    vector_db = get_vector_database()
-
-    return await vector_db.create_collection(memory_name, collection_config)

View file

@@ -25,7 +25,7 @@ async def search_similarity(query: str, graph):
             layer_id = result.payload["references"]["cognitive_layer"],
             node_id = result.payload["references"]["node_id"],
             score = result.score,
-        ) for result in results if result.score > 0.5
+        ) for result in results if result.score > 0.8
     ])

     if len(graph_nodes) == 0:
@@ -39,7 +39,10 @@ async def search_similarity(query: str, graph):
         if "chunk_collection" not in graph_node and "chunk_id" not in graph_node:
             continue

-        vector_point = await vector_engine.retrieve(graph_node["chunk_collection"], graph_node["chunk_id"])
+        vector_point = await vector_engine.retrieve(
+            graph_node["chunk_collection"],
+            graph_node["chunk_id"],
+        )

         relevant_context.append(vector_point.payload["text"])

File diff suppressed because one or more lines are too long

poetry.lock (generated): 7859 changed lines

File diff suppressed because it is too large.

View file

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "cognee"
-version = "0.1.6"
+version = "0.1.7"
 description = "Cognee - is a library for enriching LLM context with a semantic layer for better understanding and reasoning."
 authors = ["Vasilije Markovic", "Boris Arzentar"]
 readme = "README.md"
@@ -27,10 +27,8 @@ uvicorn = "0.22.0"
 boto3 = "^1.26.125"
 gunicorn = "^20.1.0"
 sqlalchemy = "^2.0.21"
-asyncpg = "^0.28.0"
 instructor = "1.2.1"
 networkx = "^3.2.1"
-graphviz = "^0.20.1"
 debugpy = "^1.8.0"
 pyarrow = "^15.0.0"
 pylint = "^3.0.3"
@@ -53,7 +51,6 @@ scikit-learn = "^1.4.1.post1"
 fastembed = "^0.2.5"
 pypdf = "^4.1.0"
 anthropic = "^0.21.3"
-xmltodict = "^0.13.0"
 neo4j = "^5.18.0"
 jinja2 = "^3.1.3"
 matplotlib = "^3.8.3"
@@ -63,38 +60,24 @@ tiktoken = "^0.6.0"
 dspy-ai = "2.4.3"
 posthog = "^3.5.0"
 lancedb = "^0.6.10"
 importlib-metadata = "6.8.0"
 deepeval = "^0.21.36"
 litellm = "^1.37.3"
 groq = "^0.5.0"
-tantivy = "^0.21.0"

 [tool.poetry.extras]
-dbt = ["dbt-core", "dbt-redshift", "dbt-bigquery", "dbt-duckdb", "dbt-snowflake", "dbt-athena-community", "dbt-databricks"]
-gcp = ["grpcio", "google-cloud-bigquery", "db-dtypes", "gcsfs"]
-# bigquery is alias on gcp extras
-bigquery = ["grpcio", "google-cloud-bigquery", "pyarrow", "db-dtypes", "gcsfs"]
-postgres = ["psycopg2-binary", "psycopg2cffi"]
-redshift = ["psycopg2-binary", "psycopg2cffi"]
 parquet = ["pyarrow"]
 duckdb = ["duckdb"]
 filesystem = ["s3fs", "botocore"]
-s3 = ["s3fs", "botocore"]
-gs = ["gcsfs"]
-az = ["adlfs"]
-snowflake = ["snowflake-connector-python"]
 motherduck = ["duckdb", "pyarrow"]
 cli = ["pipdeptree", "cron-descriptor"]
-athena = ["pyathena", "pyarrow", "s3fs", "botocore"]
 weaviate = ["weaviate-client"]
-mssql = ["pyodbc"]
-synapse = ["pyodbc", "adlfs", "pyarrow"]
-databricks = ["databricks-sql-connector"]
-lancedb = ["lancedb"]
-pinecone = ["pinecone-client"]
+qdrant = ["qdrant-client"]
 neo4j = ["neo4j", "py2neo"]
-notebook =[ "ipykernel","overrides", "ipywidgets", "jupyterlab", "jupyterlab_widgets", "jupyterlab-server", "jupyterlab-git"]
+notebook = ["ipykernel","overrides", "ipywidgets", "jupyterlab", "jupyterlab_widgets", "jupyterlab-server", "jupyterlab-git"]

 [tool.poetry.group.dev.dependencies]
 pytest = "^7.4.0"
@@ -115,7 +98,6 @@ mkdocs-redirects = "^1.2.1"

 [tool.poetry.group.test-docs.dependencies]
 fastapi = "^0.109.2"
-redis = "^5.0.1"
 diskcache = "^5.6.3"
 pandas = "^2.2.0"
 tabulate = "^0.9.0"