Commit 05956bed authored by Evers, Valentijn

L03DIGLIB-1347 - Optimized tokenizer settings

- Added max token length & enabled truncation
- Increased the text streamer timeout to 30s
parent 803ea7a6
@@ -31,9 +31,6 @@ class LlmModel:
         model_id: str = "google/gemma-2-9b-it",
         device: str = "cuda"
     ):
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_id, cache_dir=MODEL_CACHE_DIR, token=HF_ACCESS_TOKEN
-        )
         quantization_config = BitsAndBytesConfig(
             load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16
         )
@@ -45,6 +42,16 @@ class LlmModel:
             cache_dir=MODEL_CACHE_DIR,
             token=HF_ACCESS_TOKEN,
         )
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_id,
+            cache_dir=MODEL_CACHE_DIR,
+            token=HF_ACCESS_TOKEN,
+            padding=True,
+            truncation=True,
+            max_length=self.model.config.max_position_embeddings
+        )
+
         self.model.eval()
         self.device = device
         self.prompt_builder = PromptBuilder()
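
A minimal sketch (not part of the commit) of what the new limits do at encode time. In transformers, `truncation` and `max_length` are typically honoured per call, so the sketch passes them explicitly; the model id comes from the diff, while 8192 is an assumed stand-in for `self.model.config.max_position_embeddings`:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
max_len = 8192  # assumed stand-in for model.config.max_position_embeddings

# Encode an overlong prompt; truncation caps it at the context window
# instead of letting generation fail on out-of-range positions.
ids = tokenizer.encode(
    "a very long prompt " * 4000,
    add_special_tokens=False,
    truncation=True,
    max_length=max_len,
    return_tensors="pt",
)
assert ids.shape[-1] <= max_len  # never exceeds the position-embedding limit
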
@@ -74,16 +81,20 @@ class LlmModel:
         # Apply the chat template to the prompt
-        formatted_prompt = self.tokenizer.apply_chat_template(
-            chat, tokenize=False, add_generation_prompt=True
-        )
+        formatted_prompt = self.tokenizer.apply_chat_template(
+            chat,
+            tokenize=False,
+            add_generation_prompt=True,
+        )
 
         # Tokenize the system prompt
-        inputs = self.tokenizer.encode(
-            formatted_prompt, add_special_tokens=False, return_tensors="pt"
-        ).to(self.device)
+        inputs = self.tokenizer.encode(
+            formatted_prompt,
+            add_special_tokens=False,
+            return_tensors="pt",
+        ).to(self.device)
 
         # Create a stream, which will be used to fetch the generated text in a non-blocking way.
-        self.streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, timeout=10)
+        self.streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, timeout=30)
 
         # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
         generation_kwargs = dict(inputs=inputs, max_new_tokens=max_new_tokens, streamer=self.streamer)
 
...
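
For context, a self-contained sketch of the streaming pattern this hunk tunes, assuming `model`, `tokenizer`, and the encoded `inputs` are set up as in the diff. `TextIteratorStreamer` raises `queue.Empty` when no new text arrives within `timeout` seconds, which is what the bump from 10s to 30s relaxes for slow first tokens:

from threading import Thread
from transformers import TextIteratorStreamer

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=30)

# model.generate blocks until generation finishes, so it runs on a worker
# thread while the main thread consumes decoded text chunks as they arrive.
generation_kwargs = dict(inputs=inputs, max_new_tokens=256, streamer=streamer)  # 256 is a stand-in value
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()
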