fix: add max_tokens to all ExtractKnowledgeGraph calls
This commit is contained in:
parent
30055cc60c
commit
b2aecf833b
2 changed files with 3 additions and 2 deletions
|
|
@@ -36,7 +36,8 @@ def evaluate():
|
||||||
|
|
||||||
evaluate_on_hotpotqa = Evaluate(devset = devset, num_threads = 1, display_progress = True, display_table = 5, max_tokens = 4096)
|
evaluate_on_hotpotqa = Evaluate(devset = devset, num_threads = 1, display_progress = True, display_table = 5, max_tokens = 4096)
|
||||||
|
|
||||||
compiled_extract_knowledge_graph = ExtractKnowledgeGraph()
|
gpt4 = dspy.OpenAI(model = config.openai_model, api_key = config.openai_key, model_type = "chat", max_tokens = 4096)
|
||||||
|
compiled_extract_knowledge_graph = ExtractKnowledgeGraph(lm = gpt4)
|
||||||
compiled_extract_knowledge_graph.load(get_absolute_path("./programs/extract_knowledge_graph/extract_knowledge_graph.json"))
|
compiled_extract_knowledge_graph.load(get_absolute_path("./programs/extract_knowledge_graph/extract_knowledge_graph.json"))
|
||||||
|
|
||||||
def evaluate_answer(example, graph_prediction, trace = None):
|
def evaluate_answer(example, graph_prediction, trace = None):
|
||||||
|
|
|
||||||
|
|
@@ -41,7 +41,7 @@ def are_all_nodes_connected(graph: KnowledgeGraph) -> bool:
|
||||||
|
|
||||||
|
|
||||||
class ExtractKnowledgeGraph(dspy.Module):
|
class ExtractKnowledgeGraph(dspy.Module):
|
||||||
def __init__(self, lm = dspy.OpenAI(model = config.openai_model, api_key = config.openai_key, model_type = "chat")):
|
def __init__(self, lm = dspy.OpenAI(model = config.openai_model, api_key = config.openai_key, model_type = "chat", max_tokens = 4096)):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.lm = lm
|
self.lm = lm
|
||||||
dspy.settings.configure(lm=self.lm)
|
dspy.settings.configure(lm=self.lm)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue