"""
生成示例数据脚本
用于测试审核系统
"""
import os
import json
from pathlib import Path
def create_sample_dataset(
    base_path: "str | os.PathLike[str]" = "./dataset",
    *,
    charts_per_type: int = 3,
    questions_per_model: int = 2,
) -> None:
    """Create a sample dataset on disk for testing the review system.

    The following layout is produced under ``base_path``::

        web/<source>/<chart_type>/<chart_id>.html
        label/<source>/<chart_type>/<chart_id>.json
        question_answer/<source>/<chart_type>/<model>/<chart_id>_q<k>.json

    where ``chart_id`` is ``chart_<NNNN>_<chart_type>`` and ``NNNN`` is the
    1-based chart index zero-padded to four digits. Existing directories are
    reused and existing files with the same names are overwritten.

    Args:
        base_path: Root directory of the dataset. Defaults to ``./dataset``
            (relative to the current working directory).
        charts_per_type: Number of charts generated for every
            (source, chart type) pair.
        questions_per_model: Number of question/answer files generated per
            chart for every model.

    A short summary (location and totals) is printed when finished.
    """
    base_path = Path(base_path)

    # Sample data configuration: chart types available per charting library.
    chart_types = {
        "Apache_Echarts": ["bar", "line", "pie"],
        "Plotly": ["scatter", "bar", "heatmap"],
        "ChartJS": ["line", "doughnut", "radar"],
    }
    models = ["gpt-4", "claude-3", "gemini-pro"]

    # Count what is actually written so the summary can never drift from
    # the loops below.
    total_charts = 0
    total_qa = 0

    for source, type_names in chart_types.items():
        for chart_type in type_names:
            # Create the directory tree for this (source, chart type) pair.
            web_dir = base_path / "web" / source / chart_type
            label_dir = base_path / "label" / source / chart_type
            web_dir.mkdir(parents=True, exist_ok=True)
            label_dir.mkdir(parents=True, exist_ok=True)

            # One QA directory per model, built once and reused for every
            # chart of this type.
            qa_dirs = [
                base_path / "question_answer" / source / chart_type / model
                for model in models
            ]
            for qa_dir in qa_dirs:
                qa_dir.mkdir(parents=True, exist_ok=True)

            for i in range(1, charts_per_type + 1):
                number = f"{i:04d}"
                chart_id = f"chart_{number}_{chart_type}"

                # HTML page for the chart.
                html_content = (
                    f"\n{chart_id}\n"
                    f"示例图表 - {source} - {chart_type} #{i}\n"
                )
                with open(web_dir / f"{chart_id}.html", "w", encoding="utf-8") as f:
                    f.write(html_content)

                # Label (metadata) file for the chart.
                label_data = {
                    "Number": number,
                    "Type": chart_type,
                    "Source": source,
                    "Weblink": f"https://example.com/{source}/{chart_type}/{i}",
                    "Topic": f"Sample {chart_type} chart #{i}",
                    "Describe": f"This is a sample {chart_type} chart for testing the review system. It demonstrates the visualization capabilities of {source}.",
                    "Other": "",
                }
                with open(label_dir / f"{chart_id}.json", "w", encoding="utf-8") as f:
                    json.dump(label_data, f, ensure_ascii=False, indent=2)
                total_charts += 1

                # QA files: the answer varies with the question index and
                # the model's position so files differ between models.
                for j, qa_dir in enumerate(qa_dirs):
                    for q in range(1, questions_per_model + 1):
                        qa_data = {
                            "id": f"{chart_id}_q{q}",
                            "chart": chart_id,
                            "question": f"在图表 {chart_id} 中,第 {q} 个数据点的值是多少?",
                            "answer": f"约为 {50 + q * 10 + j * 5}",
                        }
                        with open(qa_dir / f"{chart_id}_q{q}.json", "w", encoding="utf-8") as f:
                            json.dump(qa_data, f, ensure_ascii=False, indent=2)
                        total_qa += 1

    print("✅ 示例数据集创建完成!")
    print(f"📁 数据集位置: {base_path.absolute()}")
    # Print statistics.
    print(f"📊 共创建 {total_charts} 个图表")
    print(f"❓ 共创建 {total_qa} 个问答对")
if __name__ == "__main__":
    # Executed as a script (not imported): generate the sample dataset.
    create_sample_dataset()