Fix search
This commit is contained in:
parent
c6586cdedc
commit
0bf5e8d047
4 changed files with 30 additions and 10 deletions
|
|
@ -25,7 +25,7 @@ async def complex_search(graph, query_params: Dict[SearchType, Dict[str, Any]])
|
||||||
SearchType.NEIGHBOR: search_neighbour,
|
SearchType.NEIGHBOR: search_neighbour,
|
||||||
}
|
}
|
||||||
|
|
||||||
results = set()
|
results = []
|
||||||
|
|
||||||
# Create a list to hold all the coroutine objects
|
# Create a list to hold all the coroutine objects
|
||||||
search_tasks = []
|
search_tasks = []
|
||||||
|
|
@ -34,7 +34,8 @@ async def complex_search(graph, query_params: Dict[SearchType, Dict[str, Any]])
|
||||||
search_func = search_functions.get(search_type)
|
search_func = search_functions.get(search_type)
|
||||||
if search_func:
|
if search_func:
|
||||||
# Schedule the coroutine for execution and store the task
|
# Schedule the coroutine for execution and store the task
|
||||||
task = search_func(graph, **params)
|
full_params = {**params, 'graph': graph}
|
||||||
|
task = search_func(**full_params)
|
||||||
search_tasks.append(task)
|
search_tasks.append(task)
|
||||||
|
|
||||||
# Use asyncio.gather to run all scheduled tasks concurrently
|
# Use asyncio.gather to run all scheduled tasks concurrently
|
||||||
|
|
@ -42,9 +43,9 @@ async def complex_search(graph, query_params: Dict[SearchType, Dict[str, Any]])
|
||||||
|
|
||||||
# Update the results set with the results from all tasks
|
# Update the results set with the results from all tasks
|
||||||
for search_result in search_results:
|
for search_result in search_results:
|
||||||
results.update(search_result)
|
results.append(search_result)
|
||||||
|
|
||||||
return list(results)
|
return results
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,8 @@ class QDrantAdapter(VectorDBInterface):
|
||||||
|
|
||||||
return await client.search(
|
return await client.search(
|
||||||
collection_name = collection_name,
|
collection_name = collection_name,
|
||||||
query_vector = query_vector,
|
query_vector = (
|
||||||
|
"content", query_vector),
|
||||||
limit = limit,
|
limit = limit,
|
||||||
with_vectors = with_vector
|
with_vectors = with_vector
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -28,12 +28,12 @@ class OpenAIAdapter(LLMInterface):
|
||||||
return await openai.chat.completions.acreate(**kwargs)
|
return await openai.chat.completions.acreate(**kwargs)
|
||||||
|
|
||||||
@retry(stop = stop_after_attempt(5))
|
@retry(stop = stop_after_attempt(5))
|
||||||
async def acreate_embedding_with_backoff(self, input: List[str], model: str = "text-embedding-ada-002"):
|
async def acreate_embedding_with_backoff(self, input: List[str], model: str = "text-embedding-3-large"):
|
||||||
"""Wrapper around Embedding.acreate w/ backoff"""
|
"""Wrapper around Embedding.acreate w/ backoff"""
|
||||||
|
|
||||||
return await self.aclient.embeddings.create(input=input, model=model)
|
return await self.aclient.embeddings.create(input=input, model=model)
|
||||||
|
|
||||||
async def async_get_embedding_with_backoff(self, text, model="text-embedding-ada-002"):
|
async def async_get_embedding_with_backoff(self, text, model="text-embedding-3-large"):
|
||||||
"""To get text embeddings, import/call this function
|
"""To get text embeddings, import/call this function
|
||||||
It specifies defaults + handles rate-limiting + is async"""
|
It specifies defaults + handles rate-limiting + is async"""
|
||||||
text = text.replace("\n", " ")
|
text = text.replace("\n", " ")
|
||||||
|
|
@ -46,7 +46,7 @@ class OpenAIAdapter(LLMInterface):
|
||||||
"""Wrapper around Embedding.create w/ backoff"""
|
"""Wrapper around Embedding.create w/ backoff"""
|
||||||
return openai.embeddings.create(**kwargs)
|
return openai.embeddings.create(**kwargs)
|
||||||
|
|
||||||
def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-ada-002"):
|
def get_embedding_with_backoff(self, text: str, model: str = "text-embedding-3-large"):
|
||||||
"""To get text embeddings, import/call this function
|
"""To get text embeddings, import/call this function
|
||||||
It specifies defaults + handles rate-limiting
|
It specifies defaults + handles rate-limiting
|
||||||
:param text: str
|
:param text: str
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,12 @@ async def search_similarity(query:str ,graph):
|
||||||
query = await client.async_get_embedding_with_backoff(query)
|
query = await client.async_get_embedding_with_backoff(query)
|
||||||
# print(query)
|
# print(query)
|
||||||
for id in unique_layer_uuids:
|
for id in unique_layer_uuids:
|
||||||
result = client.search(id, query[0])
|
from cognitive_architecture.infrastructure.databases.vector.get_vector_database import get_vector_database
|
||||||
|
vector_client = get_vector_database()
|
||||||
|
|
||||||
|
print(query)
|
||||||
|
|
||||||
|
result = await vector_client.search(id, query,10)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
result_ = [ result_.id for result_ in result]
|
result_ = [ result_.id for result_ in result]
|
||||||
|
|
@ -23,4 +28,17 @@ async def search_similarity(query:str ,graph):
|
||||||
|
|
||||||
out.append([result_, score_])
|
out.append([result_, score_])
|
||||||
|
|
||||||
return out
|
relevant_context = []
|
||||||
|
|
||||||
|
for proposition_id in out[0][0]:
|
||||||
|
print(proposition_id)
|
||||||
|
for n,attr in graph.nodes(data=True):
|
||||||
|
if proposition_id in n:
|
||||||
|
for n_, attr_ in graph.nodes(data=True):
|
||||||
|
relevant_layer = attr['layer_uuid']
|
||||||
|
|
||||||
|
if attr_.get('layer_uuid') == relevant_layer:
|
||||||
|
print(attr_['description'])
|
||||||
|
relevant_context.append(attr_['description'])
|
||||||
|
|
||||||
|
return relevant_context
|
||||||
Loading…
Add table
Reference in a new issue