diff --git a/requirements.txt b/requirements.txt
index a51befbbe04266eeb98beab94d4ed89f138f640a..a9a381abb1ee7435b7415c859db29c51b3990d9a 100644
Binary files a/requirements.txt and b/requirements.txt differ
diff --git a/ruff.toml b/ruff.toml
new file mode 100644
index 0000000000000000000000000000000000000000..242a8eb98df153edc33b4d15bdba4439890df04a
--- /dev/null
+++ b/ruff.toml
@@ -0,0 +1,6 @@
+line-length = 120
+
+[format]
+quote-style = "double"
+indent-style = "tab"
+docstring-code-format = true
\ No newline at end of file
diff --git a/src/app.py b/src/app.py
index b331e0ce5ffc9632eca015e57a1e694ed44cd2c4..8f60818abba773c60aca4686ba181bacc5369542 100644
--- a/src/app.py
+++ b/src/app.py
@@ -1,7 +1,9 @@
 import sys
-if sys.platform.startswith('win'):
-    import asyncio
-    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+
+if sys.platform.startswith("win"):
+    import asyncio
+
+    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
 
 from typing import List
 
@@ -24,7 +26,7 @@ from utils.response_generator import ResponseGenerator
 st.set_page_config(
     page_title="Chat GPP - Generative Pre-trained Peter",
     page_icon="https://www.wur.nl/favicon.ico",
-    initial_sidebar_state="collapsed"
+    initial_sidebar_state="collapsed",
 )
 
 st.title("Chat GPP\n### Generative Pre-trained Peter")
@@ -37,9 +39,7 @@ def load_model():
 
 @st.cache_resource
 def load_encoder():
-    encoder = Encoder(
-        model_name="sentence-transformers/all-MiniLM-L12-v2", device="cpu"
-    )
+    encoder = Encoder(model_name="sentence-transformers/all-MiniLM-L12-v2", device="cpu")
     return encoder
 
 
@@ -72,16 +72,12 @@ RagDB = load_rag_index()
 #
 def display_references(references: List[Document]):
     for doc in references:
-        title = doc.metadata.get('title')
-        #modified_date = doc.metadata.get('modified_date')
-        #date_only = modified_date.split('T')[0]
-        #modified_by = doc.metadata.get('modified_by')
-        readable_score = str(round(doc.metadata.get('score'), 2))
-        mention(
-            label=f"{title} ({readable_score})",
-            icon="https://www.wur.nl/favicon.ico",
-            url=doc.metadata.get('url')
-        )
+        title = doc.metadata.get("title")
+        # modified_date = doc.metadata.get('modified_date')
+        # date_only = modified_date.split('T')[0]
+        # modified_by = doc.metadata.get('modified_by')
+        readable_score = str(round(doc.metadata.get("score"), 2))
+        mention(label=f"{title} ({readable_score})", icon="https://www.wur.nl/favicon.ico", url=doc.metadata.get("url"))
 
 
 # Display settings in sidebar
@@ -124,6 +120,7 @@ for message in st.session_state.messages:
         if "references" in message:
             display_references(message["references"])
 
+
 async def main():
     # Accept user input
     if user_prompt := st.chat_input("Ask me anything!"):
@@ -135,7 +132,9 @@ async def main():
             st.markdown(user_prompt)
 
         # Retrieve context from Vector DB
-        context_retriever = RagContextRetriever(RagDB, re_ranker, context_extender, k1, k2, k2_threshold, expand_context)
+        context_retriever = RagContextRetriever(
+            RagDB, re_ranker, context_extender, k1, k2, k2_threshold, expand_context
+        )
         context, retrieved_docs = context_retriever.retrieve_context(user_prompt, debug_show_rag_context)
         # context = None
         # references = []
@@ -145,18 +144,18 @@ async def main():
         with st.chat_message("assistant", avatar=CHATBOT_AVATAR_URL):
             # Generate response and stream it to the chat
            response_generator = ResponseGenerator(model, audio_generator)
-            response = await response_generator.generate_response(user_prompt, context, max_new_tokens, funny_response_chance)
+            response = await response_generator.generate_response(
+                user_prompt, context, max_new_tokens, funny_response_chance
+            )
 
             # List references
             references = response_generator.get_unique_references(retrieved_docs, response)
             display_references(references)
 
             # Add response to chat history
-            st.session_state.messages.append({
-                "role": "assistant",
-                "content": response,
-                "avatar": CHATBOT_AVATAR_URL,
-                "references": references
-            })
-
-asyncio.run(main())
\ No newline at end of file
+            st.session_state.messages.append(
+                {"role": "assistant", "content": response, "avatar": CHATBOT_AVATAR_URL, "references": references}
+            )
+
+
+asyncio.run(main())
diff --git a/src/build-rag-index.py b/src/build-rag-index.py
index ac77417db3ab7a1294f14caf8b1e6bf5a3cc906d..9f2352c36856cd916c95ec898effbb71bf78036f 100644
--- a/src/build-rag-index.py
+++ b/src/build-rag-index.py
@@ -1,5 +1,4 @@
 # build-rag-index.py
-import os
 
 from rag import encoder, simple_document_index, document_processor, fais_db
 from config import RAG_INDEX_FOLDER, RAG_DOCUMENTS_FOLDER
@@ -8,9 +7,7 @@ index_file_path = "storage/faiss-index"
 
 # Load the encoder
 print("Loading encoder")
-encoder = encoder.Encoder(
-    model_name="sentence-transformers/all-MiniLM-L12-v2", device="cpu"
-)
+encoder = encoder.Encoder(model_name="sentence-transformers/all-MiniLM-L12-v2", device="cpu")
 
 # Create an instance of IndexHelper
 index_helper = simple_document_index.SimpleDocumentIndex(RAG_DOCUMENTS_FOLDER)
diff --git a/src/compare-rag.py b/src/compare-rag.py
index db07cece2e4ab51a7905574b6835c2fb409eabb3..108c2b220a6ff5136cecf8c61c22fe2a5c6a54da 100644
--- a/src/compare-rag.py
+++ b/src/compare-rag.py
@@ -1,29 +1,24 @@
 # build-rag-index.py
-import os
 import streamlit as st
 
 from rag import simple_document_index, fais_db
 from rag.rag_utils import display_documents_as_dataframe
 from config import RAG_DOCUMENTS_FOLDER
-from transformers import AutoTokenizer, AutoModelForSequenceClassification
 from sentence_transformers import CrossEncoder
 
 st.set_page_config(layout="wide")
-st.title('Compare RAG performance')
+st.title("Compare RAG performance")
+
+cross_encoder = CrossEncoder("encoder_name", max_length=512, device="cpu")
 
-cross_encoder = CrossEncoder(
-    "encoder_name", max_length=512, device="cpu"
-    )
 
 
 def get_vector_db(model_name: str, index_filename: str):
     from rag import encoder, document_processor
 
     # Load the encoder
     print("Loading encoder")
-    encoder = encoder.Encoder(
-        model_name=model_name, device="cpu"
-    )
+    encoder = encoder.Encoder(model_name=model_name, device="cpu")
 
     # Create an instance of IndexHelper
     index_helper = simple_document_index.SimpleDocumentIndex(RAG_DOCUMENTS_FOLDER, filename=index_filename)
@@ -55,14 +50,15 @@ def lookup(query: str, model: str, k: int = 5):
 
     return [results1, results2]
 
-#def rerank(results):
-    # Rerank the results
+
+# def rerank(results):
+# Rerank the results
 
 
 k = st.number_input("k1 (rag)", 1, 10, 5)
-modelA = st.text_input('Enter model name A', "sentence-transformers/all-MiniLM-L12-v2")
-modelB = st.text_input('Enter model name B', "sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
-query = st.text_input('Enter query')
+modelA = st.text_input("Enter model name A", "sentence-transformers/all-MiniLM-L12-v2")
+modelB = st.text_input("Enter model name B", "sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
+query = st.text_input("Enter query")
 
 if query and modelA and modelB:
     [results_a_nl, results_a_en] = lookup(query, modelA, k)
@@ -85,7 +81,3 @@ if query and modelA and modelB:
     with col_en_2:
         st.write(f"Model: {modelB}")
         display_documents_as_dataframe(results_b_en)
-
-
-
-
diff --git a/src/config.py b/src/config.py
index 658c85081d8b0042effd8666b412da9f30da3ac9..b57fc4b595f446cf993d5e059f649f9d892b40ad 100644
--- a/src/config.py
+++ b/src/config.py
@@ -11,7 +11,7 @@ def load_yaml_config(file_path):
     # Check if file exists
     if not os.path.exists(file_path):
         raise FileNotFoundError(f"Config file not found: {file_path}")
-    with open(file_path, 'r', encoding='utf-8') as file:
+    with open(file_path, "r", encoding="utf-8") as file:
         config = yaml.safe_load(file)
     return config
 
@@ -55,5 +55,5 @@ CHATBOT_AVATAR_URL = os.getenv("CHATBOT_AVATAR_URL", "https://www.wur.nl/favicon
 #
 # Sound mapping for audio generator
 #
-word_sound_mapping_str = os.getenv('WORD_SOUND_MAPPING', '{}')
+word_sound_mapping_str = os.getenv("WORD_SOUND_MAPPING", "{}")
 WORD_SOUND_MAPPING = json.loads(word_sound_mapping_str)
diff --git a/src/model/__init__.py b/src/model/__init__.py
index 99a652c9d0cd80b4bed291fa8c89c996354c0097..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644
--- a/src/model/__init__.py
+++ b/src/model/__init__.py
@@ -1,4 +0,0 @@
-from . import huggingface_llm_model
-from . import azure_openai_llm_model
-from . import prompt_builder
-from . import llm_model_factory
\ No newline at end of file
diff --git a/src/model/azure_openai_llm_model.py b/src/model/azure_openai_llm_model.py
index ff4b358fbb9a2d280b0b1e4d16b18060337a597d..7624ee8162f9edad920b9ffa2140cca3d93d4c8c 100644
--- a/src/model/azure_openai_llm_model.py
+++ b/src/model/azure_openai_llm_model.py
@@ -6,22 +6,14 @@ from .llm_model_base import LlmModelBase
 from .prompt_builder import PromptBuilder
 
 
-class AzureOpenAiLlmModel (LlmModelBase):
-    def __init__(
-        self,
-        model_id: str
-    ):
+class AzureOpenAiLlmModel(LlmModelBase):
+    def __init__(self, model_id: str):
         self.prompt_builder = PromptBuilder()
         self.model_id = model_id
 
     async def generate(
-        self,
-        question: str,
-        context: str = None,
-        max_new_tokens: int = 256,
-        funny_prompt_chance: float = 0.0
+        self, question: str, context: str = None, max_new_tokens: int = 256, funny_prompt_chance: float = 0.0
     ):
-
         # Build the system prompt
         chat_messages = self.prompt_builder.build_prompt(question, context, funny_prompt_chance)
 
@@ -29,11 +21,7 @@ class AzureOpenAiLlmModel (LlmModelBase):
         endpoint = os.environ["AZURE_OPEN_AI_ENDPOINT"]
         api_key = os.environ["AZURE_OPEN_AI_API_KEY"]
 
-        client = openai.AsyncAzureOpenAI(
-            azure_endpoint=endpoint,
-            api_key=api_key,
-            api_version="2023-09-01-preview"
-        )
+        client = openai.AsyncAzureOpenAI(azure_endpoint=endpoint, api_key=api_key, api_version="2023-09-01-preview")
 
         print("Calling Azure OpenAI API")
 
@@ -46,18 +34,18 @@ class AzureOpenAiLlmModel (LlmModelBase):
             top_p=1.0,
             max_tokens=max_new_tokens,
             messages=chat_messages,
-            stream=True
+            stream=True,
         )
 
         # Define a simplified async generator to yield text chunks
         async def text_stream():
             async for chunk in azure_open_ai_response:
                 if (
-                    hasattr(chunk, 'choices') and
-                    chunk.choices and
-                    hasattr(chunk.choices[0], 'delta') and
-                    hasattr(chunk.choices[0].delta, 'content') and
-                    chunk.choices[0].delta.content
+                    hasattr(chunk, "choices")
+                    and chunk.choices
+                    and hasattr(chunk.choices[0], "delta")
+                    and hasattr(chunk.choices[0].delta, "content")
+                    and chunk.choices[0].delta.content
                 ):
                     yield chunk.choices[0].delta.content
 
diff --git a/src/model/huggingface_llm_model.py b/src/model/huggingface_llm_model.py
index 700d2a3dc3c8ca4452f3ee3805a1f033eb608d8c..ef1fb9cd9a89856223b0bb77142384136050cb98 100644
--- a/src/model/huggingface_llm_model.py
+++ b/src/model/huggingface_llm_model.py
@@ -8,7 +8,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 from config import MODEL_CACHE_DIR, HF_ACCESS_TOKEN, MODEL_CONFIG
 from .llm_model_base import LlmModelBase
 from .prompt_builder import PromptBuilder
-import asyncio
+
 
 class HuggingfaceLlmModel(LlmModelBase):
     """
@@ -28,14 +28,8 @@ class HuggingfaceLlmModel(LlmModelBase):
 
     streamer: Optional[TextIteratorStreamer] = None
 
-    def __init__(
-        self,
-        model_id: str = "google/gemma-2-9b-it",
-        device: str = "cuda"
-    ):
-        quantization_config = BitsAndBytesConfig(
-            load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
-        )
+    def __init__(self, model_id: str = "google/gemma-2-9b-it", device: str = "cuda"):
+        quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
 
         self.model = AutoModelForCausalLM.from_pretrained(
             model_id,
@@ -51,7 +45,7 @@ class HuggingfaceLlmModel(LlmModelBase):
             token=HF_ACCESS_TOKEN,
             padding=True,
             truncation=True,
-            max_length=self.model.config.max_position_embeddings
+            max_length=self.model.config.max_position_embeddings,
         )
 
         self.model.eval()
@@ -59,13 +53,9 @@ class HuggingfaceLlmModel(LlmModelBase):
         self.prompt_builder = PromptBuilder()
 
     async def generate(
-        self,
-        question: str,
-        context: str = None,
-        max_new_tokens: int = 256,
-        funny_prompt_chance: float = 0.7
+        self, question: str, context: str = None, max_new_tokens: int = 256, funny_prompt_chance: float = 0.7
     ) -> AsyncGenerator[str, None]:
-    #-> TextIteratorStreamer:
+        # -> TextIteratorStreamer:
         """
         Generates text based on a question and optional context.
 
@@ -110,7 +100,9 @@ class HuggingfaceLlmModel(LlmModelBase):
                 self.streamer.text_queue.put(None)  # Signal the streamer to stop
             except Exception as e:
                 print(f"An unexpected error occurred: {e}")
-                self.stream_error_with_delay(f"Sorry, I ran into an unexpected error 🤯. I've output the error message to the console.")
+                self.stream_error_with_delay(
+                    "Sorry, I ran into an unexpected error 🤯. I've output the error message to the console."
+                )
                 self.streamer.text_queue.put(None)  # Signal the streamer to stop
 
         # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
@@ -124,16 +116,15 @@ class HuggingfaceLlmModel(LlmModelBase):
 
         # Return the asynchronous wrapper
         return stream_generator()
-
-    def stream_error_with_delay(self, text: str, delay_ms: int=10):
+
+    def stream_error_with_delay(self, text: str, delay_ms: int = 10):
         """
-        Streams an error message to the chat with a delay between characters, giving the appearance of typing. 
+        Streams an error message to the chat with a delay between characters, giving the appearance of typing.
 
-        Parameters:
-        text (str): The error message to stream.
-        delay_ms (int): The delay in milliseconds between each character. Defaults to 10 ms.
+        Parameters:
+        text (str): The error message to stream.
+        delay_ms (int): The delay in milliseconds between each character. Defaults to 10 ms.
+ """ for char in text: self.streamer.text_queue.put(char) time.sleep(delay_ms / 1000.0) # Convert ms to seconds - diff --git a/src/model/llm_model_base.py b/src/model/llm_model_base.py index 929b80959c93023b8fdd4905e4afc4c8d44f7e42..51b2793b8336c2badf9ea70983ab2c922db634b2 100644 --- a/src/model/llm_model_base.py +++ b/src/model/llm_model_base.py @@ -1,12 +1,9 @@ from abc import ABC, abstractmethod + class LlmModelBase(ABC): - @abstractmethod - async def generate( - self, - question: str, - context: str = None, - max_new_tokens: int = 256, - funny_prompt_chance: float = 0.0 + @abstractmethod + async def generate( + self, question: str, context: str = None, max_new_tokens: int = 256, funny_prompt_chance: float = 0.0 ): - pass \ No newline at end of file + pass diff --git a/src/model/llm_model_factory.py b/src/model/llm_model_factory.py index 88039db189ccbee237c613c3b6e7ce54a8c4ea8f..cfb22121ddf512fb12b178479774424f61e1687e 100644 --- a/src/model/llm_model_factory.py +++ b/src/model/llm_model_factory.py @@ -3,14 +3,15 @@ from model.huggingface_llm_model import HuggingfaceLlmModel from model.llm_model_base import LlmModelBase from config import MODEL_CONFIG + class LlmModelFactory: - @staticmethod - def create_model() -> LlmModelBase: - model_type = MODEL_CONFIG["modelType"] + @staticmethod + def create_model() -> LlmModelBase: + model_type = MODEL_CONFIG["modelType"] - if model_type == "huggingface": - return HuggingfaceLlmModel(model_id=MODEL_CONFIG['modelId'], device="cuda") - elif model_type == "azure-openai": - return AzureOpenAiLlmModel(model_id=MODEL_CONFIG['modelId']) - else: - raise ValueError(f"Unknown model type: {model_type}") \ No newline at end of file + if model_type == "huggingface": + return HuggingfaceLlmModel(model_id=MODEL_CONFIG["modelId"], device="cuda") + elif model_type == "azure-openai": + return AzureOpenAiLlmModel(model_id=MODEL_CONFIG["modelId"]) + else: + raise ValueError(f"Unknown model type: {model_type}") diff --git a/src/model/prompt_builder.py b/src/model/prompt_builder.py index 1858577d0976b495ad24659aa88d718589fe9d27..f5b4210626593cb38d5adc057b6cc9b4c7d5b878 100644 --- a/src/model/prompt_builder.py +++ b/src/model/prompt_builder.py @@ -4,87 +4,82 @@ from config import MODEL_CONFIG class PromptBuilder: - """ - A class to construct system prompts based on: - - Some general system context (e.g. the personality of the system) - - The question asked by the user - - The context of the question (e.g. a document retrieved by RAG) - - A percentage chance to replace the prompt with a funny prompt - - Attributes: - role (str): The role of the entity asking the question, default is 'user'. - funnyPrompts (list): A list of funny prompts configured in MODEL_CONFIG. - usedFunnyPrompts (list): A list to keep track of funny prompts that have been used. - - Methods: - build_prompt(question: str, context: str, funny_prompt_chance: float): Constructs a system prompt. - _get_random_funny_prompt(): Selects a random funny prompt from the available list. - """ - - def __init__(self, role: str = "user"): - self.role = role - self.funnyPrompts = MODEL_CONFIG['funnyPrompts'] - self.usedFunnyPrompts = [] - - def build_prompt( - self, - question: str, - context: str = None, - funny_prompt_chance: float = 0.7 - ) -> list: - """ - Builds a complete system prompt based on the provided question and context, with a certain chance to completely - replace the prompt with a funny prompt. - - Parameters: - question (str): The question asked by the user. 
-        context (str, optional): The context of the question, defaults to None.
-        funny_prompt_chance (float, optional): The chance (0.0 to 1.0) to use a funny prompt instead of a normal one, defaults to 0.7.
-
-        Returns:
-        list: A list containing a single chat object with the role and constructed prompt.
-        """
-
-        if context is None or context == "":
-            prompt = MODEL_CONFIG['templateNoContext']
-        else:
-            prompt = MODEL_CONFIG['templateWithContext']
-            prompt = prompt.replace("{context}", context)
-
-        # replace {question} and {context} in the prompt
-        system_prompt = prompt.replace("{question}", question)
-
-        # Replace the entire prompt with a funny prompt sometimes
-        if (self.funnyPrompts or self.usedFunnyPrompts) and random.random() < funny_prompt_chance:
-            system_prompt = self._get_random_funny_prompt()
-            print("Using funny prompt instead of normal prompt.")
-
-        # system_prompt = question
-        # print("Final system_prompt: ", system_prompt)
-
-        # Build the chat object
-        chat = [{"role": self.role, "content": system_prompt}]
-
-        return chat
-
-    def _get_random_funny_prompt(self) -> str:
-        """
-        Selects a random funny prompt from the configured list, ensuring it hasn't been used yet.
-
-        Returns:
-        str: The selected funny prompt.
-        """
-        if not self.funnyPrompts:
-            # If all prompts have been used, reset the list
-            self.funnyPrompts, self.usedFunnyPrompts = self.usedFunnyPrompts, []
-
-        # Get a random prompt based on weights
-        weights = [p["weight"] for p in self.funnyPrompts]
-        chosen_prompt = random.choices(self.funnyPrompts, weights=weights, k=1)[0]
-
-        # Remove the chosen prompt from the list
-        self.funnyPrompts = [p for p in self.funnyPrompts if p != chosen_prompt]
-        # Add the chosen prompt to the used prompts list
-        self.usedFunnyPrompts.append(chosen_prompt)
-
-        return chosen_prompt
+    """
+    A class to construct system prompts based on:
+    - Some general system context (e.g. the personality of the system)
+    - The question asked by the user
+    - The context of the question (e.g. a document retrieved by RAG)
+    - A percentage chance to replace the prompt with a funny prompt
+
+    Attributes:
+    role (str): The role of the entity asking the question, default is 'user'.
+    funnyPrompts (list): A list of funny prompts configured in MODEL_CONFIG.
+    usedFunnyPrompts (list): A list to keep track of funny prompts that have been used.
+
+    Methods:
+    build_prompt(question: str, context: str, funny_prompt_chance: float): Constructs a system prompt.
+    _get_random_funny_prompt(): Selects a random funny prompt from the available list.
+    """
+
+    def __init__(self, role: str = "user"):
+        self.role = role
+        self.funnyPrompts = MODEL_CONFIG["funnyPrompts"]
+        self.usedFunnyPrompts = []
+
+    def build_prompt(self, question: str, context: str = None, funny_prompt_chance: float = 0.7) -> list:
+        """
+        Builds a complete system prompt based on the provided question and context, with a certain chance to completely
+        replace the prompt with a funny prompt.
+
+        Parameters:
+        question (str): The question asked by the user.
+        context (str, optional): The context of the question, defaults to None.
+        funny_prompt_chance (float, optional): The chance (0.0 to 1.0) to use a funny prompt instead of a normal one, defaults to 0.7.
+
+        Returns:
+        list: A list containing a single chat object with the role and constructed prompt.
+ """ + + if context is None or context == "": + prompt = MODEL_CONFIG["templateNoContext"] + else: + prompt = MODEL_CONFIG["templateWithContext"] + prompt = prompt.replace("{context}", context) + + # replace {question} and {context} in the prompt + system_prompt = prompt.replace("{question}", question) + + # Replace the entire prompt with a funny prompt sometimes + if (self.funnyPrompts or self.usedFunnyPrompts) and random.random() < funny_prompt_chance: + system_prompt = self._get_random_funny_prompt() + print("Using funny prompt instead of normal prompt.") + + # system_prompt = question + # print("Final system_prompt: ", system_prompt) + + # Build the chat object + chat = [{"role": self.role, "content": system_prompt}] + + return chat + + def _get_random_funny_prompt(self) -> str: + """ + Selects a random funny prompt from the configured list, ensuring it hasn't been used yet. + + Returns: + str: The selected funny prompt. + """ + if not self.funnyPrompts: + # If all prompts have been used, reset the list + self.funnyPrompts, self.usedFunnyPrompts = self.usedFunnyPrompts, [] + + # Get a random prompt based on weights + weights = [p["weight"] for p in self.funnyPrompts] + chosen_prompt = random.choices(self.funnyPrompts, weights=weights, k=1)[0] + + # Remove the chosen prompt from the list + self.funnyPrompts = [p for p in self.funnyPrompts if p != chosen_prompt] + # Add the chosen prompt to the used prompts list + self.usedFunnyPrompts.append(chosen_prompt) + + return chosen_prompt diff --git a/src/rag/__init__.py b/src/rag/__init__.py index 7d599247a38afbde256562ad21d9706cc94e32fe..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/src/rag/__init__.py +++ b/src/rag/__init__.py @@ -1,8 +0,0 @@ -from . import context_extender -from . import encoder -from . import fais_db -from . import document_processor -from . import rag_context_retriever -from . import reranker -from . import simple_document_index -from . import rag_utils diff --git a/src/rag/context_extender.py b/src/rag/context_extender.py index 30682e0f2dd17d652127f2be1b8495409f85982a..75de464cfd9a54bd783172b0c6413414d6168e54 100644 --- a/src/rag/context_extender.py +++ b/src/rag/context_extender.py @@ -67,7 +67,7 @@ class ContextExtender: """ # Get the filename from the document's metadata - filename = doc.metadata['source'] + filename = doc.metadata["source"] # Use the IndexHelper to get the full path of the file file_paths = self.index_helper.get_file_paths() @@ -81,11 +81,12 @@ class ContextExtender: loader = UnstructuredMarkdownLoader(file_path) else: raise ValueError( - f"Unsupported file format: only PDF and HTML files support is implemented. Tried to load: {file_path}") + f"Unsupported file format: only PDF and HTML files support is implemented. 
+                f"Unsupported file format: only PDF and HTML files support is implemented. Tried to load: {file_path}"
+            )
 
         documents = loader.load()
         # Concatenate the content of all documents
-        full_content = ' '.join([document.page_content for document in documents])
+        full_content = " ".join([document.page_content for document in documents])
 
         # Create a new document with the loaded content and the original document's metadata
         full_doc = Document(page_content=full_content, metadata=doc.metadata)
diff --git a/src/rag/document_processor.py b/src/rag/document_processor.py
index 13a561a54ae45a9c050797687166d34ed0c9bf21..c7898bf906b47b0ddcd8b2b7bb7cf0268c317ac8 100644
--- a/src/rag/document_processor.py
+++ b/src/rag/document_processor.py
@@ -29,9 +29,7 @@ class DocumentProcessor:
         self.document_index = document_index
         self.chunk_size = chunk_size
         self.text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
-            tokenizer=AutoTokenizer.from_pretrained(
-                "sentence-transformers/all-MiniLM-L12-v2"
-            ),
+            tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L12-v2"),
             chunk_size=self.chunk_size,
             chunk_overlap=int(self.chunk_size / 10),
             strip_whitespace=True,
@@ -64,8 +62,7 @@ class DocumentProcessor:
         """
         print("RAG: Loading and splitting documents...")
         pages_pdf = [
-            page for file_path in file_paths if file_path.endswith(".pdf") for page in
-            PyPDFLoader(file_path).load()
+            page for file_path in file_paths if file_path.endswith(".pdf") for page in PyPDFLoader(file_path).load()
         ]
 
         pages_html = []
@@ -105,7 +102,7 @@ class DocumentProcessor:
         # Append additional metadata from our index.yaml to each document
         for doc in docs:
             # Get the filename from the source path
-            filename = os.path.basename(doc.metadata['source'])
+            filename = os.path.basename(doc.metadata["source"])
             # Get the corresponding metadata from the index.yaml
             metadata = self.document_index.get_metadata(filename)
             # Merge the metadata into the document
diff --git a/src/rag/encoder.py b/src/rag/encoder.py
index 6eaf00c5c8ab634a0259959932c9fd743662f0b9..f91ed86c8485433f4ccecc48eca65018104ea8ee 100644
--- a/src/rag/encoder.py
+++ b/src/rag/encoder.py
@@ -4,29 +4,25 @@ from config import MODEL_CACHE_DIR
 
 
 class Encoder:
-    """
-    A class to handle the encoding of text using models from Hugging Face's transformers library.
-    Used for generating embeddings for text for RAG search.
-
-    Attributes:
-    embedding_function (HuggingFaceEmbeddings): An instance of HuggingFaceEmbeddings used for generating embeddings.
-
-    Args:
-    model_name (str): The name of the sentence transformer model to be used for embeddings.
-    Defaults to "sentence-transformers/all-MiniLM-L12-v2".
-    device (str): The device to run the model on. Defaults to "gpu".
-    """
-
-    def __init__(
-        self, model_name: str = "sentence-transformers/all-MiniLM-L12-v2", device="gpu"
-    ):
-        self.embedding_function = HuggingFaceEmbeddings(
-            model_name=model_name,
-            cache_folder=MODEL_CACHE_DIR,
-            model_kwargs={"device": device},
-        )
-
-    def get(self) -> HuggingFaceEmbeddings:
-        return self.embedding_function
-
-
+    """
+    A class to handle the encoding of text using models from Hugging Face's transformers library.
+    Used for generating embeddings for text for RAG search.
+
+    Attributes:
+    embedding_function (HuggingFaceEmbeddings): An instance of HuggingFaceEmbeddings used for generating embeddings.
+
+    Args:
+    model_name (str): The name of the sentence transformer model to be used for embeddings.
+    Defaults to "sentence-transformers/all-MiniLM-L12-v2".
+    device (str): The device to run the model on. Defaults to "gpu".
+ """ + + def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L12-v2", device="gpu"): + self.embedding_function = HuggingFaceEmbeddings( + model_name=model_name, + cache_folder=MODEL_CACHE_DIR, + model_kwargs={"device": device}, + ) + + def get(self) -> HuggingFaceEmbeddings: + return self.embedding_function diff --git a/src/rag/fais_db.py b/src/rag/fais_db.py index 57f848f402cd017b8ecd70a746dca16b94c87213..f013cd71b52160dc5aba92a81bd33e9494ad415b 100644 --- a/src/rag/fais_db.py +++ b/src/rag/fais_db.py @@ -55,9 +55,7 @@ class FaissDb: Parameters: docs (List[Document]): A list of documents to be added to the FAISS database. """ - self.db = FAISS.from_documents( - docs, self.embedding_function, distance_strategy=DistanceStrategy.COSINE - ) + self.db = FAISS.from_documents(docs, self.embedding_function, distance_strategy=DistanceStrategy.COSINE) print("Loaded documents into new FAISS DB - Done!") def similarity_search(self, question: str, k: int = 3) -> List[Document]: diff --git a/src/rag/rag_context_retriever.py b/src/rag/rag_context_retriever.py index 64b3066d2e0261c427920998935be137233c7e2c..8d81ea97d737634cee84b256538bc3415dafb6ff 100644 --- a/src/rag/rag_context_retriever.py +++ b/src/rag/rag_context_retriever.py @@ -28,14 +28,14 @@ class RagContextRetriever: """ def __init__( - self, - rag_db: FaissDb, - reranker: ReRanker, - context_extender: ContextExtender, - k1: int, - k2: int, - k2_threshold: float, - expand_context: bool + self, + rag_db: FaissDb, + reranker: ReRanker, + context_extender: ContextExtender, + k1: int, + k2: int, + k2_threshold: float, + expand_context: bool, ): """ Initializes the RagContextRetriever with the necessary components and parameters. @@ -90,7 +90,7 @@ class RagContextRetriever: display_documents_as_dataframe(retrieved_docs) # Apply filter to remove documents with scores below the threshold - retrieved_docs = [doc for doc in retrieved_docs if doc.metadata.get('score') > self.k2_threshold] + retrieved_docs = [doc for doc in retrieved_docs if doc.metadata.get("score") > self.k2_threshold] if debug_mode: st.write("##### Document chunks after filtering by threshold") @@ -105,11 +105,13 @@ class RagContextRetriever: if debug_mode: st.write("##### Documents after expanding") - st.write(f"_Filtered for threshold {self.k2_threshold}, grouped by title, keeping only the highest " - f"scoring document, expanding the context to the full source document._") + st.write( + f"_Filtered for threshold {self.k2_threshold}, grouped by title, keeping only the highest " + f"scoring document, expanding the context to the full source document._" + ) display_documents_as_dataframe(retrieved_docs) # Convert the retrieved documents to a single string - context = ("".join(doc.page_content + "\n" for doc in retrieved_docs)) + context = "".join(doc.page_content + "\n" for doc in retrieved_docs) return context, retrieved_docs diff --git a/src/rag/rag_utils.py b/src/rag/rag_utils.py index 17f314bb5474392d7f3b877ca42c9ea33b71eb1f..9fef2e00d4ea33ce670dc624a63cfa272f720783 100644 --- a/src/rag/rag_utils.py +++ b/src/rag/rag_utils.py @@ -12,11 +12,11 @@ def filter_highest_scoring_entries(retrieved_docs: List[Document]) -> List[Docum highest_scoring_docs = {} for doc in retrieved_docs: - title = doc.metadata.get('title') - score = doc.metadata.get('score') + title = doc.metadata.get("title") + score = doc.metadata.get("score") # If the title is not in the dictionary or the current score is higher than the stored one, update it - if title not in highest_scoring_docs or 
-        if title not in highest_scoring_docs or score > highest_scoring_docs[title].metadata.get('score'):
+        if title not in highest_scoring_docs or score > highest_scoring_docs[title].metadata.get("score"):
             highest_scoring_docs[title] = doc
 
     # Return the values of the dictionary as a list
@@ -35,13 +35,12 @@ def display_documents_as_dataframe(documents, show_score=True) -> None:
     doc_data = []
     for doc in documents:
        doc_dict = {
-            "Title": doc.metadata.get('title'),
-            "URL": doc.metadata.get('url'),
+            "Title": doc.metadata.get("title"),
+            "URL": doc.metadata.get("url"),
             "Content": doc.page_content,
         }
         if show_score:
-            score = doc.metadata.get('score', 0)
+            score = doc.metadata.get("score", 0)
             doc_dict["Score"] = round(score, 2) if score is not None else 0
         doc_data.append(doc_dict)
     st.dataframe(doc_data)
-
diff --git a/src/rag/reranker.py b/src/rag/reranker.py
index 5ca53feb84bb3c770c72be2548bbdaae2c5a5b8e..0945f5b24765be1f02692d45334911e9ce9795f8 100644
--- a/src/rag/reranker.py
+++ b/src/rag/reranker.py
@@ -19,14 +19,9 @@ class ReRanker:
     """
 
     def __init__(
-        self,
-        encoder_name: str = "cross-encoder/ms-marco-TinyBERT-L-2-v2",
-        max_length: int = 512,
-        device: str = "cpu"
+        self, encoder_name: str = "cross-encoder/ms-marco-TinyBERT-L-2-v2", max_length: int = 512, device: str = "cpu"
     ):
-        self.cross_encoder = CrossEncoder(
-            encoder_name, max_length=max_length, device=device
-        )
+        self.cross_encoder = CrossEncoder(encoder_name, max_length=max_length, device=device)
 
     def rerank(self, question: str, docs: list, k: int = 3) -> List[Document]:
         """
@@ -41,9 +36,7 @@ class ReRanker:
         List[Document]: A list of the top k documents ranked by their relevance to the question.
         """
 
-        scores = self.cross_encoder.predict(
-            [(question, doc.page_content) for doc in docs]
-        )
+        scores = self.cross_encoder.predict([(question, doc.page_content) for doc in docs])
 
         # Pair each document with its score
         scored_docs = list(zip(docs, scores))
@@ -59,7 +52,7 @@ class ReRanker:
         for doc_score in scored_docs:
             doc = doc_score[0]
             score = doc_score[1]
-            doc.metadata['score'] = score
+            doc.metadata["score"] = score
             documents_only.append(doc)
 
         return documents_only
diff --git a/src/rag/simple_document_index.py b/src/rag/simple_document_index.py
index b821e84c3ff03c435081088b862e4f8c665b7598..c5a45cab20710f65a1388252e26b416c5e014bbf 100644
--- a/src/rag/simple_document_index.py
+++ b/src/rag/simple_document_index.py
@@ -5,53 +5,53 @@ import yaml
 
 
 class SimpleDocumentIndex:
-    """
-    Loads document index.yaml from a specified directory (./documents) and allows easy
-    retrieval of:
-    - all file paths for the documents listed in the index
-    - metadata for a specific document
-
-
-    Attributes:
-    directory_path (str): The path to the directory containing the index.yaml file.
-    index_file_path (str): The full path to the index.yaml file.
-    documents (dict): A dictionary loaded from the index.yaml file, where keys are document
-    filenames and values are metadata associated with each document.
-    """
-
-    def __init__(self, directory_path: str, filename: str = 'index.yaml'):
-        self.directory_path = directory_path
-        self.index_file_path = os.path.join(directory_path, filename)
-        self.documents = None
-
-    def _load_yaml(self):
-        """Loads the document index from the index.yaml file into the `documents` attribute."""
-        with open(self.index_file_path, 'r', encoding="utf8") as f:
-            self.documents = yaml.safe_load(f)
-
-    def get_file_paths(self) -> List[str]:
-        """
-        Retrieves the file paths for all documents listed in the index.
-
-        Returns:
-        list: A list of file paths for the documents.
- """ - if self.documents is None: - self._load_yaml() - return [os.path.join(self.directory_path, filename) for filename in self.documents] - - def get_metadata(self, filename: str) -> dict: - """ - Retrieves the metadata for a specific document, identified by its filename. - - If the documents are not already loaded, it loads them first. - - Args: - filename (str): The filename of the document to retrieve metadata for. - - Returns: - dict: The metadata associated with the document, or None if the document is not found. - """ - if self.documents is None: - self._load_yaml() - return self.documents.get(filename, None) + """ + Loads document index.yaml from a specified directory (./documents) and allows easy + retrieval of: + - all file paths for the documents listed in the index + - metadata for a specific document + + + Attributes: + directory_path (str): The path to the directory containing the index.yaml file. + index_file_path (str): The full path to the index.yaml file. + documents (dict): A dictionary loaded from the index.yaml file, where keys are document + filenames and values are metadata associated with each document. + """ + + def __init__(self, directory_path: str, filename: str = "index.yaml"): + self.directory_path = directory_path + self.index_file_path = os.path.join(directory_path, filename) + self.documents = None + + def _load_yaml(self): + """Loads the document index from the index.yaml file into the `documents` attribute.""" + with open(self.index_file_path, "r", encoding="utf8") as f: + self.documents = yaml.safe_load(f) + + def get_file_paths(self) -> List[str]: + """ + Retrieves the file paths for all documents listed in the index. + + Returns: + list: A list of file paths for the documents. + """ + if self.documents is None: + self._load_yaml() + return [os.path.join(self.directory_path, filename) for filename in self.documents] + + def get_metadata(self, filename: str) -> dict: + """ + Retrieves the metadata for a specific document, identified by its filename. + + If the documents are not already loaded, it loads them first. + + Args: + filename (str): The filename of the document to retrieve metadata for. + + Returns: + dict: The metadata associated with the document, or None if the document is not found. + """ + if self.documents is None: + self._load_yaml() + return self.documents.get(filename, None) diff --git a/src/utils/__init__.py b/src/utils/__init__.py index 0717303806dd38a97c31705146cdcd108a697e8a..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -1,2 +0,0 @@ -from . import audio_generator -from . import response_generator diff --git a/src/utils/audio_generator.py b/src/utils/audio_generator.py index 5a39a88e755e391cbf495a452729bcc97927a444..4f64f1558569dd5e8151f46c4cf8bdba793c04e8 100644 --- a/src/utils/audio_generator.py +++ b/src/utils/audio_generator.py @@ -6,29 +6,30 @@ from config import WORD_SOUND_MAPPING class AudioGenerator: - """ - Checks if the (partial) output of a LLM contains any words that are configured in the WORD_SOUND_MAPPING. - If so, plays the configured sound for that word using a Streamlit audio element. - - Not for production. Used as a fun feature for the demo. 
- """ - def __init__(self): - self.word_sound_mapping = WORD_SOUND_MAPPING - - def play_sound(self, partial_output: str) -> bool: - # Check if there are any word_sound_mappings configured - if not self.word_sound_mapping: - return False - - # Split the partial_output into words using regex - words = re.findall(r'\b\w+\b', partial_output.lower()) - - # Check for each word in the mapping - for word, sound_url in self.word_sound_mapping.items(): - if word.lower() in words: - # Make sure the large audio player is hidden - st.markdown('<style>.stAudio{display:none;}</style>', unsafe_allow_html=True) - # Add audio player to the streamlit app - st.audio(sound_url, autoplay=True) - return True - return False + """ + Checks if the (partial) output of a LLM contains any words that are configured in the WORD_SOUND_MAPPING. + If so, plays the configured sound for that word using a Streamlit audio element. + + Not for production. Used as a fun feature for the demo. + """ + + def __init__(self): + self.word_sound_mapping = WORD_SOUND_MAPPING + + def play_sound(self, partial_output: str) -> bool: + # Check if there are any word_sound_mappings configured + if not self.word_sound_mapping: + return False + + # Split the partial_output into words using regex + words = re.findall(r"\b\w+\b", partial_output.lower()) + + # Check for each word in the mapping + for word, sound_url in self.word_sound_mapping.items(): + if word.lower() in words: + # Make sure the large audio player is hidden + st.markdown("<style>.stAudio{display:none;}</style>", unsafe_allow_html=True) + # Add audio player to the streamlit app + st.audio(sound_url, autoplay=True) + return True + return False diff --git a/src/utils/response_generator.py b/src/utils/response_generator.py index a72483322afada8bd8ae912dac394d8d3acaab4c..ed4cf57698a1e89877a2b8d2bdf27ec211acd814 100644 --- a/src/utils/response_generator.py +++ b/src/utils/response_generator.py @@ -15,11 +15,7 @@ class ResponseGenerator: # Stream a LLM response to the chat async def generate_response( - self, - user_prompt: str, - context: str, - max_new_tokens: int, - funny_prompt_chance: float + self, user_prompt: str, context: str, max_new_tokens: int, funny_prompt_chance: float ) -> str: response_container = st.empty() response = "" @@ -29,10 +25,7 @@ class ResponseGenerator: # Generate response and stream it to the chat async for partial_output in await self.model.generate( - user_prompt, - context=context, - max_new_tokens=max_new_tokens, - funny_prompt_chance=funny_prompt_chance + user_prompt, context=context, max_new_tokens=max_new_tokens, funny_prompt_chance=funny_prompt_chance ): response += partial_output response = response.replace("<eos>", "").replace("<end_of_turn>", "") @@ -48,20 +41,24 @@ class ResponseGenerator: @staticmethod def get_unique_references(retrieved_docs: List[Document], response: str) -> List[Document]: # Skip if no documents were retrieved or if the response is a default message - if not retrieved_docs or response.startswith(MODEL_CONFIG['noContextFoundFlag']) or response.startswith(MODEL_CONFIG['outOfMemoryError']): + if ( + not retrieved_docs + or response.startswith(MODEL_CONFIG["noContextFoundFlag"]) + or response.startswith(MODEL_CONFIG["outOfMemoryError"]) + ): return [] # Filter to keep unique documents, showing each reference only once unique_docs = {} for doc in retrieved_docs: - title = doc.metadata.get('title') - url = doc.metadata.get('url') - score = doc.metadata.get('score') + title = doc.metadata.get("title") + url = doc.metadata.get("url") + 
+            score = doc.metadata.get("score")
             if title and url:
                 # If the title already exists in the dictionary, compare the scores
                 if title in unique_docs:
                     # If the new document's score is higher, replace the old one
-                    if unique_docs[title].metadata.get('score') < score:
+                    if unique_docs[title].metadata.get("score") < score:
                         unique_docs[title] = doc
                 else:
                     # If the title is not in the dictionary, add the new document