fix deployment
This commit is contained in:
parent
3c261ce6a1
commit
d2e17dd4b7
6 changed files with 160 additions and 71 deletions
|
|
@ -87,9 +87,8 @@ async def cognify(datasets: Union[str, List[str]] = None):
|
||||||
|
|
||||||
chunk_engine = infrastructure_config.get_config()["chunk_engine"]
|
chunk_engine = infrastructure_config.get_config()["chunk_engine"]
|
||||||
chunk_strategy = infrastructure_config.get_config()["chunk_strategy"]
|
chunk_strategy = infrastructure_config.get_config()["chunk_strategy"]
|
||||||
|
async def process_batch(files_batch):
|
||||||
for (dataset_name, files) in dataset_files:
|
for dataset_name, file_metadata in files_batch:
|
||||||
for file_metadata in files:
|
|
||||||
with open(file_metadata["file_path"], "rb") as file:
|
with open(file_metadata["file_path"], "rb") as file:
|
||||||
try:
|
try:
|
||||||
file_type = guess_file_type(file)
|
file_type = guess_file_type(file)
|
||||||
|
|
@ -102,21 +101,68 @@ async def cognify(datasets: Union[str, List[str]] = None):
|
||||||
data_chunks[dataset_name] = []
|
data_chunks[dataset_name] = []
|
||||||
|
|
||||||
for subchunk in subchunks:
|
for subchunk in subchunks:
|
||||||
data_chunks[dataset_name].append(dict(text = subchunk, chunk_id = str(uuid4()), file_metadata = file_metadata))
|
data_chunks[dataset_name].append(
|
||||||
|
dict(text=subchunk, chunk_id=str(uuid4()), file_metadata=file_metadata))
|
||||||
except FileTypeException:
|
except FileTypeException:
|
||||||
logger.warning("File (%s) has an unknown file type. We are skipping it.", file_metadata["id"])
|
logger.warning("File (%s) has an unknown file type. We are skipping it.", file_metadata["id"])
|
||||||
|
|
||||||
|
added_chunks = await add_data_chunks(data_chunks)
|
||||||
|
|
||||||
|
# await asyncio.gather(
|
||||||
|
# *[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"]) for chunk in
|
||||||
|
# added_chunks]
|
||||||
|
# )
|
||||||
|
|
||||||
|
batch_size = 20
|
||||||
|
file_count = 0
|
||||||
|
files_batch = []
|
||||||
|
|
||||||
added_chunks: list[tuple[str, str, dict]] = await add_data_chunks(data_chunks)
|
for (dataset_name, files) in dataset_files:
|
||||||
|
for file_metadata in files:
|
||||||
|
files_batch.append((dataset_name, file_metadata))
|
||||||
|
file_count += 1
|
||||||
|
|
||||||
await asyncio.gather(
|
if file_count >= batch_size:
|
||||||
*[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"]) for chunk in added_chunks]
|
await process_batch(files_batch)
|
||||||
)
|
files_batch = []
|
||||||
|
file_count = 0
|
||||||
|
|
||||||
|
# Process any remaining files in the last batch
|
||||||
|
if files_batch:
|
||||||
|
await process_batch(files_batch)
|
||||||
|
|
||||||
return graph_client.graph
|
return graph_client.graph
|
||||||
|
|
||||||
|
#
|
||||||
|
# for (dataset_name, files) in dataset_files:
|
||||||
|
# for file_metadata in files:
|
||||||
|
# with open(file_metadata["file_path"], "rb") as file:
|
||||||
|
# try:
|
||||||
|
# file_type = guess_file_type(file)
|
||||||
|
# text = extract_text_from_file(file, file_type)
|
||||||
|
# if text is None:
|
||||||
|
# text = "empty file"
|
||||||
|
# subchunks = chunk_engine.chunk_data(chunk_strategy, text, config.chunk_size, config.chunk_overlap)
|
||||||
|
#
|
||||||
|
# if dataset_name not in data_chunks:
|
||||||
|
# data_chunks[dataset_name] = []
|
||||||
|
#
|
||||||
|
# for subchunk in subchunks:
|
||||||
|
# data_chunks[dataset_name].append(dict(text = subchunk, chunk_id = str(uuid4()), file_metadata = file_metadata))
|
||||||
|
# except FileTypeException:
|
||||||
|
# logger.warning("File (%s) has an unknown file type. We are skipping it.", file_metadata["id"])
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# added_chunks: list[tuple[str, str, dict]] = await add_data_chunks(data_chunks)
|
||||||
|
#
|
||||||
|
# await asyncio.gather(
|
||||||
|
# *[process_text(chunk["collection"], chunk["chunk_id"], chunk["text"], chunk["file_metadata"]) for chunk in added_chunks]
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# return graph_client.graph
|
||||||
|
|
||||||
async def process_text(chunk_collection: str, chunk_id: str, input_text: str, file_metadata: dict):
|
async def process_text(chunk_collection: str, chunk_id: str, input_text: str, file_metadata: dict):
|
||||||
print(f"Processing chunk ({chunk_id}) from document ({file_metadata['id']}).")
|
print(f"Processing chunk ({chunk_id}) from document ({file_metadata['id']}).")
|
||||||
|
|
||||||
|
|
@ -124,36 +170,46 @@ async def process_text(chunk_collection: str, chunk_id: str, input_text: str, fi
|
||||||
|
|
||||||
graph_topology = infrastructure_config.get_config()["graph_topology"]
|
graph_topology = infrastructure_config.get_config()["graph_topology"]
|
||||||
|
|
||||||
|
print("got here")
|
||||||
|
|
||||||
|
|
||||||
document_id = await add_document_node(
|
document_id = await add_document_node(
|
||||||
graph_client,
|
graph_client,
|
||||||
parent_node_id = f"{file_metadata['name']}.{file_metadata['extension']}", #make a param of defaultgraph model to make sure when user passes his stuff, it doesn't break pipeline
|
parent_node_id = f"{file_metadata['name']}.{file_metadata['extension']}", #make a param of defaultgraph model to make sure when user passes his stuff, it doesn't break pipeline
|
||||||
document_metadata = file_metadata,
|
document_metadata = file_metadata,
|
||||||
)
|
)
|
||||||
|
# print("got here2")
|
||||||
|
# await add_label_nodes(graph_client, document_id, chunk_id, file_metadata["keywords"].split("|"))
|
||||||
|
|
||||||
|
# classified_categories = await get_content_categories(input_text)
|
||||||
#
|
#
|
||||||
await add_label_nodes(graph_client, document_id, chunk_id, file_metadata["keywords"].split("|"))
|
# print("classified_categories", classified_categories)
|
||||||
|
# await add_classification_nodes(
|
||||||
|
# graph_client,
|
||||||
|
# parent_node_id = document_id,
|
||||||
|
# categories = classified_categories,
|
||||||
|
# )
|
||||||
|
|
||||||
classified_categories = await get_content_categories(input_text)
|
classified_categories= [{'data_type': 'text', 'category_name': 'Source code in various programming languages'}]
|
||||||
await add_classification_nodes(
|
|
||||||
graph_client,
|
|
||||||
parent_node_id = document_id,
|
|
||||||
categories = classified_categories,
|
|
||||||
)
|
|
||||||
|
|
||||||
# print(f"Chunk ({chunk_id}) classified.")
|
|
||||||
|
|
||||||
print("document_id", document_id)
|
|
||||||
|
|
||||||
content_summary = await get_content_summary(input_text)
|
print(f"Chunk ({chunk_id}) classified.")
|
||||||
await add_summary_nodes(graph_client, document_id, content_summary)
|
|
||||||
|
# print("document_id", document_id)
|
||||||
|
#
|
||||||
|
# content_summary = await get_content_summary(input_text)
|
||||||
|
# await add_summary_nodes(graph_client, document_id, content_summary)
|
||||||
|
|
||||||
print(f"Chunk ({chunk_id}) summarized.")
|
print(f"Chunk ({chunk_id}) summarized.")
|
||||||
|
#
|
||||||
cognitive_layers = await get_cognitive_layers(input_text, classified_categories)
|
cognitive_layers = await get_cognitive_layers(input_text, classified_categories)
|
||||||
cognitive_layers = (await add_cognitive_layers(graph_client, document_id, cognitive_layers))[:2]
|
cognitive_layers = (await add_cognitive_layers(graph_client, document_id, cognitive_layers))[:2]
|
||||||
#
|
#
|
||||||
layer_graphs = await get_layer_graphs(input_text, cognitive_layers)
|
layer_graphs = await get_layer_graphs(input_text, cognitive_layers)
|
||||||
await add_cognitive_layer_graphs(graph_client, chunk_collection, chunk_id, layer_graphs)
|
await add_cognitive_layer_graphs(graph_client, chunk_collection, chunk_id, layer_graphs)
|
||||||
|
|
||||||
|
print("got here 4444")
|
||||||
#
|
#
|
||||||
# if infrastructure_config.get_config()["connect_documents"] is True:
|
# if infrastructure_config.get_config()["connect_documents"] is True:
|
||||||
# db_engine = infrastructure_config.get_config()["database_engine"]
|
# db_engine = infrastructure_config.get_config()["database_engine"]
|
||||||
|
|
@ -200,7 +256,7 @@ if __name__ == "__main__":
|
||||||
#
|
#
|
||||||
# await add("data://" +data_directory_path, "example")
|
# await add("data://" +data_directory_path, "example")
|
||||||
|
|
||||||
infrastructure_config.set_config( {"chunk_engine": LangchainChunkEngine() , "chunk_strategy": ChunkStrategy.CODE,'embedding_engine': LiteLLMEmbeddingEngine()})
|
infrastructure_config.set_config( {"chunk_engine": LangchainChunkEngine() , "chunk_strategy": ChunkStrategy.CODE,'embedding_engine': LiteLLMEmbeddingEngine() })
|
||||||
from cognee.shared.SourceCodeGraph import SourceCodeGraph
|
from cognee.shared.SourceCodeGraph import SourceCodeGraph
|
||||||
from cognee.api.v1.config import config
|
from cognee.api.v1.config import config
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -51,13 +51,13 @@ class Config:
|
||||||
|
|
||||||
# Model parameters
|
# Model parameters
|
||||||
llm_provider: str = os.getenv("LLM_PROVIDER","openai") #openai, or custom or ollama
|
llm_provider: str = os.getenv("LLM_PROVIDER","openai") #openai, or custom or ollama
|
||||||
custom_model: str = os.getenv("CUSTOM_LLM_MODEL", "mistralai/Mixtral-8x7B-Instruct-v0.1") #"mistralai/Mixtral-8x7B-Instruct-v0.1"
|
custom_model: str = os.getenv("CUSTOM_LLM_MODEL", "llama3-70b-8192") #"mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||||
custom_endpoint: str = os.getenv("CUSTOM_ENDPOINT", "https://api.endpoints.anyscale.com/v1") #"https://api.endpoints.anyscale.com/v1" # pass claude endpoint
|
custom_endpoint: str = os.getenv("CUSTOM_ENDPOINT", "https://api.endpoints.anyscale.com/v1") #"https://api.endpoints.anyscale.com/v1" # pass claude endpoint
|
||||||
custom_key: Optional[str] = os.getenv("CUSTOM_LLM_API_KEY")
|
custom_key: Optional[str] = os.getenv("CUSTOM_LLM_API_KEY")
|
||||||
ollama_endpoint: str = os.getenv("CUSTOM_OLLAMA_ENDPOINT", "http://localhost:11434/v1") #"http://localhost:11434/v1"
|
ollama_endpoint: str = os.getenv("CUSTOM_OLLAMA_ENDPOINT", "http://localhost:11434/v1") #"http://localhost:11434/v1"
|
||||||
ollama_key: Optional[str] = "ollama"
|
ollama_key: Optional[str] = "ollama"
|
||||||
ollama_model: str = os.getenv("CUSTOM_OLLAMA_MODEL", "mistral:instruct") #"mistral:instruct"
|
ollama_model: str = os.getenv("CUSTOM_OLLAMA_MODEL", "mistral:instruct") #"mistral:instruct"
|
||||||
openai_model: str = os.getenv("OPENAI_MODEL", "gpt-4-1106-preview" ) #"gpt-4-1106-preview"
|
openai_model: str = os.getenv("OPENAI_MODEL", "gpt-4o" ) #"gpt-4o"
|
||||||
model_endpoint: str = "openai"
|
model_endpoint: str = "openai"
|
||||||
openai_key: Optional[str] = os.getenv("OPENAI_API_KEY")
|
openai_key: Optional[str] = os.getenv("OPENAI_API_KEY")
|
||||||
openai_temperature: float = float(os.getenv("OPENAI_TEMPERATURE", 0.0))
|
openai_temperature: float = float(os.getenv("OPENAI_TEMPERATURE", 0.0))
|
||||||
|
|
|
||||||
|
|
@ -1,25 +1,37 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
from typing import List, Type
|
from typing import List, Type
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import instructor
|
import instructor
|
||||||
from tenacity import retry, stop_after_attempt
|
from tenacity import retry, stop_after_attempt
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
import openai
|
import openai
|
||||||
|
|
||||||
|
from cognee.infrastructure import infrastructure_config
|
||||||
from cognee.infrastructure.llm.llm_interface import LLMInterface
|
from cognee.infrastructure.llm.llm_interface import LLMInterface
|
||||||
from cognee.infrastructure.llm.prompts import read_query_prompt
|
from cognee.infrastructure.llm.prompts import read_query_prompt
|
||||||
|
|
||||||
|
|
||||||
class GenericAPIAdapter(LLMInterface):
|
class GenericAPIAdapter(LLMInterface):
|
||||||
"""Adapter for Ollama's API"""
|
"""Adapter for Generic API LLM provider API """
|
||||||
|
|
||||||
def __init__(self, api_endpoint, api_key: str, model: str):
|
def __init__(self, api_endpoint, api_key: str, model: str):
|
||||||
self.aclient = instructor.patch(
|
|
||||||
AsyncOpenAI(
|
|
||||||
base_url = api_endpoint,
|
if infrastructure_config.get_config()["llm_provider"] == 'groq':
|
||||||
api_key = api_key, # required, but unused
|
from groq import groq
|
||||||
),
|
self.aclient = instructor.from_openai(client = groq.Groq(
|
||||||
mode = instructor.Mode.JSON,
|
api_key=api_key,
|
||||||
)
|
), mode=instructor.Mode.MD_JSON)
|
||||||
|
else:
|
||||||
|
self.aclient = instructor.patch(
|
||||||
|
AsyncOpenAI(
|
||||||
|
base_url = api_endpoint,
|
||||||
|
api_key = api_key, # required, but unused
|
||||||
|
),
|
||||||
|
mode = instructor.Mode.JSON,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
@retry(stop = stop_after_attempt(5))
|
@retry(stop = stop_after_attempt(5))
|
||||||
|
|
@ -75,20 +87,21 @@ class GenericAPIAdapter(LLMInterface):
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
@retry(stop=stop_after_attempt(5))
|
@retry(stop = stop_after_attempt(5))
|
||||||
async def acreate_structured_output(self, text_input: str, system_prompt: str,
|
async def acreate_structured_output(self, text_input: str, system_prompt: str, response_model: Type[BaseModel]) -> BaseModel:
|
||||||
response_model: Type[BaseModel]) -> BaseModel:
|
|
||||||
"""Generate a response from a user query."""
|
"""Generate a response from a user query."""
|
||||||
|
|
||||||
return await self.aclient.chat.completions.create(
|
return await self.aclient.chat.completions.create(
|
||||||
model=self.model,
|
model = self.model,
|
||||||
messages=[
|
messages = [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": f"""Use the given format to
|
"content": f"""Use the given format to
|
||||||
extract information from the following input: {text_input}. {system_prompt} """,
|
extract information from the following input: {text_input}. """,
|
||||||
}
|
},
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
],
|
],
|
||||||
response_model=response_model,
|
response_model = response_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
def show_prompt(self, text_input: str, system_prompt: str) -> str:
|
def show_prompt(self, text_input: str, system_prompt: str) -> str:
|
||||||
|
|
|
||||||
|
|
@ -15,12 +15,12 @@ async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
|
||||||
identified_chunks = []
|
identified_chunks = []
|
||||||
|
|
||||||
for (dataset_name, chunks) in dataset_data_chunks.items():
|
for (dataset_name, chunks) in dataset_data_chunks.items():
|
||||||
# try:
|
try:
|
||||||
# # if not await vector_client.collection_exists(dataset_name):
|
# if not await vector_client.collection_exists(dataset_name):
|
||||||
# # logging.error(f"Creating collection {str(dataset_name)}")
|
# logging.error(f"Creating collection {str(dataset_name)}")
|
||||||
# await vector_client.create_collection(dataset_name)
|
await vector_client.create_collection(dataset_name)
|
||||||
# except Exception:
|
except Exception:
|
||||||
# pass
|
pass
|
||||||
|
|
||||||
dataset_chunks = [
|
dataset_chunks = [
|
||||||
dict(
|
dict(
|
||||||
|
|
@ -33,29 +33,29 @@ async def add_data_chunks(dataset_data_chunks: dict[str, list[TextChunk]]):
|
||||||
|
|
||||||
identified_chunks.extend(dataset_chunks)
|
identified_chunks.extend(dataset_chunks)
|
||||||
|
|
||||||
# # if not await vector_client.collection_exists(dataset_name):
|
# if not await vector_client.collection_exists(dataset_name):
|
||||||
# try:
|
try:
|
||||||
# logging.error("Collection still not found. Creating collection again.")
|
logging.error("Collection still not found. Creating collection again.")
|
||||||
# await vector_client.create_collection(dataset_name)
|
await vector_client.create_collection(dataset_name)
|
||||||
# except:
|
except:
|
||||||
# pass
|
pass
|
||||||
#
|
|
||||||
# async def create_collection_retry(dataset_name, dataset_chunks):
|
async def create_collection_retry(dataset_name, dataset_chunks):
|
||||||
# await vector_client.create_data_points(
|
await vector_client.create_data_points(
|
||||||
# dataset_name,
|
dataset_name,
|
||||||
# [
|
[
|
||||||
# DataPoint(
|
DataPoint(
|
||||||
# id = chunk["chunk_id"],
|
id = chunk["chunk_id"],
|
||||||
# payload = dict(text = chunk["text"]),
|
payload = dict(text = chunk["text"]),
|
||||||
# embed_field = "text"
|
embed_field = "text"
|
||||||
# ) for chunk in dataset_chunks
|
) for chunk in dataset_chunks
|
||||||
# ],
|
],
|
||||||
# )
|
)
|
||||||
#
|
|
||||||
# try:
|
try:
|
||||||
# await create_collection_retry(dataset_name, dataset_chunks)
|
await create_collection_retry(dataset_name, dataset_chunks)
|
||||||
# except Exception:
|
except Exception:
|
||||||
# logging.error("Collection not found in create data points.")
|
logging.error("Collection not found in create data points.")
|
||||||
# await create_collection_retry(dataset_name, dataset_chunks)
|
await create_collection_retry(dataset_name, dataset_chunks)
|
||||||
|
|
||||||
return identified_chunks
|
return identified_chunks
|
||||||
|
|
|
||||||
21
poetry.lock
generated
21
poetry.lock
generated
|
|
@ -2107,6 +2107,25 @@ files = [
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
colorama = ">=0.4"
|
colorama = ">=0.4"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "groq"
|
||||||
|
version = "0.5.0"
|
||||||
|
description = "The official Python library for the groq API"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "groq-0.5.0-py3-none-any.whl", hash = "sha256:a7e6be1118bcdfea3ed071ec00f505a34d4e6ec28c435adb5a5afd33545683a1"},
|
||||||
|
{file = "groq-0.5.0.tar.gz", hash = "sha256:d476cdc3383b45d2a4dc1876142a9542e663ea1029f9e07a05de24f895cae48c"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
anyio = ">=3.5.0,<5"
|
||||||
|
distro = ">=1.7.0,<2"
|
||||||
|
httpx = ">=0.23.0,<1"
|
||||||
|
pydantic = ">=1.9.0,<3"
|
||||||
|
sniffio = "*"
|
||||||
|
typing-extensions = ">=4.7,<5"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "grpcio"
|
name = "grpcio"
|
||||||
version = "1.63.0"
|
version = "1.63.0"
|
||||||
|
|
@ -7837,4 +7856,4 @@ weaviate = ["weaviate-client"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.9.0,<3.12"
|
python-versions = ">=3.9.0,<3.12"
|
||||||
content-hash = "459e5945fb184ecdc2c9a817e89b2a83b9e0c97273d1a7df8f75c9c0b355ba47"
|
content-hash = "4c2e75aba3260da9e4023e5fa1c4d117b06e4f042bda9441664a0e91af4069a1"
|
||||||
|
|
|
||||||
|
|
@ -66,6 +66,7 @@ lancedb = "^0.6.10"
|
||||||
importlib-metadata = "6.8.0"
|
importlib-metadata = "6.8.0"
|
||||||
deepeval = "^0.21.36"
|
deepeval = "^0.21.36"
|
||||||
litellm = "^1.37.3"
|
litellm = "^1.37.3"
|
||||||
|
groq = "^0.5.0"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue