from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Type, Union

from pydantic import BaseModel

from app.schemas.schema_tools import (
    convert_attribute_to_model,
    validate_json_data,
    validate_json_schema,
)
from app.utils.converter import to_snake_case
|
|
|
|
|
def cf_style_to_pydantic_percentage_shema(
    cf_style_schema: dict,
) -> "Type[BaseModel]":
    """Build a Pydantic ``Product`` model from a CF-style attribute schema.

    For every attribute, a nested model named ``Class_<Attribute>`` is created.
    Its fields are the attribute's allowed values, lower-cased, with spaces and
    hyphens replaced by underscores. Each field is a required ``int``
    (presumably a percentage score; the caller keeps multi-label values scoring
    >= 60).

    The returned ``Product`` model has one field per attribute:

    - its type is that nested model;
    - its default is ``""``;
    - its description starts with ``"multi-label classification"`` when the
      attribute's ``data_type`` contains ``"list"``, otherwise with
      ``"single-label classification"``, followed by ``", "`` and the
      attribute's own description.

    Models are built with ``pydantic.create_model`` instead of generating and
    ``exec``-ing source code. This means:

    - module globals are not modified;
    - descriptions containing quotes, backslashes or newlines are kept verbatim
      (no code injection);
    - value names that are not Python identifiers no longer raise
      ``SyntaxError``.

    Args:
        cf_style_schema: mapping of attribute name -> attribute info. Each info
            object must expose ``data_type``, ``description`` and
            ``allowed_values`` attributes.

    Returns:
        The dynamically created ``Product`` model class.
    """
    # Local import: only ``BaseModel`` is imported from pydantic at module level.
    from pydantic import Field, create_model

    product_fields = {}
    for attribute, attribute_info in cf_style_schema.items():
        multiple = "list" in attribute_info.data_type
        multiple_desc = (
            "multi-label classification" if multiple else "single-label classification"
        )

        # One required int field per allowed value. Duplicate normalized names
        # collapse to a single field, as they did in the generated class body.
        value_fields = {
            value.lower().replace(" ", "_").replace("-", "_"): (int, ...)
            for value in attribute_info.allowed_values
        }
        values_model = create_model("Class_" + attribute.capitalize(), **value_fields)

        product_fields[attribute] = (
            values_model,
            Field("", description=f"{multiple_desc}, {attribute_info.description}"),
        )

    return create_model("Product", **product_fields)
|
|
|
def build_attributes_types_prompt(attributes):
    """Render the attribute names and data types as a prompt fragment.

    The result starts with a newline and the header "List of attributes types:".
    That is followed by one ``- <name>: <data_type>`` line per entry of
    ``attributes``, in iteration order. Each value must expose a ``data_type``
    attribute.
    """
    header = "\n List of attributes types:\n"
    bullets = "".join(
        f"- {name}: {info.data_type}\n" for name, info in attributes.items()
    )
    return header + bullets
|
|
|
|
|
class BaseAttributionService(ABC):
    """Abstract base for services that extract product attributes with an AI model.

    Subclasses implement the model-facing coroutines ``extract_attributes``,
    ``reevaluate_atributes`` and ``follow_schema``. This base class wraps them
    with schema construction and JSON-schema validation.
    """

    @abstractmethod
    async def extract_attributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        pil_images: Optional[List[Any]] = None,
        appended_prompt: str = "",
    ) -> Dict[str, Any]:
        """Return attribute data shaped like ``attributes_model``.

        NOTE(review): ``extract_attributes_with_validation`` calls this with
        ``product_data`` as the fifth positional argument (the slot declared
        here as ``pil_images``) and with an ``img_paths`` keyword that this
        declaration lacks. Concrete subclasses presumably accept that calling
        convention; confirm and align this signature.
        """

    @abstractmethod
    async def reevaluate_atributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        pil_images: Optional[List[Any]] = None,
        appended_prompt: str = "",
    ) -> Dict[str, Any]:
        """Return per-field value scores for ``attributes_model``.

        NOTE(review): the caller passes ``str()`` of the first extraction pass
        as the fifth positional argument and an ``img_paths`` keyword, neither
        of which matches this declaration. Confirm against concrete subclasses.
        """

    @abstractmethod
    async def follow_schema(
        self, schema: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Return ``data`` adapted to the JSON ``schema``.

        The result is validated against ``schema`` by
        ``follow_schema_with_validation``.
        """

    async def extract_attributes_with_validation(
        self,
        attributes: Dict[str, Any],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        product_data: Dict[str, Union[str, List[str]]],
        pil_images: Optional[List[Any]] = None,
        img_paths: Optional[List[str]] = None,
        appended_prompt: str = "",
        *,
        multi_label_threshold: float = 60,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Extract attributes in two passes and collapse the scores into labels.

        The method runs these steps:

        1. Attribute names are mapped to unique field names
           ``<to_snake_case(name)>_<index>``. A percentage model is built from
           them with ``cf_style_to_pydantic_percentage_shema``.
        2. ``extract_attributes`` produces a first pass. It is validated against
           the model's JSON schema.
        3. ``reevaluate_atributes`` receives ``str()`` of that first pass and
           returns, per field, a mapping of value name -> score.
        4. The scores are collapsed:

           - Single-label fields keep the highest-scoring value (the first one
             on ties), but only if its score is > 0; otherwise the field is
             omitted.
           - Multi-label fields keep every value scoring
             >= ``multi_label_threshold``.

        Args:
            attributes: attribute name -> attribute info (exposing
                ``data_type``, ``description`` and ``allowed_values``).
            ai_model: model identifier forwarded to the subclass hooks.
            img_urls: image URLs forwarded to the subclass hooks.
            product_taxonomy: taxonomy name; ``""`` is replaced by ``"main"``.
            product_data: forwarded positionally to ``extract_attributes``.
            pil_images: accepted for interface compatibility. It is not
                forwarded to either hook.
            img_paths: forwarded to both subclass hooks.
            appended_prompt: extra prompt text appended after the generated
                list of attribute types. It is sent to both hooks.
            multi_label_threshold: minimum score for a multi-label value to be
                kept. The default of 60 matches the previous hard-coded value.

        Returns:
            A tuple ``(data, reverse_data)``:

            - ``data`` is the validated first-pass output, keyed by the
              internal field names;
            - ``reverse_data`` holds the collapsed labels, keyed by the
              original attribute names.

        Raises:
            ValueError: if a generated field description mentions neither
                "single-label" nor "multi-label".
        """
        # Internal field names: snake_cased (presumably to make them safe model
        # field names), with an index suffix so that distinct attributes never
        # collide after normalization.
        forward_mapping = {
            key: f"{to_snake_case(key)}_{i}" for i, key in enumerate(attributes)
        }
        reverse_mapping = {field: key for key, field in forward_mapping.items()}
        transformed_attributes = {
            forward_mapping[key]: value for key, value in attributes.items()
        }

        # Previously the caller's appended_prompt was silently ignored (and its
        # default was the ``str`` type itself); it is now appended.
        prompt = build_attributes_types_prompt(attributes) + (appended_prompt or "")
        taxonomy = product_taxonomy if product_taxonomy != "" else "main"

        attributes_percentage_model = cf_style_to_pydantic_percentage_shema(
            transformed_attributes
        )
        schema = attributes_percentage_model.model_json_schema()

        data = await self.extract_attributes(
            attributes_percentage_model,
            ai_model,
            img_urls,
            taxonomy,
            product_data,
            img_paths=img_paths,
            appended_prompt=prompt,
        )
        validate_json_data(data, schema)

        reevaluate_data = await self.reevaluate_atributes(
            attributes_percentage_model,
            ai_model,
            img_urls,
            taxonomy,
            str(data),
            img_paths=img_paths,
            appended_prompt=prompt,
        )

        collapsed: Dict[str, Any] = {}
        for field_name, field in attributes_percentage_model.model_fields.items():
            description = field.description.lower()
            if "single-label" in description:
                scores = reevaluate_data[field_name]
                # max() returns the first maximal key, matching a strict ">" scan.
                best = max(scores, key=scores.get, default=None)
                if best is not None and scores[best] > 0:
                    collapsed[field_name] = best
            elif "multi-label" in description:
                scores = reevaluate_data[field_name]
                collapsed[field_name] = [
                    label
                    for label, score in scores.items()
                    if score >= multi_label_threshold
                ]
            else:
                raise ValueError(
                    f"The description does not contain 'single-label' or 'multi-label': {field.description}"
                )

        reverse_data = {reverse_mapping[key]: value for key, value in collapsed.items()}
        return data, reverse_data

    async def follow_schema_with_validation(
        self, schema: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Validate ``schema``, run ``follow_schema``, then validate its output.

        Validation is delegated to ``validate_json_schema`` and
        ``validate_json_data``. Whatever those raise propagates to the caller.
        """
        validate_json_schema(schema)
        data = await self.follow_schema(schema, data)
        validate_json_data(data, schema)
        return data
|
|
|