| """ |
| ML Model Loader and Utilities |
| Handles loading and using the conflict prediction model and package embeddings. |
| Loads from local files if available, otherwise downloads from Hugging Face Hub. |
| """ |
|
|
| import json |
| import pickle |
| from pathlib import Path |
| from typing import Dict, List, Tuple, Optional |
| import numpy as np |
| from packaging.requirements import Requirement |
|
|
| |
| try: |
| from huggingface_hub import hf_hub_download |
| HF_HUB_AVAILABLE = True |
| except ImportError: |
| HF_HUB_AVAILABLE = False |
| print("Warning: huggingface_hub not available. Models must be loaded locally.") |
|
|
|
|
class ConflictPredictor:
    """Load a pickled conflict-prediction model and score requirements text.

    The model is read from a local pickle file when one exists. Otherwise,
    if ``huggingface_hub`` is importable, it is downloaded from the Hugging
    Face Hub and copied to the local path on a best-effort basis. When no
    source works, ``self.model`` stays ``None`` and :meth:`predict` returns
    ``(False, 0.0)``.
    """

    # File name of the pickled model, both on disk and in the Hub repository.
    _MODEL_FILENAME = "conflict_predictor.pkl"

    # Packages that each contribute one presence flag to the feature vector.
    # The order defines the feature layout, so it must not be changed.
    _COMMON_PACKAGES = (
        'torch', 'pytorch-lightning', 'tensorflow', 'keras', 'fastapi', 'pydantic',
        'numpy', 'pandas', 'scipy', 'scikit-learn', 'matplotlib', 'seaborn',
        'requests', 'httpx', 'sqlalchemy', 'alembic', 'uvicorn', 'starlette',
        'langchain', 'openai', 'chromadb', 'redis', 'celery', 'gunicorn',
        'pillow', 'opencv-python', 'beautifulsoup4', 'scrapy', 'plotly', 'jax',
    )

    def __init__(self, model_path: Optional[Path] = None, repo_id: str = "ysakhale/dependency-conflict-models"):
        """Initialize the conflict predictor.

        Args:
            model_path: Local path to the model file (optional). A plain
                string is accepted and converted to ``Path``. Defaults to
                ``models/conflict_predictor.pkl`` next to this module.
            repo_id: Hugging Face repository ID to download from if the
                local file is missing or cannot be loaded.
        """
        self.repo_id = repo_id
        self.model = None

        if model_path is None:
            model_path = Path(__file__).parent / "models" / self._MODEL_FILENAME
        # Coerce so that string paths work with .exists() / .parent below.
        self.model_path = Path(model_path)

        if self._load_local():
            return
        if HF_HUB_AVAILABLE and self._load_from_hub():
            return

        print("Warning: Conflict prediction model not available")

    def _load_local(self) -> bool:
        """Try to unpickle the model from ``self.model_path``; return success."""
        if not self.model_path.exists():
            return False
        try:
            # SECURITY: pickle.load can execute arbitrary code. Only point
            # model_path at files from a trusted source.
            with open(self.model_path, 'rb') as f:
                self.model = pickle.load(f)
        except Exception as e:
            print(f"Could not load conflict prediction model from local: {e}")
            return False
        print(f"Loaded conflict prediction model from {self.model_path}")
        return True

    def _load_from_hub(self) -> bool:
        """Try to download and unpickle the model from the Hub; return success."""
        try:
            print(f"Model not found locally. Downloading from Hugging Face Hub: {self.repo_id}")
            downloaded_path = hf_hub_download(
                repo_id=self.repo_id,
                filename=self._MODEL_FILENAME,
                repo_type="model",
            )
            # SECURITY: this unpickles a file fetched over the network, which
            # can execute arbitrary code. Only use repositories you trust.
            with open(downloaded_path, 'rb') as f:
                self.model = pickle.load(f)
            print("Loaded conflict prediction model from Hugging Face Hub")
        except Exception as e:
            print(f"Could not download model from Hugging Face Hub: {e}")
            return False

        self._cache_locally(downloaded_path)
        return True

    def _cache_locally(self, downloaded_path) -> None:
        """Copy the downloaded model to ``self.model_path`` (best effort)."""
        import shutil

        try:
            self.model_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy(downloaded_path, self.model_path)
        except OSError as e:
            # Caching is optional: the model is already loaded in memory.
            print(f"Could not cache model locally: {e}")
        else:
            print(f"Cached model locally at {self.model_path}")

    @staticmethod
    def _specificity(specifier: str) -> int:
        """Score a version specifier: 0 none, 3 contains '==', 2 '>='/'<=', 1 other."""
        if not specifier:
            return 0
        if '==' in specifier:
            return 3
        if '>=' in specifier or '<=' in specifier:
            return 2
        return 1

    def extract_features(self, requirements_text: str) -> np.ndarray:
        """Extract features from requirements text (same as training).

        Blank lines, lines starting with ``#`` and lines that cannot be
        parsed as a requirement are ignored.

        Feature layout (37 values):
            0      number of distinct packages / 20, capped at 1.0
            1      lines with a version specifier / distinct packages
            2      mean specificity score / 3 (0 when nothing parsed)
            3      duplicate-package flag (always 0, see note below)
            4-33   presence flag for each entry of ``_COMMON_PACKAGES``
            34-36  pair flags: torch+pytorch-lightning, tensorflow+keras,
                   fastapi+pydantic

        Args:
            requirements_text: Contents of a requirements file.

        Returns:
            One-dimensional feature array.
        """
        packages = {}
        num_packages = 0
        has_pins = 0
        version_specificity = []

        for line in requirements_text.strip().split('\n'):
            line = line.strip()
            if not line or line.startswith('#'):
                continue

            try:
                req = Requirement(line)
            except Exception:
                # Best effort: skip lines that are not plain requirements.
                continue

            pkg_name = req.name.lower()
            specifier = str(req.specifier) if req.specifier else ''

            # Only the first occurrence of a package is recorded.
            if pkg_name not in packages:
                packages[pkg_name] = specifier
                num_packages += 1

            # Pin statistics are counted per parsed line, duplicates included.
            if specifier:
                has_pins += 1
            version_specificity.append(self._specificity(specifier))

        feature_vec = [
            min(num_packages / 20.0, 1.0),
            has_pins / max(num_packages, 1),
            np.mean(version_specificity) / 3.0 if version_specificity else 0,
            # NOTE(review): this slot was computed as
            # `len(packages) < num_packages`, which can never be true because
            # both grow together. It is kept at 0 so the vector matches what
            # this code has always produced; emitting a real duplicate flag
            # should be coordinated with retraining the model.
            0,
        ]

        feature_vec.extend(1 if pkg in packages else 0 for pkg in self._COMMON_PACKAGES)

        # Pairs of packages that appear together.
        for first, second in (
            ('torch', 'pytorch-lightning'),
            ('tensorflow', 'keras'),
            ('fastapi', 'pydantic'),
        ):
            feature_vec.append(1 if (first in packages and second in packages) else 0)

        return np.array(feature_vec)

    def predict(self, requirements_text: str) -> Tuple[bool, float]:
        """Predict if requirements have conflicts.

        Args:
            requirements_text: Contents of a requirements file.

        Returns:
            (has_conflict, confidence_score). The confidence is the model's
            probability for the predicted class. ``(False, 0.0)`` is returned
            when no model is loaded or when prediction raises.
        """
        if self.model is None:
            return False, 0.0

        try:
            features = self.extract_features(requirements_text).reshape(1, -1)

            prediction = self.model.predict(features)[0]
            probability = self.model.predict_proba(features)[0]

            has_conflict = bool(prediction)
            # Index 1 is read for a positive prediction, index 0 otherwise.
            confidence = float(probability[1] if has_conflict else probability[0])

            return has_conflict, confidence
        except Exception as e:
            print(f"Error in conflict prediction: {e}")
            return False, 0.0
|
|
|
|
class PackageEmbeddings:
    """Load and use package embeddings for similarity matching.

    Embeddings are kept in ``self.embeddings``, a dict mapping a package
    name to its vector stored as a list of numbers. They are read from a
    local JSON file when one exists. Otherwise, if ``huggingface_hub`` is
    importable, the file is downloaded from the Hugging Face Hub and copied
    to the local path on a best-effort basis.
    """

    # File name of the embeddings, both on disk and in the Hub repository.
    _EMBEDDINGS_FILENAME = "package_embeddings.json"

    def __init__(self, embeddings_path: Optional[Path] = None, repo_id: str = "ysakhale/dependency-conflict-models"):
        """Initialize package embeddings.

        Args:
            embeddings_path: Local path to the embeddings file (optional). A
                plain string is accepted and converted to ``Path``. Defaults
                to ``models/package_embeddings.json`` next to this module.
            repo_id: Hugging Face repository ID to download from if the
                local file is missing or cannot be loaded.
        """
        self.repo_id = repo_id
        self.embeddings = {}
        self.model = None  # sentence-transformer encoder, loaded lazily

        if embeddings_path is None:
            embeddings_path = Path(__file__).parent / "models" / self._EMBEDDINGS_FILENAME
        # Coerce so that string paths work with .exists() / .parent below.
        self.embeddings_path = Path(embeddings_path)

        if self._load_local():
            return
        if HF_HUB_AVAILABLE and self._load_from_hub():
            return

        print("Warning: Package embeddings not available")

    def _load_local(self) -> bool:
        """Try to read embeddings from ``self.embeddings_path``; return success."""
        if not self.embeddings_path.exists():
            return False
        try:
            with open(self.embeddings_path, 'r', encoding='utf-8') as f:
                self.embeddings = json.load(f)
        except Exception as e:
            print(f"Could not load embeddings from local: {e}")
            return False
        print(f"Loaded {len(self.embeddings)} package embeddings from {self.embeddings_path}")
        return True

    def _load_from_hub(self) -> bool:
        """Try to download and read embeddings from the Hub; return success."""
        try:
            print(f"Embeddings not found locally. Downloading from Hugging Face Hub: {self.repo_id}")
            downloaded_path = hf_hub_download(
                repo_id=self.repo_id,
                filename=self._EMBEDDINGS_FILENAME,
                repo_type="model",
            )
            with open(downloaded_path, 'r', encoding='utf-8') as f:
                self.embeddings = json.load(f)
            print(f"Loaded {len(self.embeddings)} package embeddings from Hugging Face Hub")
        except Exception as e:
            print(f"Could not download embeddings from Hugging Face Hub: {e}")
            return False

        self._cache_locally(downloaded_path)
        return True

    def _cache_locally(self, downloaded_path) -> None:
        """Copy the downloaded file to ``self.embeddings_path`` (best effort)."""
        import shutil

        try:
            self.embeddings_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy(downloaded_path, self.embeddings_path)
        except OSError as e:
            # Caching is optional: the embeddings are already in memory.
            print(f"Could not cache embeddings locally: {e}")
        else:
            print(f"Cached embeddings locally at {self.embeddings_path}")

    def _load_model(self):
        """Lazy load the sentence transformer model.

        Returns:
            The encoder, or ``None`` when ``sentence_transformers`` cannot
            be imported.
        """
        if self.model is None:
            try:
                from sentence_transformers import SentenceTransformer
                self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
            except ImportError:
                print("⚠️ sentence-transformers not available, embedding similarity disabled")
                return None
        return self.model

    def get_embedding(self, package_name: str) -> Optional[np.ndarray]:
        """Get embedding for a package (from cache or compute on-the-fly).

        The cache is looked up by the lower-cased name. A freshly computed
        embedding is stored in ``self.embeddings`` under that key.

        Args:
            package_name: Package to embed.

        Returns:
            The embedding, or ``None`` when it is not cached and the encoder
            is unavailable.
        """
        package_lower = package_name.lower()

        if package_lower in self.embeddings:
            return np.array(self.embeddings[package_lower])

        model = self._load_model()
        if model is None:
            return None

        embedding = model.encode([package_name])[0]
        self.embeddings[package_lower] = embedding.tolist()
        return embedding

    def find_similar(self, package_name: str, top_k: int = 5, threshold: float = 0.6) -> List[Tuple[str, float]]:
        """Find similar packages using cosine similarity.

        Args:
            package_name: Package to compare against the stored embeddings.
            top_k: Maximum number of results; values below 1 yield ``[]``.
            threshold: Minimum cosine similarity for a result to be kept.

        Returns:
            List of (package_name, similarity_score) tuples, highest score
            first. Empty when no embedding is available for the query.
        """
        query_emb = self.get_embedding(package_name)
        if query_emb is None:
            return []

        # Cosine similarity is undefined for a zero vector; treat it as
        # matching nothing rather than dividing by zero.
        query_norm = np.linalg.norm(query_emb)
        if query_norm == 0:
            return []

        query_key = package_name.lower()
        similarities = []

        for pkg, emb in self.embeddings.items():
            if pkg == query_key:
                continue

            emb_array = np.array(emb)
            emb_norm = np.linalg.norm(emb_array)
            if emb_norm == 0:
                continue

            similarity = float(np.dot(query_emb, emb_array) / (query_norm * emb_norm))
            if similarity >= threshold:
                similarities.append((pkg, similarity))

        similarities.sort(key=lambda x: x[1], reverse=True)
        # Clamp so a negative top_k cannot turn into a "drop from the end" slice.
        return similarities[:max(top_k, 0)]

    def get_best_match(self, package_name: str, threshold: float = 0.7) -> Optional[str]:
        """Get the best matching package name.

        Args:
            package_name: Package to look up.
            threshold: Minimum cosine similarity for a match.

        Returns:
            Name of the most similar stored package, or ``None`` when no
            package reaches ``threshold``.
        """
        similar = self.find_similar(package_name, top_k=1, threshold=threshold)
        if similar:
            return similar[0][0]
        return None
|
|
|
|