| from torch_grammar import GrammarSampler |
| from transformers.generation.logits_process import LogitsProcessor |
|
|
| from modules import shared |
|
|
# Module-level cache shared by every GrammarLogitsProcessor instance.
# GrammarSampler built from the most recent non-blank grammar, else None.
sampler = None
# Object returned by sampler.logits_processor() for the newest processor,
# or None when no grammar is active.
grammar = None
# Raw grammar text that `sampler` was built from; used as the cache key.
grammar_string = ''
|
|
|
|
class GrammarLogitsProcessor(LogitsProcessor):
    """Logits processor that filters scores through a grammar sampler.

    The ``GrammarSampler`` is cached in the module-level globals ``sampler``
    and ``grammar_string``: constructing a processor with the same grammar
    text as the previous one reuses the cached sampler instead of rebuilding
    it. The active callable is stored in the module-level ``grammar`` global,
    so every instance applies whichever grammar was set up most recently.
    """

    def __init__(self, string):
        """Prepare the shared grammar state for ``string``.

        Args:
            string: Grammar definition text. ``None``, an empty string, or a
                whitespace-only string disables grammar filtering.

        Raises:
            Whatever ``GrammarSampler`` raises for a grammar it cannot build.
            In that case the previously cached sampler and its key are left
            untouched.
        """
        global sampler, grammar, grammar_string

        # Treat a missing grammar like an empty one instead of failing on
        # ``None.strip()`` after the cache key has already been overwritten.
        if string is None:
            string = ''

        if string != grammar_string:
            stripped = string.strip()
            # Build first, commit afterwards: if construction raises, the
            # cache key must not claim that `sampler` matches this string,
            # otherwise a retry with the same text would silently reuse the
            # sampler of the previous grammar.
            if stripped:
                # The grammar text is passed with exactly one trailing
                # newline; 'root' is presumably the start rule name
                # (TODO confirm against torch_grammar).
                new_sampler = GrammarSampler(stripped + '\n', 'root', shared.tokenizer)
            else:
                new_sampler = None

            sampler = new_sampler
            grammar_string = string

        # A new logits-processor object is requested from the sampler for
        # every constructed instance, even when the sampler itself is reused.
        grammar = sampler.logits_processor() if sampler is not None else None

    def __call__(self, input_ids, scores):
        """Return ``scores`` passed through the active grammar, if any.

        Args:
            input_ids: Forwarded unchanged to the grammar callable.
            scores: Scores to filter; returned as-is when no grammar is set.
        """
        if grammar is not None:
            scores = grammar(input_ids, scores)

        return scores
|
|