import sys
import os
import base64
import json
import traceback
from typing import List

import logging
logging.basicConfig(level=logging.INFO)
import modules.cloud_logging

import tokenizers
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

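# run with the small model on CPU when a local file named 'debug' is present;
# otherwise assume a GPU deployment with the full-size model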
if os.path.exists('debug'):
    BIG_MODEL = False
    CUDA = False
else:
    BIG_MODEL = True
    CUDA = True

PORT = 7860
VERBOSE = False

MAX_LENGTH = 1024 + 64
TRUNCATION_MESSAGE = f'warning: This demo is limited to {MAX_LENGTH} tokens in the document for efficiency.'

if BIG_MODEL:
    model_name = "facebook/incoder-6B"
    kwargs = dict(
        revision="float16",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
else:
    model_name = "facebook/incoder-1B"
    kwargs = dict()

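# note on the checkpoint selection above: the "float16" revision of incoder-6B
# stores half-precision weights, which (together with low_cpu_mem_usage) roughly
# halves download size and host memory needed during loading versus fp32 weights
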
from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, StreamingResponse

app = FastAPI(docs_url=None, redoc_url=None)
app.mount("/static", StaticFiles(directory="static"), name="static")

logging.info(f"loading model {model_name}")
model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
logging.info("loading tokenizer")
tokenizer = AutoTokenizer.from_pretrained(model_name)
logging.info("loading complete")

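# .half() is a no-op for the 6B checkpoint (already loaded in float16 via the
# revision above) and casts the 1B model down to fp16 for GPU inference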
if CUDA:
    model = model.half().cuda()

# special tokens used by InCoder's causal-masking objective
BOS = "<|endoftext|>"
EOM = "<|endofmask|>"

def make_sentinel(i):
    # signals (1) where to insert an infill and (2) the start of that infill's generation
    return f"<|mask:{i}|>"

SPECIAL_TOKENS = [make_sentinel(i) for i in range(256)] + [EOM]

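# generate a continuation of `input`, sampling at most `length_limit` new tokens
# (subject to the global MAX_LENGTH budget over prompt + generation); returns
# (decoded_text, truncated_flag)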
def generate(input, length_limit=None, temperature=None):
    input_ids = tokenizer(input, return_tensors="pt").input_ids
    if CUDA:
        input_ids = input_ids.cuda()
    current_length = input_ids.flatten().size(0)
    max_length = length_limit + current_length
    truncated = False
    if max_length > MAX_LENGTH:
        max_length = MAX_LENGTH
        truncated = True
    if max_length == current_length:
        # no room left in the token budget to generate anything
        return input, True
    output = model.generate(input_ids=input_ids, do_sample=True, top_p=0.95, temperature=temperature, max_length=max_length)
    detok_hypo_str = tokenizer.decode(output.flatten())
    if detok_hypo_str.startswith(BOS):
        detok_hypo_str = detok_hypo_str[len(BOS):]
    return detok_hypo_str, truncated

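# infill: the document is split into `parts` around the regions to fill in. The
# prompt interleaves the parts with sentinel tokens, then each region is generated
# left to right until the model emits <|endofmask|>. Illustrative example for
# parts = ["def count(", "): return n"] with extra_sentinel=True:
#   prompt at generation time: "def count(<|mask:0|>): return n<|mask:1|><|mask:0|>"
# and the completion sampled after the trailing <|mask:0|> becomes infill 0.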
def infill(parts: List[str], length_limit=None, temperature=None, extra_sentinel=False, max_retries=1):
    assert isinstance(parts, list)
    retries_attempted = 0
    done = False

    while (not done) and (retries_attempted < max_retries):
        any_truncated = False
        retries_attempted += 1
        if VERBOSE:
            logging.info(f"retry {retries_attempted}")

        ## (1) build the prompt: parts interleaved with sentinel tokens
        if len(parts) == 1:
            prompt = parts[0]
        else:
            prompt = ""
            for sentinel_ix, part in enumerate(parts):
                prompt += part
                if extra_sentinel or (sentinel_ix < len(parts) - 1):
                    prompt += make_sentinel(sentinel_ix)

        infills = []
        complete = []

        done = True

        ## (2) generate each infill left to right, stopping at the end-of-mask token
        for sentinel_ix, part in enumerate(parts[:-1]):
            complete.append(part)
            prompt += make_sentinel(sentinel_ix)
            completion, this_truncated = generate(prompt, length_limit, temperature)
            any_truncated |= this_truncated
            completion = completion[len(prompt):]
            if EOM not in completion:
                if VERBOSE:
                    logging.info(f"warning: {EOM} not found")
                completion += EOM
                # model never closed this infill; retry the whole document if budget allows
                done = False
            completion = completion[:completion.index(EOM) + len(EOM)]
            infilled = completion[:-len(EOM)]
            infills.append(infilled)
            complete.append(infilled)
            prompt += completion
        complete.append(parts[-1])
        text = ''.join(complete)

    if VERBOSE:
        logging.info("generated text:")
        logging.info(prompt)
        logging.info("parts:")
        logging.info(parts)
        logging.info("infills:")
        logging.info(infills)
        logging.info("restitched text:")
        logging.info(text)

    return {
        'text': text,
        'parts': parts,
        'infills': infills,
        'retries_attempted': retries_attempted,
        'truncated': any_truncated,
    }

@app.head("/")
@app.get("/")
def index() -> FileResponse:
    return FileResponse(path="static/index.html", media_type="text/html")

@app.get('/generate')
async def generate_maybe(info: str):
    # info is a urlsafe-base64-encoded JSON payload; restore any '=' padding the
    # client stripped before decoding
    info = base64.urlsafe_b64decode(info + '=' * (-len(info) % 4)).decode('utf-8')
    form = json.loads(info)
    prompt = form['prompt']
    length_limit = int(form['length'])
    temperature = float(form['temperature'])
    logging.info(json.dumps({
        'length': length_limit,
        'temperature': temperature,
        'prompt': prompt,
    }))
    try:
        generation, truncated = generate(prompt, length_limit, temperature)
        if truncated:
            message = TRUNCATION_MESSAGE
        else:
            message = ''
        return {'result': 'success', 'type': 'generate', 'prompt': prompt, 'text': generation, 'message': message}
    except Exception as e:
        traceback.print_exception(*sys.exc_info())
        logging.error(e)
        return {'result': 'error', 'type': 'generate', 'prompt': prompt, 'message': f'Error: {e}.'}

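# illustrative client call for /generate (the names here just mirror the JSON
# keys the handler reads; they are not part of the server itself):
#   payload = {'prompt': 'def hello():', 'length': 64, 'temperature': 0.6}
#   info = base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip('=')
#   GET /generate?info=<info>
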
@app.get('/infill')
async def infill_maybe(info: str):
    # info is a urlsafe-base64-encoded JSON payload, as in /generate
    info = base64.urlsafe_b64decode(info + '=' * (-len(info) % 4)).decode('utf-8')
    form = json.loads(info)
    length_limit = int(form['length'])
    temperature = float(form['temperature'])
    max_retries = 1
    extra_sentinel = True
    logging.info(json.dumps({
        'length': length_limit,
        'temperature': temperature,
        'parts_joined': '<infill>'.join(form['parts']),
    }))
    try:
        if len(form['parts']) > 4:
            return {'result': 'error', 'text': ''.join(form['parts']), 'type': 'infill', 'message': "error: Can't use more than 3 <infill> tokens in this demo (for efficiency)."}
        generation = infill(form['parts'], length_limit, temperature, extra_sentinel=extra_sentinel, max_retries=max_retries)
        generation['result'] = 'success'
        generation['type'] = 'infill'
        if generation['truncated']:
            generation['message'] = TRUNCATION_MESSAGE
        else:
            generation['message'] = ''
        return generation
    except Exception as e:
        traceback.print_exception(*sys.exc_info())
        logging.error(e)
        return {'result': 'error', 'type': 'infill', 'message': f'Error: {e}.'}

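# illustrative client call for /infill (the 'parts' list is the document split
# at each point to fill in; this example is assumed, not part of the server):
#   payload = {'parts': ['def add(a, b):\n    ', '\n'], 'length': 64, 'temperature': 0.6}
#   info = base64.urlsafe_b64encode(json.dumps(payload).encode()).decode().rstrip('=')
#   GET /infill?info=<info>
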
if __name__ == "__main__":
    # FastAPI apps have no app.run(); serve with an ASGI server such as uvicorn
    import uvicorn
    uvicorn.run(app, host='0.0.0.0', port=PORT)