From 05956bedf1b6f36ee8a6eb51aa65ec9a36c9f455 Mon Sep 17 00:00:00 2001
From: Valentijn Evers <valentijn.evers@wur.nl>
Date: Tue, 3 Sep 2024 12:04:18 +0200
Subject: [PATCH] L03DIGLIB-1347 - Optimized tokenizer settings

- Added max tokens length & enabled truncation
- Increased max timeout to 30s for text streamer
---
 src/model/llm_model.py | 23 +++++++++++++++++------
 1 file changed, 17 insertions(+), 6 deletions(-)

diff --git a/src/model/llm_model.py b/src/model/llm_model.py
index e810820..10d8ea2 100644
--- a/src/model/llm_model.py
+++ b/src/model/llm_model.py
@@ -31,9 +31,6 @@ class LlmModel:
         model_id: str = "google/gemma-2-9b-it",
         device: str = "cuda"
     ):
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_id, cache_dir=MODEL_CACHE_DIR, token=HF_ACCESS_TOKEN
-        )
         quantization_config = BitsAndBytesConfig(
             load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
         )
@@ -45,6 +42,16 @@ class LlmModel:
             cache_dir=MODEL_CACHE_DIR,
             token=HF_ACCESS_TOKEN,
         )
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_id,
+            cache_dir=MODEL_CACHE_DIR,
+            token=HF_ACCESS_TOKEN,
+            padding=True,
+            truncation=True,
+            max_length=self.model.config.max_position_embeddings
+        )
+
         self.model.eval()
         self.device = device
         self.prompt_builder = PromptBuilder()
@@ -74,16 +81,20 @@ class LlmModel:
 
         # Apply the chat template to the prompt
         formatted_prompt = self.tokenizer.apply_chat_template(
-            chat, tokenize=False, add_generation_prompt=True
+            chat,
+            tokenize=False,
+            add_generation_prompt=True,
         )
 
         # Tokenize the system prompt
         inputs = self.tokenizer.encode(
-            formatted_prompt, add_special_tokens=False, return_tensors="pt"
+            formatted_prompt,
+            add_special_tokens=False,
+            return_tensors="pt",
         ).to(self.device)
 
         # Create a stream, which will be used to fetch the generated text in a non-blocking way.
-        self.streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, timeout=10)
+        self.streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, timeout=30)
 
         # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
         generation_kwargs = dict(inputs=inputs, max_new_tokens=max_new_tokens, streamer=self.streamer)
--
GitLab
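
Note for reviewers: for context on the timeout change, below is a minimal sketch of the standard Hugging Face TextIteratorStreamer pattern that the patch is tuning. It is not code from this repository; the prompt text, max_new_tokens value, and the unquantized model load are illustrative assumptions.

    # Illustrative sketch (not part of the patch): the usual pattern for draining a
    # TextIteratorStreamer, showing where the 30-second timeout applies.
    from threading import Thread

    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    model_id = "google/gemma-2-9b-it"  # assumption: same model as in the patch
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    inputs = tokenizer("Explain quantization in one sentence.", return_tensors="pt")

    # skip_prompt=True drops the echoed prompt; timeout=30 means waiting for the next
    # chunk raises queue.Empty if the model produces nothing for 30 seconds.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=30)
    generation_kwargs = dict(**inputs, max_new_tokens=128, streamer=streamer)

    # generate() runs in a background thread so the streamer can be consumed here
    # in a non-blocking way, as the comments in the patched method describe.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    for chunk in streamer:  # yields decoded text pieces as they are generated
        print(chunk, end="", flush=True)
    thread.join()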