| import os |
| from traceback import print_exc |
| import boto3 |
| from handler import ContentHandler |
| from dotenv import load_dotenv |
|
|
# Run python-dotenv's loader first so that any values it supplies are present
# in os.environ by the time the settings below are read.
load_dotenv()
|
|
# Deployment settings taken from the process environment. Each value is None
# when the corresponding variable is not set.
endpoint_name = os.getenv("AWS_ENDPOINT_NAME")
aws_access_key_id = os.getenv("AWS_ACCESS_KEY_ID")
aws_secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY")
aws_region_name = os.getenv("AWS_REGION_NAME")
|
|
# Module-level SageMaker Runtime client, built once at import time from the
# credentials and region held in the aws_* settings of this module.
boto_client = boto3.client(
    service_name="sagemaker-runtime",
    region_name=aws_region_name,
    aws_access_key_id=aws_access_key_id,
    aws_secret_access_key=aws_secret_access_key,
)
|
|
# Single shared handler instance; invoke_endpoint uses it to build the request
# body (transform_input) and to decode the response body (transform_output).
content_handler = ContentHandler()
|
|
def invoke_endpoint(
    input_,
    model_parameters,
):
    """Call the configured SageMaker endpoint and return the handled response.

    The request body is produced by ``content_handler.transform_input`` and
    sent as ``application/json`` to the endpoint named by the module-level
    ``endpoint_name``. The ``Body`` of the response is then passed through
    ``content_handler.transform_output``.

    Args:
        input_: Value forwarded to ``transform_input`` as ``prompt``.
        model_parameters: Value forwarded to ``transform_input`` as
            ``model_kwargs``.

    Returns:
        The result of ``content_handler.transform_output``, or ``None`` if
        building the request, invoking the endpoint, or handling the
        response raised an exception (the traceback is printed).
    """
    try:
        payload = content_handler.transform_input(
            prompt=input_,
            model_kwargs=model_parameters,
        )
        response = boto_client.invoke_endpoint(
            EndpointName=endpoint_name,
            ContentType='application/json',
            Body=payload,
        )
        return content_handler.transform_output(response['Body'])
    except Exception:
        # Best-effort contract: report the failure and signal it with None.
        # Catching Exception rather than using a bare ``except`` lets
        # KeyboardInterrupt and SystemExit propagate instead of being
        # silently turned into a None result.
        print_exc()
        return None