# Diamond-Color-Prediction / inference_utils.py
# (HuggingFace upload metadata: WebashalarForML — "Upload 7 files", commit 8eab558 verified)
import os
import joblib
import numpy as np
import pandas as pd
import cv2
import tensorflow as tf
from patchify import patchify
# 1. Define Custom Layers
@tf.keras.utils.register_keras_serializable()
class ClassToken(tf.keras.layers.Layer):
    """Learnable [CLS] token, broadcast to one token per batch sample.

    The single trainable (1, 1, hidden_dim) embedding is created lazily in
    `build`, once the embedding width of the incoming sequence is known.
    """

    def build(self, input_shape):
        # Embedding width comes from the last axis of the input sequence.
        self.hidden_dim = input_shape[-1]
        self.w = self.add_weight(
            name="cls_token",
            shape=(1, 1, self.hidden_dim),
            initializer="random_normal",
            trainable=True,
        )

    def call(self, inputs):
        # Replicate the shared token across the (dynamic) batch dimension.
        batch = tf.shape(inputs)[0]
        return tf.broadcast_to(self.w, [batch, 1, self.hidden_dim])
@tf.keras.utils.register_keras_serializable()
class ExtractCLSToken(tf.keras.layers.Layer):
    """Select the first ([CLS]) token from a (batch, tokens, dim) tensor."""

    def call(self, seq):
        # Position 0 along the token axis holds the classification token.
        return seq[:, 0, :]
class DiamondInference:
    """Single-sample inference for the diamond model.

    Loads a trained Keras model plus its preprocessing artifacts
    (hyperparameters, categorical encoders, numeric scaler, target encoder,
    image normalization stats) and exposes `predict` for one tabular row
    plus one image, with optional Test-Time Augmentation (TTA).
    """

    # Must match the training pipeline's feature columns exactly.
    CATEGORICAL_COLS = ["StoneType", "Color", "Brown", "BlueUv", "GrdType", "Result"]
    NUMERICAL_COLS = ["Carat"]
    # TTA variants averaged together in `predict` when use_tta=True.
    TTA_TRANSFORMS = ["original", "horizontal_flip", "rotation_5", "rotation_minus_5", "brightness_up"]

    def __init__(self, model_path, encoder_dir, model_id=None):
        """Load the model and all pickled preprocessing artifacts.

        model_path: path to the saved Keras model.
        encoder_dir: directory holding the *.pkl artifacts.
        model_id: artifact-name suffix; falls back to the legacy
            "imagenet_100ep" names when omitted.
        """
        self.model_id = model_id
        # All artifacts share one naming scheme: <prefix>_<suffix>.pkl.
        suffix = model_id if model_id else "imagenet_100ep"

        def _artifact(prefix):
            return os.path.join(encoder_dir, f"{prefix}_{suffix}.pkl")

        print(f"[INFO] Loading artifacts for model ID: {model_id or 'default'}")
        self.hp = joblib.load(_artifact("hyperparameters"))
        self.cat_encoders = joblib.load(_artifact("cat_encoders"))
        self.num_scaler = joblib.load(_artifact("num_scaler"))
        self.target_encoder = joblib.load(_artifact("target_encoder"))

        norm_stats_path = _artifact("norm_stats")
        if os.path.exists(norm_stats_path):
            self.norm_stats = joblib.load(norm_stats_path)
        else:
            # Default fallback to ImageNet channel statistics.
            self.norm_stats = {
                "mean": np.array([0.485, 0.456, 0.406]),
                "std": np.array([0.229, 0.224, 0.225]),
            }

        self.model = tf.keras.models.load_model(
            model_path,
            custom_objects={"ClassToken": ClassToken, "ExtractCLSToken": ExtractCLSToken},
            compile=False,
        )
        print(f"[INFO] Model and artifacts loaded successfully from {model_path}.")

    @staticmethod
    def _rotate(img, angle):
        """Rotate `img` by `angle` degrees about its centre, reflecting at borders."""
        h, w = img.shape[:2]
        M = cv2.getRotationMatrix2D((w // 2, h // 2), angle, 1.0)
        return cv2.warpAffine(img, M, (w, h), borderMode=cv2.BORDER_REFLECT)

    def apply_tta_transform(self, img, transform_type):
        """Apply one named Test-Time Augmentation transformation.

        Unknown names (and "original") return the image unchanged.
        """
        if transform_type == "horizontal_flip":
            return cv2.flip(img, 1)
        if transform_type == "rotation_5":
            return self._rotate(img, 5)
        if transform_type == "rotation_minus_5":
            return self._rotate(img, -5)
        if transform_type == "brightness_up":
            # Multiplying promotes uint8 to float; clip then restore uint8.
            return np.clip(img * 1.1, 0, 255).astype(np.uint8)
        return img

    def process_image(self, image_path, tta_transform=None):
        """Load, resize, optionally augment, normalize, and patchify an image.

        Returns a float32 array of shape hp["flat_patches_shape"]; on any
        failure (unreadable file, bad shape) returns zeros of that shape so
        inference can still proceed.
        """
        try:
            image = cv2.imread(image_path, cv2.IMREAD_COLOR)
            if image is None:
                return np.zeros(self.hp["flat_patches_shape"], dtype=np.float32)
            image = cv2.resize(image, (self.hp["image_size"], self.hp["image_size"]))
            # Augment on the raw uint8 image, before normalization.
            if tta_transform:
                image = self.apply_tta_transform(image, tta_transform)
            # NOTE(review): cv2 loads BGR while these stats look like RGB
            # ImageNet values — presumably training used the same channel
            # order; confirm against the training pipeline.
            image = image / 255.0
            image = (image - self.norm_stats["mean"]) / (self.norm_stats["std"] + 1e-7)
            patch_shape = (self.hp["patch_size"], self.hp["patch_size"], self.hp["num_channels"])
            patches = patchify(image, patch_shape, self.hp["patch_size"])
            return np.reshape(patches, self.hp["flat_patches_shape"]).astype(np.float32)
        except Exception as e:
            print(f"[ERROR] Image processing failed: {e}")
            return np.zeros(self.hp["flat_patches_shape"], dtype=np.float32)

    def _encode_categorical(self, df_row):
        """Label-encode the categorical columns with safe fallbacks.

        Unseen values map to "__missing__" (or the encoder's first class);
        a missing encoder or a transform failure encodes as 0.
        """
        encoded = []
        for col in self.CATEGORICAL_COLS:
            val = str(df_row.get(col, "__missing__"))
            try:
                if col in self.cat_encoders:
                    classes = self.cat_encoders[col].classes_
                    if val not in classes:
                        val = "__missing__" if "__missing__" in classes else classes[0]
                    encoded.append(self.cat_encoders[col].transform([val])[0])
                else:
                    print(f"[WARN] Encoder for column {col} not found. Using 0.")
                    encoded.append(0)
            except Exception as e:
                print(f"[ERROR] Encoding failed for {col} with value {val}: {e}. Using 0.")
                encoded.append(0)
        return encoded

    def _scale_numerical(self, df_row):
        """Scale the numeric columns; any failure (including NaN) yields 0."""
        scaled = []
        for col in self.NUMERICAL_COLS:
            try:
                val = float(df_row.get(col, 0))
                # float("nan") does not raise, so without this guard a NaN
                # Carat would silently flow through the scaler into the model.
                if np.isnan(val):
                    raise ValueError("value is NaN")
                # Scaler expects a 2D array.
                scaled.append(self.num_scaler.transform([[val]])[0][0])
            except Exception as e:
                print(f"[ERROR] Scaling failed for {col}: {e}. Using 0.")
                scaled.append(0)
        return scaled

    def predict(self, df_row, image_path, use_tta=True):
        """Predict the decoded target label for one sample.

        df_row: mapping/Series exposing `.get` for the feature columns.
        image_path: path to the sample's image on disk.
        use_tta: when True, average model outputs over TTA_TRANSFORMS.
        Returns the label decoded by the target encoder.
        """
        # 1. Tabular features: encoded categoricals followed by scaled numerics.
        tab_values = self._encode_categorical(df_row) + self._scale_numerical(df_row)
        tab_input = np.expand_dims(np.array(tab_values, dtype=np.float32), axis=0)

        # 2. Image inference, optionally averaged across augmentations.
        if use_tta:
            all_preds = []
            for transform in self.TTA_TRANSFORMS:
                patches = self.process_image(image_path, tta_transform=transform)
                img_input = np.expand_dims(patches, axis=0)
                all_preds.append(self.model.predict([img_input, tab_input], verbose=0)[0])
            final_pred_probs = np.mean(all_preds, axis=0)
        else:
            patches = self.process_image(image_path)
            img_input = np.expand_dims(patches, axis=0)
            final_pred_probs = self.model.predict([img_input, tab_input], verbose=0)[0]

        pred_idx = int(np.argmax(final_pred_probs))
        return self.target_encoder.inverse_transform([pred_idx])[0]