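# File: ChouJuGEO/data_storage.py
# 454 lines, 18 KiB, Python
# 2026-01-23 15:43:03 +08:00
# (The lines above were captured from a web file viewer: "Files", "Raw",
# "Normal View", "History". They are not part of the module.)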
"""
轻量级数据持久化模块 - MVP版本
支持 SQLite 和 JSON 两种存储方式
"""
import sqlite3
import json
import os
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Optional, Any
import pandas as pd
class DataStorage:
    """Unified persistence interface backed by either SQLite or JSON files.

    Four record kinds are stored: keywords, generated articles, optimization
    records and verification results. With the SQLite backend each kind is a
    table inside the database at ``db_path``; with the JSON backend each kind
    is a ``<name>.json`` file (a JSON list of objects) inside the directory
    ``db_path``.

    Filter arguments (``brand``, ``platform``) are applied only when truthy,
    so ``None`` and ``""`` both mean "no filter".

    NOTE: the two backends stamp ``created_at`` differently. SQLite uses
    ``CURRENT_TIMESTAMP`` (UTC according to the SQLite documentation), while
    the JSON backend stores ``datetime.now().isoformat()`` (naive local time).
    """

    # Base names shared by both backends: SQLite table names and JSON file stems.
    _TABLES = ("keywords", "articles", "optimizations", "verify_results")

    # (storage field, Chinese label) pairs. The labels are the keys of the
    # dicts passed to save_verify_results() and the DataFrame column names
    # returned by get_verify_results().
    _VERIFY_FIELDS = (
        ("query", "问题"),
        ("brand", "品牌"),
        ("verify_model", "验证模型"),
        ("mention_count", "提及次数"),
        ("mention_position", "位置"),
    )
    _VERIFY_TIME_LABEL = "验证时间"

    # DDL executed (idempotently) when the SQLite backend is initialised.
    _SCHEMA = (
        """
        CREATE TABLE IF NOT EXISTS keywords (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            keyword TEXT NOT NULL,
            brand TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS articles (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            keyword TEXT,
            platform TEXT,
            content TEXT,
            filename TEXT,
            brand TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS optimizations (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            original_content TEXT,
            optimized_content TEXT,
            changes TEXT,
            platform TEXT,
            brand TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS verify_results (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            query TEXT,
            brand TEXT,
            verify_model TEXT,
            mention_count INTEGER,
            mention_position TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """,
    )

    def __init__(self, storage_type: str = "sqlite", db_path: str = "geo_data.db"):
        """Create the storage and make sure its backing store exists.

        Args:
            storage_type: "sqlite" or "json". Any value other than "sqlite"
                selects the JSON backend.
            db_path: Path of the SQLite database file, or the directory that
                holds the JSON files.
        """
        self.storage_type = storage_type
        self.db_path = db_path
        if storage_type == "sqlite":
            self._init_sqlite()
        else:
            self._init_json()

    # ==================== Backend helpers ====================
    def _init_sqlite(self):
        """Create the four SQLite tables if they do not exist yet."""
        conn = sqlite3.connect(self.db_path)
        try:
            for ddl in self._SCHEMA:
                conn.execute(ddl)
            conn.commit()
        finally:
            # Close even when a statement fails so the handle is not leaked.
            conn.close()

    def _init_json(self):
        """Create the JSON storage directory (and parents) if missing."""
        Path(self.db_path).mkdir(parents=True, exist_ok=True)

    def _write_rows(self, sql: str, rows: List[tuple]) -> None:
        """Run one parameterised INSERT for every row and commit."""
        conn = sqlite3.connect(self.db_path)
        try:
            conn.executemany(sql, rows)
            conn.commit()
        finally:
            conn.close()

    def _read_frame(self, sql: str, params: tuple = ()) -> pd.DataFrame:
        """Run a SELECT and return the result as a DataFrame."""
        conn = sqlite3.connect(self.db_path)
        try:
            return pd.read_sql_query(sql, conn, params=tuple(params) if params else None)
        finally:
            conn.close()

    @staticmethod
    def _where(**filters):
        """Build a ``(" WHERE a = ? AND b = ?", params)`` pair from truthy filters.

        Column names come only from keyword names written in this class,
        never from user input; the values are always bound as parameters.
        """
        active = [(column, value) for column, value in filters.items() if value]
        if not active:
            return "", ()
        clause = " WHERE " + " AND ".join(f"{column} = ?" for column, _ in active)
        return clause, tuple(value for _, value in active)

    @staticmethod
    def _matches(item: Dict, **filters) -> bool:
        """Return True when ``item`` equals every truthy filter value."""
        return all(item.get(key) == value for key, value in filters.items() if value)

    def _json_file(self, name: str) -> Path:
        """Return the path of the JSON file that stores record kind ``name``."""
        return Path(self.db_path) / f"{name}.json"

    def _load_json(self, name: str) -> List[Dict]:
        """Load all records of kind ``name``; a missing file yields ``[]``."""
        path = self._json_file(name)
        if not path.exists():
            return []
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def _append_json(self, name: str, records: List[Dict]) -> None:
        """Append ``records`` to the JSON file of kind ``name``.

        The payload is serialised before any file is opened for writing and
        then swapped in with ``os.replace``, so a serialisation error or an
        interrupted write leaves the previously stored data intact.
        """
        data = self._load_json(name)
        data.extend(records)
        payload = json.dumps(data, ensure_ascii=False, indent=2)
        path = self._json_file(name)
        tmp_path = path.with_name(path.name + ".tmp")
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(payload)
        os.replace(tmp_path, path)

    # ==================== Keywords ====================
    def save_keywords(self, keywords: List[str], brand: str):
        """Store every keyword in ``keywords`` together with ``brand``."""
        if self.storage_type == "sqlite":
            self._write_rows(
                "INSERT INTO keywords (keyword, brand) VALUES (?, ?)",
                [(keyword, brand) for keyword in keywords],
            )
        else:
            self._append_json("keywords", [
                {
                    "keyword": keyword,
                    "brand": brand,
                    "created_at": datetime.now().isoformat(),
                }
                for keyword in keywords
            ])

    def get_keywords(self, brand: Optional[str] = None) -> List[str]:
        """Return the stored keywords, optionally restricted to ``brand``."""
        if self.storage_type == "sqlite":
            clause, params = self._where(brand=brand)
            conn = sqlite3.connect(self.db_path)
            try:
                cursor = conn.execute("SELECT keyword FROM keywords" + clause, params)
                return [row[0] for row in cursor.fetchall()]
            finally:
                conn.close()
        return [
            item["keyword"]
            for item in self._load_json("keywords")
            if self._matches(item, brand=brand)
        ]

    # ==================== Articles ====================
    def save_article(self, keyword: str, platform: str, content: str,
                     filename: str, brand: str):
        """Store one generated article."""
        if self.storage_type == "sqlite":
            self._write_rows(
                """
                INSERT INTO articles (keyword, platform, content, filename, brand)
                VALUES (?, ?, ?, ?, ?)
                """,
                [(keyword, platform, content, filename, brand)],
            )
        else:
            self._append_json("articles", [{
                "keyword": keyword,
                "platform": platform,
                "content": content,
                "filename": filename,
                "brand": brand,
                "created_at": datetime.now().isoformat(),
            }])

    def get_articles(self, brand: Optional[str] = None,
                     platform: Optional[str] = None) -> List[Dict]:
        """Return stored articles as dicts.

        Args:
            brand: Keep only articles of this brand when truthy.
            platform: Keep only articles of this platform when truthy. The
                filter is applied whether or not ``brand`` is given.
        """
        if self.storage_type == "sqlite":
            clause, params = self._where(brand=brand, platform=platform)
            frame = self._read_frame("SELECT * FROM articles" + clause, params)
            return frame.to_dict("records")
        return [
            item
            for item in self._load_json("articles")
            if self._matches(item, brand=brand, platform=platform)
        ]

    # ==================== Optimization records ====================
    def save_optimization(self, original_content: str, optimized_content: str,
                          changes: str, platform: str, brand: str):
        """Store one optimization record."""
        if self.storage_type == "sqlite":
            self._write_rows(
                """
                INSERT INTO optimizations
                (original_content, optimized_content, changes, platform, brand)
                VALUES (?, ?, ?, ?, ?)
                """,
                [(original_content, optimized_content, changes, platform, brand)],
            )
        else:
            self._append_json("optimizations", [{
                "original_content": original_content,
                "optimized_content": optimized_content,
                "changes": changes,
                "platform": platform,
                "brand": brand,
                "created_at": datetime.now().isoformat(),
            }])

    def get_optimizations(self, brand: Optional[str] = None) -> List[Dict]:
        """Return optimization records, optionally restricted to ``brand``.

        The SQLite backend orders by ``created_at`` descending; the JSON
        backend returns records in the order they were saved.
        """
        if self.storage_type == "sqlite":
            clause, params = self._where(brand=brand)
            frame = self._read_frame(
                "SELECT * FROM optimizations" + clause + " ORDER BY created_at DESC",
                params,
            )
            return frame.to_dict("records")
        return [
            item
            for item in self._load_json("optimizations")
            if self._matches(item, brand=brand)
        ]

    # ==================== Verification results ====================
    def save_verify_results(self, results: List[Dict]):
        """Store a batch of verification results.

        Each dict in ``results`` is read through the Chinese labels listed in
        ``_VERIFY_FIELDS``; labels that are absent are stored as NULL/None.
        """
        if self.storage_type == "sqlite":
            self._write_rows(
                """
                INSERT INTO verify_results
                (query, brand, verify_model, mention_count, mention_position)
                VALUES (?, ?, ?, ?, ?)
                """,
                [
                    tuple(result.get(label) for _, label in self._VERIFY_FIELDS)
                    for result in results
                ],
            )
        else:
            records = []
            for result in results:
                record = {
                    field: result.get(label) for field, label in self._VERIFY_FIELDS
                }
                record["created_at"] = datetime.now().isoformat()
                records.append(record)
            self._append_json("verify_results", records)

    def get_verify_results(self, brand: Optional[str] = None, include_timestamp: bool = False) -> pd.DataFrame:
        """Return verification results as a DataFrame with Chinese column labels.

        Args:
            brand: Brand to filter on; when falsy, results of all brands are
                returned.
            include_timestamp: When True, add a "验证时间" datetime column and
                sort newest first.

        Returns:
            A DataFrame whose columns are the labels of ``_VERIFY_FIELDS``
            (plus "验证时间" when requested). The columns are present even
            when there are no rows, for both backends.
        """
        time_label = self._VERIFY_TIME_LABEL
        if self.storage_type == "sqlite":
            selected = [f'{field} as "{label}"' for field, label in self._VERIFY_FIELDS]
            if include_timestamp:
                selected.append(f'created_at as "{time_label}"')
            clause, params = self._where(brand=brand)
            sql = f"SELECT {', '.join(selected)} FROM verify_results{clause}"
            if include_timestamp:
                sql += " ORDER BY created_at DESC"
            df = self._read_frame(sql, params)
            if include_timestamp and not df.empty and time_label in df.columns:
                df[time_label] = pd.to_datetime(df[time_label])
            return df

        labels = [label for _, label in self._VERIFY_FIELDS]
        if include_timestamp:
            labels.append(time_label)
        records = []
        for item in self._load_json("verify_results"):
            if not self._matches(item, brand=brand):
                continue
            record = {label: item.get(field) for field, label in self._VERIFY_FIELDS}
            if include_timestamp:
                # A missing created_at yields None, which pandas treats as missing.
                record[time_label] = pd.to_datetime(item.get("created_at"))
            records.append(record)
        # Passing columns= keeps the schema stable when there are no records.
        df = pd.DataFrame(records, columns=labels)
        if include_timestamp and not df.empty:
            df = df.sort_values(time_label, ascending=False)
        return df

    # ==================== Statistics ====================
    def get_stats(self, brand: Optional[str] = None) -> Dict[str, Any]:
        """Count the stored records of each kind.

        Args:
            brand: Count only records of this brand when truthy.

        Returns:
            A dict with the keys ``keywords_count``, ``articles_count``,
            ``optimizations_count`` and ``verify_results_count``.
        """
        stats = {}
        if self.storage_type == "sqlite":
            clause, params = self._where(brand=brand)
            conn = sqlite3.connect(self.db_path)
            try:
                for table in self._TABLES:
                    # Table names come from the fixed _TABLES tuple, not from input.
                    cursor = conn.execute(f"SELECT COUNT(*) FROM {table}{clause}", params)
                    stats[f"{table}_count"] = cursor.fetchone()[0]
            finally:
                conn.close()
        else:
            for table in self._TABLES:
                stats[f"{table}_count"] = sum(
                    1
                    for item in self._load_json(table)
                    if self._matches(item, brand=brand)
                )
        return stats