From 05956bedf1b6f36ee8a6eb51aa65ec9a36c9f455 Mon Sep 17 00:00:00 2001
From: Valentijn Evers <valentijn.evers@wur.nl>
Date: Tue, 3 Sep 2024 12:04:18 +0200
Subject: [PATCH] L03DIGLIB-1347 - Optimized tokenizer settings

- Added max tokens length & enabled truncation
- Increased max timeout to 30s for text streamer
---
 src/model/llm_model.py | 23 +++++++++++++++++------
 1 file changed, 17 insertions(+), 6 deletions(-)

diff --git a/src/model/llm_model.py b/src/model/llm_model.py
index e810820..10d8ea2 100644
--- a/src/model/llm_model.py
+++ b/src/model/llm_model.py
@@ -31,9 +31,6 @@ class LlmModel:
         model_id: str = "google/gemma-2-9b-it",
         device: str = "cuda"
     ):
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_id, cache_dir=MODEL_CACHE_DIR, token=HF_ACCESS_TOKEN
-        )
         quantization_config = BitsAndBytesConfig(
             load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
         )
@@ -45,6 +42,16 @@ class LlmModel:
             cache_dir=MODEL_CACHE_DIR,
             token=HF_ACCESS_TOKEN,
         )
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_id,
+            cache_dir=MODEL_CACHE_DIR,
+            token=HF_ACCESS_TOKEN,
+            padding=True,
+            truncation=True,
+            max_length=self.model.config.max_position_embeddings
+        )
+
         self.model.eval()
         self.device = device
         self.prompt_builder = PromptBuilder()
@@ -74,16 +81,20 @@ class LlmModel:
 
         # Apply the chat template to the prompt
         formatted_prompt = self.tokenizer.apply_chat_template(
-            chat, tokenize=False, add_generation_prompt=True
+            chat,
+            tokenize=False,
+            add_generation_prompt=True,
         )
 
         # Tokenize the system prompt
         inputs = self.tokenizer.encode(
-            formatted_prompt, add_special_tokens=False, return_tensors="pt"
+            formatted_prompt,
+            add_special_tokens=False,
+            return_tensors="pt",
         ).to(self.device)
 
         # Create a stream, which will be used to fetch the generated text in a non-blocking way.
-        self.streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, timeout=10)
+        self.streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, timeout=30)
 
         # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
         generation_kwargs = dict(inputs=inputs, max_new_tokens=max_new_tokens, streamer=self.streamer)
--
GitLab
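
Note for reviewers: for context on the timeout change, below is a minimal sketch of the standard Hugging Face TextIteratorStreamer pattern that the patch is tuning. It is not code from this repository; the prompt text, max_new_tokens value, and the unquantized model load are illustrative assumptions.

    # Illustrative sketch (not part of the patch): the usual pattern for draining a
    # TextIteratorStreamer, showing where the 30-second timeout applies.
    from threading import Thread

    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    model_id = "google/gemma-2-9b-it"  # assumption: same model as in the patch
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    inputs = tokenizer("Explain quantization in one sentence.", return_tensors="pt")

    # skip_prompt=True drops the echoed prompt; timeout=30 means waiting for the next
    # chunk raises queue.Empty if the model produces nothing for 30 seconds.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=30)
    generation_kwargs = dict(**inputs, max_new_tokens=128, streamer=streamer)

    # generate() runs in a background thread so the streamer can be consumed here
    # in a non-blocking way, as the comments in the patched method describe.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    for chunk in streamer:  # yields decoded text pieces as they are generated
        print(chunk, end="", flush=True)
    thread.join()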