import copy
import dataclasses
import io
import json
import logging
import math
import os
import sys
import time
from typing import Optional, Sequence, Union

import openai
import tqdm
from openai import openai_object

StrOrOpenAIObject = Union[str, openai_object.OpenAIObject]
openai.api_key = os.getenv("OPENAI_API_KEY")  # Assumed: key comes from the environment; the original assigned an empty string, which fails authentication.
openai_org = os.getenv("OPENAI_ORG")
if openai_org is not None:
    openai.organization = openai_org
    logging.warning(f"Switching to organization: {openai_org} for OAI API key.")


@dataclasses.dataclass
class OpenAIDecodingArguments(object):
    max_tokens: int = 1800
    temperature: float = 0.2
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Optional[Sequence[str]] = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    suffix: Optional[str] = None
    logprobs: Optional[int] = None
    echo: bool = False
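
# The dataclass above is expanded into API keyword arguments via `__dict__`
# inside `openai_completion` below; a minimal sketch of that expansion:
#
#     args = OpenAIDecodingArguments(temperature=0.7, n=2)
#     kwargs = dict(model="text-davinci-003", **args.__dict__)
#     # kwargs now carries max_tokens=1800, temperature=0.7, top_p=1.0, n=2, ...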


def openai_completion(
    prompts: Union[str, Sequence[str], Sequence[dict[str, str]], dict[str, str]],
    decoding_args: OpenAIDecodingArguments,
    model_name="text-davinci-003",
    sleep_time=2,
    batch_size=1,
    max_instances=sys.maxsize,
    max_batches=sys.maxsize,
    return_text=False,
    **decoding_kwargs,
) -> Union[StrOrOpenAIObject, Sequence[StrOrOpenAIObject], Sequence[Sequence[StrOrOpenAIObject]]]:
| """Decode with OpenAI API.
|
|
|
| Args:
|
| prompts: A string or a list of strings to complete. If it is a chat model the strings should be formatted
|
| as explained here: https://github.com/openai/openai-python/blob/main/chatml.md. If it is a chat model
|
| it can also be a dictionary (or list thereof) as explained here:
|
| https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
|
| decoding_args: Decoding arguments.
|
| model_name: Model name. Can be either in the format of "org/model" or just "model".
|
| sleep_time: Time to sleep once the rate-limit is hit.
|
| batch_size: Number of prompts to send in a single request. Only for non chat model.
|
| max_instances: Maximum number of prompts to decode.
|
| max_batches: Maximum number of batches to decode. This argument will be deprecated in the future.
|
| return_text: If True, return text instead of full completion object (which contains things like logprob).
|
| decoding_kwargs: Additional decoding arguments. Pass in `best_of` and `logit_bias` if you need them.
|
|
|
| Returns:
|
| A completion or a list of completions.
|
| Depending on return_text, return_openai_object, and decoding_args.n, the completion type can be one of
|
| - a string (if return_text is True)
|
| - an openai_object.OpenAIObject object (if return_text is False)
|
| - a list of objects of the above types (if decoding_args.n > 1)
|
| """
    is_single_prompt = isinstance(prompts, (str, dict))
    if is_single_prompt:
        prompts = [prompts]

    if max_batches < sys.maxsize:
        logging.warning(
            "`max_batches` will be deprecated in the future, please use `max_instances` instead. "
            "Setting `max_instances` to `max_batches * batch_size` for now."
        )
        max_instances = max_batches * batch_size

    prompts = prompts[:max_instances]
    num_prompts = len(prompts)
    prompt_batches = [
        prompts[batch_id * batch_size : (batch_id + 1) * batch_size]
        for batch_id in range(int(math.ceil(num_prompts / batch_size)))
    ]

    completions = []
    for batch_id, prompt_batch in tqdm.tqdm(
        enumerate(prompt_batches),
        desc="prompt_batches",
        total=len(prompt_batches),
    ):
        # Copy so retries can shrink max_tokens for this batch without affecting later batches.
        batch_decoding_args = copy.deepcopy(decoding_args)

        while True:
            try:
                shared_kwargs = dict(
                    model=model_name,
                    **batch_decoding_args.__dict__,
                    **decoding_kwargs,
                )
                completion_batch = openai.Completion.create(prompt=prompt_batch, **shared_kwargs)
                choices = completion_batch.choices

                # Attach the batch-level token usage to each choice for downstream bookkeeping.
                for choice in choices:
                    choice["total_tokens"] = completion_batch.usage.total_tokens
                completions.extend(choices)
                break
            except openai.error.OpenAIError as e:
                logging.warning(f"OpenAIError: {e}.")
                if "Please reduce your prompt" in str(e):
                    # Context-length overflow: shrink the requested completion length and retry immediately.
                    batch_decoding_args.max_tokens = int(batch_decoding_args.max_tokens * 0.8)
                    logging.warning(f"Reducing target length to {batch_decoding_args.max_tokens}, retrying...")
                else:
                    # Otherwise assume a rate limit (or transient error) and back off before retrying.
                    logging.warning("Hit request rate limit; retrying...")
                    time.sleep(sleep_time)

    if return_text:
        completions = [completion.text for completion in completions]
    if decoding_args.n > 1:
        # Regroup the flat list of choices into one list of n samples per prompt.
        completions = [completions[i : i + decoding_args.n] for i in range(0, len(completions), decoding_args.n)]
    if is_single_prompt:
        # Return a single completion (not a one-element list) for a single prompt.
        (completions,) = completions
    return completions
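
# A minimal usage sketch, assuming OPENAI_API_KEY is set and the legacy
# `text-davinci-003` completions endpoint is available:
#
#     decoding_args = OpenAIDecodingArguments(max_tokens=64, temperature=0.7)
#     text = openai_completion("Say hello.", decoding_args, return_text=True)
#     print(text)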


def _make_w_io_base(f, mode: str):
    """Open `f` for writing if it is a path string, creating parent directories as needed."""
    if not isinstance(f, io.IOBase):
        f_dirname = os.path.dirname(f)
        if f_dirname != "":
            os.makedirs(f_dirname, exist_ok=True)
        f = open(f, mode=mode)
    return f


def _make_r_io_base(f, mode: str):
    """Open `f` for reading if it is a path string; pass file objects through unchanged."""
    if not isinstance(f, io.IOBase):
        f = open(f, mode=mode)
    return f


def jdump(obj, f, mode="w", indent=4, default=str):
    """Dump a dict, list, or string to a file in json format.

    Args:
        obj: An object to be written.
        f: A string path to the location on disk, or an open file object.
        mode: Mode for opening the file.
        indent: Indent for storing json dictionaries.
        default: A function to handle non-serializable entries; defaults to `str`.
    """
    f = _make_w_io_base(f, mode)
    if isinstance(obj, (dict, list)):
        json.dump(obj, f, indent=indent, default=default)
    elif isinstance(obj, str):
        f.write(obj)
    else:
        raise ValueError(f"Unexpected type: {type(obj)}")
    f.close()


def jload(f, mode="r"):
    """Load a .json file into a dictionary."""
    f = _make_r_io_base(f, mode)
    jdict = json.load(f)
    f.close()
    return jdict
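
# A quick round-trip sketch (the file name is illustrative):
#
#     jdump({"instruction": "Say hello.", "output": "Hello!"}, "example.json")
#     record = jload("example.json")
#     assert record["output"] == "Hello!"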