# upload_model / inference_1.py
# (Hugging Face listing metadata — uploaded by JackyChunKit, commit 6ec27fe, "Update inference_1.py")
from vllm import LLM, SamplingParams
import argparse
import json
def setup_model(model_path, tensor_parallel_size=1):
    """
    Initialize the fine-tuned Qwen-2.5-7B model from a local path.

    Args:
        model_path: Path to the directory containing the trained model.
        tensor_parallel_size: Number of GPUs to shard the model across
            (default 1). Accepts a string from argparse; coerced to int.

    Returns:
        A ready-to-use vllm.LLM instance.
    """
    print(f"Loading fine-tuned Qwen model from: {model_path}")
    # trust_remote_code=True is required for custom Qwen model code.
    # BUGFIX: use the function parameter, not the global `args` object —
    # the original referenced `args.tensor_parallel_size`, which breaks
    # any caller other than main() and ignores the parameter entirely.
    llm = LLM(
        model=model_path,
        trust_remote_code=True,
        tensor_parallel_size=int(tensor_parallel_size),  # use multiple GPUs
        # dtype="bfloat16",            # optional: more efficient inference
        # gpu_memory_utilization=0.85  # optional: control memory usage
    )
    print("Model loaded successfully!")
    return llm
def generate_response(llm, prompt, temperature=0.7, max_tokens=512, top_p=0.9):
    """Generate a single completion for one prompt string and return its text."""
    params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
    )
    # vLLM takes a batch; wrap the single prompt and unwrap the single result.
    results = llm.generate([prompt], params)
    first_request = results[0]
    return first_request.outputs[0].text
def chat_completion(llm, messages, temperature=0.7, max_tokens=512):
    """Generate a chat completion from a list of role/content message dicts."""
    params = SamplingParams(
        temperature=temperature,
        top_p=0.9,
        max_tokens=max_tokens,
    )

    # Render the conversation into a single prompt string. Prefer the
    # tokenizer's own chat template (newer transformers versions); fall
    # back to manual ChatML formatting when no template is available.
    tokenizer = llm.get_tokenizer()
    if not hasattr(tokenizer, "apply_chat_template"):
        prompt = format_messages_manually(messages)
    else:
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    completions = llm.generate([prompt], params)
    return completions[0].outputs[0].text
def format_messages_manually(messages):
    """Render messages as a ChatML prompt when no chat template is available.

    Messages with a role other than system/user/assistant are skipped.
    The result always ends with an open assistant turn so the model
    continues from there.
    """
    known_roles = ("system", "user", "assistant")
    pieces = []
    for msg in messages:
        role = msg["role"]
        if role in known_roles:
            pieces.append(f"<|im_start|>{role}\n{msg['content']}<|im_end|>\n")
    pieces.append("<|im_start|>assistant\n")
    return "".join(pieces)
def batch_inference(llm, prompts, temperature=0.7, max_tokens=512):
    """Generate completions for many prompts in one vLLM call.

    Returns a list of generated texts, in the same order as `prompts`.
    """
    params = SamplingParams(
        temperature=temperature,
        top_p=0.9,
        max_tokens=max_tokens,
    )
    return [result.outputs[0].text for result in llm.generate(prompts, params)]
def main():
    """CLI entry point: load the model, then dispatch to single/chat/batch mode."""
    parser = argparse.ArgumentParser(description="Inference with fine-tuned Qwen-2.5-7B model")
    parser.add_argument("--model_path", required=True, help="Path to the fine-tuned model directory")
    parser.add_argument("--mode", choices=["single", "chat", "batch"], default="single", help="Inference mode")
    parser.add_argument("--prompt", help="Prompt for single inference mode")
    # FIX: help text previously said "one per line", but batch mode parses the
    # file with json.load — it must be a JSON list of prompt strings.
    parser.add_argument("--prompt_file", help="JSON file containing a list of prompts for batch mode")
    parser.add_argument("--output_file", help="Path to save JSON results (default: batch_results.json)")
    # FIX: was an untyped, defaultless string — vLLM needs an int, and omitting
    # the flag passed None through to setup_model.
    parser.add_argument("--tensor_parallel_size", type=int, default=1, help="Number of GPUs")
    args = parser.parse_args()

    # Initialize the model once; all modes share it.
    llm = setup_model(args.model_path, args.tensor_parallel_size)

    if args.mode == "single":
        if not args.prompt:
            args.prompt = input("Enter your prompt: ")
        print("\nGenerating response...")
        response = generate_response(llm, args.prompt)
        print(f"\nResponse:\n{response}")

    elif args.mode == "chat":
        # Keep the full conversation history so each turn has context.
        messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
        print("\nChat mode. Type 'exit' or 'quit' to end the conversation.\n")
        while True:
            user_input = input("\nYou: ")
            if user_input.lower() in ["exit", "quit"]:
                print("Goodbye!")
                break
            messages.append({"role": "user", "content": user_input})
            response = chat_completion(llm, messages)
            print(f"\nAssistant: {response}")
            messages.append({"role": "assistant", "content": response})

    elif args.mode == "batch":
        if not args.prompt_file:
            print("Error: --prompt_file required for batch mode")
            return
        with open(args.prompt_file, 'r', encoding='utf-8') as f:
            prompts = json.load(f)
        print(f"Running batch inference on {len(prompts)} prompts...")
        responses = batch_inference(llm, prompts)
        # FIX: guard the preview — an empty prompt list used to raise IndexError.
        if prompts:
            print(prompts[0])
            print(responses[0])
        # FIX: --output_file was documented as auto-generated but open(None)
        # raised TypeError; supply the promised default.
        out_path = args.output_file or "batch_results.json"
        with open(out_path, "w", encoding="utf-8") as final:
            json.dump(responses, final, ensure_ascii=False, indent=2)


if __name__ == "__main__":
    main()