# D3V1L1810's picture
# Update app.py
# 7a437ab verified
import tensorflow as tf
from transformers import BertTokenizer, TFBertForSequenceClassification
import numpy as np
import json
import requests
import gradio as gr
import logging
# Load the fine-tuned tokenizer and sequence-classification model once, at
# import time, so every prediction request reuses the same instances.
# NOTE(review): these look like local checkpoint directories (epoch-10
# snapshots, judging by the names) rather than Hub ids — confirm.
_TOKENIZER_SOURCE = 'MultiTokenizer_ep10'
_MODEL_SOURCE = 'MultiModel_ep10'
bert_tokenizer = BertTokenizer.from_pretrained(_TOKENIZER_SOURCE)
bert_model = TFBertForSequenceClassification.from_pretrained(_MODEL_SOURCE)
# Function to send results to API
# def send_results_to_api(data, result_url):
# headers = {'Content-Type':'application/json'}
# response = requests.post(result_url, json = data, headers=headers)
# if response.status_code == 200:
# return response.json
# else:
# return {'error':f"failed to send result to API: {response.status_code}"}
# Class index -> label returned to callers. Index 7 ('None') is the fallback
# used when the model is not confident enough in any of the real classes.
_LABELS = {
    0: 'BUSINESS',
    1: 'COMEDY',
    2: 'CRIME',
    3: 'FOOD & DRINK',
    4: 'POLITICS',
    5: 'SPORTS',
    6: 'TRAVEL',
    7: 'None',
}
_FALLBACK_LABEL_ID = 7
# Predictions whose top softmax probability is below this become 'None'.
_CONFIDENCE_THRESHOLD = 0.85


def _classify_text(text):
    """Return the label for *text*, or 'None' if the top score is below threshold."""
    encoding = bert_tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=128,
        return_token_type_ids=True,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='tf',
    )
    # Pass inputs by name. Hugging Face TF models map a positional list to
    # (input_ids, attention_mask, token_type_ids) in that order, so the old
    # [input_ids, token_type_ids, attention_mask] list swapped the attention
    # mask and the segment ids.
    pred = bert_model.predict({
        'input_ids': encoding['input_ids'],
        'attention_mask': encoding['attention_mask'],
        'token_type_ids': encoding['token_type_ids'],
    })
    scores = tf.nn.softmax(pred.logits, axis=1).numpy()[0]
    label_id = int(np.argmax(scores))
    confidence = float(scores[label_id])
    logging.debug("Prediction confidence: %s", confidence)
    if confidence < _CONFIDENCE_THRESHOLD:
        label_id = _FALLBACK_LABEL_ID
    return _LABELS[label_id]


def predict_text(params):
    """Classify the texts described by a JSON request string.

    Args:
        params: A JSON object string. Texts are read from "urls" (or from
            "texts" if "urls" is absent/empty); a single string is treated
            as a one-element list. Optional "normalfileID" is a list of ids
            paired positionally with the texts; missing ids become None.

    Returns:
        On success, a JSON string ``{"solutions": [...]}`` with one entry per
        text: ``{"text", "answer": [label], "qcUser": None, "normalfileID"}``.
        On invalid input, a dict ``{"error": message}``.
    """
    try:
        params = json.loads(params)
    except json.JSONDecodeError as e:
        message = f"Invalid JSON input: {e.msg} at line {e.lineno} column {e.colno}"
        logging.error(message)
        return {"error": message}
    if not isinstance(params, dict):
        # e.g. '[1, 2]' parses fine but has no keys to read.
        logging.error("Invalid JSON input: expected an object, got %s",
                      type(params).__name__)
        return {"error": "Invalid JSON input: expected a JSON object"}

    # "urls" is the key the API sends; "texts" is also accepted because the
    # UI hint has historically used it.
    texts = params.get("urls") or params.get("texts") or []
    if isinstance(texts, str):
        # A bare string would otherwise be iterated character by character.
        texts = [texts]
    if not texts:
        return {"error": "Missing required parameters: 'urls' (or 'texts')"}
    if not isinstance(texts, list):
        return {"error": "Invalid parameter 'urls': expected a list of texts"}

    file_ids = list(params.get("normalfileID") or [])
    # Pad with None so texts without a matching id are still classified;
    # zip() would otherwise silently drop them.
    file_ids += [None] * (len(texts) - len(file_ids))

    solutions = [
        {
            'text': text,
            'answer': [_classify_text(text)],
            "qcUser": None,
            "normalfileID": file_id,
        }
        for text, file_id in zip(texts, file_ids)
    ]
    # result_url = f"{api}/{job_id}"
    # send_results_to_api(solutions, result_url)
    return json.dumps({"solutions": solutions})
# Gradio UI: one textbox receives the JSON request string, and the JSON
# component renders whatever predict_text returns. The hint shows valid JSON
# (double quotes) using the keys predict_text actually reads.
inputt = gr.Textbox(
    label='Parameters in JSON format, e.g. {"urls": ["text1", "text2"], "normalfileID": ["id1", "id2"]}'
)
outputt = gr.JSON()
application = gr.Interface(
    fn=predict_text,
    inputs=inputt,
    outputs=outputt,
    title='Multi Text Classification with API Integration..',
)
application.launch()