| """ |
| Multi-Domain Classifier - Inference Example |
| Repository: https://huggingface.co/ovinduG/multi-domain-classifier-phi3 |
| """ |
|
|
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| from peft import PeftModel |
| import torch |
| import json |
|
|
class MultiDomainClassifier:
    """Classify free-text queries into domains with an adapter-tuned causal LM.

    The constructor loads a base causal language model, applies the PEFT
    adapter stored under ``model_id`` on top of it, and loads the tokenizer
    from the same adapter repository. ``predict`` prompts the model for a
    JSON verdict and parses the generated text.
    """

    # JSON schema hint appended to every prompt. It is a plain (non-f) string
    # so that its literal braces are never treated as replacement fields.
    # NOTE(review): the whitespace inside this template should match the
    # format used during fine-tuning -- confirm against the training script.
    _OUTPUT_FORMAT = """{
  "primary_domain": "domain_name",
  "primary_confidence": 0.95,
  "is_multi_domain": true/false,
  "secondary_domains": [{"domain": "name", "confidence": 0.85}]
}"""

    def __init__(
        self,
        model_id="ovinduG/multi-domain-classifier-phi3",
        base_model_id="microsoft/Phi-3-mini-4k-instruct",
    ):
        """Load the base model, the adapter and the tokenizer.

        Args:
            model_id: Repository (or path) of the PEFT adapter; the tokenizer
                is loaded from the same location.
            base_model_id: Repository (or path) of the base causal LM the
                adapter is applied to.
        """
        print("Loading model...")

        self.base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

        self.model = PeftModel.from_pretrained(self.base_model, model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model.eval()

        print("✅ Model loaded!")

    @staticmethod
    def _extract_json(text: str) -> "dict | None":
        """Return the JSON object starting at the first ``{`` in *text*.

        Text before the opening brace and after the closing brace is ignored.
        Returns ``None`` when there is no ``{`` or the object does not parse.
        """
        start = text.find("{")
        if start == -1:
            return None
        try:
            parsed, _end = json.JSONDecoder().raw_decode(text, start)
        except json.JSONDecodeError:
            return None
        return parsed

    def predict(self, query: str) -> dict:
        """Classify a query into domains.

        Args:
            query: The text to classify.

        Returns:
            The JSON object produced by the model, as a dict. When no JSON
            object can be parsed from the generated text, returns
            ``{"error": "Failed to parse response", "raw": <generated text>}``.
        """
        prompt = (
            f"Classify this query: {query}\n\n"
            f"Output JSON format:\n{self._OUTPUT_FORMAT}"
        )

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        prompt_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=200,
                # Greedy decoding: no temperature is passed because it only
                # has an effect when sampling is enabled.
                do_sample=False,
                # NOTE(review): disabling the KV cache slows generation;
                # presumably a workaround for a cache incompatibility --
                # confirm before removing.
                use_cache=False,
            )

        # The returned sequence starts with the prompt tokens. Decode only the
        # newly generated part so the schema hint in the prompt is not
        # mistaken for the model's answer.
        completion = self.tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )

        result = self._extract_json(completion)
        if result is None:
            return {"error": "Failed to parse response", "raw": completion}
        return result
|
|
|
|
| |
if __name__ == "__main__":
    # Demo entry point: build the classifier with its default arguments and
    # print the prediction for a handful of sample queries.
    classifier = MultiDomainClassifier()

    sample_queries = (
        "Write a Python function to calculate factorial",
        "Build ML model to analyze sales data and create API endpoints",
        "What is quantum entanglement?",
        "Create a REST API for healthcare diabetes prediction",
    )

    # Header framed by two 80-character rules.
    heavy_rule = "=" * 80
    print("\n" + heavy_rule)
    print("CLASSIFICATION EXAMPLES")
    print(heavy_rule)

    for text in sample_queries:
        print(f"\nQuery: {text}")
        prediction = classifier.predict(text)
        pretty = json.dumps(prediction, indent=2)
        print(f"Result: {pretty}")
        print("-" * 80)
|
|