diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py index 38a3d6c71..ed456f5a7 100644 --- a/cognee/api/v1/cognify/cognify.py +++ b/cognee/api/v1/cognify/cognify.py @@ -87,26 +87,22 @@ async def cognify(datasets: Union[str, List[str]] = None): chunk_engine = infrastructure_config.get_config()["chunk_engine"] chunk_strategy = infrastructure_config.get_config()["chunk_strategy"] async def process_batch(files_batch): - for dataset_name, file_metadata in files_batch: + for dataset_name, file_metadata, document_id in files_batch: with open(file_metadata["file_path"], "rb") as file: try: - document_id = await add_document_node( - graph_client, - parent_node_id = f"DefaultGraphModel__{USER_ID}", - document_metadata = file_metadata, - ) - file_type = guess_file_type(file) text = extract_text_from_file(file, file_type) if text is None: text = "empty file" + if text == "": + text = "empty file" subchunks = chunk_engine.chunk_data(chunk_strategy, text, config.chunk_size, config.chunk_overlap) if dataset_name not in data_chunks: data_chunks[dataset_name] = [] for subchunk in subchunks: - data_chunks[dataset_name].append(dict(document_id = document_id, chunk_id = str(uuid4()), text = subchunk)) + data_chunks[dataset_name].append(dict(document_id = document_id, chunk_id = str(uuid4()), text = subchunk, file_metadata = file_metadata)) except FileTypeException: logger.warning("File (%s) has an unknown file type. We are skipping it.", file_metadata["id"]) @@ -124,8 +120,14 @@ async def cognify(datasets: Union[str, List[str]] = None): files_batch = [] for (dataset_name, files) in dataset_files: + for file_metadata in files: - files_batch.append((dataset_name, file_metadata)) + document_id = await add_document_node( + graph_client, + parent_node_id=f"DefaultGraphModel__{USER_ID}", + document_metadata=file_metadata, + ) + files_batch.append((dataset_name, file_metadata, document_id)) file_count += 1 if file_count >= batch_size: @@ -199,11 +201,6 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi classified_categories= [{'data_type': 'text', 'category_name': 'Source code in various programming languages'}] - await asyncio.gather( - *[process_text(chunk["document_id"], chunk["chunk_id"], chunk["collection"], chunk["text"]) for chunk in added_chunks] - ) - - return graph_client.graph # # async def process_text(document_id: str, chunk_id: str, chunk_collection: str, input_text: str): # raw_document_id = document_id.split("__")[-1] diff --git a/cognee/infrastructure/databases/vector/embeddings/DefaultEmbeddingEngine.py b/cognee/infrastructure/databases/vector/embeddings/DefaultEmbeddingEngine.py index 32f6f25ce..f67d5f541 100644 --- a/cognee/infrastructure/databases/vector/embeddings/DefaultEmbeddingEngine.py +++ b/cognee/infrastructure/databases/vector/embeddings/DefaultEmbeddingEngine.py @@ -27,25 +27,23 @@ class DefaultEmbeddingEngine(EmbeddingEngine): class LiteLLMEmbeddingEngine(EmbeddingEngine): + import asyncio + from typing import List - async def embed_text(self, text: List[str]) -> List[float]: - - - - print("text", text) - try: - text = str(text[0]) - except: - text = str(text) - - - response = await aembedding(config.litellm_embedding_model, input=text) + async def embed_text(self, text: List[str]) -> List[List[float]]: + async def get_embedding(text_): + response = await aembedding(config.litellm_embedding_model, input=text_) + return response.data[0]['embedding'] + tasks = [get_embedding(text_) for text_ in text] + result = await asyncio.gather(*tasks) + return result # embedding = response.data[0].embedding - # embeddings_list = list(map(lambda embedding: embedding.tolist(), embedding_model.embed(text))) - print("response", type(response.data[0]['embedding'])) - return response.data[0]['embedding'] + # # embeddings_list = list(map(lambda embedding: embedding.tolist(), embedding_model.embed(text))) + # print("response", type(response.data[0]['embedding'])) + # print("response", response.data[0]) + # return [response.data[0]['embedding']] def get_vector_size(self) -> int: diff --git a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py index 9ffe3d8de..adb1c161d 100644 --- a/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py +++ b/cognee/infrastructure/databases/vector/lancedb/LanceDBAdapter.py @@ -37,10 +37,11 @@ class LanceDBAdapter(VectorDBInterface): async def create_collection(self, collection_name: str, payload_schema: BaseModel): data_point_types = get_type_hints(DataPoint) + vector_size = self.embedding_engine.get_vector_size() class LanceDataPoint(LanceModel): id: data_point_types["id"] = Field(...) - vector: Vector(self.embedding_engine.get_vector_size()) + vector: Vector(vector_size) payload: payload_schema if not await self.collection_exists(collection_name): @@ -68,10 +69,11 @@ class LanceDBAdapter(VectorDBInterface): IdType = TypeVar("IdType") PayloadSchema = TypeVar("PayloadSchema") + vector_size = self.embedding_engine.get_vector_size() class LanceDataPoint(LanceModel, Generic[IdType, PayloadSchema]): id: IdType - vector: Vector(self.embedding_engine.get_vector_size()) + vector: Vector(vector_size) payload: PayloadSchema lance_data_points = [ diff --git a/cognee/modules/cognify/graph/add_data_chunks.py b/cognee/modules/cognify/graph/add_data_chunks.py index e283f01e5..291c15716 100644 --- a/cognee/modules/cognify/graph/add_data_chunks.py +++ b/cognee/modules/cognify/graph/add_data_chunks.py @@ -34,6 +34,7 @@ async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]): collection = dataset_name, text = chunk["text"], document_id = chunk["document_id"], + file_metadata = chunk["file_metadata"], ) for chunk in chunks ] diff --git a/cognee/modules/cognify/graph/add_document_node.py b/cognee/modules/cognify/graph/add_document_node.py index 3e878dcf1..b70f700e5 100644 --- a/cognee/modules/cognify/graph/add_document_node.py +++ b/cognee/modules/cognify/graph/add_document_node.py @@ -25,6 +25,7 @@ async def add_document_node(graph_client: GraphDBInterface, parent_node_id, docu dict(relationship_name = "has_document"), ) - await add_label_nodes(graph_client, document_id, document_metadata["keywords"].split("|")) + # + # await add_label_nodes(graph_client, document_id, document_metadata["keywords"].split("|")) return document_id diff --git a/cognee/modules/cognify/graph/add_label_nodes.py b/cognee/modules/cognify/graph/add_label_nodes.py index 8d991c9d9..574b19f6c 100644 --- a/cognee/modules/cognify/graph/add_label_nodes.py +++ b/cognee/modules/cognify/graph/add_label_nodes.py @@ -19,7 +19,7 @@ async def add_label_nodes(graph_client, parent_node_id: str, keywords: List[str] id = keyword_id, name = keyword.lower().capitalize(), keyword = keyword.lower(), - type = "Keyword", + entity_type = "Keyword", created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"), updated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S"), ), @@ -45,29 +45,6 @@ async def add_label_nodes(graph_client, parent_node_id: str, keywords: List[str] references: References # Add data to vector - - keyword_data_points = [ - DataPoint( - id = str(uuid4()), - payload = dict( - value = keyword_data["keyword"], - references = dict( - node_id = keyword_node_id, - cognitive_layer = parent_node_id, - ), - ), - embed_field = "value" - ) for (keyword_node_id, keyword_data) in keyword_nodes - ] - - try: - await vector_client.create_collection(parent_node_id) - except Exception: - # It's ok if the collection already exists. - pass - - await vector_client.create_data_points(parent_node_id, keyword_data_points) - keyword_data_points = [ DataPoint[PayloadSchema]( id = str(uuid4()), @@ -84,9 +61,8 @@ async def add_label_nodes(graph_client, parent_node_id: str, keywords: List[str] try: await vector_client.create_collection(parent_node_id, payload_schema = PayloadSchema) - except Exception: + except Exception as e: # It's ok if the collection already exists. - pass - - await vector_client.create_data_points(parent_node_id, keyword_data_points) + print(e) + await vector_client.create_data_points(parent_node_id, keyword_data_points) \ No newline at end of file