| import logging |
| from typing import Optional |
|
|
| from core.app.app_config.entities import AppConfig |
| from core.moderation.base import ModerationAction, ModerationError |
| from core.moderation.factory import ModerationFactory |
| from core.ops.entities.trace_entity import TraceTaskName |
| from core.ops.ops_trace_manager import TraceQueueManager, TraceTask |
| from core.ops.utils import measure_time |
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(name=__name__)
|
|
|
|
class InputModeration:
    """Runs the app's configured input moderation ("sensitive word avoidance")."""

    def check(
        self,
        app_id: str,
        tenant_id: str,
        app_config: AppConfig,
        inputs: dict,
        query: str,
        message_id: str,
        trace_manager: Optional[TraceQueueManager] = None,
    ) -> tuple[bool, dict, str]:
        """
        Process sensitive_word_avoidance for the given inputs and query.

        :param app_id: app id, forwarded to the moderation factory
        :param tenant_id: tenant id, forwarded to the moderation factory
        :param app_config: app config; its ``sensitive_word_avoidance`` entry
            selects the moderation type and its config
        :param inputs: inputs to moderate
        :param query: query to moderate
        :param message_id: message id attached to the moderation trace task
        :param trace_manager: optional trace manager; when truthy, a
            MODERATION_TRACE task is enqueued after moderation runs
        :return: ``(flagged, inputs, query)``.
            ``flagged`` is False (with the original inputs/query) when no
            moderation is configured or the result is not flagged.
            For a flagged ``OVERRIDDEN`` result, the moderation result's
            inputs/query are returned. Any other non-raising flagged action
            returns the original inputs/query with ``flagged`` True.
        :raises ModerationError: when the result is flagged with action
            ``DIRECT_OUTPUT``; carries the result's ``preset_response``.
        """
        avoidance = app_config.sensitive_word_avoidance
        if not avoidance:
            # Moderation is not configured for this app: pass through untouched.
            return False, inputs, query

        factory = ModerationFactory(
            name=avoidance.type,
            app_id=app_id,
            tenant_id=tenant_id,
            config=avoidance.config,
        )

        # The moderation call runs inside measure_time(); the resulting timer
        # object is handed to the trace task below.
        with measure_time() as timer:
            result = factory.moderation_for_inputs(inputs, query)

        # Tracing happens before the flagged/action handling, so it is recorded
        # even when a ModerationError is raised further down.
        if trace_manager:
            trace_task = TraceTask(
                TraceTaskName.MODERATION_TRACE,
                message_id=message_id,
                moderation_result=result,
                inputs=inputs,
                timer=timer,
            )
            trace_manager.add_trace_task(trace_task)

        if not result.flagged:
            return False, inputs, query

        if result.action == ModerationAction.DIRECT_OUTPUT:
            raise ModerationError(result.preset_response)

        if result.action == ModerationAction.OVERRIDDEN:
            # Replace the caller's inputs/query with the moderated versions.
            return True, result.inputs, result.query

        return True, inputs, query
|
|