Commit 6cc3c593 authored by Evers, Valentijn

Merge branch 'L03DIGLIB-1563/add-linter'

parents d716dbb5 c0cac2d4
Branches: main
Showing changes with 219 additions and 279 deletions
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
line-length = 120
[format]
quote-style = "double"
indent-style = "tab"
docstring-code-format = true
import sys

# Use the selector event loop on Windows; the default Proactor loop can cause issues with some async libraries.
if sys.platform.startswith("win"):
    import asyncio

    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

from typing import List
@@ -24,7 +26,7 @@ from utils.response_generator import ResponseGenerator
st.set_page_config(
    page_title="Chat GPP - Generative Pre-trained Peter",
    page_icon="https://www.wur.nl/favicon.ico",
    initial_sidebar_state="collapsed",
)

st.title("Chat GPP\n### Generative Pre-trained Peter")
@@ -37,9 +39,7 @@ def load_model():
@st.cache_resource
def load_encoder():
    encoder = Encoder(model_name="sentence-transformers/all-MiniLM-L12-v2", device="cpu")
    return encoder
@@ -72,16 +72,12 @@ RagDB = load_rag_index()
#
def display_references(references: List[Document]):
    for doc in references:
        title = doc.metadata.get("title")
        # modified_date = doc.metadata.get('modified_date')
        # date_only = modified_date.split('T')[0]
        # modified_by = doc.metadata.get('modified_by')
        readable_score = str(round(doc.metadata.get("score"), 2))
        mention(label=f"{title} ({readable_score})", icon="https://www.wur.nl/favicon.ico", url=doc.metadata.get("url"))

# Display settings in sidebar
@@ -124,6 +120,7 @@ for message in st.session_state.messages:
        if "references" in message:
            display_references(message["references"])
async def main():
    # Accept user input
    if user_prompt := st.chat_input("Ask me anything!"):
@@ -135,7 +132,9 @@ async def main():
            st.markdown(user_prompt)

        # Retrieve context from Vector DB
        context_retriever = RagContextRetriever(
            RagDB, re_ranker, context_extender, k1, k2, k2_threshold, expand_context
        )
        context, retrieved_docs = context_retriever.retrieve_context(user_prompt, debug_show_rag_context)
        # context = None
        # references = []
@@ -145,18 +144,18 @@ async def main():
        with st.chat_message("assistant", avatar=CHATBOT_AVATAR_URL):
            # Generate response and stream it to the chat
            response_generator = ResponseGenerator(model, audio_generator)
            response = await response_generator.generate_response(
                user_prompt, context, max_new_tokens, funny_response_chance
            )

            # List references
            references = response_generator.get_unique_references(retrieved_docs, response)
            display_references(references)

            # Add response to chat history
            st.session_state.messages.append(
                {"role": "assistant", "content": response, "avatar": CHATBOT_AVATAR_URL, "references": references}
            )


asyncio.run(main())
# build-rag-index.py # build-rag-index.py
from rag import encoder, simple_document_index, document_processor, fais_db
from config import RAG_INDEX_FOLDER, RAG_DOCUMENTS_FOLDER
@@ -8,9 +7,7 @@ index_file_path = "storage/faiss-index"
# Load the encoder
print("Loading encoder")
encoder = encoder.Encoder(model_name="sentence-transformers/all-MiniLM-L12-v2", device="cpu")

# Create an instance of IndexHelper
index_helper = simple_document_index.SimpleDocumentIndex(RAG_DOCUMENTS_FOLDER)
...
# build-rag-index.py # build-rag-index.py
import streamlit as st
from rag import simple_document_index, fais_db
from rag.rag_utils import display_documents_as_dataframe
from config import RAG_DOCUMENTS_FOLDER
from sentence_transformers import CrossEncoder

st.set_page_config(layout="wide")
st.title("Compare RAG performance")

cross_encoder = CrossEncoder("encoder_name", max_length=512, device="cpu")


def get_vector_db(model_name: str, index_filename: str):
    from rag import encoder, document_processor

    # Load the encoder
    print("Loading encoder")
    encoder = encoder.Encoder(model_name=model_name, device="cpu")

    # Create an instance of IndexHelper
    index_helper = simple_document_index.SimpleDocumentIndex(RAG_DOCUMENTS_FOLDER, filename=index_filename)
@@ -55,14 +50,15 @@ def lookup(query: str, model: str, k: int = 5):
    return [results1, results2]


# def rerank(results):
# Rerank the results

k = st.number_input("k1 (rag)", 1, 10, 5)
modelA = st.text_input("Enter model name A", "sentence-transformers/all-MiniLM-L12-v2")
modelB = st.text_input("Enter model name B", "sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
query = st.text_input("Enter query")

if query and modelA and modelB:
    [results_a_nl, results_a_en] = lookup(query, modelA, k)
@@ -85,7 +81,3 @@ if query and modelA and modelB:
    with col_en_2:
        st.write(f"Model: {modelB}")
        display_documents_as_dataframe(results_b_en)
@@ -11,7 +11,7 @@ def load_yaml_config(file_path):
    # Check if file exists
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Config file not found: {file_path}")

    with open(file_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
    return config
@@ -55,5 +55,5 @@ CHATBOT_AVATAR_URL = os.getenv("CHATBOT_AVATAR_URL", "https://www.wur.nl/favicon
#
# Sound mapping for audio generator
#
word_sound_mapping_str = os.getenv("WORD_SOUND_MAPPING", "{}")
WORD_SOUND_MAPPING = json.loads(word_sound_mapping_str)
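For context, WORD_SOUND_MAPPING is parsed from a JSON string in the environment. A minimal sketch of the expected shape, with hypothetical word-to-audio-file values (the real mapping lives in the deployment environment, not in the repo):

# Hypothetical example value for the WORD_SOUND_MAPPING environment variable:
#   export WORD_SOUND_MAPPING='{"peter": "sounds/peter.mp3", "wur": "sounds/wur.mp3"}'
import json
import os

WORD_SOUND_MAPPING = json.loads(os.getenv("WORD_SOUND_MAPPING", "{}"))
print(WORD_SOUND_MAPPING.get("peter"))  # -> "sounds/peter.mp3" when the variable above is set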
from . import huggingface_llm_model
from . import azure_openai_llm_model
from . import prompt_builder
from . import llm_model_factory
@@ -6,22 +6,14 @@ from .llm_model_base import LlmModelBase
from .prompt_builder import PromptBuilder


class AzureOpenAiLlmModel(LlmModelBase):
    def __init__(self, model_id: str):
        self.prompt_builder = PromptBuilder()
        self.model_id = model_id

    async def generate(
        self, question: str, context: str = None, max_new_tokens: int = 256, funny_prompt_chance: float = 0.0
    ):
        # Build the system prompt
        chat_messages = self.prompt_builder.build_prompt(question, context, funny_prompt_chance)
@@ -29,11 +21,7 @@ class AzureOpenAiLlmModel (LlmModelBase):
        endpoint = os.environ["AZURE_OPEN_AI_ENDPOINT"]
        api_key = os.environ["AZURE_OPEN_AI_API_KEY"]

        client = openai.AsyncAzureOpenAI(azure_endpoint=endpoint, api_key=api_key, api_version="2023-09-01-preview")

        print("Calling Azure OpenAI API")
@@ -46,18 +34,18 @@ class AzureOpenAiLlmModel (LlmModelBase):
            top_p=1.0,
            max_tokens=max_new_tokens,
            messages=chat_messages,
            stream=True,
        )

        # Define a simplified async generator to yield text chunks
        async def text_stream():
            async for chunk in azure_open_ai_response:
                if (
                    hasattr(chunk, "choices")
                    and chunk.choices
                    and hasattr(chunk.choices[0], "delta")
                    and hasattr(chunk.choices[0].delta, "content")
                    and chunk.choices[0].delta.content
                ):
                    yield chunk.choices[0].delta.content
...
@@ -8,7 +8,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from config import MODEL_CACHE_DIR, HF_ACCESS_TOKEN, MODEL_CONFIG
from .llm_model_base import LlmModelBase
from .prompt_builder import PromptBuilder


class HuggingfaceLlmModel(LlmModelBase):
    """
@@ -28,14 +28,8 @@ class HuggingfaceLlmModel(LlmModelBase):
    streamer: Optional[TextIteratorStreamer] = None

    def __init__(self, model_id: str = "google/gemma-2-9b-it", device: str = "cuda"):
        quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
@@ -51,7 +45,7 @@ class HuggingfaceLlmModel(LlmModelBase):
            token=HF_ACCESS_TOKEN,
            padding=True,
            truncation=True,
            max_length=self.model.config.max_position_embeddings,
        )
        self.model.eval()
@@ -59,13 +53,9 @@ class HuggingfaceLlmModel(LlmModelBase):
        self.prompt_builder = PromptBuilder()

    async def generate(
        self, question: str, context: str = None, max_new_tokens: int = 256, funny_prompt_chance: float = 0.7
    ) -> AsyncGenerator[str, None]:
        # -> TextIteratorStreamer:
        """
        Generates text based on a question and optional context.
@@ -110,7 +100,9 @@ class HuggingfaceLlmModel(LlmModelBase):
                self.streamer.text_queue.put(None)  # Signal the streamer to stop
            except Exception as e:
                print(f"An unexpected error occurred: {e}")
                self.stream_error_with_delay(
                    "Sorry, I ran into an unexpected error 🤯. I've output the error message to the console."
                )
                self.streamer.text_queue.put(None)  # Signal the streamer to stop

        # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
@@ -124,16 +116,15 @@ class HuggingfaceLlmModel(LlmModelBase):
        # Return the asynchronous wrapper
        return stream_generator()
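The generation call itself is elided from this hunk. For readers unfamiliar with the pattern the comments describe, this is a minimal sketch of the usual transformers TextIteratorStreamer-plus-thread approach; the names model, tokenizer and inputs are placeholders, not the repository's exact code:

from threading import Thread

from transformers import TextIteratorStreamer


def stream_generate(model, tokenizer, inputs, max_new_tokens=256):
    # Run model.generate() in a worker thread; it pushes decoded text into the streamer's queue,
    # so the caller can consume chunks without blocking on the full generation.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=max_new_tokens)
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    for text_chunk in streamer:
        yield text_chunk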
    def stream_error_with_delay(self, text: str, delay_ms: int = 10):
        """
        Streams an error message to the chat with a delay between characters, giving the appearance of typing.

        Parameters:
            text (str): The error message to stream.
            delay_ms (int): The delay in milliseconds between each character. Defaults to 10 ms.
        """
        for char in text:
            self.streamer.text_queue.put(char)
            time.sleep(delay_ms / 1000.0)  # Convert ms to seconds
from abc import ABC, abstractmethod


class LlmModelBase(ABC):
    @abstractmethod
    async def generate(
        self, question: str, context: str = None, max_new_tokens: int = 256, funny_prompt_chance: float = 0.0
    ):
        pass
@@ -3,14 +3,15 @@ from model.huggingface_llm_model import HuggingfaceLlmModel
from model.llm_model_base import LlmModelBase
from config import MODEL_CONFIG


class LlmModelFactory:
    @staticmethod
    def create_model() -> LlmModelBase:
        model_type = MODEL_CONFIG["modelType"]
        if model_type == "huggingface":
            return HuggingfaceLlmModel(model_id=MODEL_CONFIG["modelId"], device="cuda")
        elif model_type == "azure-openai":
            return AzureOpenAiLlmModel(model_id=MODEL_CONFIG["modelId"])
        else:
            raise ValueError(f"Unknown model type: {model_type}")
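A short usage sketch of the factory (illustrative; it assumes MODEL_CONFIG is populated as in config.py and that generate() returns an async text stream, as both implementations above do):

import asyncio

from model.llm_model_factory import LlmModelFactory


async def demo():
    model = LlmModelFactory.create_model()  # HuggingfaceLlmModel or AzureOpenAiLlmModel, depending on modelType
    stream = await model.generate("What is Chat GPP?", context=None, max_new_tokens=64)
    async for chunk in stream:
        print(chunk, end="")


asyncio.run(demo())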
@@ -4,87 +4,82 @@ from config import MODEL_CONFIG


class PromptBuilder:
    """
    A class to construct system prompts based on:
    - Some general system context (e.g. the personality of the system)
    - The question asked by the user
    - The context of the question (e.g. a document retrieved by RAG)
    - A percentage chance to replace the prompt with a funny prompt

    Attributes:
        role (str): The role of the entity asking the question, default is 'user'.
        funnyPrompts (list): A list of funny prompts configured in MODEL_CONFIG.
        usedFunnyPrompts (list): A list to keep track of funny prompts that have been used.

    Methods:
        build_prompt(question: str, context: str, funny_prompt_chance: float): Constructs a system prompt.
        _get_random_funny_prompt(): Selects a random funny prompt from the available list.
    """

    def __init__(self, role: str = "user"):
        self.role = role
        self.funnyPrompts = MODEL_CONFIG["funnyPrompts"]
        self.usedFunnyPrompts = []

    def build_prompt(self, question: str, context: str = None, funny_prompt_chance: float = 0.7) -> list:
        """
        Builds a complete system prompt based on the provided question and context, with a certain chance to completely
        replace the prompt with a funny prompt.

        Parameters:
            question (str): The question asked by the user.
            context (str, optional): The context of the question, defaults to None.
            funny_prompt_chance (float, optional): The chance (0.0 to 1.0) to use a funny prompt instead of a normal one, defaults to 0.7.

        Returns:
            list: A list containing a single chat object with the role and constructed prompt.
        """

        if context is None or context == "":
            prompt = MODEL_CONFIG["templateNoContext"]
        else:
            prompt = MODEL_CONFIG["templateWithContext"]
            prompt = prompt.replace("{context}", context)

        # replace {question} and {context} in the prompt
        system_prompt = prompt.replace("{question}", question)

        # Replace the entire prompt with a funny prompt sometimes
        if (self.funnyPrompts or self.usedFunnyPrompts) and random.random() < funny_prompt_chance:
            system_prompt = self._get_random_funny_prompt()
            print("Using funny prompt instead of normal prompt.")

        # system_prompt = question
        # print("Final system_prompt: ", system_prompt)

        # Build the chat object
        chat = [{"role": self.role, "content": system_prompt}]

        return chat

    def _get_random_funny_prompt(self) -> str:
        """
        Selects a random funny prompt from the configured list, ensuring it hasn't been used yet.

        Returns:
            str: The selected funny prompt.
        """
        if not self.funnyPrompts:
            # If all prompts have been used, reset the list
            self.funnyPrompts, self.usedFunnyPrompts = self.usedFunnyPrompts, []

        # Get a random prompt based on weights
        weights = [p["weight"] for p in self.funnyPrompts]
        chosen_prompt = random.choices(self.funnyPrompts, weights=weights, k=1)[0]

        # Remove the chosen prompt from the list
        self.funnyPrompts = [p for p in self.funnyPrompts if p != chosen_prompt]
        # Add the chosen prompt to the used prompts list
        self.usedFunnyPrompts.append(chosen_prompt)

        return chosen_prompt
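A minimal usage sketch of PromptBuilder (illustrative; it assumes MODEL_CONFIG defines templateWithContext, templateNoContext and funnyPrompts):

from model.prompt_builder import PromptBuilder

builder = PromptBuilder(role="user")
chat_messages = builder.build_prompt(
    question="Who is Peter?",
    context="Peter is a researcher at WUR.",  # pass None or "" to fall back to templateNoContext
    funny_prompt_chance=0.0,  # raise above 0 to occasionally swap in a weighted funny prompt
)
print(chat_messages)  # [{"role": "user", "content": "<filled-in template>"}]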
from . import context_extender
from . import encoder
from . import fais_db
from . import document_processor
from . import rag_context_retriever
from . import reranker
from . import simple_document_index
from . import rag_utils
@@ -67,7 +67,7 @@ class ContextExtender:
        """
        # Get the filename from the document's metadata
        filename = doc.metadata["source"]

        # Use the IndexHelper to get the full path of the file
        file_paths = self.index_helper.get_file_paths()
@@ -81,11 +81,12 @@ class ContextExtender:
            loader = UnstructuredMarkdownLoader(file_path)
        else:
            raise ValueError(
                f"Unsupported file format: only PDF and HTML files support is implemented. Tried to load: {file_path}"
            )

        documents = loader.load()
        # Concatenate the content of all documents
        full_content = " ".join([document.page_content for document in documents])

        # Create a new document with the loaded content and the original document's metadata
        full_doc = Document(page_content=full_content, metadata=doc.metadata)
...
@@ -29,9 +29,7 @@ class DocumentProcessor:
        self.document_index = document_index
        self.chunk_size = chunk_size
        self.text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
            tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L12-v2"),
            chunk_size=self.chunk_size,
            chunk_overlap=int(self.chunk_size / 10),
            strip_whitespace=True,
@@ -64,8 +62,7 @@ class DocumentProcessor:
        """
        print("RAG: Loading and splitting documents...")
        pages_pdf = [
            page for file_path in file_paths if file_path.endswith(".pdf") for page in PyPDFLoader(file_path).load()
        ]

        pages_html = []
@@ -105,7 +102,7 @@ class DocumentProcessor:
        # Append additional metadata from our index.yaml to each document
        for doc in docs:
            # Get the filename from the source path
            filename = os.path.basename(doc.metadata["source"])
            # Get the corresponding metadata from the index.yaml
            metadata = self.document_index.get_metadata(filename)
            # Merge the metadata into the document
...
@@ -4,29 +4,25 @@ from config import MODEL_CACHE_DIR


class Encoder:
    """
    A class to handle the encoding of text using models from Hugging Face's transformers library.
    Used for generating embeddings for text for RAG search.

    Attributes:
        embedding_function (HuggingFaceEmbeddings): An instance of HuggingFaceEmbeddings used for generating embeddings.

    Args:
        model_name (str): The name of the sentence transformer model to be used for embeddings.
            Defaults to "sentence-transformers/all-MiniLM-L12-v2".
        device (str): The device to run the model on. Defaults to "gpu".
    """

    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L12-v2", device="gpu"):
        self.embedding_function = HuggingFaceEmbeddings(
            model_name=model_name,
            cache_folder=MODEL_CACHE_DIR,
            model_kwargs={"device": device},
        )

    def get(self) -> HuggingFaceEmbeddings:
        return self.embedding_function
@@ -55,9 +55,7 @@ class FaissDb:
        Parameters:
            docs (List[Document]): A list of documents to be added to the FAISS database.
        """
        self.db = FAISS.from_documents(docs, self.embedding_function, distance_strategy=DistanceStrategy.COSINE)
        print("Loaded documents into new FAISS DB - Done!")

    def similarity_search(self, question: str, k: int = 3) -> List[Document]:
...
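For orientation, this is a minimal sketch of what the Encoder and FaissDb wrappers do with the LangChain FAISS API (import paths vary between LangChain versions; the document content and metadata here are placeholders):

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_core.documents import Document

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L12-v2")
docs = [Document(page_content="Peter works at WUR.", metadata={"title": "Example", "url": "https://www.wur.nl"})]
# Build an in-memory FAISS index with cosine distance, then query it the way FaissDb.similarity_search does.
db = FAISS.from_documents(docs, embeddings, distance_strategy=DistanceStrategy.COSINE)
hits = db.similarity_search("Who is Peter?", k=3)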
@@ -28,14 +28,14 @@ class RagContextRetriever:
    """

    def __init__(
        self,
        rag_db: FaissDb,
        reranker: ReRanker,
        context_extender: ContextExtender,
        k1: int,
        k2: int,
        k2_threshold: float,
        expand_context: bool,
    ):
        """
        Initializes the RagContextRetriever with the necessary components and parameters.
@@ -90,7 +90,7 @@ class RagContextRetriever:
            display_documents_as_dataframe(retrieved_docs)

        # Apply filter to remove documents with scores below the threshold
        retrieved_docs = [doc for doc in retrieved_docs if doc.metadata.get("score") > self.k2_threshold]

        if debug_mode:
            st.write("##### Document chunks after filtering by threshold")
@@ -105,11 +105,13 @@ class RagContextRetriever:
        if debug_mode:
            st.write("##### Documents after expanding")
            st.write(
                f"_Filtered for threshold {self.k2_threshold}, grouped by title, keeping only the highest "
                f"scoring document, expanding the context to the full source document._"
            )
            display_documents_as_dataframe(retrieved_docs)

        # Convert the retrieved documents to a single string
        context = "".join(doc.page_content + "\n" for doc in retrieved_docs)

        return context, retrieved_docs
@@ -12,11 +12,11 @@ def filter_highest_scoring_entries(retrieved_docs: List[Document]) -> List[Docum
    highest_scoring_docs = {}

    for doc in retrieved_docs:
        title = doc.metadata.get("title")
        score = doc.metadata.get("score")

        # If the title is not in the dictionary or the current score is higher than the stored one, update it
        if title not in highest_scoring_docs or score > highest_scoring_docs[title].metadata.get("score"):
            highest_scoring_docs[title] = doc

    # Return the values of the dictionary as a list
@@ -35,13 +35,12 @@ def display_documents_as_dataframe(documents, show_score=True) -> None:
    doc_data = []
    for doc in documents:
        doc_dict = {
            "Title": doc.metadata.get("title"),
            "URL": doc.metadata.get("url"),
            "Content": doc.page_content,
        }
        if show_score:
            score = doc.metadata.get("score", 0)
            doc_dict["Score"] = round(score, 2) if score is not None else 0
        doc_data.append(doc_dict)
    st.dataframe(doc_data)
@@ -19,14 +19,9 @@ class ReRanker:
    """

    def __init__(
        self, encoder_name: str = "cross-encoder/ms-marco-TinyBERT-L-2-v2", max_length: int = 512, device: str = "cpu"
    ):
        self.cross_encoder = CrossEncoder(encoder_name, max_length=max_length, device=device)

    def rerank(self, question: str, docs: list, k: int = 3) -> List[Document]:
        """
@@ -41,9 +36,7 @@ class ReRanker:
            List[Document]: A list of the top k documents ranked by their relevance to the question.
        """
        scores = self.cross_encoder.predict([(question, doc.page_content) for doc in docs])

        # Pair each document with its score
        scored_docs = list(zip(docs, scores))
@@ -59,7 +52,7 @@ class ReRanker:
        for doc_score in scored_docs:
            doc = doc_score[0]
            score = doc_score[1]
            doc.metadata["score"] = score
            documents_only.append(doc)

        return documents_only
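And a short usage sketch of the re-ranking step (the cross-encoder name is the class default; the candidate documents and the Document import path are illustrative):

from langchain_core.documents import Document

from rag.reranker import ReRanker

re_ranker = ReRanker(encoder_name="cross-encoder/ms-marco-TinyBERT-L-2-v2", device="cpu")
candidates = [
    Document(page_content="Peter leads the digital library team.", metadata={"title": "Team page"}),
    Document(page_content="Opening hours of the library.", metadata={"title": "Hours"}),
]
# The cross-encoder scores each (question, chunk) pair and stores the score in doc.metadata["score"].
reranked = re_ranker.rerank("Who is Peter?", candidates, k=1)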