import time

from anthropic import Anthropic
from google import genai
from google.genai import types
from openai import OpenAI
def get_client(cfg):
    """Build an API client for the configured model family."""
    if 'gpt' in cfg.model.family_name or cfg.model.family_name == 'o':
        client = OpenAI(api_key=cfg.model.api_key)
    elif 'claude' in cfg.model.family_name:
        client = Anthropic(api_key=cfg.model.api_key)
    elif 'deepseek' in cfg.model.family_name:
        # DeepSeek serves an OpenAI-compatible API, so the OpenAI client is
        # reused with a custom base_url.
        client = OpenAI(api_key=cfg.model.api_key, base_url=cfg.model.base_url)
    elif 'gemini' in cfg.model.family_name:
        client = genai.Client(api_key=cfg.model.api_key)
    elif cfg.model.family_name == 'qwen':
        # Qwen is likewise served through an OpenAI-compatible endpoint.
        client = OpenAI(api_key=cfg.model.api_key, base_url=cfg.model.base_url)
    else:
        raise ValueError(f'Model {cfg.model.family_name} not recognized')
    return client
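# Note on `cfg`: this module assumes an attribute-style config (e.g. an
# OmegaConf/Hydra object or an argparse.Namespace). The only fields it reads
# are cfg.model.family_name, cfg.model.name, cfg.model.api_key,
# cfg.model.base_url, and cfg.model.thinking. A minimal stand-in for local
# testing might look like this (hypothetical values):
#
#     from types import SimpleNamespace
#     cfg = SimpleNamespace(model=SimpleNamespace(
#         family_name='claude', name='claude-3-7-sonnet-latest',
#         api_key='YOUR_API_KEY', base_url=None, thinking=False))
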
def generate_response(messages, cfg):
    client = get_client(cfg)
    model_name = cfg.model.name

    # OpenAI reasoning models (o1/o3/o4): use max_completion_tokens and the
    # default temperature of 1.0.
    if 'o1' in model_name or 'o3' in model_name or 'o4' in model_name:
        # o1 does not accept a system message, so fold it into the first user turn.
        if 'o1' in model_name and len(messages) > 0 and messages[0]['role'] == 'system':
            system_prompt = messages[0]['content']
            messages = messages[1:]
            messages[0]['content'] = system_prompt + messages[0]['content']

        num_tokens = 16384
        temperature = 1.0
        start_time = time.time()
        response = client.chat.completions.create(
            model=model_name,
            messages=messages,
            max_completion_tokens=num_tokens,
            temperature=temperature,
        )
        end_time = time.time()
        print(f'It took {model_name} {end_time - start_time:.2f}s to generate the response.')
        return response

    if 'claude' in model_name:
        num_tokens = 8192
        temperature = 0.7

        # Anthropic takes the system prompt as a dedicated argument rather than
        # a message, so peel it off the message list if present.
        system_prompt = ''
        if len(messages) > 0 and messages[0]['role'] == 'system':
            system_prompt = messages[0]['content']
            messages = messages[1:]

        start_time = time.time()
        if cfg.model.thinking:
            # Extended thinking: the thinking budget counts toward max_tokens,
            # and the API requires temperature=1.0 when thinking is enabled.
            num_thinking_tokens = 12288
            response = client.messages.create(
                model=model_name,
                max_tokens=num_tokens + num_thinking_tokens,
                thinking={'type': 'enabled', 'budget_tokens': num_thinking_tokens},
                system=system_prompt,
                messages=messages,
                temperature=1.0,
            )
        else:
            response = client.messages.create(
                model=model_name,
                max_tokens=num_tokens,
                system=system_prompt,
                messages=messages,
                temperature=temperature,
            )
        end_time = time.time()
        print(f'It took {model_name} {end_time - start_time:.2f}s to generate the response.')
        return response
    if 'gemini' in model_name:
        start_time = time.time()
        # The genai chat interface takes no separate system role here, so merge
        # the system prompt into the first user turn.
        if len(messages) > 0 and messages[0]['role'] == 'system':
            system_prompt = messages[0]['content']
            messages = messages[1:]
            messages[0]['content'] = system_prompt + messages[0]['content']

        # Gemini labels assistant turns with the role 'model'.
        for message in messages:
            if message['role'] == 'assistant':
                message['role'] = 'model'

        # Replay everything but the last message as chat history, then send the
        # last message to obtain the new response.
        chat = client.chats.create(
            model=model_name,
            history=[
                types.Content(role=message['role'], parts=[types.Part(text=message['content'])])
                for message in messages[:-1]
            ],
        )
        response = chat.send_message(message=messages[-1]['content'])
        end_time = time.time()
        print(f'It took {model_name} {end_time - start_time:.2f}s to generate the response.')
        return response
    # Default: OpenAI-compatible chat completions (GPT, DeepSeek, Qwen).
    num_tokens = 4096
    temperature = 0.7

    start_time = time.time()
    response = client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=num_tokens,
        temperature=temperature,
        stream=('qwq' in model_name),  # QwQ is served via streaming output only
    )

    if 'qwq' in model_name:
        # QwQ streams reasoning tokens in a separate `reasoning_content` field;
        # discard those and keep only the final answer text.
        answer_content = ''
        for chunk in response:
            if chunk.choices:
                delta = chunk.choices[0].delta
                if getattr(delta, 'reasoning_content', None) is not None:
                    continue  # skip intermediate reasoning
                if delta.content:  # content can be None on some chunks
                    answer_content += delta.content
        response = answer_content

    end_time = time.time()
    print(f'It took {model_name} {end_time - start_time:.2f}s to generate the response.')
    return response
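

# generate_response returns a different object per provider: a ChatCompletion
# for the OpenAI-compatible branches, an Anthropic Message, a Gemini
# GenerateContentResponse, or a plain string for QwQ. The helper below is a
# hypothetical sketch (not part of the original module) showing one way a
# caller might normalize all of them to text.
def extract_text(response, cfg):
    if isinstance(response, str):  # QwQ branch already returns plain text
        return response
    if 'claude' in cfg.model.name:
        # Keep only text blocks; thinking blocks (when enabled) are dropped.
        return ''.join(block.text for block in response.content if block.type == 'text')
    if 'gemini' in cfg.model.name:
        return response.text
    return response.choices[0].message.content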
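

# Minimal usage sketch, assuming the SimpleNamespace-style cfg described above;
# the model name and API key are placeholders, not values from the original code.
if __name__ == '__main__':
    from types import SimpleNamespace

    cfg = SimpleNamespace(model=SimpleNamespace(
        family_name='gpt',
        name='gpt-4o',
        api_key='YOUR_API_KEY',
        base_url=None,
        thinking=False,
    ))
    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant. '},
        {'role': 'user', 'content': 'Say hello in one sentence.'},
    ]
    response = generate_response(messages, cfg)
    print(extract_text(response, cfg))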