| import os
|
| import joblib
|
| import numpy as np
|
| import pandas as pd
|
| import cv2
|
| import tensorflow as tf
|
| from patchify import patchify
|
|
|
|
|
@tf.keras.utils.register_keras_serializable()
class ClassToken(tf.keras.layers.Layer):
    """Learnable [CLS] token for a Vision-Transformer-style model.

    Holds a single trainable embedding of shape (1, 1, hidden_dim) and, on
    call, broadcasts it across the batch dimension so the caller can prepend
    it to the patch-embedding sequence.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # Hidden size is inferred from the incoming patch embeddings.
        self.hidden_dim = input_shape[-1]
        self.w = self.add_weight(
            name="cls_token",
            shape=(1, 1, self.hidden_dim),
            initializer="random_normal",
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        batch_size = tf.shape(inputs)[0]
        cls = tf.broadcast_to(self.w, [batch_size, 1, self.hidden_dim])
        # Cast to the input dtype so the token can be concatenated with the
        # patch embeddings even under a mixed-precision policy (no-op when
        # everything is float32, as in the original code).
        return tf.cast(cls, dtype=inputs.dtype)
|
|
|
@tf.keras.utils.register_keras_serializable()
class ExtractCLSToken(tf.keras.layers.Layer):
    """Slice out the [CLS] token from a transformer token sequence.

    Given inputs of shape (batch, tokens, hidden_dim), returns the first
    token, shape (batch, hidden_dim), for the classification head.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        # The CLS token is by convention at sequence position 0.
        cls_embedding = inputs[:, 0, :]
        return cls_embedding
|
|
|
class DiamondInference:
    """Inference wrapper for a multimodal (image + tabular) diamond classifier.

    Loads a trained Keras model together with the preprocessing artifacts
    produced at training time (hyperparameters, categorical label encoders,
    numerical scaler, target encoder, image-normalization statistics) and
    exposes a single `predict` entry point with optional test-time
    augmentation (TTA).
    """

    def __init__(self, model_path, encoder_dir, model_id=None):
        """Load the model and all preprocessing artifacts.

        Args:
            model_path: Path to the saved Keras model.
            encoder_dir: Directory containing the joblib-pickled artifacts.
            model_id: Optional run identifier used as the artifact filename
                suffix; falls back to the default "imagenet_100ep" run.
        """
        self.model_id = model_id

        # All artifact filenames share one suffix: either the explicit
        # model_id or the default training-run identifier. (Truthiness check
        # mirrors the original branching: empty string also falls back.)
        suffix = model_id if model_id else "imagenet_100ep"
        hp_path = os.path.join(encoder_dir, f"hyperparameters_{suffix}.pkl")
        cat_path = os.path.join(encoder_dir, f"cat_encoders_{suffix}.pkl")
        num_path = os.path.join(encoder_dir, f"num_scaler_{suffix}.pkl")
        target_path = os.path.join(encoder_dir, f"target_encoder_{suffix}.pkl")
        norm_stats_path = os.path.join(encoder_dir, f"norm_stats_{suffix}.pkl")

        print(f"[INFO] Loading artifacts for model ID: {model_id or 'default'}")
        self.hp = joblib.load(hp_path)
        self.cat_encoders = joblib.load(cat_path)
        self.num_scaler = joblib.load(num_path)
        self.target_encoder = joblib.load(target_path)

        if os.path.exists(norm_stats_path):
            self.norm_stats = joblib.load(norm_stats_path)
        else:
            # Fall back to the standard ImageNet channel statistics.
            self.norm_stats = {
                "mean": np.array([0.485, 0.456, 0.406]),
                "std": np.array([0.229, 0.224, 0.225]),
            }

        # compile=False: inference only, no optimizer/loss state needed.
        self.model = tf.keras.models.load_model(
            model_path,
            custom_objects={"ClassToken": ClassToken, "ExtractCLSToken": ExtractCLSToken},
            compile=False,
        )
        print(f"[INFO] Model and artifacts loaded successfully from {model_path}.")

    def apply_tta_transform(self, img, transform_type):
        """Apply specific Test-Time Augmentation transformation"""
        if transform_type == "original":
            return img
        if transform_type == "horizontal_flip":
            return cv2.flip(img, 1)
        if transform_type in ("rotation_5", "rotation_minus_5"):
            # Both rotations share the same warp; only the angle differs.
            angle = 5 if transform_type == "rotation_5" else -5
            h, w = img.shape[:2]
            M = cv2.getRotationMatrix2D((w // 2, h // 2), angle, 1.0)
            return cv2.warpAffine(img, M, (w, h), borderMode=cv2.BORDER_REFLECT)
        if transform_type == "brightness_up":
            return np.clip(img * 1.1, 0, 255).astype(np.uint8)
        # Unknown transform names leave the image untouched.
        return img

    def process_image(self, image_path, tta_transform=None):
        """Load, normalize, and patchify an image for the ViT branch.

        Returns a float32 array of shape hp["flat_patches_shape"]; on any
        failure (unreadable file, patchify error) a zero array of the same
        shape is returned so a batch prediction can still proceed.
        """
        try:
            image = cv2.imread(image_path, cv2.IMREAD_COLOR)
            if image is None:
                return np.zeros(self.hp["flat_patches_shape"], dtype=np.float32)

            image = cv2.resize(image, (self.hp["image_size"], self.hp["image_size"]))

            # TTA is applied on the uint8 image, before normalization.
            if tta_transform:
                image = self.apply_tta_transform(image, tta_transform)

            # Scale to [0, 1] then standardize per channel; epsilon guards
            # against a zero std in loaded stats.
            image = image / 255.0
            image = (image - self.norm_stats["mean"]) / (self.norm_stats["std"] + 1e-7)

            patch_shape = (self.hp["patch_size"], self.hp["patch_size"], self.hp["num_channels"])
            patches = patchify(image, patch_shape, self.hp["patch_size"])
            patches = np.reshape(patches, self.hp["flat_patches_shape"]).astype(np.float32)
            return patches
        except Exception as e:
            print(f"[ERROR] Image processing failed: {e}")
            return np.zeros(self.hp["flat_patches_shape"], dtype=np.float32)

    def _encode_tabular(self, df_row):
        """Encode one row of tabular features into a (1, n_features) array.

        Categorical values unseen at training time are mapped to the
        "__missing__" class when available (else the encoder's first class);
        any encoding/scaling failure degrades to 0 rather than raising.
        """
        categorical_cols = ["StoneType", "Color", "Brown", "BlueUv", "GrdType", "Result"]
        numerical_cols = ["Carat"]

        tab_data_list = []
        for col in categorical_cols:
            val = str(df_row.get(col, "__missing__"))
            try:
                if col in self.cat_encoders:
                    if val not in self.cat_encoders[col].classes_:
                        val = "__missing__" if "__missing__" in self.cat_encoders[col].classes_ else self.cat_encoders[col].classes_[0]
                    encoded_val = self.cat_encoders[col].transform([val])[0]
                else:
                    print(f"[WARN] Encoder for column {col} not found. Using 0.")
                    encoded_val = 0
            except Exception as e:
                print(f"[ERROR] Encoding failed for {col} with value {val}: {e}. Using 0.")
                encoded_val = 0
            tab_data_list.append(encoded_val)

        for col in numerical_cols:
            try:
                val = float(df_row.get(col, 0))
                scaled_val = self.num_scaler.transform([[val]])[0][0]
            except Exception as e:
                print(f"[ERROR] Scaling failed for {col}: {e}. Using 0.")
                scaled_val = 0
            tab_data_list.append(scaled_val)

        return np.expand_dims(np.array(tab_data_list, dtype=np.float32), axis=0)

    def predict(self, df_row, image_path, use_tta=True):
        """Predict the decoded target label for one sample.

        Args:
            df_row: Mapping (dict or pandas Series) with the tabular columns.
            image_path: Path to the diamond image.
            use_tta: When True, average the softmax outputs over five
                TTA views; otherwise run a single forward pass.

        Returns:
            The target label decoded via the fitted target encoder.
        """
        tab_input = self._encode_tabular(df_row)

        if use_tta:
            tta_transforms = ["original", "horizontal_flip", "rotation_5", "rotation_minus_5", "brightness_up"]
            all_preds = []
            for transform in tta_transforms:
                img_patches = self.process_image(image_path, tta_transform=transform)
                img_input = np.expand_dims(img_patches, axis=0)
                preds = self.model.predict([img_input, tab_input], verbose=0)[0]
                all_preds.append(preds)
            # Average class probabilities across the augmented views.
            final_pred_probs = np.mean(all_preds, axis=0)
        else:
            img_patches = self.process_image(image_path)
            img_input = np.expand_dims(img_patches, axis=0)
            final_pred_probs = self.model.predict([img_input, tab_input], verbose=0)[0]

        pred_idx = np.argmax(final_pred_probs)
        decoded_pred = self.target_encoder.inverse_transform([pred_idx])[0]
        return decoded_pred
|
|
|