| |
| from flask import Flask, request, jsonify |
| from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM |
| import os |
|
|
# Hugging Face checkpoint to serve; override via the MODEL_ID environment variable.
MODEL_ID = os.getenv("MODEL_ID", "Helsinki-NLP/opus-mt-en-ar")

app = Flask(__name__)

# Cached translation pipeline; stays None until get_translator() builds it.
translator = None
|
|
def get_translator():
    """Return the shared translation pipeline, building it on first call.

    Loads the tokenizer and seq2seq model named by ``MODEL_ID`` and wraps them
    in a transformers ``"translation"`` pipeline. The result is cached in the
    module-level ``translator`` global, so later calls reuse the same object.

    NOTE(review): the lazy initialisation is not guarded by a lock. Under a
    threaded server, concurrent first requests may each load the model.

    Returns:
        The cached transformers translation pipeline.
    """
    global translator
    if translator is None:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
        # No language keywords are passed. Single-pair checkpoints such as the
        # default opus-mt-en-ar already encode the direction. The former
        # ``src=``/``tgt=`` arguments are not pipeline options: unknown kwargs
        # are forwarded to ``model.generate()``, which rejects them as unused
        # model kwargs and fails every request.
        translator = pipeline("translation", model=model, tokenizer=tokenizer)
    return translator
|
|
@app.route("/health", methods=["GET"])
def health():
    """Liveness probe: always answer HTTP 200 with ``{"status": "ok"}``."""
    status_body = {"status": "ok"}
    return jsonify(status_body), 200
|
|
def _extract_texts(payload):
    """Return the request's input strings as a list, or an error message.

    Reads ``payload["texts"]`` and falls back to ``payload["text"]`` when the
    former is missing or falsy. Either field may hold a single string or a
    list of strings.

    Returns:
        ``(texts, None)`` on success, ``(None, error_message)`` otherwise.
    """
    texts = payload.get("texts") or payload.get("text")
    if texts is None:
        return None, "provide 'text' or 'texts' in JSON"
    if isinstance(texts, str):
        return [texts], None
    # Any other element type would make the pipeline crash (HTTP 500);
    # reject it here as a client error instead.
    if not isinstance(texts, list) or not all(isinstance(t, str) for t in texts):
        return None, "'text'/'texts' must be a string or a list of strings"
    return texts, None


def _positive_int_param(payload, key, default):
    """Return ``payload[key]`` as a positive int, or an error message.

    A missing key or an explicit JSON ``null`` yields *default*. ``bool`` is
    rejected even though it subclasses ``int``, so ``true`` is not read as 1.

    Returns:
        ``(value, None)`` on success, ``(None, error_message)`` otherwise.
    """
    value = payload.get(key)
    if value is None:
        return default, None
    if isinstance(value, bool) or not isinstance(value, int) or value < 1:
        return None, f"'{key}' must be a positive integer"
    return value, None


@app.route("/translate", methods=["POST"])
def translate():
    """
    Accepts JSON:
    {
        "texts": ["Hello", "How are you?"],  # or a single string as "text"
        "max_length": 256,                   # optional, positive integer
        "batch_size": 8                      # optional, positive integer
    }
    Returns:
    {
        "translations": ["مرحبا", "كيف حالك؟"]
    }
    Errors are returned as ``{"error": "..."}``:
    400 for malformed or invalid input, 500 if the model fails.
    """
    # silent=True makes malformed JSON return None, which is answered below
    # as a JSON 400, instead of raising werkzeug's BadRequest (HTML body).
    payload = request.get_json(force=True, silent=True)
    if payload is None:
        return jsonify({"error": "invalid json"}), 400
    if not isinstance(payload, dict):
        return jsonify({"error": "JSON body must be an object"}), 400

    texts, error = _extract_texts(payload)
    if error is not None:
        return jsonify({"error": error}), 400

    max_length, error = _positive_int_param(payload, "max_length", 256)
    if error is not None:
        return jsonify({"error": error}), 400
    batch_size, error = _positive_int_param(payload, "batch_size", 8)
    if error is not None:
        return jsonify({"error": error}), 400

    # Nothing to translate: avoid loading or invoking the model at all.
    if not texts:
        return jsonify({"translations": []}), 200

    try:
        pipe = get_translator()
        translated = pipe(texts, max_length=max_length, batch_size=batch_size)
    except Exception:
        # Request boundary: keep the traceback in the server log and give the
        # client a JSON error rather than Flask's default HTML 500 page.
        app.logger.exception("translation failed")
        return jsonify({"error": "translation failed"}), 500

    out = [t["translation_text"] for t in translated]
    return jsonify({"translations": out}), 200
|
|
if __name__ == "__main__":
    # Run Flask's built-in server on all interfaces; PORT defaults to 8080.
    listen_port = int(os.environ.get("PORT", 8080))
    app.run(host="0.0.0.0", port=listen_port)
|
|