diff --git a/cognee/api/v1/cognify/cognify.py b/cognee/api/v1/cognify/cognify.py
index 28203f7e2..88f09ec64 100644
--- a/cognee/api/v1/cognify/cognify.py
+++ b/cognee/api/v1/cognify/cognify.py
@@ -87,9 +87,8 @@ async def cognify(datasets: Union[str, List[str]] = None):
     chunk_engine = infrastructure_config.get_config()["chunk_engine"]
     chunk_strategy = infrastructure_config.get_config()["chunk_strategy"]
-
-    for (dataset_name, files) in dataset_files:
-        for file_metadata in files:
+    async def process_batch(files_batch):
+        for dataset_name, file_metadata in files_batch:
             with open(file_metadata["file_path"], "rb") as file:
                 try:
                     file_type = guess_file_type(file)
@@ -102,21 +101,68 @@ async def cognify(datasets: Union[str, List[str]] = None):
                         data_chunks[dataset_name] = []

                     for subchunk in subchunks:
-                        data_chunks[dataset_name].append(dict(text = subchunk, chunk_id = str(uuid4()), file_metadata = file_metadata))
+                        data_chunks[dataset_name].append(
+                            dict(text = subchunk, chunk_id = str(uuid4()), file_metadata = file_metadata))
                 except FileTypeException:
                     logger.warning("File (%s) has an unknown file type. We are skipping it.", file_metadata["id"])
+
+        added_chunks = await add_data_chunks(data_chunks)
+        # await asyncio.gather(
+        #     *[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"]) for chunk in
+        #       added_chunks]
+        # )
+
+    batch_size = 20
+    file_count = 0
+    files_batch = []

-    added_chunks: list[tuple[str, str, dict]] = await add_data_chunks(data_chunks)
+    for (dataset_name, files) in dataset_files:
+        for file_metadata in files:
+            files_batch.append((dataset_name, file_metadata))
+            file_count += 1

-    await asyncio.gather(
-        *[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"]) for chunk in added_chunks]
-    )
+            if file_count >= batch_size:
+                await process_batch(files_batch)
+                files_batch = []
+                file_count = 0
+
+    # Process any remaining files in the last batch
+    if files_batch:
+        await process_batch(files_batch)

     return graph_client.graph
+    #
+    # for (dataset_name, files) in dataset_files:
+    #     for file_metadata in files:
+    #         with open(file_metadata["file_path"], "rb") as file:
+    #             try:
+    #                 file_type = guess_file_type(file)
+    #                 text = extract_text_from_file(file, file_type)
+    #                 if text is None:
+    #                     text = "empty file"
+    #                 subchunks = chunk_engine.chunk_data(chunk_strategy, text, config.chunk_size, config.chunk_overlap)
+    #
+    #                 if dataset_name not in data_chunks:
+    #                     data_chunks[dataset_name] = []
+    #
+    #                 for subchunk in subchunks:
+    #                     data_chunks[dataset_name].append(dict(text = subchunk, chunk_id = str(uuid4()), file_metadata = file_metadata))
+    #             except FileTypeException:
+    #                 logger.warning("File (%s) has an unknown file type. We are skipping it.", file_metadata["id"])
+    #
+    #
+    #
+    #
+    # added_chunks: list[tuple[str, str, dict]] = await add_data_chunks(data_chunks)
+    #
+    # await asyncio.gather(
+    #     *[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"]) for chunk in added_chunks]
+    # )
+    #
+    # return graph_client.graph
+

 async def process_text(chunk_collection: str, chunk_id: str, input_text: str, file_metadata: dict):
     print(f"Processing chunk ({chunk_id}) from document ({file_metadata['id']}).")
@@ -124,36 +170,46 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
     graph_topology = infrastructure_config.get_config()["graph_topology"]

+    print("got here")
+
     document_id = await add_document_node(
         graph_client,
         parent_node_id = f"{file_metadata['name']}.{file_metadata['extension']}", #make a param of defaultgraph model to make sure when user passes his stuff, it doesn't break pipeline
         document_metadata = file_metadata,
     )

+    # print("got here2")
+    # await add_label_nodes(graph_client, document_id, chunk_id, file_metadata["keywords"].split("|"))
+
+    # classified_categories = await get_content_categories(input_text)
     #
-    await add_label_nodes(graph_client, document_id, chunk_id, file_metadata["keywords"].split("|"))
+    # print("classified_categories", classified_categories)
+    # await add_classification_nodes(
+    #     graph_client,
+    #     parent_node_id = document_id,
+    #     categories = classified_categories,
+    # )

-    classified_categories = await get_content_categories(input_text)
-    await add_classification_nodes(
-        graph_client,
-        parent_node_id = document_id,
-        categories = classified_categories,
-    )
+    classified_categories = [{'data_type': 'text', 'category_name': 'Source code in various programming languages'}]

-    # print(f"Chunk ({chunk_id}) classified.")
-    print("document_id", document_id)
-    content_summary = await get_content_summary(input_text)
-    await add_summary_nodes(graph_client, document_id, content_summary)
+    print(f"Chunk ({chunk_id}) classified.")
+
+    # print("document_id", document_id)
+    #
+    # content_summary = await get_content_summary(input_text)
+    # await add_summary_nodes(graph_client, document_id, content_summary)

     print(f"Chunk ({chunk_id}) summarized.")
-
+    #
     cognitive_layers = await get_cognitive_layers(input_text, classified_categories)
     cognitive_layers = (await add_cognitive_layers(graph_client, document_id, cognitive_layers))[:2]
     # layer_graphs = await get_layer_graphs(input_text, cognitive_layers)
     await add_cognitive_layer_graphs(graph_client, chunk_collection, chunk_id, layer_graphs)
+
+    print("got here 4444")
 #
 #     if infrastructure_config.get_config()["connect_documents"] is True:
 #         db_engine = infrastructure_config.get_config()["database_engine"]
@@ -200,7 +256,7 @@ if __name__ == "__main__":
 #
 #     await add("data://" +data_directory_path, "example")

-    infrastructure_config.set_config( {"chunk_engine": LangchainChunkEngine() , "chunk_strategy": ChunkStrategy.CODE,'embedding_engine': LiteLLMEmbeddingEngine()})
+    infrastructure_config.set_config( {"chunk_engine": LangchainChunkEngine() , "chunk_strategy": ChunkStrategy.CODE,'embedding_engine': LiteLLMEmbeddingEngine() })

     from cognee.shared.SourceCodeGraph import SourceCodeGraph
     from cognee.api.v1.config import config
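Note on the cognify.py hunks above: instead of fanning out one process_text task per chunk with asyncio.gather, the new loop drains files sequentially in batches of 20. A minimal, self-contained sketch of that pattern, with process_batch as a hypothetical stand-in for the real chunk-and-embed work:

import asyncio

BATCH_SIZE = 20  # mirrors batch_size in the patch

async def process_batch(files_batch: list[tuple[str, dict]]) -> None:
    # Stand-in for the real work: chunking, add_data_chunks, process_text.
    print(f"processing a batch of {len(files_batch)} files")

async def process_all(dataset_files: list[tuple[str, list[dict]]]) -> None:
    files_batch: list[tuple[str, dict]] = []
    for dataset_name, files in dataset_files:
        for file_metadata in files:
            files_batch.append((dataset_name, file_metadata))
            if len(files_batch) >= BATCH_SIZE:
                await process_batch(files_batch)  # batches run sequentially
                files_batch = []
    if files_batch:
        await process_batch(files_batch)  # flush the final partial batch

asyncio.run(process_all([("example", [{"file_path": f"/tmp/{i}.txt"} for i in range(45)])]))

Because each batch is awaited before the next one starts, in-flight LLM calls and memory are bounded by the batch size rather than by the total number of chunks.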
We are skipping it.", file_metadata["id"]) + # + # + # + # + # added_chunks: list[tuple[str, str, dict]] = await add_data_chunks(data_chunks) + # + # await asyncio.gather( + # *[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"]) for chunk in added_chunks] + # ) + # + # return graph_client.graph + async def process_text(chunk_collection: str, chunk_id: str, input_text: str, file_metadata: dict): print(f"Processing chunk ({chunk_id}) from document ({file_metadata['id']}).") @@ -124,36 +170,46 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi graph_topology = infrastructure_config.get_config()["graph_topology"] + print("got here") + document_id = await add_document_node( graph_client, parent_node_id = f"{file_metadata['name']}.{file_metadata['extension']}", #make a param of defaultgraph model to make sure when user passes his stuff, it doesn't break pipeline document_metadata = file_metadata, ) + # print("got here2") + # await add_label_nodes(graph_client, document_id, chunk_id, file_metadata["keywords"].split("|")) + + # classified_categories = await get_content_categories(input_text) # - await add_label_nodes(graph_client, document_id, chunk_id, file_metadata["keywords"].split("|")) + # print("classified_categories", classified_categories) + # await add_classification_nodes( + # graph_client, + # parent_node_id = document_id, + # categories = classified_categories, + # ) - classified_categories = await get_content_categories(input_text) - await add_classification_nodes( - graph_client, - parent_node_id = document_id, - categories = classified_categories, - ) + classified_categories= [{'data_type': 'text', 'category_name': 'Source code in various programming languages'}] - # print(f"Chunk ({chunk_id}) classified.") - print("document_id", document_id) - content_summary = await get_content_summary(input_text) - await add_summary_nodes(graph_client, document_id, content_summary) + print(f"Chunk ({chunk_id}) classified.") + + # print("document_id", document_id) + # + # content_summary = await get_content_summary(input_text) + # await add_summary_nodes(graph_client, document_id, content_summary) print(f"Chunk ({chunk_id}) summarized.") - + # cognitive_layers = await get_cognitive_layers(input_text, classified_categories) cognitive_layers = (await add_cognitive_layers(graph_client, document_id, cognitive_layers))[:2] # layer_graphs = await get_layer_graphs(input_text, cognitive_layers) await add_cognitive_layer_graphs(graph_client, chunk_collection, chunk_id, layer_graphs) + + print("got here 4444") # # if infrastructure_config.get_config()["connect_documents"] is True: # db_engine = infrastructure_config.get_config()["database_engine"] @@ -200,7 +256,7 @@ if __name__ == "__main__": # # await add("data://" +data_directory_path, "example") - infrastructure_config.set_config( {"chunk_engine": LangchainChunkEngine() , "chunk_strategy": ChunkStrategy.CODE,'embedding_engine': LiteLLMEmbeddingEngine()}) + infrastructure_config.set_config( {"chunk_engine": LangchainChunkEngine() , "chunk_strategy": ChunkStrategy.CODE,'embedding_engine': LiteLLMEmbeddingEngine() }) from cognee.shared.SourceCodeGraph import SourceCodeGraph from cognee.api.v1.config import config diff --git a/cognee/config.py b/cognee/config.py index 639a6b41e..149a72eab 100644 --- a/cognee/config.py +++ b/cognee/config.py @@ -51,13 +51,13 @@ class Config: # Model parameters llm_provider: str = os.getenv("LLM_PROVIDER","openai") #openai, or custom or ollama - 
diff --git a/cognee/infrastructure/llm/generic_llm_api/adapter.py b/cognee/infrastructure/llm/generic_llm_api/adapter.py
index 1741b6f3c..2ee65178c 100644
--- a/cognee/infrastructure/llm/generic_llm_api/adapter.py
+++ b/cognee/infrastructure/llm/generic_llm_api/adapter.py
@@ -1,25 +1,37 @@
 import asyncio
+import os
 from typing import List, Type
 from pydantic import BaseModel
 import instructor
 from tenacity import retry, stop_after_attempt
 from openai import AsyncOpenAI
 import openai
+
+from cognee.infrastructure import infrastructure_config
 from cognee.infrastructure.llm.llm_interface import LLMInterface
 from cognee.infrastructure.llm.prompts import read_query_prompt

-
 class GenericAPIAdapter(LLMInterface):
-    """Adapter for Ollama's API"""
+    """Adapter for generic LLM provider APIs."""

     def __init__(self, api_endpoint, api_key: str, model: str):
-        self.aclient = instructor.patch(
-            AsyncOpenAI(
-                base_url = api_endpoint,
-                api_key = api_key, # required, but unused
-            ),
-            mode = instructor.Mode.JSON,
-        )
+        if infrastructure_config.get_config()["llm_provider"] == 'groq':
+            from groq import AsyncGroq
+            self.aclient = instructor.from_openai(client = AsyncGroq(
+                api_key = api_key,
+            ), mode = instructor.Mode.MD_JSON)
+        else:
+            self.aclient = instructor.patch(
+                AsyncOpenAI(
+                    base_url = api_endpoint,
+                    api_key = api_key, # required, but unused
+                ),
+                mode = instructor.Mode.JSON,
+            )
+
+        self.model = model

     @retry(stop = stop_after_attempt(5))
@@ -75,20 +87,21 @@ class GenericAPIAdapter(LLMInterface):

         return embeddings

-    @retry(stop=stop_after_attempt(5))
-    async def acreate_structured_output(self, text_input: str, system_prompt: str,
-                                        response_model: Type[BaseModel]) -> BaseModel:
+    @retry(stop = stop_after_attempt(5))
+    async def acreate_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
         """Generate a response from a user query."""
+
         return await self.aclient.chat.completions.create(
-            model=self.model,
-            messages=[
+            model = self.model,
+            messages = [
                 {
                     "role": "user",
                     "content": f"""Use the given format to
-                    extract information from the following input: {text_input}. {system_prompt} """,
-                }
+                    extract information from the following input: {text_input}. """,
+                },
+                {"role": "system", "content": system_prompt},
             ],
-            response_model=response_model,
+            response_model = response_model,
         )

     def show_prompt(self, text_input: str, system_prompt: str) -> str:
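Note on the adapter hunk above: for the groq provider the client is wrapped by instructor in MD_JSON mode, which requests fenced JSON instead of OpenAI tool calls; AsyncGroq is used because acreate_structured_output awaits the create call. A standalone sketch of driving that client, assuming GROQ_API_KEY is set in the environment; the ContentCategory model and both prompts are illustrative, not part of the patch:

import asyncio
import os

import instructor
from groq import AsyncGroq
from pydantic import BaseModel

class ContentCategory(BaseModel):
    data_type: str
    category_name: str

async def main() -> None:
    # Same construction as the adapter's groq branch.
    aclient = instructor.from_openai(
        client = AsyncGroq(api_key = os.environ["GROQ_API_KEY"]),
        mode = instructor.Mode.MD_JSON,
    )
    category = await aclient.chat.completions.create(
        model = "llama3-70b-8192",  # the new custom_model default in cognee/config.py
        messages = [
            {"role": "user", "content": "Use the given format to extract information from the following input: def add(a, b): return a + b"},
            {"role": "system", "content": "Classify the content into a category."},
        ],
        response_model = ContentCategory,  # instructor validates the reply into this model
    )
    print(category)

asyncio.run(main())

Newer instructor releases also expose a dedicated instructor.from_groq constructor, which would make the from_openai call above unnecessary.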
""", + }, + {"role": "system", "content": system_prompt}, ], - response_model=response_model, + response_model = response_model, ) def show_prompt(self, text_input: str, system_prompt: str) -> str: diff --git a/cognee/modules/cognify/graph/add_data_chunks.py b/cognee/modules/cognify/graph/add_data_chunks.py index 1be631815..42b51170b 100644 --- a/cognee/modules/cognify/graph/add_data_chunks.py +++ b/cognee/modules/cognify/graph/add_data_chunks.py @@ -15,12 +15,12 @@ async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]): identified_chunks = [] for (dataset_name, chunks) in dataset_data_chunks.items(): - # try: - # # if not await vector_client.collection_exists(dataset_name): - # # logging.error(f"Creating collection {str(dataset_name)}") - # await vector_client.create_collection(dataset_name) - # except Exception: - # pass + try: + # if not await vector_client.collection_exists(dataset_name): + # logging.error(f"Creating collection {str(dataset_name)}") + await vector_client.create_collection(dataset_name) + except Exception: + pass dataset_chunks = [ dict( @@ -33,29 +33,29 @@ async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]): identified_chunks.extend(dataset_chunks) - # # if not await vector_client.collection_exists(dataset_name): - # try: - # logging.error("Collection still not found. Creating collection again.") - # await vector_client.create_collection(dataset_name) - # except: - # pass - # - # async def create_collection_retry(dataset_name, dataset_chunks): - # await vector_client.create_data_points( - # dataset_name, - # [ - # DataPoint( - # id = chunk["chunk_id"], - # payload = dict(text = chunk["text"]), - # embed_field = "text" - # ) for chunk in dataset_chunks - # ], - # ) - # - # try: - # await create_collection_retry(dataset_name, dataset_chunks) - # except Exception: - # logging.error("Collection not found in create data points.") - # await create_collection_retry(dataset_name, dataset_chunks) + # if not await vector_client.collection_exists(dataset_name): + try: + logging.error("Collection still not found. 
Creating collection again.") + await vector_client.create_collection(dataset_name) + except: + pass + + async def create_collection_retry(dataset_name, dataset_chunks): + await vector_client.create_data_points( + dataset_name, + [ + DataPoint( + id = chunk["chunk_id"], + payload = dict(text = chunk["text"]), + embed_field = "text" + ) for chunk in dataset_chunks + ], + ) + + try: + await create_collection_retry(dataset_name, dataset_chunks) + except Exception: + logging.error("Collection not found in create data points.") + await create_collection_retry(dataset_name, dataset_chunks) return identified_chunks diff --git a/poetry.lock b/poetry.lock index 46a96ece4..5536f45da 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2107,6 +2107,25 @@ files = [ [package.dependencies] colorama = ">=0.4" +[[package]] +name = "groq" +version = "0.5.0" +description = "The official Python library for the groq API" +optional = false +python-versions = ">=3.7" +files = [ + {file = "groq-0.5.0-py3-none-any.whl", hash = "sha256:a7e6be1118bcdfea3ed071ec00f505a34d4e6ec28c435adb5a5afd33545683a1"}, + {file = "groq-0.5.0.tar.gz", hash = "sha256:d476cdc3383b45d2a4dc1876142a9542e663ea1029f9e07a05de24f895cae48c"}, +] + +[package.dependencies] +anyio = ">=3.5.0,<5" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +pydantic = ">=1.9.0,<3" +sniffio = "*" +typing-extensions = ">=4.7,<5" + [[package]] name = "grpcio" version = "1.63.0" @@ -7837,4 +7856,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.9.0,<3.12" -content-hash = "459e5945fb184ecdc2c9a817e89b2a83b9e0c97273d1a7df8f75c9c0b355ba47" +content-hash = "4c2e75aba3260da9e4023e5fa1c4d117b06e4f042bda9441664a0e91af4069a1" diff --git a/pyproject.toml b/pyproject.toml index c36abd566..4905a8bf4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ lancedb = "^0.6.10" importlib-metadata = "6.8.0" deepeval = "^0.21.36" litellm = "^1.37.3" +groq = "^0.5.0"