| import time |
| import re |
| import pandas as pd |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from transformers import AutoTokenizer, AutoModel |
| from tokenizers import Tokenizer, AddedToken |
| import streamlit as st |
| from st_click_detector import click_detector |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| |
def update_params():
    """Callback wired to the radio group's ``on_change``; logs that it fired.

    Writes ``update1`` to stdout and returns ``None``. A ``ValueError``
    raised while printing is ignored so the callback never fails.
    """
    try:
        print("update1")
    except ValueError:
        # Best-effort trace only; nothing to recover.
        return None
| |
| |
| |
# Topics offered as one-click choices in the radio group (and reused by the
# sidebar select box further down).
options = [
    "artificial intelligence",
    "robot",
    "VR",
    "medicine",
    "genomics",
    "cure",
    "heal",
    "brain",
    "support",
    "friendship",
    "memory",
    "aging",
    "pharma",
    "virus",
    "nurse",
    "doctor",
    "therapist",
    "nutrition",
    "technology",
    "computer",
    "software",
    "neuroscience",
    "birth",
    "death",
    "soul",
    "space",
    "sci-fi",
]

# Pre-select the radio entry named by the `query` URL parameter, falling back
# to the first entry when the parameter is absent or not one of `options`.
query_params = st.experimental_get_query_params()
ix = 0
if query_params:
    try:
        q0 = query_params['query'][0]
        ix = options.index(q0)
    except (KeyError, IndexError, ValueError):
        # KeyError/IndexError: the URL carries parameters but no usable
        # `query` value (this script itself writes an `option=` parameter).
        # ValueError: the value is not one of `options`.
        pass
selected_option = st.radio(
    "Param", options, index=ix, key="query", on_change=update_params
)
# Inline CSS so the radio buttons are laid out in a row.
st.write("<style>div.row-widget.stRadio > div{flex-direction:row;}</style>", unsafe_allow_html=True)
|
|
|
|
# Mirror the chosen radio entry into the URL as the `option` parameter.
st.experimental_set_query_params(option=selected_option)

# Best-effort attempt to copy `query` into session state.
# NOTE(review): `query` is first assigned by the text input further down, so
# on a straight top-to-bottom run this statement appears to fail every time
# and only the message below is printed -- confirm whether it is still needed.
try:
    st.session_state.query = query
except Exception:
    # `except Exception` instead of a bare `except:` so that BaseException
    # subclasses (KeyboardInterrupt, SystemExit, ...) are not swallowed.
    print("Error cant set after init")
|
|
|
|
| |
| |
# Free-text query box, seeded from session state when a value is present and
# from a fixed default otherwise.
# NOTE(review): the radio group created earlier also uses key="query";
# Streamlit normally expects widget keys to be unique -- confirm this does
# not raise a duplicate-key error at runtime.
if 'query' in st.session_state:
    _initial_query = st.session_state["query"]
else:
    _initial_query = "artificial intelligence"
query = st.text_input("", value=_initial_query, key="query")

# Offer the same topics in the sidebar, pre-selected from the `query` URL
# parameter. When the parameter is missing or not one of `options`, reset
# the URL to a default instead.
try:
    query_params = st.experimental_get_query_params()
    query_option = query_params['query'][0]
    option_selected = st.sidebar.selectbox('Pick option', options, index=options.index(query_option))
except Exception:
    # `except Exception` instead of a bare `except:` so that BaseException
    # subclasses raised inside the Streamlit calls are not swallowed.
    st.experimental_set_query_params(query="health")
    query_params = st.experimental_get_query_params()
    query_option = query_params['query'][0]
    # NOTE(review): this overrides the value read on the previous line;
    # confirm whether it was meant to sit outside the handler.
    query_option = "ai"
| |
# Torch device that models and tokenized inputs are moved to.
DEVICE = "cpu"

# Model names; each is loaded from the "sentence-transformers/" namespace.
MODEL_OPTIONS = [
    "msmarco-distilbert-base-tas-b",
    "all-mpnet-base-v2",
]

# Markdown shown in the sidebar.
DESCRIPTION = """
# Semantic search
**Enter your query and hit enter**
Built with 🤗 Hugging Face's [transformers](https://huggingface.co/transformers/) library, [SentenceBert](https://www.sbert.net/) models, [Streamlit](https://streamlit.io/) and 44k movie descriptions from the Kaggle [Movies Dataset](https://www.kaggle.com/rounakbanik/the-movies-dataset)
"""
|
|
| |
# Seed a demo entry in session state (once) and echo the state to the page.
# The original repeated this membership check twice; the second copy could
# never fire because the first one had already created the entry.
if 'key' not in st.session_state:
    st.session_state['key'] = 'value'
st.write(st.session_state.key)
st.write(st.session_state)

# Remove every entry from session state. Iterate over a snapshot of the keys
# so the mapping is not mutated while it is being iterated.
# NOTE(review): this runs on every script pass, before the click handling at
# the bottom of the file reads "last_clicked" -- confirm the wipe is intended.
for key in list(st.session_state.keys()):
    del st.session_state[key]
| |
| |
|
|
@st.cache(
    show_spinner=False,
    hash_funcs={
        AutoModel: lambda _: None,
        AutoTokenizer: lambda _: None,
        dict: lambda _: None,
    },
)
def load():
    """Load everything the app needs; the result is cached by ``st.cache``.

    For each name in ``MODEL_OPTIONS`` (in order) the tokenizer and then the
    model are fetched from ``sentence-transformers/<name>``, and the model is
    moved to ``DEVICE``. The arrays in ``embeddings.npy`` and
    ``embeddings2.npy`` and the table in ``movies.csv`` are read from the
    working directory.

    Returns:
        A ``(tokenizers, models, embeddings, df)`` tuple: three lists and a
        pandas DataFrame.
    """
    loaded_tokenizers = []
    loaded_models = []
    for name in MODEL_OPTIONS:
        repo_id = f"sentence-transformers/{name}"
        loaded_tokenizers.append(AutoTokenizer.from_pretrained(repo_id))
        loaded_models.append(AutoModel.from_pretrained(repo_id).to(DEVICE))

    # One embedding array per file, in this fixed order.
    # NOTE(review): presumably aligned with MODEL_OPTIONS -- confirm.
    embedding_arrays = [
        np.load(path) for path in ("embeddings.npy", "embeddings2.npy")
    ]
    movies = pd.read_csv("movies.csv")
    return loaded_tokenizers, loaded_models, embedding_arrays, movies


# Module-level handles used by the embedding and search functions below.
tokenizers, models, embeddings, df = load()
def pooling(model_output):
    """Return index 0 along the second axis of ``last_hidden_state``.

    Assumes ``last_hidden_state`` is laid out as (batch, tokens, hidden), in
    which case this is the first token's vector per sequence -- TODO confirm.
    """
    hidden_states = model_output.last_hidden_state
    return hidden_states[:, 0]
|
|
def compute_embeddings(texts):
    """Embed ``texts`` with the first loaded tokenizer/model pair.

    The texts are tokenized with padding and truncation, run through
    ``models[0]`` without gradient tracking, and reduced with ``pooling``.

    Returns:
        The pooled output as a NumPy array.
    """
    tokenizer, model = tokenizers[0], models[0]
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    batch = batch.to(DEVICE)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = model(**batch, return_dict=True)

    pooled = pooling(output)
    return pooled.cpu().numpy()
|
|
def pooling2(model_output, attention_mask):
    """Average the token embeddings of each sequence, weighted by the mask.

    Args:
        model_output: Indexable model result; element 0 holds the token
            embeddings.
        attention_mask: Mask tensor that is broadcast over the last axis of
            the token embeddings.

    Returns:
        Sum of the masked token embeddings along axis 1, divided by the mask
        total along axis 1 (clamped to at least 1e-9 to avoid dividing by
        zero).
    """
    token_embeddings = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    weighted_sum = torch.sum(token_embeddings * weights, 1)
    weight_total = torch.clamp(weights.sum(1), min=1e-9)
    return weighted_sum / weight_total
|
|
def compute_embeddings2(list_of_strings):
    """Embed strings with the second loaded tokenizer/model pair.

    The strings are tokenized with padding and truncation, run through
    ``models[1]`` without gradient tracking, reduced with ``pooling2`` using
    the attention mask, and L2-normalised along axis 1.

    Returns:
        The normalised sentence embeddings as a NumPy array.
    """
    tokenizer, model = tokenizers[1], models[1]
    batch = tokenizer(
        list_of_strings, padding=True, truncation=True, return_tensors="pt"
    )
    batch = batch.to(DEVICE)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = model(**batch)

    pooled = pooling2(output, batch["attention_mask"])
    unit_vectors = F.normalize(pooled, p=2, dim=1)
    return unit_vectors.cpu().numpy()
|
|
@st.cache(
    show_spinner=False,
    hash_funcs={Tokenizer: lambda _: None, AddedToken: lambda _: None},
)
def semantic_search(query, model_id):
    """Return an HTML list of the ten best-scoring movies for ``query``.

    Args:
        query: Free text, or a string beginning with ``[Similar:<row>`` to
            search with the stored embedding of that row instead.
        model_id: 0 embeds free text with ``compute_embeddings``; any other
            value uses ``compute_embeddings2``. It also selects the entry of
            ``embeddings`` that is searched.

    Returns:
        An HTML string with the computation time and an ordered list of
        results, or ``""`` when the query is blank, the ``[Similar:`` form
        cannot be parsed, the referenced row does not exist, or nothing is
        found.
    """
    started = time.time()

    if not query.strip():
        return ""

    if "[Similar:" in query:
        # Reuse the stored embedding of the referenced row.
        match = re.match(r"\[Similar:(\d{1,5}).*", query)
        if not match:
            return ""
        row_index = int(match.group(1))
        query_embedding = embeddings[model_id][row_index : row_index + 1, :]
        if query_embedding.shape[0] == 0:
            # Row index past the end of the embedding array.
            return ""
    elif model_id == 0:
        query_embedding = compute_embeddings([query])
    else:
        query_embedding = compute_embeddings2([query])

    # Dot-product score per row; keep the ten highest, best first.
    scores = embeddings[model_id] @ np.transpose(query_embedding)[:, 0]
    top_rows = np.argsort(scores)[-1:-11:-1]
    if len(top_rows) == 0:
        return ""

    items = []
    for row_number in top_rows:
        movie = df.iloc[row_number]
        items.append(
            f"<li style='padding-top: 10px'><b>{movie.title}</b> ({movie.release_date}). {movie.overview} "
        )
    listing = "<ol>" + "".join(items)

    elapsed = f"{time.time() - started:.3f}"
    return f"<p><i>Computation time: {elapsed} seconds</i></p>{listing}</ol>"
|
|
# Sidebar: app description followed by the model picker.
st.sidebar.markdown(DESCRIPTION)
model_choice = st.sidebar.selectbox("Similarity model", options=MODEL_OPTIONS)
if model_choice == MODEL_OPTIONS[0]:
    model_id = 0
else:
    model_id = 1

# Render the search results through click_detector, which hands back the
# value of whatever was clicked ("" is treated as "no click").
results_html = semantic_search(query, model_id)
clicked = click_detector(results_html)

if clicked != "":
    st.markdown(clicked)
    # Only act when this click differs from the one handled previously.
    change_query = (
        "last_clicked" not in st.session_state
        or clicked != st.session_state["last_clicked"]
    )
    if change_query:
        st.session_state["last_clicked"] = clicked
        # Turn the click into a "[Similar:<row>] <title>" query and rerun.
        st.session_state["query"] = f"[Similar:{clicked}] {df.iloc[int(clicked)].title}"
        st.experimental_rerun()
|
|