Fix search
This commit is contained in:
parent
c6586cdedc
commit
0bf5e8d047
4 changed files with 30 additions and 10 deletions
|
|
@ -25,7 +25,7 @@ async def complex_search(graph, query_params: Dict[SearchType, Dict[str, Any]])
|
|||
SearchType.NEIGHBOR: search_neighbour,
|
||||
}
|
||||
|
||||
results = set()
|
||||
results = []
|
||||
|
||||
# Create a list to hold all the coroutine objects
|
||||
search_tasks = []
|
||||
|
|
@ -34,7 +34,8 @@ async def complex_search(graph, query_params: Dict[SearchType, Dict[str, Any]])
|
|||
search_func = search_functions.get(search_type)
|
||||
if search_func:
|
||||
# Schedule the coroutine for execution and store the task
|
||||
task = search_func(graph, **params)
|
||||
full_params = {**params, 'graph': graph}
|
||||
task = search_func(**full_params)
|
||||
search_tasks.append(task)
|
||||
|
||||
# Use asyncio.gather to run all scheduled tasks concurrently
|
||||
|
|
@ -42,9 +43,9 @@ async def complex_search(graph, query_params: Dict[SearchType, Dict[str, Any]])
|
|||
|
||||
# Update the results set with the results from all tasks
|
||||
for search_result in search_results:
|
||||
results.update(search_result)
|
||||
results.append(search_result)
|
||||
|
||||
return list(results)
|
||||
return results
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
|
|
|||
|
|
@ -65,7 +65,8 @@ class QDrantAdapter(VectorDBInterface):
|
|||
|
||||
return await client.search(
|
||||
collection_name = collection_name,
|
||||
query_vector = query_vector,
|
||||
query_vector = (
|
||||
"content", query_vector),
|
||||
limit = limit,
|
||||
with_vectors = with_vector
|
||||
)
|
||||
|
|
|
|||
|
|
@ -28,12 +28,12 @@ class OpenAIAdapter(LLMInterface):
|
|||
return await openai.chat.completions.acreate(**kwargs)
|
||||
|
||||
@retry(stop = stop_after_attempt(5))
|
||||
async def acreate_embedding_with_backoff(self, input: List[str], model: str = "text-embedding-ada-002"):
|
||||
async def acreate_embedding_with_backoff(self, input: List[str], model: str = "text-embedding-3-large"):
|
||||
"""Wrapper around Embedding.acreate w/ backoff"""
|
||||
|
||||
return await self.aclient.embeddings.create(input=input, model=model)
|
||||
|
||||
async def async_get_embedding_with_backoff(self, text, model="text-embedding-ada-002"):
|
||||
async def async_get_embedding_with_backoff(self, text, model="text-embedding-3-large"):
|
||||
"""To get text embeddings, import/call this function
|
||||
It specifies defaults + handles rate-limiting + is async"""
|
||||
text = text.replace("\n", " ")
|
||||
|
|
@ -46,7 +46,7 @@ class OpenAIAdapter(LLMInterface):
|
|||
"""Wrapper around Embedding.create w/ backoff"""
|
||||
return openai.embeddings.create(**kwargs)
|
||||
|
||||
def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-ada-002"):
|
||||
def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-3-large"):
|
||||
"""To get text embeddings, import/call this function
|
||||
It specifies defaults + handles rate-limiting
|
||||
:param text: str
|
||||
|
|
|
|||
|
|
@ -15,7 +15,12 @@ async def search_similarity(query:str ,graph):
|
|||
query = await client.async_get_embedding_with_backoff(query)
|
||||
# print(query)
|
||||
for id in unique_layer_uuids:
|
||||
result = client.search(id, query[0])
|
||||
from cognitive_architecture.infrastructure.databases.vector.get_vector_database import get_vector_database
|
||||
vector_client = get_vector_database()
|
||||
|
||||
print(query)
|
||||
|
||||
result = await vector_client.search(id, query,10)
|
||||
|
||||
if result:
|
||||
result_ = [ result_.id for result_ in result]
|
||||
|
|
@ -23,4 +28,17 @@ async def search_similarity(query:str ,graph):
|
|||
|
||||
out.append([result_, score_])
|
||||
|
||||
return out
|
||||
relevant_context = []
|
||||
|
||||
for proposition_id in out[0][0]:
|
||||
print(proposition_id)
|
||||
for n,attr in graph.nodes(data=True):
|
||||
if proposition_id in n:
|
||||
for n_, attr_ in graph.nodes(data=True):
|
||||
relevant_layer = attr['layer_uuid']
|
||||
|
||||
if attr_.get('layer_uuid') == relevant_layer:
|
||||
print(attr_['description'])
|
||||
relevant_context.append(attr_['description'])
|
||||
|
||||
return relevant_context
|
||||
Loading…
Add table
Reference in a new issue