import os
import joblib
import numpy as np
import pandas as pd
import cv2
import tensorflow as tf
from patchify import patchify


# 1. Define Custom Layers
@tf.keras.utils.register_keras_serializable()
class ClassToken(tf.keras.layers.Layer):
    """Learnable [CLS] token for a ViT-style model.

    Registered with Keras serialization so the saved model can be reloaded;
    this definition must stay in sync with the one used at training time.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # Embedding width is inferred from the incoming patch embeddings
        # (last axis of the input).
        self.hidden_dim = input_shape[-1]
        self.w = self.add_weight(
            name="cls_token",
            shape=(1, 1, self.hidden_dim),
            initializer="random_normal",
            trainable=True,
        )

    def call(self, inputs):
        # Broadcast the single learned token across the batch. NOTE: this
        # layer returns ONLY the token (shape [batch, 1, hidden_dim]);
        # any concatenation with `inputs` happens elsewhere in the saved
        # model graph — confirm against the training architecture.
        batch_size = tf.shape(inputs)[0]
        cls = tf.broadcast_to(self.w, [batch_size, 1, self.hidden_dim])
        return cls


@tf.keras.utils.register_keras_serializable()
class ExtractCLSToken(tf.keras.layers.Layer):
    """Slices out the first sequence position (the [CLS] token)."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        # [batch, seq, dim] -> [batch, dim]
        return inputs[:, 0, :]


class DiamondInference:
    """Loads a trained multimodal (image patches + tabular) Keras model and
    its preprocessing artifacts, and runs single-sample predictions with
    optional test-time augmentation."""

    def __init__(self, model_path, encoder_dir, model_id=None):
        # Use provided model_id to load specific artifacts, fallback to generic if not provided
        self.model_id = model_id
        if model_id:
            hp_path = os.path.join(encoder_dir, f"hyperparameters_{model_id}.pkl")
            cat_path = os.path.join(encoder_dir, f"cat_encoders_{model_id}.pkl")
            num_path = os.path.join(encoder_dir, f"num_scaler_{model_id}.pkl")
            target_path = os.path.join(encoder_dir, f"target_encoder_{model_id}.pkl")
            norm_stats_path = os.path.join(encoder_dir, f"norm_stats_{model_id}.pkl")
        else:
            # Fallback to older generic names if no ID is passed
            hp_path = os.path.join(encoder_dir, "hyperparameters_imagenet_100ep.pkl")
            cat_path = os.path.join(encoder_dir, "cat_encoders_imagenet_100ep.pkl")
            num_path = os.path.join(encoder_dir, "num_scaler_imagenet_100ep.pkl")
            target_path = os.path.join(encoder_dir, "target_encoder_imagenet_100ep.pkl")
            norm_stats_path = os.path.join(encoder_dir, "norm_stats_imagenet_100ep.pkl")
        print(f"[INFO] Loading artifacts for model ID: {model_id or 'default'}")
        # hp: hyperparameter dict — presumably holds image_size, patch_size,
        # num_channels, flat_patches_shape (see process_image); verify
        # against the training pipeline.
        self.hp = joblib.load(hp_path)
        # Per-column categorical encoders and the numeric scaler fitted
        # at training time (used in predict).
        self.cat_encoders = joblib.load(cat_path)
        self.num_scaler = joblib.load(num_path)
        self.target_encoder = joblib.load(target_path)
        if os.path.exists(norm_stats_path):
            self.norm_stats = joblib.load(norm_stats_path)
        else:
            # Default fallback to ImageNet stats
            self.norm_stats = {"mean": np.array([0.485, 0.456, 0.406]), "std": np.array([0.229, 0.224, 0.225])}
        # compile=False: inference only, no optimizer state needed. Custom
        # layers must be passed so deserialization can resolve them.
        self.model = tf.keras.models.load_model(
            model_path,
            custom_objects={"ClassToken": ClassToken, "ExtractCLSToken": ExtractCLSToken},
            compile=False
        )
        print(f"[INFO] Model and artifacts loaded successfully from {model_path}.")

    def apply_tta_transform(self, img, transform_type):
        """Apply specific Test-Time Augmentation transformation"""
        if transform_type == "original":
            return img
        elif transform_type == "horizontal_flip":
            return cv2.flip(img, 1)
        elif transform_type == "rotation_5":
            # Rotate +5 degrees about the image center; reflected border
            # avoids introducing black corners.
            h, w = img.shape[:2]
            M = cv2.getRotationMatrix2D((w//2, h//2), 5, 1.0)
            return cv2.warpAffine(img, M, (w, h), borderMode=cv2.BORDER_REFLECT)
        elif transform_type == "rotation_minus_5":
            h, w = img.shape[:2]
            M = cv2.getRotationMatrix2D((w//2, h//2), -5, 1.0)
            return cv2.warpAffine(img, M, (w, h), borderMode=cv2.BORDER_REFLECT)
        elif transform_type == "brightness_up":
            # +10% brightness; clip keeps values inside the uint8 range.
            return np.clip(img * 1.1, 0, 255).astype(np.uint8)
        # Unknown transform names fall through unchanged.
        return img

    def process_image(self, image_path, tta_transform=None):
        """Read, normalize, and patchify one image into the flat patch
        array the model expects.

        Returns zeros of ``hp["flat_patches_shape"]`` when the file is
        unreadable or processing fails (best-effort: prediction proceeds
        on tabular data alone in that case).
        """
        try:
            image = cv2.imread(image_path, cv2.IMREAD_COLOR)
            if image is None:
                # cv2.imread returns None for missing/corrupt files.
                return np.zeros(self.hp["flat_patches_shape"], dtype=np.float32)
            image = cv2.resize(image, (self.hp["image_size"], self.hp["image_size"]))
            if tta_transform:
                image = self.apply_tta_transform(image, tta_transform)
            # Scale to [0, 1] then standardize with the stored channel
            # stats; epsilon guards against a zero std.
            image = image / 255.0
            image = (image - self.norm_stats["mean"]) / (self.norm_stats["std"] + 1e-7)
            # Non-overlapping patches: step equals patch size.
            patch_shape = (self.hp["patch_size"], self.hp["patch_size"], self.hp["num_channels"])
            patches = patchify(image, patch_shape, self.hp["patch_size"])
            patches = np.reshape(patches, self.hp["flat_patches_shape"]).astype(np.float32)
            return patches
        except Exception as e:
            print(f"[ERROR] Image processing failed: {e}")
            return np.zeros(self.hp["flat_patches_shape"], dtype=np.float32)
def predict(self, df_row, image_path, use_tta=True): # 1. Preprocess Tabular Data # Match training categorical features: StoneType, Color, Brown, BlueUv, GrdType, Result categorical_cols = ["StoneType", "Color", "Brown", "BlueUv", "GrdType", "Result"] numerical_cols = ["Carat"] tab_data_list = [] for col in categorical_cols: val = str(df_row.get(col, "__missing__")) # Safe transform for categorical values try: # First check if the column exists in encoders if col in self.cat_encoders: # Check if val is in encoder classes, otherwise fallback to __missing__ if val not in self.cat_encoders[col].classes_: val = "__missing__" if "__missing__" in self.cat_encoders[col].classes_ else self.cat_encoders[col].classes_[0] encoded_val = self.cat_encoders[col].transform([val])[0] else: print(f"[WARN] Encoder for column {col} not found. Using 0.") encoded_val = 0 except Exception as e: print(f"[ERROR] Encoding failed for {col} with value {val}: {e}. Using 0.") encoded_val = 0 tab_data_list.append(encoded_val) for col in numerical_cols: try: val = float(df_row.get(col, 0)) # Reshape for scaler (expected 2D array) scaled_val = self.num_scaler.transform([[val]])[0][0] except Exception as e: print(f"[ERROR] Scaling failed for {col}: {e}. Using 0.") scaled_val = 0 tab_data_list.append(scaled_val) tab_input = np.expand_dims(np.array(tab_data_list, dtype=np.float32), axis=0) # 2. 
Inference with TTA if use_tta: tta_transforms = ["original", "horizontal_flip", "rotation_5", "rotation_minus_5", "brightness_up"] all_preds = [] for transform in tta_transforms: img_patches = self.process_image(image_path, tta_transform=transform) img_input = np.expand_dims(img_patches, axis=0) preds = self.model.predict([img_input, tab_input], verbose=0)[0] all_preds.append(preds) final_pred_probs = np.mean(all_preds, axis=0) else: img_patches = self.process_image(image_path) img_input = np.expand_dims(img_patches, axis=0) final_pred_probs = self.model.predict([img_input, tab_input], verbose=0)[0] pred_idx = np.argmax(final_pred_probs) decoded_pred = self.target_encoder.inverse_transform([pred_idx])[0] return decoded_pred