Initial Commit

This commit is contained in:
mricopeng 2024-10-16 20:59:08 -07:00
parent 7ab699955e
commit 8f25798a91
13 changed files with 262 additions and 552 deletions

6
.gitignore vendored Normal file
View file

@ -0,0 +1,6 @@
.env
venv/
lightrag_hku.egg-info/
output/
ragtest/
__pycache__/

330
README.md
View file

@ -1,122 +1,21 @@
<center><h2>🚀 LightRAG: Simple and Fast Retrieval-Augmented Generation</h2></center>
![请添加图片描述](https://i-blog.csdnimg.cn/direct/567139f1a36e4564abc63ce5c12b6271.jpeg)
<div align='center'>
<p>
<a href='https://lightrag.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a>
<a href='https://arxiv.org/abs/2410.05779'><img src='https://img.shields.io/badge/arXiv-2410.05779-b31b1b'></a>
<img src='https://img.shields.io/github/stars/hkuds/lightrag?color=green&style=social' />
</p>
<p>
<img src="https://img.shields.io/badge/python->=3.9.11-blue">
<a href="https://pypi.org/project/lightrag-hku/"><img src="https://img.shields.io/pypi/v/lightrag-hku.svg"></a>
<a href="https://pepy.tech/project/lightrag-hku"><img src="https://static.pepy.tech/badge/lightrag-hku/month"></a>
</p>
This repository hosts the code of LightRAG. The structure of this code is based on [nano-graphrag](https://github.com/gusye1234/nano-graphrag).
![请添加图片描述](https://i-blog.csdnimg.cn/direct/b2aaf634151b4706892693ffb43d9093.png)
</div>
## 🎉 News
- [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
- [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
<center><h2>🚀 PalmierRAG: State-of-the-Art RAG for Any Codebase</h2></center>
## Install
* Install from source (Recommend)
* Install from source
```bash
cd LightRAG
cd palmier-lightrag
pip install -e .
```
* Install from PyPI
```bash
pip install lightrag-hku
```
## Quick Start
* All the code can be found in the `examples`.
* Set the OpenAI API key in your environment if using OpenAI models: `export OPENAI_API_KEY="sk-..."`.
* Download the demo text "A Christmas Carol by Charles Dickens":
```bash
curl https://raw.githubusercontent.com/gusye1234/nano-graphrag/main/tests/mock_data.txt > ./book.txt
```
Use the below Python snippet to initialize LightRAG and perform queries:
```python
from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
WORKING_DIR = "./dickens"
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete # Use gpt_4o_mini_complete LLM model
# llm_model_func=gpt_4o_complete # Optionally, use a stronger model
)
with open("./book.txt") as f:
rag.insert(f.read())
# Perform naive search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
# Perform local search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
# Perform global search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
# Perform hybrid search
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
```
<details>
<summary> Using Open AI-like APIs </summary>
LightRAG also supports OpenAI-like chat/embeddings APIs:
```python
async def llm_model_func(
prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
return await openai_complete_if_cache(
"solar-mini",
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar",
**kwargs
)
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embedding(
texts,
model="solar-embedding-1-large-query",
api_key=os.getenv("UPSTAGE_API_KEY"),
base_url="https://api.upstage.ai/v1/solar"
)
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
embedding_func=EmbeddingFunc(
embedding_dim=4096,
max_token_size=8192,
func=embedding_func
)
)
```
</details>
<details>
<summary> Using Hugging Face Models </summary>
* Set the OpenAI API key in your environment if using OpenAI models: `export OPENAI_API_KEY="sk-..."`, or add a `.env` file to `lightrag/`.
* Create `ragtest/input` directory containing `.txt` files to be indexed.
* Run `python index.py` to index the documents.
* Modify `query.py` with desired query and run `python query.py` to query the documents.
### Using Hugging Face Models
If you want to use Hugging Face models, you only need to set LightRAG as follows:
```python
from lightrag.llm import hf_model_complete, hf_embedding
@ -125,7 +24,7 @@ from transformers import AutoModel, AutoTokenizer
# Initialize LightRAG with Hugging Face model
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=hf_model_complete, # Use Hugging Face model for text generation
llm_model_func=hf_model_complete, # Use Hugging Face complete model for text generation
llm_model_name='meta-llama/Llama-3.1-8B-Instruct', # Model name from Hugging Face
# Use Hugging Face embedding function
embedding_func=EmbeddingFunc(
@ -139,39 +38,11 @@ rag = LightRAG(
),
)
```
</details>
<details>
<summary> Using Ollama Models (There are some bugs. I'll fix them ASAP.) </summary>
If you want to use Ollama models, you only need to set LightRAG as follows:
```python
from lightrag.llm import ollama_model_complete, ollama_embedding
# Initialize LightRAG with Ollama model
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=ollama_model_complete, # Use Ollama model for text generation
llm_model_name='your_model_name', # Your model name
# Use Ollama embedding function
embedding_func=EmbeddingFunc(
embedding_dim=768,
max_token_size=8192,
func=lambda texts: ollama_embedding(
texts,
embed_model="nomic-embed-text"
)
),
)
```
</details>
### Batch Insert
```python
# Batch Insert: Insert multiple texts at once
rag.insert(["TEXT1", "TEXT2",...])
```
### Incremental Insert
```python
@ -187,10 +58,6 @@ The dataset used in LightRAG can be download from [TommyChien/UltraDomain](https
### Generate Query
LightRAG uses the following prompt to generate high-level queries, with the corresponding code located in `examples/generate_query.py`.
<details>
<summary> Prompt </summary>
```python
Given the following description of a dataset:
@ -214,14 +81,9 @@ Output the results in the following structure:
- User 5: [user description]
...
```
</details>
### Batch Eval
To evaluate the performance of two RAG systems on high-level queries, LightRAG uses the following prompt, with the specific code available in `examples/batch_eval.py`.
<details>
<summary> Prompt </summary>
```python
---Role---
You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
@ -264,162 +126,6 @@ Output your evaluation in the following JSON format:
}}
}}
```
</details>
### Overall Performance Table
| | **Agriculture** | | **CS** | | **Legal** | | **Mix** | |
|----------------------|-------------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|
| | NaiveRAG | **LightRAG** | NaiveRAG | **LightRAG** | NaiveRAG | **LightRAG** | NaiveRAG | **LightRAG** |
| **Comprehensiveness** | 32.69% | **67.31%** | 35.44% | **64.56%** | 19.05% | **80.95%** | 36.36% | **63.64%** |
| **Diversity** | 24.09% | **75.91%** | 35.24% | **64.76%** | 10.98% | **89.02%** | 30.76% | **69.24%** |
| **Empowerment** | 31.35% | **68.65%** | 35.48% | **64.52%** | 17.59% | **82.41%** | 40.95% | **59.05%** |
| **Overall** | 33.30% | **66.70%** | 34.76% | **65.24%** | 17.46% | **82.54%** | 37.59% | **62.40%** |
| | RQ-RAG | **LightRAG** | RQ-RAG | **LightRAG** | RQ-RAG | **LightRAG** | RQ-RAG | **LightRAG** |
| **Comprehensiveness** | 32.05% | **67.95%** | 39.30% | **60.70%** | 18.57% | **81.43%** | 38.89% | **61.11%** |
| **Diversity** | 29.44% | **70.56%** | 38.71% | **61.29%** | 15.14% | **84.86%** | 28.50% | **71.50%** |
| **Empowerment** | 32.51% | **67.49%** | 37.52% | **62.48%** | 17.80% | **82.20%** | 43.96% | **56.04%** |
| **Overall** | 33.29% | **66.71%** | 39.03% | **60.97%** | 17.80% | **82.20%** | 39.61% | **60.39%** |
| | HyDE | **LightRAG** | HyDE | **LightRAG** | HyDE | **LightRAG** | HyDE | **LightRAG** |
| **Comprehensiveness** | 24.39% | **75.61%** | 36.49% | **63.51%** | 27.68% | **72.32%** | 42.17% | **57.83%** |
| **Diversity** | 24.96% | **75.34%** | 37.41% | **62.59%** | 18.79% | **81.21%** | 30.88% | **69.12%** |
| **Empowerment** | 24.89% | **75.11%** | 34.99% | **65.01%** | 26.99% | **73.01%** | **45.61%** | **54.39%** |
| **Overall** | 23.17% | **76.83%** | 35.67% | **64.33%** | 27.68% | **72.32%** | 42.72% | **57.28%** |
| | GraphRAG | **LightRAG** | GraphRAG | **LightRAG** | GraphRAG | **LightRAG** | GraphRAG | **LightRAG** |
| **Comprehensiveness** | 45.56% | **54.44%** | 45.98% | **54.02%** | 47.13% | **52.87%** | **51.86%** | 48.14% |
| **Diversity** | 19.65% | **80.35%** | 39.64% | **60.36%** | 25.55% | **74.45%** | 35.87% | **64.13%** |
| **Empowerment** | 36.69% | **63.31%** | 45.09% | **54.91%** | 42.81% | **57.19%** | **52.94%** | 47.06% |
| **Overall** | 43.62% | **56.38%** | 45.98% | **54.02%** | 45.70% | **54.30%** | **51.86%** | 48.14% |
## Reproduce
All the code can be found in the `./reproduce` directory.
### Step-0 Extract Unique Contexts
First, we need to extract unique contexts in the datasets.
<details>
<summary> Code </summary>
```python
def extract_unique_contexts(input_directory, output_directory):
os.makedirs(output_directory, exist_ok=True)
jsonl_files = glob.glob(os.path.join(input_directory, '*.jsonl'))
print(f"Found {len(jsonl_files)} JSONL files.")
for file_path in jsonl_files:
filename = os.path.basename(file_path)
name, ext = os.path.splitext(filename)
output_filename = f"{name}_unique_contexts.json"
output_path = os.path.join(output_directory, output_filename)
unique_contexts_dict = {}
print(f"Processing file: {filename}")
try:
with open(file_path, 'r', encoding='utf-8') as infile:
for line_number, line in enumerate(infile, start=1):
line = line.strip()
if not line:
continue
try:
json_obj = json.loads(line)
context = json_obj.get('context')
if context and context not in unique_contexts_dict:
unique_contexts_dict[context] = None
except json.JSONDecodeError as e:
print(f"JSON decoding error in file {filename} at line {line_number}: {e}")
except FileNotFoundError:
print(f"File not found: {filename}")
continue
except Exception as e:
print(f"An error occurred while processing file {filename}: {e}")
continue
unique_contexts_list = list(unique_contexts_dict.keys())
print(f"There are {len(unique_contexts_list)} unique `context` entries in the file {filename}.")
try:
with open(output_path, 'w', encoding='utf-8') as outfile:
json.dump(unique_contexts_list, outfile, ensure_ascii=False, indent=4)
print(f"Unique `context` entries have been saved to: {output_filename}")
except Exception as e:
print(f"An error occurred while saving to the file {output_filename}: {e}")
print("All files have been processed.")
```
</details>
### Step-1 Insert Contexts
For the extracted contexts, we insert them into the LightRAG system.
<details>
<summary> Code </summary>
```python
def insert_text(rag, file_path):
with open(file_path, mode='r') as f:
unique_contexts = json.load(f)
retries = 0
max_retries = 3
while retries < max_retries:
try:
rag.insert(unique_contexts)
break
except Exception as e:
retries += 1
print(f"Insertion failed, retrying ({retries}/{max_retries}), error: {e}")
time.sleep(10)
if retries == max_retries:
print("Insertion failed after exceeding the maximum number of retries")
```
</details>
### Step-2 Generate Queries
We extract tokens from both the first half and the second half of each context in the dataset, then combine them as the dataset description to generate queries.
<details>
<summary> Code </summary>
```python
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
def get_summary(context, tot_tokens=2000):
tokens = tokenizer.tokenize(context)
half_tokens = tot_tokens // 2
start_tokens = tokens[1000:1000 + half_tokens]
end_tokens = tokens[-(1000 + half_tokens):1000]
summary_tokens = start_tokens + end_tokens
summary = tokenizer.convert_tokens_to_string(summary_tokens)
return summary
```
</details>
### Step-3 Query
For the queries generated in Step-2, we will extract them and query LightRAG.
<details>
<summary> Code </summary>
```python
def extract_queries(file_path):
with open(file_path, 'r') as f:
data = f.read()
data = data.replace('**', '')
queries = re.findall(r'- Question \d+: (.+)', data)
return queries
```
</details>
## Code Structure
@ -428,10 +134,8 @@ def extract_queries(file_path):
├── examples
│ ├── batch_eval.py
│ ├── generate_query.py
│ ├── lightrag_hf_demo.py
│ ├── lightrag_ollama_demo.py
│ ├── lightrag_openai_compatible_demo.py
│ └── lightrag_openai_demo.py
│ ├── lightrag_openai_demo.py
│ └── lightrag_hf_demo.py
├── lightrag
│ ├── __init__.py
│ ├── base.py
@ -452,16 +156,6 @@ def extract_queries(file_path):
└── setup.py
```
## Star History
<a href="https://star-history.com/#HKUDS/LightRAG&Date">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=HKUDS/LightRAG&type=Date&theme=dark" />
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=HKUDS/LightRAG&type=Date" />
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=HKUDS/LightRAG&type=Date" />
</picture>
</a>
## Citation
```python
@ -473,4 +167,4 @@ eprint={2410.05779},
archivePrefix={arXiv},
primaryClass={cs.IR}
}
```
```

21
index.py Normal file
View file

@ -0,0 +1,21 @@
from lightrag import LightRAG
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
import os
WORKING_DIR = "./ragtest"

# exist_ok=True replaces the separate existence check and its
# check-then-create race.
os.makedirs(WORKING_DIR, exist_ok=True)

rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=gpt_4o_mini_complete,  # Use gpt_4o_mini_complete LLM model
    # llm_model_func=gpt_4o_complete  # Optionally, use a stronger model
)

# Load all .txt files from the input folder
input_folder = os.path.join(WORKING_DIR, "input")
if not os.path.isdir(input_folder):
    # Fail with an actionable message instead of a bare FileNotFoundError
    # raised from os.listdir.
    raise SystemExit(
        f"Input folder not found: {input_folder} "
        "(create it and add the .txt files to index)"
    )

texts_to_insert = []
# os.listdir returns entries in arbitrary order; sort so runs are reproducible.
for filename in sorted(os.listdir(input_folder)):
    if filename.endswith(".txt"):
        file_path = os.path.join(input_folder, filename)
        # Read as UTF-8 explicitly; the platform default encoding varies.
        with open(file_path, "r", encoding="utf-8") as f:
            texts_to_insert.append(f.read())

# Batch insert all loaded texts
if texts_to_insert:
    rag.insert(texts_to_insert)
else:
    print(f"No .txt files found in {input_folder}; nothing to index.")

View file

@ -5,6 +5,7 @@ from datetime import datetime
from functools import partial
from typing import Type, cast, Any
from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
from dotenv import load_dotenv
from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding, hf_model_complete, hf_embedding
from .operate import (
@ -49,6 +50,8 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
@dataclass
class LightRAG:
# Load environment variables from .env file
load_dotenv()
working_dir: str = field(
default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
)

19
query.py Normal file
View file

@ -0,0 +1,19 @@
from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
import os
WORKING_DIR = "./ragtest"
QUESTION = "what happens when i click the palmier tab button?"

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)

rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=gpt_4o_mini_complete,  # Use gpt_4o_mini_complete LLM model
    # llm_model_func=gpt_4o_complete  # Optionally, use a stronger model
)

# Ask the same question in each search mode (naive, local, global, hybrid)
# and print every answer.
for search_mode in ("naive", "local", "global", "hybrid"):
    print(rag.query(QUESTION, param=QueryParam(mode=search_mode)))

View file

@ -1,63 +0,0 @@
import os
import json
import glob
import argparse
def extract_unique_contexts(input_directory, output_directory):
    """Write the de-duplicated `context` values of each JSONL file to JSON.

    For every ``*.jsonl`` file in *input_directory*, the truthy ``context``
    values are collected in first-seen order and saved as a JSON list named
    ``<name>_unique_contexts.json`` in *output_directory* (created if
    missing). Lines that are not valid JSON are reported and skipped; any
    other error while reading a file is reported and that file is skipped.
    """
    os.makedirs(output_directory, exist_ok=True)

    sources = glob.glob(os.path.join(input_directory, '*.jsonl'))
    print(f"Found {len(sources)} JSONL files.")

    for source in sources:
        filename = os.path.basename(source)
        stem = os.path.splitext(filename)[0]
        output_filename = f"{stem}_unique_contexts.json"
        target = os.path.join(output_directory, output_filename)

        # A dict is used as an insertion-ordered set of contexts.
        seen = {}

        print(f"Processing file: {filename}")

        try:
            with open(source, 'r', encoding='utf-8') as reader:
                for line_number, raw in enumerate(reader, start=1):
                    record = raw.strip()
                    if record:
                        try:
                            context = json.loads(record).get('context')
                        except json.JSONDecodeError as e:
                            print(f"JSON decoding error in file {filename} at line {line_number}: {e}")
                        else:
                            if context and context not in seen:
                                seen[context] = None
        except FileNotFoundError:
            print(f"File not found: {filename}")
            continue
        except Exception as e:
            print(f"An error occurred while processing file {filename}: {e}")
            continue

        unique_contexts = list(seen)
        print(f"There are {len(unique_contexts)} unique `context` entries in the file {filename}.")

        try:
            with open(target, 'w', encoding='utf-8') as writer:
                json.dump(unique_contexts, writer, ensure_ascii=False, indent=4)
            print(f"Unique `context` entries have been saved to: {output_filename}")
        except Exception as e:
            print(f"An error occurred while saving to the file {output_filename}: {e}")

    print("All files have been processed.")
if __name__ == "__main__":
    # Command-line entry point: both directories are optional flags.
    cli = argparse.ArgumentParser()
    cli.add_argument('-i', '--input_dir', type=str, default='../datasets')
    cli.add_argument('-o', '--output_dir', type=str, default='../datasets/unique_contexts')
    options = cli.parse_args()

    extract_unique_contexts(
        input_directory=options.input_dir,
        output_directory=options.output_dir,
    )

View file

@ -1,32 +0,0 @@
import os
import json
import time
from lightrag import LightRAG
def insert_text(rag, file_path, max_retries=3, retry_delay=10):
    """Load a JSON document from *file_path* and insert it into *rag*.

    Args:
        rag: Object exposing ``insert(data)``.
        file_path: Path to a UTF-8 JSON file; its parsed content is passed
            to ``rag.insert`` unchanged.
        max_retries: Maximum number of insertion attempts.
        retry_delay: Seconds to wait between two attempts.

    Insertion errors are printed rather than raised; after the last failed
    attempt a final failure message is printed.
    """
    # Read as UTF-8 explicitly so non-ASCII text does not depend on the
    # platform's default encoding.
    with open(file_path, mode='r', encoding='utf-8') as f:
        unique_contexts = json.load(f)

    for attempt in range(1, max_retries + 1):
        try:
            rag.insert(unique_contexts)
            return
        except Exception as e:
            print(f"Insertion failed, retrying ({attempt}/{max_retries}), error: {e}")
            # Only wait when another attempt will follow; sleeping after the
            # final failure just delays the failure report.
            if attempt < max_retries:
                time.sleep(retry_delay)

    print("Insertion failed after exceeding the maximum number of retries")
cls = "agriculture"
# Must be an f-string: without the prefix the directory was literally
# named "../{cls}" instead of "../agriculture".
WORKING_DIR = f"../{cls}"

os.makedirs(WORKING_DIR, exist_ok=True)

rag = LightRAG(working_dir=WORKING_DIR)

insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")

View file

@ -1,76 +0,0 @@
import os
import json
from openai import OpenAI
from transformers import GPT2Tokenizer
def openai_complete_if_cache(
    model="gpt-4o", prompt=None, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Send one chat-completion request and return the reply text.

    Args:
        model: Model name passed to the chat completions endpoint.
        prompt: Content of the final user message.
        system_prompt: Optional system message placed first.
        history_messages: Optional list of prior message dicts inserted
            between the system message and the user prompt.
        **kwargs: Forwarded unchanged to ``chat.completions.create``.

    Returns:
        The ``content`` of the first choice's message.
    """
    openai_client = OpenAI()

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    # None stands in for "no history"; a list literal as the default would be
    # one shared object across all calls.
    messages.extend(history_messages or [])
    messages.append({"role": "user", "content": prompt})

    response = openai_client.chat.completions.create(
        model=model, messages=messages, **kwargs
    )
    return response.choices[0].message.content
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
def get_summary(context, tot_tokens=2000):
    """Build a short excerpt of *context* from a head window and a tail window.

    The text is tokenized with the module-level ``tokenizer``. Up to
    ``tot_tokens // 2`` tokens are taken starting 1000 tokens from the
    beginning, and the same number ending 1000 tokens before the end; the two
    windows are concatenated and converted back to a string.

    For inputs shorter than the combined offsets the windows may be shorter,
    empty, or overlapping.
    """
    tokens = tokenizer.tokenize(context)
    half_tokens = tot_tokens // 2

    start_tokens = tokens[1000:1000 + half_tokens]
    # The end bound must be negative. The previous positive bound (1000)
    # produced an empty tail window for any long input, because the slice
    # start was already past index 1000.
    end_tokens = tokens[-(1000 + half_tokens):-1000]

    summary_tokens = start_tokens + end_tokens
    return tokenizer.convert_tokens_to_string(summary_tokens)
# Dataset names to generate questions for; each one maps to
# ../datasets/unique_contexts/<name>_unique_contexts.json.
clses = ['agriculture']
for cls in clses:
# Load the contexts for this dataset.
with open(f'../datasets/unique_contexts/{cls}_unique_contexts.json', mode='r') as f:
unique_contexts = json.load(f)
# Summarize every context and join the summaries, separated by blank lines,
# into one dataset description.
summaries = [get_summary(context) for context in unique_contexts]
total_description = "\n\n".join(summaries)
# Prompt asking for users, tasks per user, and questions per (user, task).
# The "- Question N:" layout is what extract_queries' regex in the query
# script matches.
prompt = f"""
Given the following description of a dataset:
{total_description}
Please identify 5 potential users who would engage with this dataset. For each user, list 5 tasks they would perform with this dataset. Then, for each (user, task) combination, generate 5 questions that require a high-level understanding of the entire dataset.
Output the results in the following structure:
- User 1: [user description]
- Task 1: [task description]
- Question 1:
- Question 2:
- Question 3:
- Question 4:
- Question 5:
- Task 2: [task description]
...
- Task 5: [task description]
- User 2: [user description]
...
- User 5: [user description]
...
"""
# Ask the model for the questions.
result = openai_complete_if_cache(model='gpt-4o', prompt=prompt)
# Save the raw model output; NOTE(review): ../datasets/questions is not
# created here, so it presumably must already exist — confirm.
file_path = f"../datasets/questions/{cls}_questions.txt"
with open(file_path, "w") as file:
file.write(result)
print(f"{cls}_questions written to {file_path}")

View file

@ -1,62 +0,0 @@
import re
import json
import asyncio
from lightrag import LightRAG, QueryParam
from tqdm import tqdm
def extract_queries(file_path):
    """Return the question texts found in the file at *file_path*.

    Markdown bold markers (``**``) are removed first; every
    ``- Question <number>: <text>`` occurrence then contributes its
    ``<text>`` (up to the end of the line) to the returned list, in order.
    """
    with open(file_path, 'r') as handle:
        raw_text = handle.read()

    question_pattern = re.compile(r'- Question \d+: (.+)')
    return question_pattern.findall(raw_text.replace('**', ''))
async def process_query(query_text, rag_instance, query_param):
    """Run one query and return a ``(result, error)`` pair.

    Awaits ``rag_instance.aquery(query_text, param=query_param)``, which is
    unpacked into an answer and its context. On success the pair is
    ``({"query", "result", "context"}, None)``; if any exception is raised it
    is ``(None, {"query", "error"})`` with the exception rendered via ``str``.
    """
    try:
        answer, retrieved = await rag_instance.aquery(query_text, param=query_param)
    except Exception as exc:
        return None, {"query": query_text, "error": str(exc)}
    return {"query": query_text, "result": answer, "context": retrieved}, None
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    """Return the current event loop, creating and installing one if needed.

    When ``asyncio.get_event_loop()`` raises ``RuntimeError``, a new loop is
    created, set as the current loop, and returned.
    """
    try:
        return asyncio.get_event_loop()
    except RuntimeError:
        fresh_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(fresh_loop)
        return fresh_loop
def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file, error_file):
    """Run every query and stream the outcomes to two files.

    Args:
        queries: Iterable of query strings.
        rag_instance: Object passed through to ``process_query``.
        query_param: Query parameters passed through to ``process_query``.
        output_file: Path that receives one JSON array of successful results.
        error_file: Path to which failed queries are appended, one JSON
            object per failure.
    """
    loop = always_get_an_event_loop()

    # The result file is opened with 'w' because a complete "[ ... ]" array
    # is written on every run; appending a second array to an existing file
    # would leave it as invalid JSON. The error file stays in append mode: it
    # is a sequence of independent objects.
    with open(output_file, 'w', encoding='utf-8') as result_file, \
            open(error_file, 'a', encoding='utf-8') as err_file:
        result_file.write("[\n")
        first_entry = True

        for query_text in tqdm(queries, desc="Processing queries", unit="query"):
            result, error = loop.run_until_complete(
                process_query(query_text, rag_instance, query_param)
            )

            if result:
                # Separator goes before every entry except the first so the
                # array never ends with a trailing comma.
                if not first_entry:
                    result_file.write(",\n")
                json.dump(result, result_file, ensure_ascii=False, indent=4)
                first_entry = False
            elif error:
                json.dump(error, err_file, ensure_ascii=False, indent=4)
                err_file.write("\n")

        result_file.write("\n]")
if __name__ == "__main__":
    cls = "agriculture"
    mode = "hybrid"
    # Must be an f-string: without the prefix the working directory was the
    # literal path "../{cls}" rather than "../agriculture".
    WORKING_DIR = f"../{cls}"

    rag = LightRAG(working_dir=WORKING_DIR)
    query_param = QueryParam(mode=mode)

    queries = extract_queries(f"../datasets/questions/{cls}_questions.txt")
    run_queries_and_save_to_json(queries, rag, query_param, "result.json", "errors.json")

View file

@ -9,4 +9,7 @@ tenacity
transformers
torch
ollama
accelerate
accelerate
transformers
torch
python-dotenv

45
scripts/repo_chunking.py Normal file
View file

@ -0,0 +1,45 @@
import os
import requests
from github import Github
from dotenv import load_dotenv
def chunk_repo(repo_url, output_dir='scripts/output', timeout=30):
    """Download every code file of a GitHub repository into text files.

    Args:
        repo_url: URL of the form ``https://github.com/<owner>/<repo>``.
        output_dir: Directory that receives one ``.txt`` file per code file
            (created if missing).
        timeout: Seconds to wait for each file download before giving up.

    Each output file starts with a ``File Path:`` header followed by the raw
    file contents. Files whose download fails are reported and skipped.

    Raises:
        ValueError: If *repo_url* does not split into exactly five
            ``/``-separated parts.
    """
    # Load environment variables from .env file
    load_dotenv()

    # Extract owner and repo name from the URL
    _, _, _, owner, repo_name = repo_url.rstrip('/').split('/')
    print(f"Owner: {owner}, Repo: {repo_name}")

    # Initialize GitHub API client using the token from .env
    g = Github(os.getenv('GITHUB_TOKEN'))

    # Get the repository
    repo = g.get_repo(f"{owner}/{repo_name}")

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Common code file extensions; a set gives constant-time membership tests
    code_extensions = {
        '.py', '.js', '.ts', '.java', '.c', '.cpp', '.cs', '.go', '.rb',
        '.php', '.swift', '.kt', '.rs', '.html', '.css', '.scss', '.sql',
    }

    # Traverse through all files in the repository
    contents = repo.get_contents("")
    while contents:
        file_content = contents.pop(0)
        if file_content.type == "dir":
            contents.extend(repo.get_contents(file_content.path))
            continue

        file_extension = os.path.splitext(file_content.name)[1]
        if file_extension not in code_extensions:
            continue

        # NOTE(review): presumably some entries (e.g. submodules) carry no
        # download URL; skip them rather than requesting None.
        if not file_content.download_url:
            print(f"Skipped (no download URL): {file_content.path}")
            continue

        # Get the raw content of the file. A timeout keeps one stalled
        # connection from hanging the run, and the status check keeps an
        # HTTP error page from being saved as if it were source code.
        try:
            response = requests.get(file_content.download_url, timeout=timeout)
            response.raise_for_status()
        except requests.RequestException as e:
            print(f"Skipped {file_content.path}: download failed ({e})")
            continue
        raw_content = response.text

        # Create a unique filename for the output
        output_filename = f"{output_dir}/{file_content.path.replace('/', '_')}.txt"

        # Write metadata and file contents to the output file
        with open(output_filename, 'w', encoding='utf-8') as f:
            f.write(f"File Path: {file_content.path}\n")
            f.write("\n--- File Contents ---\n\n")
            f.write(raw_content)

        print(f"Processed: {file_content.path}")

    print("Repository chunking completed.")
# Example usage: chunk the Palmier VS Code extension repository
if __name__ == "__main__":
    chunk_repo("https://github.com/palmier-io/palmier-vscode-extension")

73
scripts/repo_stats.py Normal file
View file

@ -0,0 +1,73 @@
''' Collects stats about a repo to prepare for cost estimation '''
import os
import requests
from github import Github
import tiktoken
from dotenv import load_dotenv
import ast
def count_functions(file_content, file_extension):
    """Return the number of function definitions in a source file.

    Only Python (``'.py'``) is implemented; every other extension yields 0.
    Both ``def`` and ``async def`` definitions are counted, at any nesting
    level (methods and inner functions included). Source that cannot be
    parsed is reported and counted as 0.
    """
    # TODO: Implement for other languages
    if file_extension != '.py':
        return 0

    try:
        tree = ast.parse(file_content)
    except (SyntaxError, ValueError):
        # ValueError covers source containing null bytes on Python versions
        # that do not report it as a SyntaxError.
        print("Syntax error in Python file")
        return 0

    # `async def` has its own node type and would otherwise be missed.
    function_nodes = (ast.FunctionDef, ast.AsyncFunctionDef)
    return sum(isinstance(node, function_nodes) for node in ast.walk(tree))
def count_tokens(file_content):
    """Return how many tokens *file_content* encodes to.

    Uses the tiktoken encoding registered for the model
    ``gpt-4o-mini-2024-07-18``.
    """
    enc = tiktoken.encoding_for_model("gpt-4o-mini-2024-07-18")
    # By default encode() raises ValueError when the text contains a
    # special-token string such as "<|endoftext|>". Arbitrary repository
    # files may contain such strings, so disable that check and encode them
    # as ordinary text.
    tokens = enc.encode(file_content, disallowed_special=())
    return len(tokens)
def analyze_repo(repo_url, timeout=30):
    """Collect size statistics for the code files of a GitHub repository.

    Args:
        repo_url: URL of the form ``https://github.com/<owner>/<repo>``.
        timeout: Seconds to wait for each file download before giving up.

    Returns:
        A dict with the keys ``number_of_files``, ``total_tokens``,
        ``total_functions`` and ``total_lines_of_code``. Files whose download
        fails are reported, skipped, and not counted.

    Raises:
        ValueError: If *repo_url* does not split into exactly five
            ``/``-separated parts.
    """
    # Load environment variables from .env file
    load_dotenv()

    # Extract owner and repo name from the URL
    _, _, _, owner, repo_name = repo_url.rstrip('/').split('/')
    print(f"Owner: {owner}, Repo: {repo_name}")

    # Initialize GitHub API client using the token from .env
    g = Github(os.getenv('GITHUB_TOKEN'))

    # Get the repository
    repo = g.get_repo(f"{owner}/{repo_name}")

    # Common code file extensions, built once instead of once per file
    code_extensions = {
        '.py', '.js', '.ts', '.java', '.c', '.cpp', '.cs', '.go', '.rb',
        '.php', '.swift', '.kt', '.rs', '.html', '.css', '.scss', '.sql',
    }

    # Initialize counters
    total_tokens = 0
    file_count = 0
    total_functions = 0
    total_lines = 0

    # Traverse through all files in the repository
    contents = repo.get_contents("")
    while contents:
        file_content = contents.pop(0)
        if file_content.type == "dir":
            contents.extend(repo.get_contents(file_content.path))
            continue

        file_extension = os.path.splitext(file_content.name)[1]
        if file_extension not in code_extensions:
            continue

        # NOTE(review): presumably some entries (e.g. submodules) carry no
        # download URL; skip them rather than requesting None.
        if not file_content.download_url:
            print(f"Skipped (no download URL): {file_content.path}")
            continue

        # Get the raw content of the file. A timeout keeps one stalled
        # connection from hanging the run, and the status check keeps an
        # HTTP error page from being counted as source code.
        try:
            response = requests.get(file_content.download_url, timeout=timeout)
            response.raise_for_status()
        except requests.RequestException as e:
            print(f"Skipped {file_content.path}: download failed ({e})")
            continue
        raw_content = response.text

        file_count += 1

        # Count tokens
        total_tokens += count_tokens(raw_content)

        # Count lines of code
        total_lines += len(raw_content.splitlines())

        # Function counting is requested for Python, JavaScript and
        # TypeScript files; what is actually counted for each language is
        # decided by count_functions.
        if file_extension in ['.py', '.js', '.ts']:
            total_functions += count_functions(raw_content, file_extension)

    # Prepare the output dictionary
    result = {
        "number_of_files": file_count,
        "total_tokens": total_tokens,
        "total_functions": total_functions,
        "total_lines_of_code": total_lines,
    }

    return result
# Example usage: print the statistics for the Palmier VS Code extension repository
if __name__ == "__main__":
    print(analyze_repo("https://github.com/palmier-io/palmier-vscode-extension"))

79
scripts/view_graph.py Normal file
View file

@ -0,0 +1,79 @@
import networkx as nx
import os
import plotly.graph_objects as go
WORKING_DIR = "./ragtest"

# Collect the GraphML files in the working directory, in alphabetical order
graphml_files = sorted(
    entry for entry in os.listdir(WORKING_DIR) if entry.endswith('.graphml')
)
def create_hover_text(attributes):
    """Format a mapping as ``key: value`` lines joined by ``<br>``.

    The ``'id'`` key is left out; an empty mapping yields an empty string.
    """
    lines = []
    for key, value in attributes.items():
        if key == 'id':
            continue
        lines.append(f'{key}: {value}')
    return '<br>'.join(lines)
# Render every GraphML file as a standalone interactive HTML page
for graphml_file in graphml_files:
    # Load the graph and compute a spring layout (node -> coordinates)
    G = nx.read_graphml(os.path.join(WORKING_DIR, graphml_file))
    pos = nx.spring_layout(G)

    # Node trace: one labelled marker per node, attributes shown on hover
    node_ids = list(G.nodes())
    node_trace = go.Scatter(
        x=[pos[node_id][0] for node_id in node_ids],
        y=[pos[node_id][1] for node_id in node_ids],
        mode='markers+text',
        hoverinfo='text',
        text=node_ids,  # Use node IDs as labels
        hovertext=[create_hover_text(G.nodes[node_id]) for node_id in node_ids],
        marker=dict(size=20, color='lightblue'),
        textposition='top center'
    )

    # Edge geometry: each segment is followed by None so the segments are
    # not joined to each other. Edge attributes are attached to a separate
    # transparent marker placed at the segment midpoint.
    edge_x, edge_y = [], []
    edge_hover_x, edge_hover_y, edge_text = [], [], []
    for edge in G.edges():
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]
        edge_hover_x.append((x0 + x1) / 2)
        edge_hover_y.append((y0 + y1) / 2)
        edge_text.append(create_hover_text(G.edges[edge]))

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color='#888'),
        mode='lines'
    )

    edge_hover_trace = go.Scatter(
        x=edge_hover_x, y=edge_hover_y,
        mode='markers',
        marker=dict(size=0.5, color='rgba(0,0,0,0)'),
        hoverinfo='text',
        hovertext=edge_text,
        hoverlabel=dict(bgcolor='white'),
    )

    # Assemble the figure; both axes are stripped of grid, zero line and ticks
    figure_layout = go.Layout(
        title=graphml_file,
        showlegend=False,
        hovermode='closest',
        margin=dict(b=0, l=0, r=0, t=40),
        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)
    )
    fig = go.Figure(
        data=[edge_trace, edge_hover_trace, node_trace],
        layout=figure_layout,
    )

    # Save the figure as an interactive HTML file next to the GraphML file
    output_file = f'{os.path.splitext(graphml_file)[0]}_interactive.html'
    fig.write_html(os.path.join(WORKING_DIR, output_file))
    print(f"Interactive graph saved as {output_file}")