| from typing import List, Dict, Callable, Optional |
| from langchain_text_splitters import RecursiveCharacterTextSplitter |
| from langchain_community.document_loaders import ( |
| DirectoryLoader, |
| UnstructuredMarkdownLoader, |
| PyPDFLoader, |
| TextLoader |
| ) |
| import os |
| import requests |
| import base64 |
| from PIL import Image |
| import io |
|
|
class DocumentLoader:
    """Generic document loader.

    Dispatches on the file extension to the matching LangChain loader
    (Markdown, PDF, plain text). Image files are turned into a single
    Document whose content is a textual description produced by a
    SiliconFlow-compatible vision-language model.
    """

    # MIME type per supported image extension; used when building the
    # base64 data URL sent to the vision model.
    # FIX: previously "image/jpeg" was hard-coded even for PNG/GIF/BMP.
    _IMAGE_MIME = {
        '.png': 'image/png',
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.gif': 'image/gif',
        '.bmp': 'image/bmp',
    }

    def __init__(self, file_path: str, original_filename: Optional[str] = None):
        """
        Args:
            file_path: Path of the file on disk (may be a renamed upload/temp file).
            original_filename: Display name as uploaded (may contain CJK
                characters); defaults to the basename of *file_path*.
        """
        self.file_path = file_path
        self.original_filename = original_filename or os.path.basename(file_path)
        # The extension comes from the ORIGINAL name: the on-disk temp file
        # may carry a meaningless suffix.
        self.extension = os.path.splitext(self.original_filename)[1].lower()
        # Credentials/endpoint for the VLM API; may be None if unset —
        # process_image() will then fail and return its fallback string.
        self.api_key = os.getenv("API_KEY")
        self.api_base = os.getenv("BASE_URL")

    def process_image(self, image_path: str) -> str:
        """Describe an image using the SiliconFlow VLM model.

        Best-effort: returns the model's description on success, or the
        literal fallback string "图片处理失败" on any failure (never raises).
        """
        try:
            # Read and base64-encode the image for an inline data URL.
            with open(image_path, 'rb') as image_file:
                base64_image = base64.b64encode(image_file.read()).decode('utf-8')

            mime_type = self._IMAGE_MIME.get(
                os.path.splitext(image_path)[1].lower(), 'image/jpeg')

            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }

            response = requests.post(
                f"{self.api_base}/chat/completions",
                headers=headers,
                json={
                    "model": "Qwen/Qwen2.5-VL-72B-Instruct",
                    "messages": [
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": f"data:{mime_type};base64,{base64_image}",
                                        "detail": "high"
                                    }
                                },
                                {
                                    "type": "text",
                                    "text": "请详细描述这张图片的内容,包括主要对象、场景、活动、颜色、布局等关键信息。"
                                }
                            ]
                        }
                    ],
                    "temperature": 0.7,
                    "max_tokens": 500
                },
                # FIX: without a timeout a stalled endpoint hangs forever.
                timeout=120,
            )

            if response.status_code != 200:
                raise Exception(f"图片处理API调用失败: {response.text}")

            return response.json()["choices"][0]["message"]["content"]

        except Exception as e:
            print(f"处理图片时出错: {str(e)}")
            return "图片处理失败"

    def load(self):
        """Load the file and return a list of LangChain ``Document`` objects.

        Raises:
            ValueError: for unsupported file extensions.
            Exception: loader errors are logged and re-raised.
        """
        try:
            print(f"正在加载文件: {self.file_path}, 原始文件名: {self.original_filename}, 扩展名: {self.extension}")

            if self.extension == '.md':
                try:
                    loader = UnstructuredMarkdownLoader(self.file_path, encoding='utf-8')
                    return loader.load()
                except UnicodeDecodeError:
                    # Fall back to GBK for legacy Chinese-encoded files.
                    loader = UnstructuredMarkdownLoader(self.file_path, encoding='gbk')
                    return loader.load()
            elif self.extension == '.pdf':
                loader = PyPDFLoader(self.file_path)
                return loader.load()
            elif self.extension == '.txt':
                try:
                    loader = TextLoader(self.file_path, encoding='utf-8')
                    return loader.load()
                except UnicodeDecodeError:
                    # Fall back to GBK for legacy Chinese-encoded files.
                    loader = TextLoader(self.file_path, encoding='gbk')
                    return loader.load()
            elif self.extension in ['.png', '.jpg', '.jpeg', '.gif', '.bmp']:
                # Images become a single Document carrying the VLM description;
                # the absolute path is kept in metadata for later display.
                description = self.process_image(self.file_path)

                from langchain.schema import Document
                doc = Document(
                    page_content=description,
                    metadata={
                        'source': self.file_path,
                        'file_name': self.original_filename,
                        'img_url': os.path.abspath(self.file_path)
                    }
                )
                return [doc]
            else:
                print(f"不支持的文件扩展名: {self.extension}")
                raise ValueError(f"不支持的文件格式: {self.extension}")

        except UnicodeDecodeError:
            # Last-resort GBK retry for text-like formats.
            print(f"文件编码错误,尝试其他编码: {self.file_path}")
            if self.extension in ['.md', '.txt']:
                try:
                    loader = TextLoader(self.file_path, encoding='gbk')
                    return loader.load()
                except Exception as e:
                    print(f"尝试GBK编码也失败: {str(e)}")
                    raise
            # FIX: previously fell through here and silently returned None
            # for non-text extensions; propagate the error instead.
            raise
        except Exception as e:
            print(f"加载文件 {self.file_path} 时出错: {str(e)}")
            import traceback
            traceback.print_exc()
            raise
|
|
class DocumentProcessor:
    """Loads documents (file or directory) and splits them into overlapping
    chunks suitable for RAG indexing."""

    def __init__(self):
        # 1000-char chunks with 200-char overlap; plain len() as the
        # length metric.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
        )

    def get_index_name(self, path: str) -> str:
        """Derive an index name from *path*: ``rag_`` + lowercase base name
        (extension stripped for files)."""
        base = os.path.basename(path)
        if os.path.isdir(path):
            return f"rag_{base.lower()}"
        return f"rag_{os.path.splitext(base)[0].lower()}"

    def _load_directory(self, path: str, progress_callback: Optional[Callable]):
        """Walk *path*, load every file (best-effort), and return
        ``(documents, bytes_processed)``.

        Files that fail to load are logged and skipped. Progress is only
        accumulated when a callback is supplied (matches original behavior).
        """
        documents = []
        total_files = sum(len(files) for _, _, files in os.walk(path))
        processed_files = 0
        processed_size = 0

        for root, _, files in os.walk(path):
            for file in files:
                file_path = os.path.join(root, file)
                try:
                    if progress_callback:
                        processed_size += os.path.getsize(file_path)
                        processed_files += 1
                        progress_callback(processed_size, f"处理文件 {processed_files}/{total_files}: {file}")

                    docs = DocumentLoader(file_path, original_filename=file).load()
                    # Stamp each document with its display name.
                    for doc in docs:
                        doc.metadata['file_name'] = file
                    documents.extend(docs)
                except Exception as e:
                    print(f"警告:加载文件 {file_path} 时出错: {str(e)}")
                    continue
        return documents, processed_size

    def _load_single_file(self, path: str, progress_callback: Optional[Callable],
                          original_filename: Optional[str]):
        """Load one file and stamp its metadata with the display file name.

        Errors are logged and re-raised (a single-file failure is fatal,
        unlike the directory case).
        """
        try:
            if progress_callback:
                file_size = os.path.getsize(path)
                progress_callback(file_size * 0.3, f"加载文件: {original_filename or os.path.basename(path)}")

            documents = DocumentLoader(path, original_filename=original_filename).load()

            if progress_callback:
                progress_callback(file_size * 0.6, "处理文件内容...")

            file_name = original_filename or os.path.basename(path)
            for doc in documents:
                doc.metadata['file_name'] = file_name
            return documents
        except Exception as e:
            print(f"加载文件时出错: {str(e)}")
            raise

    def process(self, path: str, progress_callback: Optional[Callable] = None,
                original_filename: Optional[str] = None) -> List[Dict]:
        """
        Load and process documents; *path* may be a directory or a single file.

        Args:
            path: Document path.
            progress_callback: Optional ``callback(progress, message)`` used
                to report processing progress.
            original_filename: Original (possibly CJK) file name; only
                meaningful for the single-file case.

        Returns:
            List of dicts with ``id``, ``content`` and ``metadata`` keys,
            one per text chunk.
        """
        if os.path.isdir(path):
            documents, processed_size = self._load_directory(path, progress_callback)
        else:
            documents = self._load_single_file(path, progress_callback, original_filename)
            processed_size = None  # computed lazily below for the file case

        chunks = self.text_splitter.split_documents(documents)

        if progress_callback:
            if processed_size is None:
                processed_size = os.path.getsize(path) * 0.9
            progress_callback(processed_size, f"文档分块完成,共{len(chunks)}个文档片段")

        # Flatten chunks into plain dicts with stable sequential ids.
        return [
            {
                'id': f'doc_{i}',
                'content': chunk.page_content,
                'metadata': chunk.metadata,
            }
            for i, chunk in enumerate(chunks)
        ]