fix: add max_tokens to all ExtractKnowledgeGraph calls

This commit is contained in:
Boris Arzentar 2024-04-20 20:16:33 +02:00
parent 30055cc60c
commit b2aecf833b
2 changed files with 3 additions and 2 deletions

View file

@@ -36,7 +36,8 @@ def evaluate():
evaluate_on_hotpotqa = Evaluate(devset = devset, num_threads = 1, display_progress = True, display_table = 5, max_tokens = 4096)
compiled_extract_knowledge_graph = ExtractKnowledgeGraph()
gpt4 = dspy.OpenAI(model = config.openai_model, api_key = config.openai_key, model_type = "chat", max_tokens = 4096)
compiled_extract_knowledge_graph = ExtractKnowledgeGraph(lm = gpt4)
compiled_extract_knowledge_graph.load(get_absolute_path("./programs/extract_knowledge_graph/extract_knowledge_graph.json"))
def evaluate_answer(example, graph_prediction, trace = None):

View file

@@ -41,7 +41,7 @@ def are_all_nodes_connected(graph: KnowledgeGraph) -> bool:
class ExtractKnowledgeGraph(dspy.Module):
def __init__(self, lm = dspy.OpenAI(model = config.openai_model, api_key = config.openai_key, model_type = "chat")):
def __init__(self, lm = dspy.OpenAI(model = config.openai_model, api_key = config.openai_key, model_type = "chat", max_tokens = 4096)):
super().__init__()
self.lm = lm
dspy.settings.configure(lm=self.lm)