| import torch |
| import torch.nn.functional as F |
| from datetime import datetime |
| import os |
| import csv |
|
|
# Column order of the CSV log; also the key order of every result dict.
_CSV_FIELDNAMES = (
    "datetime_of_operation", "num_few_shot_nominal_imgs", "image_path", "image_name",
    "classification_result", "non_defect_prob", "defect_prob",
    "nominal_description", "defective_description",
    "max_nominal_similarity", "max_defective_similarity",
)


def _as_list(value):
    """Return ``value`` as a list: lists/tuples are copied, anything else is wrapped."""
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def _encode_unit_feature(model, image, device):
    """Encode one image with ``model.encode_image`` and return it as a flat unit vector."""
    # Flattening assumes encode_image yields a single embedding per call
    # (e.g. shape (1, D) or (D,)) -- the per-pair scalar similarity relies on it.
    feature = model.encode_image(image.to(device)).reshape(-1)
    # Out-of-place division so the tensor returned by the model is never mutated.
    return feature / feature.norm(dim=-1, keepdim=True)


def _best_match(reference_features, query_feature):
    """Return ``(index, similarity)`` of the reference row most similar to the query."""
    scores = (reference_features @ query_feature).tolist()
    # max() keeps the first maximal index, i.e. ties resolve to the earliest reference.
    best_idx = max(range(len(scores)), key=scores.__getitem__)
    return best_idx, scores[best_idx]


def few_shot_fault_classification(
    model,
    test_images,
    test_image_filenames,
    nominal_images,
    nominal_descriptions,
    defective_images,
    defective_descriptions,
    num_few_shot_nominal_imgs: int,
    device="cpu",
    file_path: str = '.',
    file_name: str = 'image_classification_results.csv',
    print_one_liner: bool = False
):
    """
    Classify test images as nominal or defective by similarity to reference images.

    Every image is encoded with ``model.encode_image`` and L2-normalised, so the dot
    product between two features is their cosine similarity. For each test image the
    best-matching nominal and best-matching defective reference are found; a softmax
    over those two similarities gives the reported probabilities, and the image is
    labelled ``"Defective"`` only when the defective probability is strictly larger.
    All rows are appended to a CSV log (a header is written when the file is new or
    empty) and also returned.

    Args:
        model: Object exposing ``encode_image(tensor)``; assumed to return one
            embedding per call.
        test_images: One image tensor or a list/tuple of them.
        test_image_filenames: One path or a list/tuple of paths, parallel to
            ``test_images``.
        nominal_images: One or more reference images of the nominal class.
        nominal_descriptions: Descriptions parallel to ``nominal_images``.
        defective_images: One or more reference images of the defective class.
        defective_descriptions: Descriptions parallel to ``defective_images``.
        num_few_shot_nominal_imgs: Recorded verbatim in every result row.
        device: Device each image is moved to before encoding.
        file_path: Directory of the CSV log; created if missing. An empty string
            means the current working directory.
        file_name: File name of the CSV log.
        print_one_liner: When true, print a one-line summary per test image.

    Returns:
        A list with one dict per test image, keyed by the CSV column names.

    Raises:
        ValueError: If no nominal or defective reference image is given, if the
            number of test images and filenames differ, or if a reference set has
            fewer descriptions than images.
    """
    test_images = _as_list(test_images)
    test_image_filenames = _as_list(test_image_filenames)
    nominal_images = _as_list(nominal_images)
    nominal_descriptions = _as_list(nominal_descriptions)
    defective_images = _as_list(defective_images)
    defective_descriptions = _as_list(defective_descriptions)

    # Fail fast with a clear message instead of an IndexError/RuntimeError after
    # all images have already been encoded.
    if not nominal_images or not defective_images:
        raise ValueError("At least one nominal and one defective reference image is required.")
    if len(test_images) != len(test_image_filenames):
        raise ValueError(
            f"Got {len(test_images)} test images but "
            f"{len(test_image_filenames)} test image filenames."
        )
    if len(nominal_descriptions) < len(nominal_images):
        raise ValueError("Each nominal image needs a matching nominal description.")
    if len(defective_descriptions) < len(defective_images):
        raise ValueError("Each defective image needs a matching defective description.")

    # os.makedirs('') raises, so an empty path is treated as "current directory".
    if file_path:
        os.makedirs(file_path, exist_ok=True)
    csv_file = os.path.join(file_path, file_name)

    results = []
    with torch.no_grad():
        nominal_features = torch.stack(
            [_encode_unit_feature(model, img, device) for img in nominal_images]
        )
        defective_features = torch.stack(
            [_encode_unit_feature(model, img, device) for img in defective_images]
        )

        for image_path, test_img in zip(test_image_filenames, test_images):
            test_feature = _encode_unit_feature(model, test_img, device)

            max_nominal_idx, max_nominal_similarity = _best_match(
                nominal_features, test_feature
            )
            max_defective_idx, max_defective_similarity = _best_match(
                defective_features, test_feature
            )

            # Two-way softmax over the best similarity of each class.
            similarities = torch.tensor([max_nominal_similarity, max_defective_similarity])
            prob_not_defective, prob_defective = F.softmax(similarities, dim=0).tolist()

            classification = "Defective" if prob_defective > prob_not_defective else "Nominal"

            results.append({
                "datetime_of_operation": datetime.now().isoformat(),
                "num_few_shot_nominal_imgs": num_few_shot_nominal_imgs,
                "image_path": image_path,
                # basename also handles the platform's native separator.
                "image_name": os.path.basename(image_path),
                "classification_result": classification,
                "non_defect_prob": round(prob_not_defective, 3),
                "defect_prob": round(prob_defective, 3),
                "nominal_description": nominal_descriptions[max_nominal_idx],
                "defective_description": defective_descriptions[max_defective_idx],
                "max_nominal_similarity": round(max_nominal_similarity, 3),
                "max_defective_similarity": round(max_defective_similarity, 3),
            })

            if print_one_liner:
                print(f"{image_path} → {classification} "
                      f"(Nominal: {prob_not_defective:.3f}, Defective: {prob_defective:.3f})")

    # A file that exists but is empty still needs its header row.
    needs_header = not os.path.isfile(csv_file) or os.path.getsize(csv_file) == 0
    # Explicit UTF-8 so non-ASCII descriptions do not depend on the locale encoding.
    with open(csv_file, mode='a', newline='', encoding='utf-8') as file:
        writer = csv.DictWriter(file, fieldnames=_CSV_FIELDNAMES)
        if needs_header:
            writer.writeheader()
        writer.writerows(results)

    return results