Fixes to memory component
This commit is contained in:
parent
fee4982aa2
commit
0a38e09b3f
5 changed files with 36 additions and 38 deletions
|
|
@ -1,11 +1,10 @@
|
||||||
|
""" This module contains the classifiers for the documents. """
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from langchain.prompts import ChatPromptTemplate
|
from langchain.prompts import ChatPromptTemplate
|
||||||
import json
|
import json
|
||||||
|
from langchain.document_loaders import TextLoader
|
||||||
# TO DO, ADD ALL CLASSIFIERS HERE
|
from langchain.document_loaders import DirectoryLoader
|
||||||
|
|
||||||
|
|
||||||
from langchain.chains import create_extraction_chain
|
from langchain.chains import create_extraction_chain
|
||||||
from langchain.chat_models import ChatOpenAI
|
from langchain.chat_models import ChatOpenAI
|
||||||
|
|
||||||
|
|
@ -15,17 +14,19 @@ from ..database.vectordb.loaders.loaders import _document_loader
|
||||||
config = Config()
|
config = Config()
|
||||||
config.load()
|
config.load()
|
||||||
OPENAI_API_KEY = config.openai_key
|
OPENAI_API_KEY = config.openai_key
|
||||||
from langchain.document_loaders import TextLoader
|
|
||||||
from langchain.document_loaders import DirectoryLoader
|
|
||||||
|
|
||||||
|
|
||||||
async def classify_documents(query: str, document_id: str, content: str):
|
async def classify_documents(query: str, document_id: str, content: str):
|
||||||
|
"""Classify the documents based on the query and content."""
|
||||||
document_context = content
|
document_context = content
|
||||||
logging.info("This is the document context", document_context)
|
logging.info("This is the document context", document_context)
|
||||||
|
|
||||||
llm = ChatOpenAI(temperature=0, model=config.model)
|
llm = ChatOpenAI(temperature=0, model=config.model)
|
||||||
prompt_classify = ChatPromptTemplate.from_template(
|
prompt_classify = ChatPromptTemplate.from_template(
|
||||||
"""You are a summarizer and classifier. Determine what book this is and where does it belong in the output : {query}, Id: {d_id} Document context is: {context}"""
|
"""You are a summarizer and classifier.
|
||||||
|
Determine what book this is and where does it belong in the output :
|
||||||
|
{query}, Id: {d_id} Document context is: {context}"""
|
||||||
)
|
)
|
||||||
json_structure = [
|
json_structure = [
|
||||||
{
|
{
|
||||||
|
|
@ -36,7 +37,8 @@ async def classify_documents(query: str, document_id: str, content: str):
|
||||||
"properties": {
|
"properties": {
|
||||||
"DocumentCategory": {
|
"DocumentCategory": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The classification of documents in groups such as legal, medical, etc.",
|
"description": "The classification of documents "
|
||||||
|
"in groups such as legal, medical, etc.",
|
||||||
},
|
},
|
||||||
"Title": {
|
"Title": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
|
|
@ -58,7 +60,8 @@ async def classify_documents(query: str, document_id: str, content: str):
|
||||||
classifier_output = await chain_filter.ainvoke(
|
classifier_output = await chain_filter.ainvoke(
|
||||||
{"query": query, "d_id": document_id, "context": str(document_context)}
|
{"query": query, "d_id": document_id, "context": str(document_context)}
|
||||||
)
|
)
|
||||||
|
|
||||||
arguments_str = classifier_output.additional_kwargs["function_call"]["arguments"]
|
arguments_str = classifier_output.additional_kwargs["function_call"]["arguments"]
|
||||||
logging.info("This is the arguments string %s", arguments_str)
|
logging.info("This is the arguments string %s", arguments_str)
|
||||||
arguments_dict = json.loads(arguments_str)
|
arguments_dict = json.loads(arguments_str)
|
||||||
return arguments_dict
|
return arguments_dict
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,16 @@
|
||||||
|
""" This module contains the function to classify a summary of a document. """
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from langchain.prompts import ChatPromptTemplate
|
from langchain.prompts import ChatPromptTemplate
|
||||||
import json
|
import json
|
||||||
|
|
||||||
# TO DO, ADD ALL CLASSIFIERS HERE
|
|
||||||
|
|
||||||
|
|
||||||
from langchain.chains import create_extraction_chain
|
from langchain.chains import create_extraction_chain
|
||||||
from langchain.chat_models import ChatOpenAI
|
from langchain.chat_models import ChatOpenAI
|
||||||
|
|
||||||
from ..config import Config
|
from ..config import Config
|
||||||
from ..database.vectordb.loaders.loaders import _document_loader
|
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
config.load()
|
config.load()
|
||||||
OPENAI_API_KEY = config.openai_key
|
OPENAI_API_KEY = config.openai_key
|
||||||
from langchain.document_loaders import TextLoader
|
|
||||||
from langchain.document_loaders import DirectoryLoader
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -24,9 +18,12 @@ from langchain.document_loaders import DirectoryLoader
|
||||||
|
|
||||||
|
|
||||||
async def classify_summary(query, document_summaries):
|
async def classify_summary(query, document_summaries):
|
||||||
|
"""Classify which documents are relevant to the query, based on their summaries."""
|
||||||
llm = ChatOpenAI(temperature=0, model=config.model)
|
llm = ChatOpenAI(temperature=0, model=config.model)
|
||||||
prompt_classify = ChatPromptTemplate.from_template(
|
prompt_classify = ChatPromptTemplate.from_template(
|
||||||
"""You are a classifier. Determine what document are relevant for the given query: {query}, Document summaries and ids:{document_summaries}"""
|
"""You are a classifier. Determine what document
|
||||||
|
are relevant for the given query: {query},
|
||||||
|
Document summaries and ids:{document_summaries}"""
|
||||||
)
|
)
|
||||||
json_structure = [
|
json_structure = [
|
||||||
{
|
{
|
||||||
|
|
@ -37,7 +34,8 @@ async def classify_summary(query, document_summaries):
|
||||||
"properties": {
|
"properties": {
|
||||||
"DocumentSummary": {
|
"DocumentSummary": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The summary of the document and the topic it deals with.",
|
"description": "The summary of the document "
|
||||||
|
"and the topic it deals with.",
|
||||||
},
|
},
|
||||||
"d_id": {"type": "string", "description": "The id of the document"},
|
"d_id": {"type": "string", "description": "The id of the document"},
|
||||||
},
|
},
|
||||||
|
|
@ -59,4 +57,4 @@ async def classify_summary(query, document_summaries):
|
||||||
|
|
||||||
logging.info("This is the classifier id %s", classfier_id)
|
logging.info("This is the classifier id %s", classfier_id)
|
||||||
|
|
||||||
return classfier_id
|
return classfier_id
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
|
""" This module contains the classifier for the user input. """
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from langchain.prompts import ChatPromptTemplate
|
from langchain.prompts import ChatPromptTemplate
|
||||||
import json
|
import json
|
||||||
|
|
||||||
# TO DO, ADD ALL CLASSIFIERS HERE
|
|
||||||
|
|
||||||
|
|
||||||
from langchain.chains import create_extraction_chain
|
from langchain.chains import create_extraction_chain
|
||||||
from langchain.chat_models import ChatOpenAI
|
from langchain.chat_models import ChatOpenAI
|
||||||
|
|
@ -15,14 +15,14 @@ from ..database.vectordb.loaders.loaders import _document_loader
|
||||||
config = Config()
|
config = Config()
|
||||||
config.load()
|
config.load()
|
||||||
OPENAI_API_KEY = config.openai_key
|
OPENAI_API_KEY = config.openai_key
|
||||||
from langchain.document_loaders import TextLoader
|
|
||||||
from langchain.document_loaders import DirectoryLoader
|
|
||||||
|
|
||||||
|
|
||||||
async def classify_user_input(query, input_type):
|
async def classify_user_input(query, input_type):
|
||||||
|
"""Classify the user input based on the query and input type."""
|
||||||
llm = ChatOpenAI(temperature=0, model=config.model)
|
llm = ChatOpenAI(temperature=0, model=config.model)
|
||||||
prompt_classify = ChatPromptTemplate.from_template(
|
prompt_classify = ChatPromptTemplate.from_template(
|
||||||
"""You are a classifier. Determine with a True or False if the following input: {query}, is relevant for the following memory category: {input_type}"""
|
"""You are a classifier.
|
||||||
|
Determine with a True or False if the following input: {query},
|
||||||
|
is relevant for the following memory category: {input_type}"""
|
||||||
)
|
)
|
||||||
json_structure = [
|
json_structure = [
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,10 @@
|
||||||
import logging
|
""" This module contains the function to classify the user query. """
|
||||||
|
|
||||||
from langchain.prompts import ChatPromptTemplate
|
from langchain.prompts import ChatPromptTemplate
|
||||||
import json
|
import json
|
||||||
|
|
||||||
# TO DO, ADD ALL CLASSIFIERS HERE
|
|
||||||
|
|
||||||
|
|
||||||
from langchain.chains import create_extraction_chain
|
from langchain.chains import create_extraction_chain
|
||||||
from langchain.chat_models import ChatOpenAI
|
from langchain.chat_models import ChatOpenAI
|
||||||
|
from langchain.document_loaders import TextLoader
|
||||||
|
from langchain.document_loaders import DirectoryLoader
|
||||||
|
|
||||||
from ..config import Config
|
from ..config import Config
|
||||||
from ..database.vectordb.loaders.loaders import _document_loader
|
from ..database.vectordb.loaders.loaders import _document_loader
|
||||||
|
|
@ -15,14 +12,15 @@ from ..database.vectordb.loaders.loaders import _document_loader
|
||||||
config = Config()
|
config = Config()
|
||||||
config.load()
|
config.load()
|
||||||
OPENAI_API_KEY = config.openai_key
|
OPENAI_API_KEY = config.openai_key
|
||||||
from langchain.document_loaders import TextLoader
|
|
||||||
from langchain.document_loaders import DirectoryLoader
|
|
||||||
|
|
||||||
|
|
||||||
async def classify_user_query(query, context, document_types):
|
async def classify_user_query(query, context, document_types):
|
||||||
|
"""Classify the user query based on the context and document types."""
|
||||||
llm = ChatOpenAI(temperature=0, model=config.model)
|
llm = ChatOpenAI(temperature=0, model=config.model)
|
||||||
prompt_classify = ChatPromptTemplate.from_template(
|
prompt_classify = ChatPromptTemplate.from_template(
|
||||||
"""You are a classifier. You store user memories, thoughts and feelings. Determine if you need to use them to answer this query : {query}"""
|
"""You are a classifier.
|
||||||
|
You store user memories, thoughts and feelings.
|
||||||
|
Determine if you need to use them to answer this query : {query}"""
|
||||||
)
|
)
|
||||||
json_structure = [
|
json_structure = [
|
||||||
{
|
{
|
||||||
|
|
@ -33,7 +31,8 @@ async def classify_user_query(query, context, document_types):
|
||||||
"properties": {
|
"properties": {
|
||||||
"UserQueryClassifier": {
|
"UserQueryClassifier": {
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
"description": "The classification of documents in groups such as legal, medical, etc.",
|
"description": "The classification of documents "
|
||||||
|
"in groups such as legal, medical, etc.",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": ["UserQueryClassifier"],
|
"required": ["UserQueryClassifier"],
|
||||||
|
|
@ -50,7 +49,5 @@ async def classify_user_query(query, context, document_types):
|
||||||
print("This is the arguments string", arguments_str)
|
print("This is the arguments string", arguments_str)
|
||||||
arguments_dict = json.loads(arguments_str)
|
arguments_dict = json.loads(arguments_str)
|
||||||
classfier_value = arguments_dict.get("UserQueryClassifier", None)
|
classfier_value = arguments_dict.get("UserQueryClassifier", None)
|
||||||
|
|
||||||
print("This is the classifier value", classfier_value)
|
print("This is the classifier value", classfier_value)
|
||||||
|
|
||||||
return classfier_value
|
return classfier_value
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import networkx as nx
|
||||||
|
|
||||||
|
|
||||||
class NetworkXGraphDB:
|
class NetworkXGraphDB:
|
||||||
def __init__(self, filename="networkx_graph.pkl"):
|
def __init__(self, filename="cognee_graph.pkl"):
|
||||||
self.filename = filename
|
self.filename = filename
|
||||||
try:
|
try:
|
||||||
self.graph = self.load_graph() # Attempt to load an existing graph
|
self.graph = self.load_graph() # Attempt to load an existing graph
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue