| | """ |
| | Client test. |
| | |
| | Run server: |
| | |
| | python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b |
| | |
| | NOTE: For private models, add --use-auth_token=True |
| | |
| | NOTE: --use_gpu_id=True (default) must be used for multi-GPU if you see failures with cuda:x cuda:y mismatches. |
| | Currently, this forces the model onto a single GPU. |
| | |
| | Then run this client as: |
| | |
| | python src/client_test.py |
| | |
| | |
| | |
| | For HF spaces: |
| | |
| | HOST="https://h2oai-h2ogpt-chatbot.hf.space" python src/client_test.py |
| | |
| | Result: |
| | |
| | Loaded as API: https://h2oai-h2ogpt-chatbot.hf.space ✔ |
| | {'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a large language model developed by LAION.', 'sources': ''} |
| | |
| | |
| | For demo: |
| | |
| | HOST="https://gpt.h2o.ai" python src/client_test.py |
| | |
| | Result: |
| | |
| | Loaded as API: https://gpt.h2o.ai ✔ |
| | {'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a chatbot created by LAION.', 'sources': ''} |
| | |
| | NOTE: Raw output from API for nochat case is a string of a python dict and will remain so if other entries are added to dict: |
| | |
| | {'response': "I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.", 'sources': ''} |
| | |
| | |
| | """ |
| | import ast |
| | import time |
| | import os |
| | import markdown |
| | import pytest |
| | from bs4 import BeautifulSoup |
| |
|
| | from src.utils import is_gradio_version4 |
| |
|
| | try: |
| | from enums import DocumentSubset, LangChainAction |
| | except: |
| | from src.enums import DocumentSubset, LangChainAction |
| |
|
| | from tests.utils import get_inf_server |
| |
|
# When true, get_client() prints the server's full API listing after connecting.
debug = False

# Set HF_HUB_DISABLE_TELEMETRY=1 for this process before any client is created.
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
| |
|
| |
|
def get_client(serialize=not is_gradio_version4):
    """
    Create a gradio client for the inference server returned by ``get_inf_server()``.

    :param serialize: forwarded to ``gradio_client.Client``; defaults to ``not is_gradio_version4``
    :return: the new client (its API listing is printed first when the module-level ``debug`` flag is set)
    """
    import gradio_client

    api_client = gradio_client.Client(get_inf_server(), serialize=serialize)
    if not debug:
        return api_client
    print(api_client.view_api(all_endpoints=True))
    return api_client
| |
|
| |
|
def get_args(prompt, prompt_type=None, chat=False, stream_output=False,
             max_new_tokens=50,
             top_k_docs=3,
             langchain_mode='Disabled',
             add_chat_history_to_context=True,
             langchain_action=LangChainAction.QUERY.value,
             langchain_agents=None,
             prompt_dict=None,
             version=None,
             h2ogpt_key=None,
             visible_models=None,
             system_prompt='',
             add_search_to_context=False,
             chat_conversation=None,
             text_context_list=None,
             document_choice=None,
             document_source_substrings=None,
             document_source_substrings_op='and',
             document_content_substrings=None,
             document_content_substrings_op='and',
             max_time=20,
             repetition_penalty=1.0,
             do_sample=True,
             ):
    """
    Build the full set of evaluation parameters for a request to the server.

    The insertion order of the returned mapping matters: the second return value is the list of its values,
    which callers pass positionally to the gradio endpoint.

    :param prompt: user prompt; stored under ``instruction`` when ``chat`` is true, else under
        ``instruction_nochat``
    :param langchain_agents: list of agents; ``None`` (default) means an empty list
    :param document_choice: list of chosen documents; ``None`` (default) means an empty list
    :param document_source_substrings: list of source substrings; ``None`` (default) means an empty list
    :param document_source_substrings_op: operator for the source substrings; falsy values fall back to ``'and'``
    :param document_content_substrings: list of content substrings; ``None`` (default) means an empty list
    :param document_content_substrings_op: operator for the content substrings; falsy values fall back to
        ``'and'``
    :param version: API version; ``None`` is treated as 1
    :return: ``(kwargs, args)`` where ``kwargs`` is an ``OrderedDict`` and ``args`` is ``list(kwargs.values())``
    :raises AssertionError: if the keys do not line up with ``evaluate_params.eval_func_param_names``

    The remaining parameters are copied into the mapping under the same name.
    """
    from collections import OrderedDict

    # Fresh lists per call: the lists end up in the returned kwargs/args, so a shared default
    # object would leak between calls if any caller mutated it.
    langchain_agents = [] if langchain_agents is None else langchain_agents
    document_choice = [] if document_choice is None else document_choice
    document_source_substrings = [] if document_source_substrings is None else document_source_substrings
    document_content_substrings = [] if document_content_substrings is None else document_content_substrings

    kwargs = OrderedDict(instruction=prompt if chat else '',
                         iinput='',
                         context='',
                         stream_output=stream_output,
                         prompt_type=prompt_type,
                         prompt_dict=prompt_dict,
                         temperature=0.1,
                         top_p=0.75,
                         top_k=40,
                         penalty_alpha=0,
                         num_beams=1,
                         max_new_tokens=max_new_tokens,
                         min_new_tokens=0,
                         early_stopping=False,
                         max_time=max_time,
                         repetition_penalty=repetition_penalty,
                         num_return_sequences=1,
                         do_sample=do_sample,
                         chat=chat,
                         instruction_nochat=prompt if not chat else '',
                         iinput_nochat='',
                         langchain_mode=langchain_mode,
                         add_chat_history_to_context=add_chat_history_to_context,
                         langchain_action=langchain_action,
                         langchain_agents=langchain_agents,
                         top_k_docs=top_k_docs,
                         chunk=True,
                         chunk_size=512,
                         document_subset=DocumentSubset.Relevant.name,
                         document_choice=document_choice,
                         document_source_substrings=document_source_substrings,
                         # caller's operator first, 'and' only as the fallback
                         # (the reverse order, "'and' or op", always yields 'and')
                         document_source_substrings_op=document_source_substrings_op or 'and',
                         document_content_substrings=document_content_substrings,
                         document_content_substrings_op=document_content_substrings_op or 'and',
                         pre_prompt_query=None,
                         prompt_query=None,
                         pre_prompt_summary=None,
                         prompt_summary=None,
                         hyde_llm_prompt=None,
                         system_prompt=system_prompt,
                         image_audio_loaders=None,
                         pdf_loaders=None,
                         url_loaders=None,
                         jq_schema=None,
                         extract_frames=None,
                         llava_prompt=None,
                         visible_models=visible_models,
                         h2ogpt_key=h2ogpt_key,
                         add_search_to_context=add_search_to_context,
                         chat_conversation=chat_conversation,
                         text_context_list=text_context_list,
                         docs_ordering_type=None,
                         min_max_new_tokens=None,
                         max_input_tokens=None,
                         max_total_input_tokens=None,
                         docs_token_handling=None,
                         docs_joiner=None,
                         hyde_level=0,
                         hyde_template=None,
                         hyde_show_only_final=False,
                         doc_json_mode=False,
                         chatbot_role='None',
                         speaker='None',
                         tts_language='autodetect',
                         tts_speed=1.0,
                         )

    # diff is the number of server-side parameter names allowed to be absent from kwargs.
    diff = 0
    if version is None:
        version = 1
    if version == 0:
        # NOTE(review): system_prompt is already present in kwargs above, so it is unclear which
        # parameter is expected to be the missing one for version 0 -- confirm against the server.
        diff = 1
    if version >= 1:
        kwargs.update(dict(system_prompt=system_prompt))
        diff = 0

    from evaluate_params import eval_func_param_names
    missing = set(eval_func_param_names).difference(kwargs.keys())
    assert len(missing) == diff, "Unexpected missing parameters: %s" % sorted(missing)
    if chat:
        # chat endpoints take the chatbot history as one extra trailing argument
        kwargs.update(dict(chatbot=[]))

    return kwargs, list(kwargs.values())
| |
|
| |
|
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_basic(prompt_type='human_bot', version=None, visible_models=None, prompt='Who are you?',
                      h2ogpt_key=None):
    """Send one non-chat prompt (50 new tokens max) via ``run_client_nochat`` and return its result."""
    nochat_options = dict(
        prompt=prompt,
        prompt_type=prompt_type,
        max_new_tokens=50,
        version=version,
        visible_models=visible_models,
        h2ogpt_key=h2ogpt_key,
    )
    return run_client_nochat(**nochat_options)
| |
|
| |
|
| | """ |
| | time HOST=https://gpt-internal.h2o.ai PYTHONPATH=. pytest -n 20 src/client_test.py::test_client_basic_benchmark |
| | 32 seconds to answer 20 questions at once with 70B llama2 on 4x A100 80GB using TGI 0.9.3 |
| | """ |
| |
|
| |
|
# Benchmark variant of the basic non-chat test: skipped by default and parametrized over
# range(20), so the same request is generated 20 times. The `id` argument only carries the
# parametrize value and is not used in the body.
# The prompt is a captured pytest ImportError log followed by "What happened?"; it is sent
# through run_client_nochat with max_new_tokens=100.
| | @pytest.mark.skip(reason="For manual use against some server, no server launched") |
| | @pytest.mark.parametrize("id", range(20)) |
| | def test_client_basic_benchmark(id, prompt_type='human_bot', version=None): |
| | return run_client_nochat(prompt=""" |
| | /nfs4/llm/h2ogpt/h2ogpt/bin/python /home/arno/pycharm-2022.2.2/plugins/python/helpers/pycharm/_jb_pytest_runner.py --target src/client_test.py::test_client_basic |
| | Testing started at 8:41 AM ... |
| | Launching pytest with arguments src/client_test.py::test_client_basic --no-header --no-summary -q in /nfs4/llm/h2ogpt |
| | |
| | ============================= test session starts ============================== |
| | collecting ... |
| | src/client_test.py:None (src/client_test.py) |
| | ImportError while importing test module '/nfs4/llm/h2ogpt/src/client_test.py'. |
| | Hint: make sure your test modules/packages have valid Python names. |
| | Traceback: |
| | h2ogpt/lib/python3.10/site-packages/_pytest/python.py:618: in _importtestmodule |
| | mod = import_path(self.path, mode=importmode, root=self.config.rootpath) |
| | h2ogpt/lib/python3.10/site-packages/_pytest/pathlib.py:533: in import_path |
| | importlib.import_module(module_name) |
| | /usr/lib/python3.10/importlib/__init__.py:126: in import_module |
| | return _bootstrap._gcd_import(name[level:], package, level) |
| | <frozen importlib._bootstrap>:1050: in _gcd_import |
| | ??? |
| | <frozen importlib._bootstrap>:1027: in _find_and_load |
| | ??? |
| | <frozen importlib._bootstrap>:1006: in _find_and_load_unlocked |
| | ??? |
| | <frozen importlib._bootstrap>:688: in _load_unlocked |
| | ??? |
| | h2ogpt/lib/python3.10/site-packages/_pytest/assertion/rewrite.py:168: in exec_module |
| | exec(co, module.__dict__) |
| | src/client_test.py:51: in <module> |
| | from enums import DocumentSubset, LangChainAction |
| | E ModuleNotFoundError: No module named 'enums' |
| | |
| | |
| | collected 0 items / 1 error |
| | |
| | =============================== 1 error in 0.14s =============================== |
| | ERROR: not found: /nfs4/llm/h2ogpt/src/client_test.py::test_client_basic |
| | (no name '/nfs4/llm/h2ogpt/src/client_test.py::test_client_basic' in any of [<Module client_test.py>]) |
| | |
| | |
| | Process finished with exit code 4 |
| | |
| | What happened? |
| | """, prompt_type=prompt_type, max_new_tokens=100, version=version) |
| |
|
| |
|
def run_client_nochat(prompt, prompt_type, max_new_tokens, version=None, h2ogpt_key=None, visible_models=None):
    """
    Call the positional ``/submit_nochat`` endpoint with the full argument list from ``get_args``.

    :return: ``(res_dict, client)``; ``res_dict`` has ``prompt``, ``iinput`` and the plain-text ``response``
    """
    kwargs, args = get_args(prompt, prompt_type, chat=False, max_new_tokens=max_new_tokens, version=version,
                            visible_models=visible_models, h2ogpt_key=h2ogpt_key)

    client = get_client(serialize=not is_gradio_version4)
    raw = client.predict(*args, api_name='/submit_nochat')
    print("Raw client result: %s" % raw, flush=True)

    res_dict = {
        'prompt': kwargs['instruction_nochat'],
        'iinput': kwargs['iinput_nochat'],
        'response': md_to_text(raw),
    }
    print(res_dict)
    return res_dict, client
| |
|
| |
|
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_basic_api(prompt_type='human_bot', version=None, h2ogpt_key=None):
    """Ask 'Who are you?' through ``run_client_nochat_api`` (50 new tokens max) and return its result."""
    api_options = dict(
        prompt='Who are you?',
        prompt_type=prompt_type,
        max_new_tokens=50,
        version=version,
        h2ogpt_key=h2ogpt_key,
    )
    return run_client_nochat_api(**api_options)
| |
|
| |
|
def run_client_nochat_api(prompt, prompt_type, max_new_tokens, version=None, h2ogpt_key=None):
    """
    Call ``/submit_nochat_api`` with the full ``get_args`` mapping serialized as a dict string.

    The endpoint's reply is itself a string of a python dict and is decoded with ``ast.literal_eval``.

    :return: ``(res_dict, client)``; ``res_dict`` has ``prompt``, ``iinput``, plain-text ``response``, ``sources``
    """
    kwargs, _ = get_args(prompt, prompt_type, chat=False, max_new_tokens=max_new_tokens, version=version,
                         h2ogpt_key=h2ogpt_key)

    client = get_client(serialize=not is_gradio_version4)
    raw = client.predict(str(dict(kwargs)), api_name='/submit_nochat_api')
    print("Raw client result: %s" % raw, flush=True)

    reply = ast.literal_eval(raw)
    res_dict = {
        'prompt': kwargs['instruction_nochat'],
        'iinput': kwargs['iinput_nochat'],
        'response': md_to_text(reply['response']),
        'sources': reply['sources'],
    }
    print(res_dict)
    return res_dict, client
| |
|
| |
|
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_basic_api_lean(prompt='Who are you?', prompt_type='human_bot', version=None, h2ogpt_key=None,
                               chat_conversation=None, system_prompt=''):
    """Run the minimal-payload API client (``run_client_nochat_api_lean``) with 50 new tokens max."""
    lean_options = dict(
        prompt=prompt,
        prompt_type=prompt_type,
        max_new_tokens=50,
        version=version,
        h2ogpt_key=h2ogpt_key,
        chat_conversation=chat_conversation,
        system_prompt=system_prompt,
    )
    return run_client_nochat_api_lean(**lean_options)
| |
|
| |
|
def run_client_nochat_api_lean(prompt, prompt_type, max_new_tokens, version=None, h2ogpt_key=None,
                               chat_conversation=None, system_prompt=''):
    """
    Call ``/submit_nochat_api`` with only a handful of keys.

    Only ``instruction_nochat``, ``h2ogpt_key``, ``chat_conversation`` and ``system_prompt`` are sent;
    ``prompt_type``, ``max_new_tokens`` and ``version`` are accepted but not used here.

    :return: ``(res_dict, client)``; ``res_dict`` has ``prompt``, plain-text ``response``, ``sources``, ``h2ogpt_key``
    """
    kwargs = {
        'instruction_nochat': prompt,
        'h2ogpt_key': h2ogpt_key,
        'chat_conversation': chat_conversation,
        'system_prompt': system_prompt,
    }

    client = get_client(serialize=not is_gradio_version4)
    raw = client.predict(str(dict(kwargs)), api_name='/submit_nochat_api')
    print("Raw client result: %s" % raw, flush=True)

    reply = ast.literal_eval(raw)
    res_dict = {
        'prompt': kwargs['instruction_nochat'],
        'response': md_to_text(reply['response']),
        'sources': reply['sources'],
        'h2ogpt_key': h2ogpt_key,
    }
    print(res_dict)
    return res_dict, client
| |
|
| |
|
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_basic_api_lean_morestuff(prompt_type='human_bot', version=None, h2ogpt_key=None):
    """Ask 'Who are you?' through ``run_client_nochat_api_lean_morestuff`` and return its result."""
    morestuff_options = dict(
        prompt='Who are you?',
        prompt_type=prompt_type,
        max_new_tokens=50,
        version=version,
        h2ogpt_key=h2ogpt_key,
    )
    return run_client_nochat_api_lean_morestuff(**morestuff_options)
| |
|
| |
|
def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_new_tokens=512, version=None,
                                         h2ogpt_key=None):
    """
    Call ``/submit_nochat_api`` with an explicit, hand-written subset of the generation parameters.

    :param prompt: sent as ``instruction_nochat``
    :param prompt_type: sent as ``prompt_type``
    :param max_new_tokens: sent as ``max_new_tokens`` (previously ignored in favor of a hard-coded 1024)
    :param version: accepted for signature parity with the other clients; not used here
    :param h2ogpt_key: sent as ``h2ogpt_key`` and echoed in the result
    :return: ``(res_dict, client)``; ``res_dict`` has ``prompt``, plain-text ``response``, ``sources``, ``h2ogpt_key``
    """
    kwargs = dict(
        instruction='',
        iinput='',
        context='',
        stream_output=False,
        prompt_type=prompt_type,
        temperature=0.1,
        top_p=0.75,
        top_k=40,
        penalty_alpha=0,
        num_beams=1,
        max_new_tokens=max_new_tokens,  # honor the caller's limit
        min_new_tokens=0,
        early_stopping=False,
        max_time=20,
        repetition_penalty=1.0,
        num_return_sequences=1,
        do_sample=True,
        chat=False,
        instruction_nochat=prompt,
        iinput_nochat='',
        langchain_mode='Disabled',
        add_chat_history_to_context=True,
        langchain_action=LangChainAction.QUERY.value,
        langchain_agents=[],
        top_k_docs=4,
        document_subset=DocumentSubset.Relevant.name,
        document_choice=[],
        document_source_substrings=[],
        document_source_substrings_op='and',
        document_content_substrings=[],
        document_content_substrings_op='and',
        h2ogpt_key=h2ogpt_key,
        add_search_to_context=False,
    )

    api_name = '/submit_nochat_api'
    client = get_client(serialize=not is_gradio_version4)
    res = client.predict(
        str(dict(kwargs)),
        api_name=api_name,
    )
    print("Raw client result: %s" % res, flush=True)

    # the reply is a string of a python dict; decode it once and reuse
    reply = ast.literal_eval(res)
    res_dict = dict(prompt=kwargs['instruction_nochat'],
                    response=md_to_text(reply['response']),
                    sources=reply['sources'],
                    h2ogpt_key=h2ogpt_key)
    print(res_dict)
    return res_dict, client
| |
|
| |
|
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_chat(prompt_type='human_bot', version=None, h2ogpt_key=None):
    """Run a single non-streaming chat turn ('Who are you?') via ``run_client_chat``."""
    chat_options = dict(
        prompt='Who are you?',
        prompt_type=prompt_type,
        stream_output=False,
        max_new_tokens=50,
        langchain_mode='Disabled',
        langchain_action=LangChainAction.QUERY.value,
        langchain_agents=[],
        version=version,
        h2ogpt_key=h2ogpt_key,
    )
    return run_client_chat(**chat_options)
| |
|
| |
|
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_chat_stream(prompt_type='human_bot', version=None, h2ogpt_key=None):
    """Run a single streaming chat turn (long story prompt, 512 new tokens max) via ``run_client_chat``."""
    chat_options = dict(
        prompt="Tell a very long kid's story about birds.",
        prompt_type=prompt_type,
        stream_output=True,
        max_new_tokens=512,
        langchain_mode='Disabled',
        langchain_action=LangChainAction.QUERY.value,
        langchain_agents=[],
        version=version,
        h2ogpt_key=h2ogpt_key,
    )
    return run_client_chat(**chat_options)
| |
|
| |
|
def run_client_chat(prompt='',
                    stream_output=None,
                    max_new_tokens=128,
                    langchain_mode='Disabled',
                    langchain_action=LangChainAction.QUERY.value,
                    langchain_agents=None,
                    prompt_type=None, prompt_dict=None,
                    version=None,
                    h2ogpt_key=None,
                    chat_conversation=None,
                    system_prompt='',
                    document_choice=None,
                    document_content_substrings=None,
                    document_content_substrings_op='and',
                    document_source_substrings=None,
                    document_source_substrings_op='and',
                    top_k_docs=3,
                    max_time=20,
                    repetition_penalty=1.0,
                    do_sample=True):
    """
    Create a client, build chat-mode arguments with ``get_args`` and run them through ``run_client``.

    :param langchain_agents: list of agents; ``None`` (default) means an empty list
    :param document_choice: list of chosen documents; ``None`` (default) means an empty list
    :param document_content_substrings: list of content substrings; ``None`` (default) means an empty list
    :param document_source_substrings: list of source substrings; ``None`` (default) means an empty list
    :return: whatever ``run_client`` returns, i.e. ``(res_dict, client)``

    All other parameters are forwarded to ``get_args`` under the same name.
    """
    # Fresh lists per call: these values end up inside the kwargs that run_client returns,
    # so a shared default list object would be exposed to (and mutable by) every caller.
    langchain_agents = [] if langchain_agents is None else langchain_agents
    document_choice = [] if document_choice is None else document_choice
    document_content_substrings = [] if document_content_substrings is None else document_content_substrings
    document_source_substrings = [] if document_source_substrings is None else document_source_substrings

    client = get_client(serialize=False)

    kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output,
                            max_new_tokens=max_new_tokens,
                            langchain_mode=langchain_mode,
                            langchain_action=langchain_action,
                            langchain_agents=langchain_agents,
                            prompt_dict=prompt_dict,
                            version=version,
                            h2ogpt_key=h2ogpt_key,
                            chat_conversation=chat_conversation,
                            system_prompt=system_prompt,
                            document_choice=document_choice,
                            document_source_substrings=document_source_substrings,
                            document_source_substrings_op=document_source_substrings_op,
                            document_content_substrings=document_content_substrings,
                            document_content_substrings_op=document_content_substrings_op,
                            top_k_docs=top_k_docs,
                            max_time=max_time,
                            repetition_penalty=repetition_penalty,
                            do_sample=do_sample)
    return run_client(client, prompt, args, kwargs)
| |
|
| |
|
def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
    """
    Execute one chat request and return ``(res_dict, client)``.

    When ``is_gradio_version4`` is set, the request is delegated to ``run_client_gen`` with the
    source-display options switched on, and ``sources_str`` is appended to the response.
    Otherwise ``/instruction`` is called first, its last output is appended to the final positional
    argument, and ``/instruction_bot`` is then called either blocking or as a polled streaming job.

    ``kwargs`` is modified in place and is the same object as the returned ``res_dict``.

    :raises AssertionError: on the non-gradio-4 path when ``kwargs['chat']`` is falsy
    """
    if is_gradio_version4:
        kwargs.update(
            answer_with_sources=True,
            show_accordions=True,
            append_sources_to_answer=True,
            append_sources_to_chat=False,
            show_link_in_sources=True,
        )
        res_dict, client = run_client_gen(client, kwargs, do_md_to_text=do_md_to_text)
        res_dict['response'] += str(res_dict['sources_str'])
        return res_dict, client

    assert kwargs['chat'], "Chat mode only"
    instruction_out = client.predict(*args, api_name='/instruction')
    # extend the trailing positional argument in place with the endpoint's last output
    args[-1] += [instruction_out[-1]]

    res_dict = kwargs
    res_dict['prompt'] = prompt

    if kwargs['stream_output']:
        job = client.submit(*args, api_name='/instruction_bot')
        partial_text = ''
        # poll the job, printing the newest partial answer on each pass
        while not job.done():
            if job.communicator.job.outputs:
                latest = job.communicator.job.outputs[-1]
                partial_text = md_to_text(latest[0][-1][-1], do_md_to_text=do_md_to_text)
                print(partial_text)
            time.sleep(0.1)
        full_outputs = job.outputs()
        if verbose:
            print('job.outputs: %s' % str(full_outputs))
        # take the answer from the job's final output rather than from the last polled partial
        # NOTE(review): index path [-1][0][0][1] presumably selects the bot answer of the last output -- confirm
        final_answer = full_outputs[-1][0][0][1]
        res_dict['response'] = md_to_text(final_answer, do_md_to_text=do_md_to_text)
        return res_dict, client

    bot_out = client.predict(*args, api_name='/instruction_bot')
    res_dict['response'] = bot_out[0][-1][1]
    print(md_to_text(res_dict['response'], do_md_to_text=do_md_to_text))
    return res_dict, client
| |
|
| |
|
| | @pytest.mark.skip(reason="For manual use against some server, no server launched") |
| | def test_client_nochat_stream(prompt_type='human_bot', version=None, h2ogpt_key=None): |
| | return run_client_nochat_gen(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type, |
| | stream_output=True, max_new_tokens=512, |
| | langchain_mode='Disabled', |
| | langchain_action=LangChainAction.QUERY.value, |
| | langchain_agents=[], |
| | version=version, |
| | h2ogpt_key=h2ogpt_key) |
| |
|
| |
|
def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens,
                          langchain_mode, langchain_action, langchain_agents, version=None,
                          h2ogpt_key=None):
    """
    Create a non-serializing client, build non-chat arguments and run them through ``run_client_gen``.

    :return: whatever ``run_client_gen`` returns, i.e. ``(res_dict, client)``
    """
    client = get_client(serialize=False)

    # only the keyword mapping is needed; run_client_gen sends it as a dict string
    kwargs, _ = get_args(
        prompt,
        prompt_type,
        chat=False,
        stream_output=stream_output,
        max_new_tokens=max_new_tokens,
        langchain_mode=langchain_mode,
        langchain_action=langchain_action,
        langchain_agents=langchain_agents,
        version=version,
        h2ogpt_key=h2ogpt_key,
    )
    return run_client_gen(client, kwargs)
| |
|
| |
|
def run_client_gen(client, kwargs, do_md_to_text=True):
    """
    Send ``kwargs`` as a dict string to ``/submit_nochat_api`` and merge the decoded reply into it.

    ``kwargs`` is modified in place (a ``prompt`` key is added before sending, and the reply's keys
    are merged afterwards) and is the same object as the returned ``res_dict``.  With
    ``kwargs['stream_output']`` set, the job is polled and each partial response printed until done.

    :return: ``(res_dict, client)``
    :raises AssertionError: in streaming mode when the finished job produced no outputs
    """
    res_dict = kwargs
    res_dict['prompt'] = kwargs['instruction'] or kwargs['instruction_nochat']
    payload = str(dict(kwargs))
    endpoint = '/submit_nochat_api'

    if kwargs['stream_output']:
        job = client.submit(payload, api_name=endpoint)
        while not job.done():
            if job.communicator.job.outputs:
                newest = ast.literal_eval(job.communicator.job.outputs[-1])
                print('Stream: %s' % newest['response'])
            time.sleep(0.1)
        res_list = job.outputs()
        assert len(res_list) > 0, "No response, check server"
        final = ast.literal_eval(res_list[-1])
        print('Final: %s' % final['response'])
        res_dict.update(final)
        return res_dict, client

    reply = ast.literal_eval(client.predict(payload, api_name=endpoint))
    res_dict.update(reply)
    print(md_to_text(res_dict['response'], do_md_to_text=do_md_to_text))
    return res_dict, client
| |
|
| |
|
def md_to_text(md, do_md_to_text=True):
    """
    Convert markdown to plain text by rendering it to HTML and extracting the text.

    :param md: markdown source
    :param do_md_to_text: when false, ``md`` is returned untouched (no conversion, no ``None`` check)
    :return: the extracted text, or ``md`` itself when conversion is disabled
    :raises AssertionError: if conversion is requested and ``md`` is ``None``
    """
    if do_md_to_text:
        assert md is not None, "Markdown is None"
        rendered = markdown.markdown(md)
        return BeautifulSoup(rendered, features='html.parser').get_text()
    return md
| |
|
| |
|
def run_client_many(prompt_type='human_bot', version=None, h2ogpt_key=None):
    """
    Run every manual client test in sequence with the same options.

    :return: 7-tuple with the first element (the result dict) of each test's ``(res_dict, client)`` pair,
        in the order: chat, chat stream, nochat stream, basic, basic api, basic api lean, lean morestuff
    """
    common = dict(prompt_type=prompt_type, version=version, h2ogpt_key=h2ogpt_key)
    runners = (
        test_client_chat,
        test_client_chat_stream,
        test_client_nochat_stream,
        test_client_basic,
        test_client_basic_api,
        test_client_basic_api_lean,
        test_client_basic_api_lean_morestuff,
    )
    collected = []
    for runner in runners:
        res_dict, _ = runner(**common)
        collected.append(res_dict)
    return tuple(collected)
| |
|
| |
|
# Script entry point: run all manual client checks against the configured server.
if __name__ == "__main__":
    run_client_many()
| |
|