| import streamlit as st |
| from persist import persist, load_widget_state |
| from modelcards import CardData, ModelCard |
| from huggingface_hub import create_repo |
|
|
|
|
def is_float(value):
    """Return True if ``value`` can be converted with ``float()``.

    Anything ``float()`` accepts counts, including ints, numeric strings with
    surrounding whitespace, and "nan"/"inf". Values ``float()`` rejects,
    such as ``""``, ``"abc"``, ``None``, or ints too large for a float,
    return False.
    """
    try:
        float(value)
    except (TypeError, ValueError, OverflowError):
        # Only the errors float() raises for unconvertible input; anything
        # else (KeyboardInterrupt, real bugs) should propagate.
        return False
    return True
|
|
def _split_csv(text):
    """Split comma-separated free text into a list of stripped, non-empty items."""
    return [item.strip() for item in text.split(',') if item.strip()]


def get_card():
    """Build a ModelCard from the form values stored in ``st.session_state``.

    Reads the following fields from Streamlit session state and normalizes
    empty values to ``None``:

    - languages, license, library name, tags, datasets and metrics;
    - model name and description, authors;
    - paper/GitHub URLs, BibTeX citations and CO2 emissions.

    It then renders ``template.md`` with these values. The tag
    "autogenerated-modelcard" is always appended to the user's tags.

    If a required field (languages or license) is empty, an error listing the
    missing fields is shown and the script run is halted via ``st.stop()``.

    Returns:
        ModelCard: the card rendered from ``template.md``.
    """
    state = st.session_state

    languages = state.languages or None
    # Trailing underscore avoids shadowing the ``license`` builtin.
    license_ = state.license or None
    library_name = state.library_name or None
    tags = _split_csv(state.tags)
    tags.append("autogenerated-modelcard")
    datasets = _split_csv(state.datasets) or None
    metrics = state.metrics or None
    model_name = state.model_name or None
    model_description = state.model_description or None

    authors = state.authors or None
    paper_url = state.paper_url or None
    github_url = state.github_url or None
    bibtex_citations = state.bibtex_citations or None
    # Emissions is free text; anything not parseable as a float is ignored.
    emissions = float(state.emissions) if is_float(state.emissions) else None

    required = (("Languages", languages), ("License", license_))
    missing = [label for label, value in required if not value]
    if missing:
        st.error(
            "Warning: The following fields are required but have not been filled in: "
            + "".join(f"\n- {label}" for label in missing)
        )
        st.stop()

    card_data = CardData(
        language=languages,
        license=license_,
        library_name=library_name,
        tags=tags,
        datasets=datasets,
        metrics=metrics,
    )
    # Compare against None so an explicitly entered 0 is still recorded in
    # the metadata (a truthiness check would silently drop it).
    if emissions is not None:
        card_data.co2_eq_emissions = {'emissions': emissions}

    return ModelCard.from_template(
        card_data,
        template_path='template.md',
        model_id=model_name,
        model_description=model_description,
        license=license_,
        authors=authors,
        paper_url=paper_url,
        github_url=github_url,
        bibtex_citations=bibtex_citations,
        emissions=emissions,
    )
|
|
|
|
def main():
    """Render the generated model card and offer an upload-to-Hub form.

    The card is built from session state and saved to ``current_card.md``.
    It is then shown either as raw text or as rendered markdown, depending on
    the sidebar "View Raw" checkbox.

    A sidebar form collects a Hub token and a repo ID. On submit, the repo is
    created (or reused if it already exists) and the card is pushed to it.
    """
    card = get_card()
    card.save('current_card.md')

    if st.sidebar.checkbox("View Raw"):
        st.text(card)
    else:
        st.markdown(card.text, unsafe_allow_html=True)

    with st.sidebar:
        with st.form("Upload to 🤗 Hub"):
            st.markdown("Use a token with write access from [here](https://hf.co/settings/tokens)")
            token = st.text_input("Token", type='password')
            # Strip so a pasted ID with stray whitespace still validates.
            repo_id = st.text_input("Repo ID").strip()
            submit = st.form_submit_button('Upload to 🤗 Hub')

        if submit:
            parts = repo_id.split('/')
            # Require exactly "namespace/name" with both parts non-empty.
            if len(parts) == 2 and all(parts):
                try:
                    repo_url = create_repo(repo_id, exist_ok=True, token=token)
                    card.push_to_hub(repo_id, token=token)
                except Exception as err:  # UI boundary: report instead of a raw traceback
                    st.error(f"Upload failed: {err}")
                else:
                    st.success(f"Pushed the card to the repo [here]({repo_url})!")
            else:
                st.error("Repo ID invalid. It should be username/repo-name. For example: nateraw/food")
|
|
|
|
if __name__ == "__main__":
    # NOTE(review): presumably restores widget values saved via `persist`
    # into st.session_state; it must run before main(), because get_card()
    # reads the form fields from st.session_state. Confirm in persist.py.
    load_widget_state()
    main()