from pydantic import BaseModel
from transformers import PreTrainedTokenizerFast, StoppingCriteria


def fallback(value, fallback_value):
    """Return `value`, or `fallback_value` when `value` is None."""
    if value is None:
        return fallback_value
    return value


class Body(BaseModel):
    """Request body: a prompt plus optional generation parameters.

    Fields left as None fall back to server-side defaults (see `fallback`).
    """

    prompt: str
    posts_count: int
    max_length: int | None = None
    temperature: float | None = None
    top_p: float | None = None
    # top_k and no_repeat_ngram_size are integer-valued in transformers'
    # generation API; `float` here was a typo.
    top_k: int | None = None
    repetition_penalty: float | None = None
    no_repeat_ngram_size: int | None = None
    do_sample: bool | None = None


class MaxPostsStoppingCriteria(StoppingCriteria):
    """Stop generation once `posts_count` occurrences of the
    `<|end_of_post|>` marker have been generated."""

    def __init__(self, tokenizer: PreTrainedTokenizerFast, posts_count: int):
        # encode() returns a list of token ids: the marker may span several
        # tokens if it is not registered as a single special token.
        self.end_of_post_token_ids = tokenizer.encode(
            "<|end_of_post|>", add_special_tokens=False
        )
        self.posts_count = posts_count
        self.counter = 0

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # A sequence ends with the marker ids only at the step where the
        # marker's final token was just generated, so each occurrence is
        # counted exactly once. The counter is shared across the batch,
        # so this criterion assumes batch size 1.
        for sequence in input_ids:
            if sequence[-len(self.end_of_post_token_ids):].tolist() == self.end_of_post_token_ids:
                self.counter += 1
                if self.counter >= self.posts_count:
                    return True
        return False
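

# --- Usage sketch (illustrative, not part of the original module) ---
# Shows how Body, fallback, and MaxPostsStoppingCriteria are meant to combine
# with transformers' generate(). "gpt2" is a stand-in checkpoint: the real
# service presumably loads a fine-tuned model whose tokenizer knows
# <|end_of_post|>, and the default values passed to fallback() below are
# assumptions, not values taken from the original code.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Only the required fields are set; the rest resolve via fallback().
    body = Body(prompt="Hello world", posts_count=2)
    inputs = tokenizer(body.prompt, return_tensors="pt")

    output_ids = model.generate(
        **inputs,
        max_length=fallback(body.max_length, 200),
        temperature=fallback(body.temperature, 0.9),
        top_p=fallback(body.top_p, 0.95),
        top_k=fallback(body.top_k, 50),
        repetition_penalty=fallback(body.repetition_penalty, 1.1),
        no_repeat_ngram_size=fallback(body.no_repeat_ngram_size, 3),
        do_sample=fallback(body.do_sample, True),
        stopping_criteria=StoppingCriteriaList(
            [MaxPostsStoppingCriteria(tokenizer, body.posts_count)]
        ),
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=False))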