# WebashalarForML's picture
# Upload 5 files
# 623a404 verified
import os
import sys
import zipfile
import pandas as pd
import numpy as np
from flask import Flask, request, redirect, url_for, send_from_directory, flash, render_template
from werkzeug.utils import secure_filename
from tqdm import tqdm
from sklearn.metrics import classification_report, precision_recall_fscore_support
from inference_utils import DiamondInference
from dotenv import load_dotenv
# Load local environment variables from .env
load_dotenv()
app = Flask(__name__)
# NOTE(review): hard-coded secret key — should come from an env var in production.
app.secret_key = "supersecretkey"
# Hugging Face Hub Integration
HF_REPO_ID = os.getenv("HF_REPO_ID", "WebashalarForML/Diamcol")
HF_TOKEN = os.getenv("HF_TOKEN")
# Model Configuration
# MODEL_ID is the run/artifact hash shared by the model file and every encoder pickle.
MODEL_ID = "322c4f4d"
MODEL_NAME = f"model_vit_robust_{MODEL_ID}.keras"
def download_model_from_hf():
    """Fetch the Keras model and its encoder artifacts from the Hugging Face Hub.

    Files already present on disk are left untouched; missing ones are
    downloaded into the current directory (encoders under ./encoder/).
    Requires HF_REPO_ID / HF_TOKEN module-level configuration.
    """
    from huggingface_hub import hf_hub_download

    print("[INFO] Checking model files from Hugging Face...")

    # Main model weights
    if not os.path.exists(MODEL_NAME):
        print(f"[INFO] Downloading {MODEL_NAME}...")
        hf_hub_download(repo_id=HF_REPO_ID, filename=MODEL_NAME, token=HF_TOKEN, local_dir=".")

    # Encoder files (Matches names in inference_utils.py)
    artifact_names = [
        f"{stem}_{MODEL_ID}.pkl"
        for stem in ("hyperparameters", "cat_encoders", "num_scaler", "target_encoder", "norm_stats")
    ]
    os.makedirs("encoder", exist_ok=True)
    for name in artifact_names:
        if os.path.exists(os.path.join("encoder", name)):
            continue  # already cached locally
        print(f"[INFO] Downloading {name}...")
        # Note: Assuming the structure on HF is encoder/filename
        hf_hub_download(repo_id=HF_REPO_ID, filename=f"encoder/{name}", token=HF_TOKEN, local_dir=".")
UPLOAD_FOLDER = 'uploads'
RESULTS_FOLDER = 'results'
EXTRACT_FOLDER = os.path.join(UPLOAD_FOLDER, 'extracted_images')
# Ensure working directories exist. exist_ok=True is idempotent and avoids the
# check-then-create race of the previous `if not os.path.exists` pattern.
for folder in (UPLOAD_FOLDER, RESULTS_FOLDER, EXTRACT_FOLDER):
    os.makedirs(folder, exist_ok=True)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
app.config['MAX_CONTENT_LENGTH'] = 500 * 1024 * 1024  # 500MB max upload
# Global inference object (lazy loaded by get_inference_engine)
model_path = MODEL_NAME
encoder_dir = "encoder"
infer_engine = None
def get_inference_engine():
    """Return the process-wide DiamondInference instance, creating it lazily.

    On first call this best-effort downloads missing artifacts, constructs the
    engine, and runs one dummy prediction so the first real request does not
    pay the TensorFlow graph-initialization cost. Subsequent calls return the
    cached instance immediately.
    """
    global infer_engine
    if infer_engine is not None:
        return infer_engine

    # Try downloading if missing (for Docker/HF Spaces environment)
    try:
        download_model_from_hf()
    except Exception as exc:
        print(f"[WARNING] Could not download from HF: {exc}. Expecting local files.")

    infer_engine = DiamondInference(model_path, encoder_dir, MODEL_ID)

    # Warmup prediction to initialize TF graph and prevent "stuck" feeling on first stone
    print("[INFO] Warming up Inference Engine...")
    try:
        # Synthetic metadata row; the image load is mocked out below.
        warmup_row = {"StoneType": "NATURAL", "Color": "D", "Brown": "N", "BlueUv": "N", "GrdType": "GIA", "Carat": 1.0, "Result": "D"}
        # Temporarily replace process_image with a zero-patch stub so no real
        # file is required; always restore the original afterwards.
        real_process_image = infer_engine.process_image
        try:
            infer_engine.process_image = lambda path, tta_transform=None: np.zeros(infer_engine.hp["flat_patches_shape"], dtype=np.float32)
            infer_engine.predict(warmup_row, "warmup.jpg", use_tta=False)
        finally:
            infer_engine.process_image = real_process_image
        print("[INFO] Warmup complete.")
    except Exception as exc:
        print(f"[WARNING] Warmup failed: {exc}")
    return infer_engine
def _empty_directory(directory):
    """Delete every file, symlink, and subdirectory inside *directory*.

    Failures on individual entries are logged and skipped so one locked or
    missing file does not abort the whole flush.
    """
    import shutil
    for filename in os.listdir(directory):
        file_path = os.path.join(directory, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            print(f'Failed to delete {file_path}. Reason: {e}')

@app.route('/flush', methods=['POST'])
def flush_data():
    """Remove all uploaded inputs and generated results, then redirect home.

    The two cleanup loops in the original were duplicated verbatim; both now
    delegate to _empty_directory.
    """
    try:
        _empty_directory(UPLOAD_FOLDER)
        # Re-create EXTRACT_FOLDER: it lives inside UPLOAD_FOLDER and was just removed.
        os.makedirs(EXTRACT_FOLDER, exist_ok=True)
        _empty_directory(RESULTS_FOLDER)
        flash('All data flushed successfully.')
    except Exception as e:
        flash(f'Error during flushing: {e}')
    return redirect(url_for('index'))
@app.route('/')
def index():
    """Render the landing page with the upload form."""
    return render_template('index.html')
def _find_extracted_images():
    """Return full paths of every .jpg/.jpeg/.png file under EXTRACT_FOLDER."""
    image_paths = []
    for root, _dirs, files in os.walk(EXTRACT_FOLDER):
        for f in files:
            if f.lower().endswith(('.jpg', '.jpeg', '.png')):
                image_paths.append(os.path.join(root, f))
    return image_paths

def _match_image_path(all_paths, l_code, sr_no, stone_id):
    """Locate the image file belonging to one stone row.

    Primary match: basename contains both the lot code and serial number.
    Fallback: basename contains the stone id (skipped for empty/'nan' ids).
    Returns the full path, or None when nothing matches.
    """
    for full_path in all_paths:
        fname = os.path.basename(full_path)
        if l_code in fname and sr_no in fname:
            return full_path
    if stone_id and stone_id != 'nan':
        for full_path in all_paths:
            # BUGFIX: original called os.basename (no such attribute -> AttributeError
            # whenever this fallback ran); os.path.basename is the correct function.
            if stone_id in os.path.basename(full_path):
                return full_path
    return None

def _build_metrics(y_true, y_pred):
    """Build accuracy / per-class / averaged metrics plus a confusion matrix.

    Returns None when no ground-truth labels were collected.
    """
    if not y_true:
        return None
    from sklearn.metrics import confusion_matrix
    report_dict = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
    labels = sorted(set(y_true) | set(y_pred))
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    # Per-class rows only; the aggregate entries are surfaced separately below.
    class_metrics = [
        {
            'label': label,
            'precision': round(scores['precision'], 4),
            'recall': round(scores['recall'], 4),
            'f1': round(scores['f1-score'], 4),
            'support': scores['support'],
        }
        for label, scores in report_dict.items()
        if label not in ('accuracy', 'macro avg', 'weighted avg')
    ]
    return {
        'accuracy': round(report_dict['accuracy'], 4),
        'class_metrics': class_metrics,
        'weighted_avg': report_dict['weighted avg'],
        'macro_avg': report_dict['macro avg'],
        'precision': round(report_dict['weighted avg']['precision'], 4),
        'recall': round(report_dict['weighted avg']['recall'], 4),
        'f1': round(report_dict['weighted avg']['f1-score'], 4),
        'macro_f1': round(report_dict['macro avg']['f1-score'], 4),
        'macro_precision': round(report_dict['macro avg']['precision'], 4),
        'macro_recall': round(report_dict['macro avg']['recall'], 4),
        'confusion_matrix': {
            'labels': labels,
            'matrix': cm.tolist(),
        },
    }

@app.route('/upload', methods=['POST'])
def upload_files():
    """Accept a zip of stone images plus an Excel metadata sheet, run inference
    on every row, write an annotated report spreadsheet, and render it.

    Expects multipart fields 'zip_file' and 'excel_file'. Adds columns
    Predicted_FGrdCol and Image_Path to the sheet; computes metrics when the
    sheet carries FGrdCol ground truth.
    """
    if 'zip_file' not in request.files or 'excel_file' not in request.files:
        flash('Both Zip and Excel files are required.')
        return redirect(request.url)
    zip_file = request.files['zip_file']
    excel_file = request.files['excel_file']
    if zip_file.filename == '' or excel_file.filename == '':
        flash('No selected file')
        return redirect(request.url)

    # Save and extract the image archive.
    # SECURITY NOTE(review): extractall() on an untrusted zip is vulnerable to
    # "zip slip" path traversal — acceptable only if uploaders are trusted;
    # otherwise validate member names before extraction.
    zip_path = os.path.join(app.config['UPLOAD_FOLDER'], secure_filename(zip_file.filename))
    zip_file.save(zip_path)
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(EXTRACT_FOLDER)

    # Save and load the stone metadata sheet.
    excel_path = os.path.join(app.config['UPLOAD_FOLDER'], secure_filename(excel_file.filename))
    excel_file.save(excel_path)
    df = pd.read_excel(excel_path)

    engine = get_inference_engine()

    # Pre-cache all image paths so the per-row search is in-memory only.
    all_extracted_files = _find_extracted_images()
    print(f"[INFO] Found {len(all_extracted_files)} images in extraction folder.")

    # Ground-truth / prediction pairs for optional metrics.
    y_true = []
    y_pred = []
    print(f"[INFO] Initializing Inference Pipeline for {len(df)} stones...")
    sys.stdout.flush()

    # Progress bar with direct stdout for Gunicorn visibility
    pbar = tqdm(df.iterrows(), total=len(df), desc="Inference Progress", file=sys.stdout)
    for index, row in pbar:
        # Excel numeric cells round-trip as floats ("123.0") — drop the decimal part.
        l_code = str(row.get('L_Code', '')).split('.')[0]
        sr_no = str(row.get('SrNo', '')).split('.')[0]
        stone_id = str(row.get('Stone_Id', ''))

        # Periodic heartbeat log for "aliveness" verification
        if index % 5 == 0:
            print(f"[PROC] Stone {index+1}/{len(df)}: {l_code}")
            sys.stdout.flush()

        img_path = _match_image_path(all_extracted_files, l_code, sr_no, stone_id)
        if img_path:
            prediction = engine.predict(row, img_path)
            df.at[index, 'Predicted_FGrdCol'] = prediction
            # Store path relative to EXTRACT_FOLDER for web serving via /image/.
            df.at[index, 'Image_Path'] = os.path.relpath(img_path, start=EXTRACT_FOLDER)
            # Collect ground truth when present so the run can be scored.
            if 'FGrdCol' in row and pd.notna(row['FGrdCol']):
                y_true.append(str(row['FGrdCol']))
                y_pred.append(str(prediction))
        else:
            df.at[index, 'Predicted_FGrdCol'] = "Image Not Found"
            df.at[index, 'Image_Path'] = "N/A"

    # Calculate metrics if ground truth is available.
    metrics = _build_metrics(y_true, y_pred)

    # Model parameters (features used for prediction)
    model_features = ["StoneType", "Color", "Brown", "BlueUv", "GrdType", "Carat", "Result"]

    # "Out of box" columns: extra sheet columns surfaced only when they hold data.
    potential_oob = ['FancyYellow', 'Type2A', 'YellowUv']
    out_of_box_cols = []
    for col in potential_oob:
        if col in df.columns:
            # Keep the column only if at least one cell is non-null/non-empty.
            if df[col].dropna().astype(str).str.strip().replace(['nan', 'None', ''], pd.NA).notna().any():
                out_of_box_cols.append(col)

    output_filename = f"report_{secure_filename(excel_file.filename)}"
    output_path = os.path.join(RESULTS_FOLDER, output_filename)
    df.to_excel(output_path, index=False)
    return render_template('report.html',
                           report_data=df.to_dict(orient='records'),
                           report_file=output_filename,
                           out_of_box_cols=out_of_box_cols,
                           model_features=model_features,
                           metrics=metrics)
@app.route('/download/<filename>')
def download_file(filename):
    """Serve a generated report spreadsheet from the results folder."""
    return send_from_directory(RESULTS_FOLDER, filename)
@app.route('/image/<path:filename>')
def serve_image(filename):
    """Serve an extracted stone image so the report page can display it."""
    return send_from_directory(EXTRACT_FOLDER, filename)
# Development entry point only; production deployments run under a WSGI
# server such as Gunicorn. debug=True must not be used in production.
if __name__ == '__main__':
    app.run(debug=True)