Initial Commit
This commit is contained in:
parent
7ab699955e
commit
8f25798a91
13 changed files with 262 additions and 552 deletions
6
.gitignore
vendored
Normal file
6
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
.env
|
||||
venv/
|
||||
lightrag_hku.egg-info/
|
||||
output/
|
||||
ragtest/
|
||||
__pycache__/
|
||||
330
README.md
330
README.md
|
|
@ -1,122 +1,21 @@
|
|||
<center><h2>🚀 LightRAG: Simple and Fast Retrieval-Augmented Generation</h2></center>
|
||||
|
||||
|
||||

|
||||
|
||||
<div align='center'>
|
||||
<p>
|
||||
<a href='https://lightrag.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a>
|
||||
<a href='https://arxiv.org/abs/2410.05779'><img src='https://img.shields.io/badge/arXiv-2410.05779-b31b1b'></a>
|
||||
<img src='https://img.shields.io/github/stars/hkuds/lightrag?color=green&style=social' />
|
||||
</p>
|
||||
<p>
|
||||
<img src="https://img.shields.io/badge/python->=3.9.11-blue">
|
||||
<a href="https://pypi.org/project/lightrag-hku/"><img src="https://img.shields.io/pypi/v/lightrag-hku.svg"></a>
|
||||
<a href="https://pepy.tech/project/lightrag-hku"><img src="https://static.pepy.tech/badge/lightrag-hku/month"></a>
|
||||
</p>
|
||||
|
||||
This repository hosts the code of LightRAG. The structure of this code is based on [nano-graphrag](https://github.com/gusye1234/nano-graphrag).
|
||||

|
||||
</div>
|
||||
|
||||
## 🎉 News
|
||||
- [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
|
||||
- [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
|
||||
<center><h2>🚀 PalmierRAG: State-of-the-Art RAG for Any Codebase</h2></center>
|
||||
|
||||
## Install
|
||||
|
||||
* Install from source (Recommend)
|
||||
* Install from source
|
||||
|
||||
```bash
|
||||
cd LightRAG
|
||||
cd palmier-lightrag
|
||||
pip install -e .
|
||||
```
|
||||
* Install from PyPI
|
||||
```bash
|
||||
pip install lightrag-hku
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
* All the code can be found in the `examples`.
|
||||
* Set OpenAI API key in environment if using OpenAI models: `export OPENAI_API_KEY="sk-...".`
|
||||
* Download the demo text "A Christmas Carol by Charles Dickens":
|
||||
```bash
|
||||
curl https://raw.githubusercontent.com/gusye1234/nano-graphrag/main/tests/mock_data.txt > ./book.txt
|
||||
```
|
||||
Use the below Python snippet to initialize LightRAG and perform queries:
|
||||
|
||||
```python
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
|
||||
|
||||
WORKING_DIR = "./dickens"
|
||||
|
||||
if not os.path.exists(WORKING_DIR):
|
||||
os.mkdir(WORKING_DIR)
|
||||
|
||||
rag = LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
llm_model_func=gpt_4o_mini_complete # Use gpt_4o_mini_complete LLM model
|
||||
# llm_model_func=gpt_4o_complete # Optionally, use a stronger model
|
||||
)
|
||||
|
||||
with open("./book.txt") as f:
|
||||
rag.insert(f.read())
|
||||
|
||||
# Perform naive search
|
||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
|
||||
|
||||
# Perform local search
|
||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
|
||||
|
||||
# Perform global search
|
||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
|
||||
|
||||
# Perform hybrid search
|
||||
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary> Using Open AI-like APIs </summary>
|
||||
|
||||
LightRAG also support Open AI-like chat/embeddings APIs:
|
||||
```python
|
||||
async def llm_model_func(
|
||||
prompt, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
return await openai_complete_if_cache(
|
||||
"solar-mini",
|
||||
prompt,
|
||||
system_prompt=system_prompt,
|
||||
history_messages=history_messages,
|
||||
api_key=os.getenv("UPSTAGE_API_KEY"),
|
||||
base_url="https://api.upstage.ai/v1/solar",
|
||||
**kwargs
|
||||
)
|
||||
|
||||
async def embedding_func(texts: list[str]) -> np.ndarray:
|
||||
return await openai_embedding(
|
||||
texts,
|
||||
model="solar-embedding-1-large-query",
|
||||
api_key=os.getenv("UPSTAGE_API_KEY"),
|
||||
base_url="https://api.upstage.ai/v1/solar"
|
||||
)
|
||||
|
||||
rag = LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
llm_model_func=llm_model_func,
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=4096,
|
||||
max_token_size=8192,
|
||||
func=embedding_func
|
||||
)
|
||||
)
|
||||
```
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary> Using Hugging Face Models </summary>
|
||||
|
||||
* Set OpenAI API key in environment if using OpenAI models: `export OPENAI_API_KEY="sk-...".` OR add .env to `lightrag/`
|
||||
* Create `ragtest/input` directory containing `.txt` files to be indexed.
|
||||
* Run `python index.py` to index the documents.
|
||||
* Modify `query.py` with desired query and run `python query.py` to query the documents.
|
||||
### Using Hugging Face Models
|
||||
If you want to use Hugging Face models, you only need to set LightRAG as follows:
|
||||
```python
|
||||
from lightrag.llm import hf_model_complete, hf_embedding
|
||||
|
|
@ -125,7 +24,7 @@ from transformers import AutoModel, AutoTokenizer
|
|||
# Initialize LightRAG with Hugging Face model
|
||||
rag = LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
llm_model_func=hf_model_complete, # Use Hugging Face model for text generation
|
||||
llm_model_func=hf_model_complete, # Use Hugging Face complete model for text generation
|
||||
llm_model_name='meta-llama/Llama-3.1-8B-Instruct', # Model name from Hugging Face
|
||||
# Use Hugging Face embedding function
|
||||
embedding_func=EmbeddingFunc(
|
||||
|
|
@ -139,39 +38,11 @@ rag = LightRAG(
|
|||
),
|
||||
)
|
||||
```
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary> Using Ollama Models (There are some bugs. I'll fix them ASAP.) </summary>
|
||||
If you want to use Ollama models, you only need to set LightRAG as follows:
|
||||
|
||||
```python
|
||||
from lightrag.llm import ollama_model_complete, ollama_embedding
|
||||
|
||||
# Initialize LightRAG with Ollama model
|
||||
rag = LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
llm_model_func=ollama_model_complete, # Use Ollama model for text generation
|
||||
llm_model_name='your_model_name', # Your model name
|
||||
# Use Ollama embedding function
|
||||
embedding_func=EmbeddingFunc(
|
||||
embedding_dim=768,
|
||||
max_token_size=8192,
|
||||
func=lambda texts: ollama_embedding(
|
||||
texts,
|
||||
embed_model="nomic-embed-text"
|
||||
)
|
||||
),
|
||||
)
|
||||
```
|
||||
</details>
|
||||
|
||||
### Batch Insert
|
||||
```python
|
||||
# Batch Insert: Insert multiple texts at once
|
||||
rag.insert(["TEXT1", "TEXT2",...])
|
||||
```
|
||||
|
||||
### Incremental Insert
|
||||
|
||||
```python
|
||||
|
|
@ -187,10 +58,6 @@ The dataset used in LightRAG can be download from [TommyChien/UltraDomain](https
|
|||
|
||||
### Generate Query
|
||||
LightRAG uses the following prompt to generate high-level queries, with the corresponding code located in `example/generate_query.py`.
|
||||
|
||||
<details>
|
||||
<summary> Prompt </summary>
|
||||
|
||||
```python
|
||||
Given the following description of a dataset:
|
||||
|
||||
|
|
@ -214,14 +81,9 @@ Output the results in the following structure:
|
|||
- User 5: [user description]
|
||||
...
|
||||
```
|
||||
</details>
|
||||
|
||||
### Batch Eval
|
||||
To evaluate the performance of two RAG systems on high-level queries, LightRAG uses the following prompt, with the specific code available in `example/batch_eval.py`.
|
||||
|
||||
<details>
|
||||
<summary> Prompt </summary>
|
||||
|
||||
```python
|
||||
---Role---
|
||||
You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
|
||||
|
|
@ -264,162 +126,6 @@ Output your evaluation in the following JSON format:
|
|||
}}
|
||||
}}
|
||||
```
|
||||
</details>
|
||||
|
||||
### Overall Performance Table
|
||||
| | **Agriculture** | | **CS** | | **Legal** | | **Mix** | |
|
||||
|----------------------|-------------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|
|
||||
| | NaiveRAG | **LightRAG** | NaiveRAG | **LightRAG** | NaiveRAG | **LightRAG** | NaiveRAG | **LightRAG** |
|
||||
| **Comprehensiveness** | 32.69% | **67.31%** | 35.44% | **64.56%** | 19.05% | **80.95%** | 36.36% | **63.64%** |
|
||||
| **Diversity** | 24.09% | **75.91%** | 35.24% | **64.76%** | 10.98% | **89.02%** | 30.76% | **69.24%** |
|
||||
| **Empowerment** | 31.35% | **68.65%** | 35.48% | **64.52%** | 17.59% | **82.41%** | 40.95% | **59.05%** |
|
||||
| **Overall** | 33.30% | **66.70%** | 34.76% | **65.24%** | 17.46% | **82.54%** | 37.59% | **62.40%** |
|
||||
| | RQ-RAG | **LightRAG** | RQ-RAG | **LightRAG** | RQ-RAG | **LightRAG** | RQ-RAG | **LightRAG** |
|
||||
| **Comprehensiveness** | 32.05% | **67.95%** | 39.30% | **60.70%** | 18.57% | **81.43%** | 38.89% | **61.11%** |
|
||||
| **Diversity** | 29.44% | **70.56%** | 38.71% | **61.29%** | 15.14% | **84.86%** | 28.50% | **71.50%** |
|
||||
| **Empowerment** | 32.51% | **67.49%** | 37.52% | **62.48%** | 17.80% | **82.20%** | 43.96% | **56.04%** |
|
||||
| **Overall** | 33.29% | **66.71%** | 39.03% | **60.97%** | 17.80% | **82.20%** | 39.61% | **60.39%** |
|
||||
| | HyDE | **LightRAG** | HyDE | **LightRAG** | HyDE | **LightRAG** | HyDE | **LightRAG** |
|
||||
| **Comprehensiveness** | 24.39% | **75.61%** | 36.49% | **63.51%** | 27.68% | **72.32%** | 42.17% | **57.83%** |
|
||||
| **Diversity** | 24.96% | **75.34%** | 37.41% | **62.59%** | 18.79% | **81.21%** | 30.88% | **69.12%** |
|
||||
| **Empowerment** | 24.89% | **75.11%** | 34.99% | **65.01%** | 26.99% | **73.01%** | **45.61%** | **54.39%** |
|
||||
| **Overall** | 23.17% | **76.83%** | 35.67% | **64.33%** | 27.68% | **72.32%** | 42.72% | **57.28%** |
|
||||
| | GraphRAG | **LightRAG** | GraphRAG | **LightRAG** | GraphRAG | **LightRAG** | GraphRAG | **LightRAG** |
|
||||
| **Comprehensiveness** | 45.56% | **54.44%** | 45.98% | **54.02%** | 47.13% | **52.87%** | **51.86%** | 48.14% |
|
||||
| **Diversity** | 19.65% | **80.35%** | 39.64% | **60.36%** | 25.55% | **74.45%** | 35.87% | **64.13%** |
|
||||
| **Empowerment** | 36.69% | **63.31%** | 45.09% | **54.91%** | 42.81% | **57.19%** | **52.94%** | 47.06% |
|
||||
| **Overall** | 43.62% | **56.38%** | 45.98% | **54.02%** | 45.70% | **54.30%** | **51.86%** | 48.14% |
|
||||
|
||||
## Reproduce
|
||||
All the code can be found in the `./reproduce` directory.
|
||||
|
||||
### Step-0 Extract Unique Contexts
|
||||
First, we need to extract unique contexts in the datasets.
|
||||
|
||||
<details>
|
||||
<summary> Code </summary>
|
||||
|
||||
```python
|
||||
def extract_unique_contexts(input_directory, output_directory):
|
||||
|
||||
os.makedirs(output_directory, exist_ok=True)
|
||||
|
||||
jsonl_files = glob.glob(os.path.join(input_directory, '*.jsonl'))
|
||||
print(f"Found {len(jsonl_files)} JSONL files.")
|
||||
|
||||
for file_path in jsonl_files:
|
||||
filename = os.path.basename(file_path)
|
||||
name, ext = os.path.splitext(filename)
|
||||
output_filename = f"{name}_unique_contexts.json"
|
||||
output_path = os.path.join(output_directory, output_filename)
|
||||
|
||||
unique_contexts_dict = {}
|
||||
|
||||
print(f"Processing file: {filename}")
|
||||
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as infile:
|
||||
for line_number, line in enumerate(infile, start=1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
json_obj = json.loads(line)
|
||||
context = json_obj.get('context')
|
||||
if context and context not in unique_contexts_dict:
|
||||
unique_contexts_dict[context] = None
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"JSON decoding error in file {filename} at line {line_number}: {e}")
|
||||
except FileNotFoundError:
|
||||
print(f"File not found: {filename}")
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"An error occurred while processing file {filename}: {e}")
|
||||
continue
|
||||
|
||||
unique_contexts_list = list(unique_contexts_dict.keys())
|
||||
print(f"There are {len(unique_contexts_list)} unique `context` entries in the file {filename}.")
|
||||
|
||||
try:
|
||||
with open(output_path, 'w', encoding='utf-8') as outfile:
|
||||
json.dump(unique_contexts_list, outfile, ensure_ascii=False, indent=4)
|
||||
print(f"Unique `context` entries have been saved to: {output_filename}")
|
||||
except Exception as e:
|
||||
print(f"An error occurred while saving to the file {output_filename}: {e}")
|
||||
|
||||
print("All files have been processed.")
|
||||
|
||||
```
|
||||
</details>
|
||||
|
||||
### Step-1 Insert Contexts
|
||||
For the extracted contexts, we insert them into the LightRAG system.
|
||||
|
||||
<details>
|
||||
<summary> Code </summary>
|
||||
|
||||
```python
|
||||
def insert_text(rag, file_path):
|
||||
with open(file_path, mode='r') as f:
|
||||
unique_contexts = json.load(f)
|
||||
|
||||
retries = 0
|
||||
max_retries = 3
|
||||
while retries < max_retries:
|
||||
try:
|
||||
rag.insert(unique_contexts)
|
||||
break
|
||||
except Exception as e:
|
||||
retries += 1
|
||||
print(f"Insertion failed, retrying ({retries}/{max_retries}), error: {e}")
|
||||
time.sleep(10)
|
||||
if retries == max_retries:
|
||||
print("Insertion failed after exceeding the maximum number of retries")
|
||||
```
|
||||
</details>
|
||||
|
||||
### Step-2 Generate Queries
|
||||
|
||||
We extract tokens from both the first half and the second half of each context in the dataset, then combine them as the dataset description to generate queries.
|
||||
|
||||
<details>
|
||||
<summary> Code </summary>
|
||||
|
||||
```python
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
|
||||
def get_summary(context, tot_tokens=2000):
|
||||
tokens = tokenizer.tokenize(context)
|
||||
half_tokens = tot_tokens // 2
|
||||
|
||||
start_tokens = tokens[1000:1000 + half_tokens]
|
||||
end_tokens = tokens[-(1000 + half_tokens):1000]
|
||||
|
||||
summary_tokens = start_tokens + end_tokens
|
||||
summary = tokenizer.convert_tokens_to_string(summary_tokens)
|
||||
|
||||
return summary
|
||||
```
|
||||
</details>
|
||||
|
||||
### Step-3 Query
|
||||
For the queries generated in Step-2, we will extract them and query LightRAG.
|
||||
|
||||
<details>
|
||||
<summary> Code </summary>
|
||||
|
||||
```python
|
||||
def extract_queries(file_path):
|
||||
with open(file_path, 'r') as f:
|
||||
data = f.read()
|
||||
|
||||
data = data.replace('**', '')
|
||||
|
||||
queries = re.findall(r'- Question \d+: (.+)', data)
|
||||
|
||||
return queries
|
||||
```
|
||||
</details>
|
||||
|
||||
## Code Structure
|
||||
|
||||
|
|
@ -428,10 +134,8 @@ def extract_queries(file_path):
|
|||
├── examples
|
||||
│ ├── batch_eval.py
|
||||
│ ├── generate_query.py
|
||||
│ ├── lightrag_hf_demo.py
|
||||
│ ├── lightrag_ollama_demo.py
|
||||
│ ├── lightrag_openai_compatible_demo.py
|
||||
│ └── lightrag_openai_demo.py
|
||||
│ ├── lightrag_openai_demo.py
|
||||
│ └── lightrag_hf_demo.py
|
||||
├── lightrag
|
||||
│ ├── __init__.py
|
||||
│ ├── base.py
|
||||
|
|
@ -452,16 +156,6 @@ def extract_queries(file_path):
|
|||
└── setup.py
|
||||
```
|
||||
|
||||
## Star History
|
||||
|
||||
<a href="https://star-history.com/#HKUDS/LightRAG&Date">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=HKUDS/LightRAG&type=Date&theme=dark" />
|
||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=HKUDS/LightRAG&type=Date" />
|
||||
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=HKUDS/LightRAG&type=Date" />
|
||||
</picture>
|
||||
</a>
|
||||
|
||||
## Citation
|
||||
|
||||
```python
|
||||
|
|
@ -473,4 +167,4 @@ eprint={2410.05779},
|
|||
archivePrefix={arXiv},
|
||||
primaryClass={cs.IR}
|
||||
}
|
||||
```
|
||||
```
|
||||
21
index.py
Normal file
21
index.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
from lightrag import LightRAG
|
||||
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
|
||||
import os
|
||||
WORKING_DIR = "./ragtest"
|
||||
if not os.path.exists(WORKING_DIR):
|
||||
os.mkdir(WORKING_DIR)
|
||||
rag = LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
llm_model_func=gpt_4o_mini_complete # Use gpt_4o_mini_complete LLM model
|
||||
# llm_model_func=gpt_4o_complete # Optionally, use a stronger model
|
||||
)
|
||||
# Load all .txt files from the input folder
|
||||
input_folder = os.path.join(WORKING_DIR, "input")
|
||||
texts_to_insert = []
|
||||
for filename in os.listdir(input_folder):
|
||||
if filename.endswith(".txt"):
|
||||
file_path = os.path.join(input_folder, filename)
|
||||
with open(file_path, "r") as f:
|
||||
texts_to_insert.append(f.read())
|
||||
# Batch insert all loaded texts
|
||||
rag.insert(texts_to_insert)
|
||||
|
|
@ -5,6 +5,7 @@ from datetime import datetime
|
|||
from functools import partial
|
||||
from typing import Type, cast, Any
|
||||
from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding, hf_model_complete, hf_embedding
|
||||
from .operate import (
|
||||
|
|
@ -49,6 +50,8 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
|||
|
||||
@dataclass
|
||||
class LightRAG:
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
working_dir: str = field(
|
||||
default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
|
||||
)
|
||||
|
|
|
|||
19
query.py
Normal file
19
query.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
from lightrag import LightRAG, QueryParam
|
||||
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
|
||||
import os
|
||||
WORKING_DIR = "./ragtest"
|
||||
if not os.path.exists(WORKING_DIR):
|
||||
os.mkdir(WORKING_DIR)
|
||||
rag = LightRAG(
|
||||
working_dir=WORKING_DIR,
|
||||
llm_model_func=gpt_4o_mini_complete # Use gpt_4o_mini_complete LLM model
|
||||
# llm_model_func=gpt_4o_complete # Optionally, use a stronger model
|
||||
)
|
||||
# Perform naive search
|
||||
print(rag.query("what happens when i click the palmier tab button?", param=QueryParam(mode="naive")))
|
||||
# Perform local search
|
||||
print(rag.query("what happens when i click the palmier tab button?", param=QueryParam(mode="local")))
|
||||
# Perform global search
|
||||
print(rag.query("what happens when i click the palmier tab button?", param=QueryParam(mode="global")))
|
||||
# Perform hybrid search
|
||||
print(rag.query("what happens when i click the palmier tab button?", param=QueryParam(mode="hybrid")))
|
||||
|
|
@ -1,63 +0,0 @@
|
|||
import os
|
||||
import json
|
||||
import glob
|
||||
import argparse
|
||||
|
||||
def extract_unique_contexts(input_directory, output_directory):
|
||||
|
||||
os.makedirs(output_directory, exist_ok=True)
|
||||
|
||||
jsonl_files = glob.glob(os.path.join(input_directory, '*.jsonl'))
|
||||
print(f"Found {len(jsonl_files)} JSONL files.")
|
||||
|
||||
for file_path in jsonl_files:
|
||||
filename = os.path.basename(file_path)
|
||||
name, ext = os.path.splitext(filename)
|
||||
output_filename = f"{name}_unique_contexts.json"
|
||||
output_path = os.path.join(output_directory, output_filename)
|
||||
|
||||
unique_contexts_dict = {}
|
||||
|
||||
print(f"Processing file: {filename}")
|
||||
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as infile:
|
||||
for line_number, line in enumerate(infile, start=1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
json_obj = json.loads(line)
|
||||
context = json_obj.get('context')
|
||||
if context and context not in unique_contexts_dict:
|
||||
unique_contexts_dict[context] = None
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"JSON decoding error in file {filename} at line {line_number}: {e}")
|
||||
except FileNotFoundError:
|
||||
print(f"File not found: {filename}")
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"An error occurred while processing file {filename}: {e}")
|
||||
continue
|
||||
|
||||
unique_contexts_list = list(unique_contexts_dict.keys())
|
||||
print(f"There are {len(unique_contexts_list)} unique `context` entries in the file {filename}.")
|
||||
|
||||
try:
|
||||
with open(output_path, 'w', encoding='utf-8') as outfile:
|
||||
json.dump(unique_contexts_list, outfile, ensure_ascii=False, indent=4)
|
||||
print(f"Unique `context` entries have been saved to: {output_filename}")
|
||||
except Exception as e:
|
||||
print(f"An error occurred while saving to the file {output_filename}: {e}")
|
||||
|
||||
print("All files have been processed.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-i', '--input_dir', type=str, default='../datasets')
|
||||
parser.add_argument('-o', '--output_dir', type=str, default='../datasets/unique_contexts')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
extract_unique_contexts(args.input_dir, args.output_dir)
|
||||
|
|
@ -1,32 +0,0 @@
|
|||
import os
|
||||
import json
|
||||
import time
|
||||
|
||||
from lightrag import LightRAG
|
||||
|
||||
def insert_text(rag, file_path):
|
||||
with open(file_path, mode='r') as f:
|
||||
unique_contexts = json.load(f)
|
||||
|
||||
retries = 0
|
||||
max_retries = 3
|
||||
while retries < max_retries:
|
||||
try:
|
||||
rag.insert(unique_contexts)
|
||||
break
|
||||
except Exception as e:
|
||||
retries += 1
|
||||
print(f"Insertion failed, retrying ({retries}/{max_retries}), error: {e}")
|
||||
time.sleep(10)
|
||||
if retries == max_retries:
|
||||
print("Insertion failed after exceeding the maximum number of retries")
|
||||
|
||||
cls = "agriculture"
|
||||
WORKING_DIR = "../{cls}"
|
||||
|
||||
if not os.path.exists(WORKING_DIR):
|
||||
os.mkdir(WORKING_DIR)
|
||||
|
||||
rag = LightRAG(working_dir=WORKING_DIR)
|
||||
|
||||
insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")
|
||||
|
|
@ -1,76 +0,0 @@
|
|||
import os
|
||||
import json
|
||||
from openai import OpenAI
|
||||
from transformers import GPT2Tokenizer
|
||||
|
||||
def openai_complete_if_cache(
|
||||
model="gpt-4o", prompt=None, system_prompt=None, history_messages=[], **kwargs
|
||||
) -> str:
|
||||
openai_client = OpenAI()
|
||||
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.extend(history_messages)
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
response = openai_client.chat.completions.create(
|
||||
model=model, messages=messages, **kwargs
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
|
||||
def get_summary(context, tot_tokens=2000):
|
||||
tokens = tokenizer.tokenize(context)
|
||||
half_tokens = tot_tokens // 2
|
||||
|
||||
start_tokens = tokens[1000:1000 + half_tokens]
|
||||
end_tokens = tokens[-(1000 + half_tokens):1000]
|
||||
|
||||
summary_tokens = start_tokens + end_tokens
|
||||
summary = tokenizer.convert_tokens_to_string(summary_tokens)
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
clses = ['agriculture']
|
||||
for cls in clses:
|
||||
with open(f'../datasets/unique_contexts/{cls}_unique_contexts.json', mode='r') as f:
|
||||
unique_contexts = json.load(f)
|
||||
|
||||
summaries = [get_summary(context) for context in unique_contexts]
|
||||
|
||||
total_description = "\n\n".join(summaries)
|
||||
|
||||
prompt = f"""
|
||||
Given the following description of a dataset:
|
||||
|
||||
{total_description}
|
||||
|
||||
Please identify 5 potential users who would engage with this dataset. For each user, list 5 tasks they would perform with this dataset. Then, for each (user, task) combination, generate 5 questions that require a high-level understanding of the entire dataset.
|
||||
|
||||
Output the results in the following structure:
|
||||
- User 1: [user description]
|
||||
- Task 1: [task description]
|
||||
- Question 1:
|
||||
- Question 2:
|
||||
- Question 3:
|
||||
- Question 4:
|
||||
- Question 5:
|
||||
- Task 2: [task description]
|
||||
...
|
||||
- Task 5: [task description]
|
||||
- User 2: [user description]
|
||||
...
|
||||
- User 5: [user description]
|
||||
...
|
||||
"""
|
||||
|
||||
result = openai_complete_if_cache(model='gpt-4o', prompt=prompt)
|
||||
|
||||
file_path = f"../datasets/questions/{cls}_questions.txt"
|
||||
with open(file_path, "w") as file:
|
||||
file.write(result)
|
||||
|
||||
print(f"{cls}_questions written to {file_path}")
|
||||
|
|
@ -1,62 +0,0 @@
|
|||
import re
|
||||
import json
|
||||
import asyncio
|
||||
from lightrag import LightRAG, QueryParam
|
||||
from tqdm import tqdm
|
||||
|
||||
def extract_queries(file_path):
|
||||
with open(file_path, 'r') as f:
|
||||
data = f.read()
|
||||
|
||||
data = data.replace('**', '')
|
||||
|
||||
queries = re.findall(r'- Question \d+: (.+)', data)
|
||||
|
||||
return queries
|
||||
|
||||
async def process_query(query_text, rag_instance, query_param):
|
||||
try:
|
||||
result, context = await rag_instance.aquery(query_text, param=query_param)
|
||||
return {"query": query_text, "result": result, "context": context}, None
|
||||
except Exception as e:
|
||||
return None, {"query": query_text, "error": str(e)}
|
||||
|
||||
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
return loop
|
||||
|
||||
def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file, error_file):
|
||||
loop = always_get_an_event_loop()
|
||||
|
||||
with open(output_file, 'a', encoding='utf-8') as result_file, open(error_file, 'a', encoding='utf-8') as err_file:
|
||||
result_file.write("[\n")
|
||||
first_entry = True
|
||||
|
||||
for query_text in tqdm(queries, desc="Processing queries", unit="query"):
|
||||
result, error = loop.run_until_complete(process_query(query_text, rag_instance, query_param))
|
||||
|
||||
if result:
|
||||
if not first_entry:
|
||||
result_file.write(",\n")
|
||||
json.dump(result, result_file, ensure_ascii=False, indent=4)
|
||||
first_entry = False
|
||||
elif error:
|
||||
json.dump(error, err_file, ensure_ascii=False, indent=4)
|
||||
err_file.write("\n")
|
||||
|
||||
result_file.write("\n]")
|
||||
|
||||
if __name__ == "__main__":
|
||||
cls = "agriculture"
|
||||
mode = "hybrid"
|
||||
WORKING_DIR = "../{cls}"
|
||||
|
||||
rag = LightRAG(working_dir=WORKING_DIR)
|
||||
query_param = QueryParam(mode=mode)
|
||||
|
||||
queries = extract_queries(f"../datasets/questions/{cls}_questions.txt")
|
||||
run_queries_and_save_to_json(queries, rag, query_param, "result.json", "errors.json")
|
||||
|
|
@ -9,4 +9,7 @@ tenacity
|
|||
transformers
|
||||
torch
|
||||
ollama
|
||||
accelerate
|
||||
accelerate
|
||||
transformers
|
||||
torch
|
||||
python-dotenv
|
||||
45
scripts/repo_chunking.py
Normal file
45
scripts/repo_chunking.py
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
import os
|
||||
import requests
|
||||
from github import Github
|
||||
from dotenv import load_dotenv
|
||||
def chunk_repo(repo_url):
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
# Extract owner and repo name from the URL
|
||||
_, _, _, owner, repo_name = repo_url.rstrip('/').split('/')
|
||||
print(f"Owner: {owner}, Repo: {repo_name}")
|
||||
# Initialize GitHub API client using the token from .env
|
||||
g = Github(os.getenv('GITHUB_TOKEN'))
|
||||
# Get the repository
|
||||
repo = g.get_repo(f"{owner}/{repo_name}")
|
||||
# Create output directory if it doesn't exist
|
||||
output_dir = 'scripts/output'
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
# List of common code file extensions
|
||||
code_extensions = ['.py', '.js', '.ts', '.java', '.c', '.cpp', '.cs', '.go', '.rb', '.php', '.swift', '.kt', '.rs', '.html', '.css', '.scss', '.sql']
|
||||
# Traverse through all files in the repository
|
||||
contents = repo.get_contents("")
|
||||
while contents:
|
||||
file_content = contents.pop(0)
|
||||
if file_content.type == "dir":
|
||||
contents.extend(repo.get_contents(file_content.path))
|
||||
else:
|
||||
file_extension = os.path.splitext(file_content.name)[1]
|
||||
if file_extension in code_extensions:
|
||||
# Get the raw content of the file
|
||||
raw_content = requests.get(file_content.download_url).text
|
||||
|
||||
# Create a unique filename for the output
|
||||
output_filename = f"{output_dir}/{file_content.path.replace('/', '_')}.txt"
|
||||
|
||||
# Write metadata and file contents to the output file
|
||||
with open(output_filename, 'w', encoding='utf-8') as f:
|
||||
f.write(f"File Path: {file_content.path}\n")
|
||||
f.write("\n--- File Contents ---\n\n")
|
||||
f.write(raw_content)
|
||||
print(f"Processed: {file_content.path}")
|
||||
print("Repository chunking completed.")
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
repo_url = "https://github.com/palmier-io/palmier-vscode-extension"
|
||||
chunk_repo(repo_url)
|
||||
73
scripts/repo_stats.py
Normal file
73
scripts/repo_stats.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
''' Collects stats about a repo to prepare for cost estimation '''
|
||||
import os
|
||||
import requests
|
||||
from github import Github
|
||||
import tiktoken
|
||||
from dotenv import load_dotenv
|
||||
import ast
|
||||
def count_functions(file_content, file_extension):
|
||||
# TODO: Implement for other languages
|
||||
if file_extension == '.py':
|
||||
try:
|
||||
tree = ast.parse(file_content)
|
||||
return len([node for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)])
|
||||
except SyntaxError:
|
||||
print(f"Syntax error in Python file")
|
||||
return 0
|
||||
else:
|
||||
return 0
|
||||
def count_tokens(file_content):
|
||||
enc = tiktoken.encoding_for_model("gpt-4o-mini-2024-07-18")
|
||||
tokens = enc.encode(file_content)
|
||||
return len(tokens)
|
||||
def analyze_repo(repo_url):
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
# Extract owner and repo name from the URL
|
||||
_, _, _, owner, repo_name = repo_url.rstrip('/').split('/')
|
||||
print(f"Owner: {owner}, Repo: {repo_name}")
|
||||
# Initialize GitHub API client using the token from .env
|
||||
g = Github(os.getenv('GITHUB_TOKEN'))
|
||||
# Get the repository
|
||||
repo = g.get_repo(f"{owner}/{repo_name}")
|
||||
# Initialize counters
|
||||
total_tokens = 0
|
||||
file_count = 0
|
||||
total_functions = 0
|
||||
total_lines = 0 # New counter for lines of code
|
||||
# Traverse through all files in the repository
|
||||
contents = repo.get_contents("")
|
||||
while contents:
|
||||
file_content = contents.pop(0)
|
||||
if file_content.type == "dir":
|
||||
contents.extend(repo.get_contents(file_content.path))
|
||||
else:
|
||||
# List of common code file extensions
|
||||
code_extensions = ['.py', '.js', '.ts', '.java', '.c', '.cpp', '.cs', '.go', '.rb', '.php', '.swift', '.kt', '.rs', '.html', '.css', '.scss', '.sql']
|
||||
|
||||
file_extension = os.path.splitext(file_content.name)[1]
|
||||
if file_extension in code_extensions:
|
||||
file_count += 1
|
||||
# Get the raw content of the file
|
||||
raw_content = requests.get(file_content.download_url).text
|
||||
|
||||
# Count tokens
|
||||
total_tokens += count_tokens(raw_content)
|
||||
# Count lines of code
|
||||
total_lines += len(raw_content.splitlines()) # New line to count lines of code
|
||||
# Count functions for Python, JavaScript, and TypeScript files
|
||||
if file_extension in ['.py', '.js', '.ts']:
|
||||
total_functions += count_functions(raw_content, file_extension)
|
||||
# Prepare the output dictionary
|
||||
result = {
|
||||
"number_of_files": file_count,
|
||||
"total_tokens": total_tokens,
|
||||
"total_functions": total_functions,
|
||||
"total_lines_of_code": total_lines # New stat in the result dictionary
|
||||
}
|
||||
return result
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
repo_url = "https://github.com/palmier-io/palmier-vscode-extension"
|
||||
stats = analyze_repo(repo_url)
|
||||
print(stats)
|
||||
79
scripts/view_graph.py
Normal file
79
scripts/view_graph.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
import networkx as nx
|
||||
import os
|
||||
import plotly.graph_objects as go
|
||||
WORKING_DIR = "./ragtest"
|
||||
# Find all GraphML files in the directory and sort them alphabetically
|
||||
graphml_files = [f for f in os.listdir(WORKING_DIR) if f.endswith('.graphml')]
|
||||
graphml_files.sort()
|
||||
def create_hover_text(attributes):
|
||||
return '<br>'.join([f'{k}: {v}' for k, v in attributes.items() if k != 'id'])
|
||||
# Loop through each GraphML file
|
||||
for graphml_file in graphml_files:
|
||||
# Construct the full path to the GraphML file
|
||||
graphml_path = os.path.join(WORKING_DIR, graphml_file)
|
||||
|
||||
# Load the graph
|
||||
G = nx.read_graphml(graphml_path)
|
||||
|
||||
# Create a layout for the graph
|
||||
pos = nx.spring_layout(G)
|
||||
|
||||
# Create node trace
|
||||
node_x = []
|
||||
node_y = []
|
||||
node_text = []
|
||||
for node in G.nodes():
|
||||
x, y = pos[node]
|
||||
node_x.append(x)
|
||||
node_y.append(y)
|
||||
node_text.append(create_hover_text(G.nodes[node]))
|
||||
node_trace = go.Scatter(
|
||||
x=node_x, y=node_y,
|
||||
mode='markers+text',
|
||||
hoverinfo='text',
|
||||
text=[node for node in G.nodes()], # Use node IDs as labels
|
||||
hovertext=node_text,
|
||||
marker=dict(size=20, color='lightblue'),
|
||||
textposition='top center'
|
||||
)
|
||||
# Create edge trace
|
||||
edge_x = []
|
||||
edge_y = []
|
||||
edge_text = []
|
||||
edge_hover_x = []
|
||||
edge_hover_y = []
|
||||
for edge in G.edges():
|
||||
x0, y0 = pos[edge[0]]
|
||||
x1, y1 = pos[edge[1]]
|
||||
edge_x.extend([x0, x1, None])
|
||||
edge_y.extend([y0, y1, None])
|
||||
edge_hover_x.append((x0 + x1) / 2)
|
||||
edge_hover_y.append((y0 + y1) / 2)
|
||||
edge_text.append(create_hover_text(G.edges[edge]))
|
||||
edge_trace = go.Scatter(
|
||||
x=edge_x, y=edge_y,
|
||||
line=dict(width=0.5, color='#888'),
|
||||
mode='lines'
|
||||
)
|
||||
edge_hover_trace = go.Scatter(
|
||||
x=edge_hover_x, y=edge_hover_y,
|
||||
mode='markers',
|
||||
marker=dict(size=0.5, color='rgba(0,0,0,0)'),
|
||||
hoverinfo='text',
|
||||
hovertext=edge_text,
|
||||
hoverlabel=dict(bgcolor='white'),
|
||||
)
|
||||
# Create the figure
|
||||
fig = go.Figure(data=[edge_trace, edge_hover_trace, node_trace],
|
||||
layout=go.Layout(
|
||||
title=graphml_file,
|
||||
showlegend=False,
|
||||
hovermode='closest',
|
||||
margin=dict(b=0,l=0,r=0,t=40),
|
||||
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
||||
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
|
||||
))
|
||||
# Save the figure as an interactive HTML file
|
||||
output_file = f'{os.path.splitext(graphml_file)[0]}_interactive.html'
|
||||
fig.write_html(os.path.join(WORKING_DIR, output_file))
|
||||
print(f"Interactive graph saved as {output_file}")
|
||||
Loading…
Add table
Reference in a new issue