| | import os |
| | from typing import List |
| |
|
| | import pinecone |
| | from tqdm.auto import tqdm |
| | from uuid import uuid4 |
| | import arxiv |
| |
|
| | from langchain.document_loaders import PyPDFLoader |
| | from langchain.text_splitter import RecursiveCharacterTextSplitter |
| | from langchain.embeddings.openai import OpenAIEmbeddings |
| | from langchain.embeddings import CacheBackedEmbeddings |
| | from langchain.storage import LocalFileStore |
| | from langchain.vectorstores import Pinecone |
| |
|
| | INDEX_BATCH_LIMIT = 100 |
| |
|
class CharacterTextSplitter:
    """Thin wrapper around ``RecursiveCharacterTextSplitter`` with validated
    chunking parameters.

    Parameters
    ----------
    chunk_size:
        Maximum number of characters per chunk.
    chunk_overlap:
        Number of characters shared between consecutive chunks; must be
        strictly smaller than ``chunk_size``.

    Raises
    ------
    ValueError
        If ``chunk_size`` is not greater than ``chunk_overlap``.
    """

    def __init__(
        self,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
    ):
        # Raise instead of assert: asserts are stripped under ``python -O``,
        # which would silently disable this validation.
        if chunk_size <= chunk_overlap:
            raise ValueError("Chunk size must be greater than chunk overlap")

        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            length_function=len,
        )

    def split(self, text: str) -> List[str]:
        """Split ``text`` into (possibly overlapping) chunks."""
        return self.text_splitter.split_text(text)
|
class ArxivLoader:
    """Fetch paper PDFs from arXiv and load them as LangChain page documents."""

    def __init__(self, query: str = "Nuclear Fission", max_results: int = 5, encoding: str = "utf-8"):
        """Store search parameters and prepare empty result holders.

        Parameters
        ----------
        query:
            Full-text arXiv search query.
        max_results:
            Maximum number of papers to retrieve.
        encoding:
            Kept for backward compatibility; not used by the current
            implementation.
        """
        self.query = query
        self.max_results = max_results
        self.encoding = encoding  # stored for interface compatibility; unused below

        self.paper_urls = []
        self.documents = []
        self.splitter = CharacterTextSplitter()

    def retrieve_urls(self):
        """Populate ``self.paper_urls`` with PDF URLs of the most relevant papers."""
        arxiv_client = arxiv.Client()
        search = arxiv.Search(
            query=self.query,
            max_results=self.max_results,
            sort_by=arxiv.SortCriterion.Relevance,
        )

        # Reset first so repeated calls do not accumulate duplicate URLs.
        self.paper_urls = []
        for result in arxiv_client.results(search):
            self.paper_urls.append(result.pdf_url)

    def load_documents(self):
        """Download each PDF in ``self.paper_urls`` into ``self.documents``.

        Each entry of ``self.documents`` is the list of page documents for one
        paper (``PyPDFLoader.load`` yields one document per page).
        """
        # Reset first so repeated calls do not accumulate duplicate documents.
        self.documents = []
        for paper_url in self.paper_urls:
            loader = PyPDFLoader(paper_url)
            self.documents.append(loader.load())

    def format_document(self, document):
        """Split one page document into chunk texts plus per-chunk metadata.

        Parameters
        ----------
        document:
            A LangChain document exposing ``page_content`` and a ``metadata``
            dict with ``"source"`` and ``"page"`` keys.

        Returns
        -------
        tuple
            ``(record_texts, record_metadatas)`` where each metadata dict
            carries the chunk index, the chunk text, the source document and
            the page number.
        """
        metadata = {
            'source_document': document.metadata["source"],
            'page_number': document.metadata["page"],
        }

        record_texts = self.splitter.split(document.page_content)
        record_metadatas = [
            {"chunk": j, "text": text, **metadata}
            for j, text in enumerate(record_texts)
        ]

        return record_texts, record_metadatas

    def main(self):
        """Convenience entry point: retrieve URLs, then download the papers."""
        self.retrieve_urls()
        self.load_documents()
| |
|
| |
|
class PineconeIndexer:
    """Embed document chunks with OpenAI and upsert them into a Pinecone index."""

    def __init__(self, index_name: str = "arxiv-paper-index", metric: str = "cosine", n_dims: int = 1536):
        """Connect to Pinecone and create the index if it does not yet exist.

        Reads ``PINECONE_API_KEY`` and ``PINECONE_ENV`` from the environment
        and raises ``KeyError`` if either is missing.

        Parameters
        ----------
        index_name:
            Name of the Pinecone index to use or create.
        metric:
            Similarity metric used when a new index is created.
        n_dims:
            Embedding dimensionality (1536 matches OpenAI embeddings).
        """
        pinecone.init(
            api_key=os.environ["PINECONE_API_KEY"],
            environment=os.environ["PINECONE_ENV"],
        )

        if index_name not in pinecone.list_indexes():
            pinecone.create_index(
                name=index_name,
                metric=metric,
                dimension=n_dims,
            )

        # Used only for its ``format_document`` chunking helper.
        self.arxiv_loader = ArxivLoader()
        self.index = pinecone.Index(index_name)

    def load_embedder(self):
        """Build the OpenAI embedder backed by a local on-disk cache.

        Must be called before ``upsert``, ``index_documents`` or
        ``get_vectorstore``, all of which read ``self.embedder``.
        """
        store = LocalFileStore("./cache/")
        core_embeddings_model = OpenAIEmbeddings()

        # Cache embeddings on disk, namespaced by model name, so re-indexing
        # the same text does not hit the OpenAI API again.
        self.embedder = CacheBackedEmbeddings.from_bytes_store(
            core_embeddings_model,
            store,
            namespace=core_embeddings_model.model,
        )

    def upsert(self, texts, metadatas):
        """Embed ``texts`` and upsert them with ``metadatas`` under fresh UUIDs."""
        ids = [str(uuid4()) for _ in range(len(texts))]
        embeds = self.embedder.embed_documents(texts)
        self.index.upsert(vectors=zip(ids, embeds, metadatas))

    def index_documents(self, documents, batch_limit: int = INDEX_BATCH_LIMIT):
        """Chunk every page of every document and upsert them in batches.

        Parameters
        ----------
        documents:
            List of papers, each itself a list of page documents (the shape
            produced by ``ArxivLoader.load_documents``).
        batch_limit:
            Flush to Pinecone once at least this many chunks are buffered.
        """
        texts = []
        metadatas = []

        # Iterate the documents directly rather than indexing by position.
        for pages in tqdm(documents):
            for page in pages:
                record_texts, record_metadatas = self.arxiv_loader.format_document(page)

                texts.extend(record_texts)
                metadatas.extend(record_metadatas)

                if len(texts) >= batch_limit:
                    self.upsert(texts, metadatas)
                    texts = []
                    metadatas = []

        # Flush any remaining partial batch.
        if texts:
            self.upsert(texts, metadatas)

    def get_vectorstore(self):
        """Return a LangChain ``Pinecone`` store over this index, reading the
        ``"text"`` metadata field as document content."""
        return Pinecone(self.index, self.embedder.embed_query, "text")
| |
|
| |
|
| | if __name__ == "__main__": |
| | |
| | print("-------------- Loading Arxiv --------------") |
| | axloader = ArxivLoader() |
| | axloader.retrieve_urls() |
| | axloader.load_documents() |
| |
|
| | print("\n-------------- Splitting sample doc --------------") |
| | sample_doc = axloader.documents[0] |
| | sample_page = sample_doc[0] |
| |
|
| | splitter = CharacterTextSplitter() |
| | chunks = splitter.split(sample_page.page_content) |
| | print(len(chunks)) |
| | print(chunks[0]) |
| |
|
| | print("\n-------------- testing pinecode indexer --------------") |
| |
|
| | pi = PineconeIndexer() |
| | pi.load_embedder() |
| | pi.index_documents(axloader.documents) |
| |
|
| | print(pi.index.describe_index_stats()) |
| |
|