Files
ChouJuGEO/geo_tool.py
T

660 lines
22 KiB
Python
Raw Normal View History

2026-01-23 15:43:03 +08:00
import streamlit as st
from pathlib import Path
2026-01-23 15:43:03 +08:00
import json
from typing import Optional
from modules.data_storage import DataStorage
from modules.keyword_tool import KeywordTool
from modules.roi_analyzer import ROIAnalyzer
from modules.knowledge_base import KnowledgeBase
from modules.ui import (
tab_keywords,
tab_autowrite,
tab_optimize,
tab_validation,
tab_history,
tab_reports,
tab_workflow,
tab_resources,
tab_platform_sync,
tab_config_optimizer,
)
from modules.ui.tab_knowledge import render_tab_knowledge
from modules.ui.state import ss_init, init_session_state
from modules.ui.theme import inject_global_theme
2026-01-23 15:43:03 +08:00
2026-05-30 12:51:19 +08:00
APP_TITLE = "丑橘GEO内容优化平台"

# ---- Page setup and global theme ----
# The page config is issued ahead of every other st.* call in this script.
st.set_page_config(
    page_title=APP_TITLE,
    layout="wide",
    initial_sidebar_state="collapsed",
)
inject_global_theme()
init_session_state()

st.title(APP_TITLE)
st.caption("🚀 AI 驱动的品牌内容策略 · 让您的品牌在 AI 对话中脱颖而出")

# ---- Data storage (SQLite) ----
storage = DataStorage(storage_type="sqlite", db_path="geo_data.db")

# ---- Knowledge base (RAG) ----
kb = KnowledgeBase(storage_path="knowledge_base")
# ------------------- 成本记录辅助函数 -------------------
def estimate_tokens(text: str) -> int:
    """Roughly estimate how many tokens ``text`` would consume.

    Heuristic: characters in the range U+4E00..U+9FFF count as one token per
    1.5 characters, every other character as one token per 4 characters.
    The result is never lower than ``len(text) // 4``; falsy input yields 0.
    """
    if not text:
        return 0

    cjk_count = 0
    for ch in text:
        if '\u4e00' <= ch <= '\u9fff':
            cjk_count += 1
    remainder_count = len(text) - cjk_count

    weighted = int(cjk_count / 1.5 + remainder_count / 4)
    lower_bound = len(text) // 4
    return weighted if weighted >= lower_bound else lower_bound
def record_api_cost(
    operation_type: str,
    provider: str,
    model: str,
    input_text: str,
    output_text: str,
    keyword: Optional[str] = None,
    platform: Optional[str] = None,
    brand: Optional[str] = None,
):
    """Record the estimated cost of one API call.

    Token counts come from ``estimate_tokens``; the USD/CNY cost comes from
    ``ROIAnalyzer.calculate_cost``. The record is stored through
    ``storage.save_api_call``. Best-effort: any failure is logged as a
    warning and never propagated to the caller.
    """
    try:
        analyzer = ROIAnalyzer()
        tokens_in = estimate_tokens(input_text)
        tokens_out = estimate_tokens(output_text)
        usd, cny = analyzer.calculate_cost(provider, model, tokens_in, tokens_out)
        call_record = dict(
            operation_type=operation_type,
            provider=provider,
            model=model,
            input_tokens=tokens_in,
            output_tokens=tokens_out,
            total_tokens=tokens_in + tokens_out,
            cost_usd=usd,
            cost_cny=cny,
            keyword=keyword,
            platform=platform,
            brand=brand,
        )
        storage.save_api_call(**call_record)
    except Exception as exc:
        import logging
        logging.warning(f"记录 API 成本失败: {exc}")
2026-01-23 15:43:03 +08:00
2026-05-30 12:51:19 +08:00
# =================== 函数定义:配置管理(在 expander 之前定义,因为 expander 内使用了它们) ===================
2026-01-23 15:43:03 +08:00
def _merge_secrets_into_cfg(base_cfg: dict, secrets) -> None:
    """Overlay values from a secrets mapping onto ``base_cfg`` in place.

    ``secrets`` only needs to support ``in`` and ``[]`` (``st.secrets`` or a
    plain dict). ``base_cfg["verify_keys"]`` must already be a dict.
    """
    # Secret names used by save_cfg_to_file for the non-DeepSeek providers.
    # Reading them back keeps saved verification keys across restarts.
    verify_secret_names = {
        "OpenAI (GPT)": "openai",
        "Groq": "groq",
        "Moonshot (Kimi)": "moonshot",
        "Tongyi (通义千问)": "tongyi",
        "豆包(字节跳动)": "doubao",
        "文心一言(百度)": "wenxin",
    }
    if "api_keys" in secrets:
        api_keys = secrets["api_keys"]
        if "deepseek" in api_keys and api_keys["deepseek"]:
            base_cfg["gen_api_key"] = api_keys["deepseek"]
            base_cfg["verify_keys"]["DeepSeek"] = api_keys["deepseek"]
        for display_name, secret_name in verify_secret_names.items():
            if secret_name in api_keys and api_keys[secret_name]:
                base_cfg["verify_keys"][display_name] = api_keys[secret_name]
        if "tongyi_wanxiang" in api_keys and api_keys["tongyi_wanxiang"]:
            base_cfg["tongyi_wanxiang_api_key"] = api_keys["tongyi_wanxiang"]

    if "app_config" in secrets:
        app_config = secrets["app_config"]
        for key in ["brand", "advantages", "competitors"]:
            if key in app_config and app_config[key]:
                base_cfg[key] = app_config[key]
        if "temperature" in app_config:
            temperature = app_config["temperature"]
            # Accept any real number, including 0 (a truthiness test would
            # silently discard a stored temperature of 0).
            if isinstance(temperature, (int, float)) and not isinstance(temperature, bool):
                base_cfg["temperature"] = temperature


def load_default_cfg():
    """
    Build the default configuration.

    Order of precedence (later wins):
      1. built-in defaults,
      2. ``config.json`` next to this script (non-sensitive settings),
      3. ``st.secrets`` (API keys and brand info from .streamlit/secrets.toml).

    Returns:
        A fresh dict on every call. Problems reading either file are logged
        as warnings and the remaining sources are still used.
    """
    base_cfg = {
        "gen_provider": "DeepSeek",
        "gen_api_key": "",
        "verify_providers": ["DeepSeek"],
        "verify_keys": {
            "DeepSeek": ""
        },
        "tongyi_wanxiang_api_key": "",
        "brand": "",
        "advantages": "",
        "competitors": "",
        "temperature": 0.7,
    }

    # Non-sensitive settings from config.json
    config_path = Path(__file__).with_name("config.json")
    if config_path.exists():
        try:
            with config_path.open("r", encoding="utf-8") as f:
                file_cfg = json.load(f)
            if isinstance(file_cfg, dict):
                base_cfg.update(file_cfg)
        except Exception as e:
            import logging
            logging.warning(f"配置文件加载失败: {e}")

    # config.json is hand-editable and may have replaced verify_keys with a
    # non-dict (e.g. null); the secrets overlay below needs a dict to write to.
    if not isinstance(base_cfg.get("verify_keys"), dict):
        base_cfg["verify_keys"] = {}

    # Sensitive values from st.secrets (higher priority)
    try:
        if hasattr(st, 'secrets') and st.secrets:
            _merge_secrets_into_cfg(base_cfg, st.secrets)
    except FileNotFoundError:
        # No secrets file is a normal situation, not an error.
        pass
    except Exception as e:
        import logging
        logging.warning(f"读取 secrets.toml 失败: {e}")
    return base_cfg
def save_cfg_to_file(cfg: dict) -> None:
    """
    Persist the configuration to local files.

    - Non-sensitive settings -> ``config.json`` (next to this script)
    - API keys + brand info  -> ``.streamlit/secrets.toml``

    Both steps are best-effort: failures are logged (and, for the secrets
    file, also surfaced as a Streamlit warning) instead of being raised.
    When ``tomllib``/``tomli_w`` cannot be imported, the secrets file is
    written by ``_save_secrets_fallback`` instead.
    """
    import logging

    # ── 1. Non-sensitive settings → config.json ──
    config_path = Path(__file__).with_name("config.json")
    try:
        data = {}
        if config_path.exists():
            try:
                with config_path.open("r", encoding="utf-8") as f:
                    loaded = json.load(f)
                if isinstance(loaded, dict):
                    data.update(loaded)
            except Exception as e:
                logging.warning(f"读取 config.json 失败: {e}")
                data = {}

        for key in ["gen_provider", "verify_providers", "brand", "advantages", "competitors", "temperature"]:
            if key in cfg:
                data[key] = cfg[key]

        # Serialize before opening the file so that a serialization error
        # cannot leave a truncated config.json behind.
        serialized = json.dumps(data, ensure_ascii=False, indent=2)
        with config_path.open("w", encoding="utf-8") as f:
            f.write(serialized)
    except Exception as e:
        logging.error(f"保存 config.json 失败: {e}")

    # ── 2. API keys + brand info → .streamlit/secrets.toml ──
    secrets_path = Path(__file__).parent / ".streamlit" / "secrets.toml"
    try:
        # Imported inside the try block: a missing module must reach the
        # `except ImportError` fallback below rather than abort the function.
        import tomllib
        import tomli_w  # type: ignore

        # Read the current secrets.toml (if any) so it can be merged.
        existing = {}
        if secrets_path.exists():
            try:
                with secrets_path.open("rb") as f:
                    existing = tomllib.load(f)
            except Exception as e:
                logging.warning(f"读取 secrets.toml 失败: {e}")
                existing = {}

        def _existing_table(name: str) -> dict:
            """Return a copy of table ``name`` from the existing file, or {}."""
            table = existing.get(name)
            return dict(table) if isinstance(table, dict) else {}

        # Build the [api_keys] table.
        new_api_keys = _existing_table("api_keys")
        # NOTE(review): the generation key is stored under "deepseek"
        # whatever gen_provider is; load_default_cfg reads it back from there.
        if cfg.get("gen_api_key"):
            new_api_keys["deepseek"] = cfg["gen_api_key"]

        # verify_keys → one entry per provider
        verify_keys = cfg.get("verify_keys") or {}
        provider_map = {
            "DeepSeek": "deepseek",
            "OpenAI (GPT)": "openai",
            "Groq": "groq",
            "Moonshot (Kimi)": "moonshot",
            "Tongyi (通义千问)": "tongyi",
            "豆包(字节跳动)": "doubao",
            "文心一言(百度)": "wenxin",
        }
        for display_name, key_name in provider_map.items():
            if verify_keys.get(display_name):
                new_api_keys[key_name] = verify_keys[display_name]

        if cfg.get("tongyi_wanxiang_api_key"):
            new_api_keys["tongyi_wanxiang"] = cfg["tongyi_wanxiang_api_key"]

        # Build the [app_config] table. Empty strings are written too:
        # secrets take priority over config.json when loading, so skipping a
        # cleared field would resurrect its old value on the next start.
        new_app_config = _existing_table("app_config")
        for key in ["brand", "advantages", "competitors", "temperature"]:
            if cfg.get(key) is not None:
                new_app_config[key] = cfg[key]

        # Keep any other sections already present in the file.
        content = dict(existing)
        content["api_keys"] = new_api_keys
        content["app_config"] = new_app_config

        # Serialize first, then write, so a failure cannot truncate the file.
        payload = tomli_w.dumps(content).encode("utf-8")
        secrets_path.parent.mkdir(parents=True, exist_ok=True)
        with secrets_path.open("wb") as f:
            f.write(payload)
    except ImportError:
        # tomllib / tomli_w unavailable — fall back to hand-written TOML.
        _save_secrets_fallback(secrets_path, cfg)
    except Exception as e:
        logging.error(f"保存 secrets.toml 失败: {e}")
        try:
            st.warning("⚠️ 无法将 API Key 写入 .streamlit/secrets.toml,但当前会话已生效。")
        except Exception:
            # Showing the warning is itself best-effort.
            pass
2026-05-30 12:51:19 +08:00
def _save_secrets_fallback(secrets_path: Path, cfg: dict) -> None:
"""tomli_w 不可用时的纯文本备选方案"""
try:
secrets_path.parent.mkdir(parents=True, exist_ok=True)
lines = ["# Streamlit Secrets(自动保存)\n"]
lines.append("\n[api_keys]\n")
if cfg.get("gen_api_key"):
lines.append(f'deepseek = "{cfg["gen_api_key"]}"\n')
verify_keys = cfg.get("verify_keys", {})
provider_map = {
"DeepSeek": "deepseek",
"OpenAI (GPT)": "openai",
"Groq": "groq",
"Moonshot (Kimi)": "moonshot",
"Tongyi (通义千问)": "tongyi",
"豆包(字节跳动)": "doubao",
"文心一言(百度)": "wenxin",
}
for display_name, key_name in provider_map.items():
if display_name in verify_keys and verify_keys[display_name]:
lines.append(f'{key_name} = "{verify_keys[display_name]}"\n')
if cfg.get("tongyi_wanxiang_api_key"):
lines.append(f'tongyi_wanxiang = "{cfg["tongyi_wanxiang_api_key"]}"\n')
lines.append("\n[app_config]\n")
for key in ["brand", "advantages", "competitors"]:
if cfg.get(key):
lines.append(f'{key} = "{cfg[key]}"\n')
if "temperature" in cfg:
lines.append(f"temperature = {cfg['temperature']}\n")
2026-01-23 15:43:03 +08:00
2026-05-30 12:51:19 +08:00
with secrets_path.open("w", encoding="utf-8") as f:
f.writelines(lines)
except Exception as e:
import logging
logging.error(f"保存 secrets.toml(备选方案)失败: {e}")
2026-01-23 15:43:03 +08:00
def validate_cfg(cfg: dict) -> tuple[bool, list[str]]:
    """Check the configuration for completeness.

    Returns:
        ``(is_valid, messages)``. ``is_valid`` depends on blocking errors
        only (missing generation key, no verification provider, missing
        verification key). ``messages`` lists the errors followed by
        non-blocking warnings (missing brand / advantages).
    """

    def _is_blank(value) -> bool:
        # Values may be None or non-string (e.g. loaded from a hand-edited
        # file), so do not call .strip() on them directly.
        return value is None or not str(value).strip()

    errors = []
    warnings = []

    if _is_blank(cfg.get("gen_api_key")):
        errors.append("生成&优化 LLM 的 API Key 未填写")

    verify_providers = cfg.get("verify_providers") or []
    verify_keys = cfg.get("verify_keys")
    if not isinstance(verify_keys, dict):
        verify_keys = {}
    if not verify_providers:
        errors.append("至少选择一个验证模型")
    for vp in verify_providers:
        if _is_blank(verify_keys.get(vp)):
            errors.append(f"验证模型 {vp} 的 API Key 未填写")

    if _is_blank(cfg.get("brand")):
        warnings.append("品牌名称未填写(部分功能需要)")
    if _is_blank(cfg.get("advantages")):
        warnings.append("核心优势未填写(部分功能需要)")

    return (len(errors) == 0), errors + warnings
2026-01-23 15:43:03 +08:00
2026-05-30 12:51:19 +08:00
# Initialize the default config. This must happen before the expander below,
# because the widgets inside it read st.session_state.cfg.
ss_init("cfg", load_default_cfg())
2026-01-23 15:43:03 +08:00
2026-05-30 12:51:19 +08:00
# Configuration panel.
# NOTE(review): the "LLM 配置" expander is the first statement after the
# "配置" `with` header, so it is nested inside it — confirm the installed
# Streamlit version permits nested expanders. How the later statements are
# nested cannot be determined from this view of the file.
with st.expander("配置", expanded=False):
2026-05-30 12:51:19 +08:00
with st.expander("LLM 配置", expanded=True):
PROVIDER_LIST = ["DeepSeek", "OpenAI (GPT)", "Tongyi (通义千问)", "Groq", "Moonshot (Kimi)", "豆包(字节跳动)", "文心一言(百度)"]
gen_provider = st.selectbox(
2026-01-23 15:43:03 +08:00
"生成&优化 LLM",
PROVIDER_LIST,
index=PROVIDER_LIST.index(st.session_state.cfg["gen_provider"]) if st.session_state.cfg["gen_provider"] in PROVIDER_LIST else 0,
2026-01-23 15:43:03 +08:00
key="sb_gen_provider",
)
# Hint text for the API-key input (only these two providers have one)
api_key_help = ""
if gen_provider == "豆包(字节跳动)":
api_key_help = "格式:access_key:secret_key:endpoint_id(用冒号分隔)"
elif gen_provider == "文心一言(百度)":
api_key_help = "格式:app_key:app_secret(用冒号分隔)"
gen_api_key = st.text_input(
f"{gen_provider} API Key(生成&优化用)",
type="password",
value=st.session_state.cfg.get("gen_api_key", ""),
key="sb_gen_api_key",
help=api_key_help if api_key_help else None,
)
# Verification settings group
with st.expander("🔍 验证配置", expanded=False):
verify_providers = st.multiselect(
"选择验证模型",
PROVIDER_LIST,
default=st.session_state.cfg.get("verify_providers", []),
key="sb_verify_providers",
)
# One password input per selected provider, pre-filled from the stored keys
verify_keys = {}
old_keys = st.session_state.cfg.get("verify_keys", {})
for vp in verify_providers:
vp_help = ""
if vp == "豆包(字节跳动)":
vp_help = "格式:access_key:secret_key:endpoint_id(用冒号分隔)"
elif vp == "文心一言(百度)":
vp_help = "格式:app_key:app_secret(用冒号分隔)"
verify_keys[vp] = st.text_input(
f"{vp} API Key(验证用)",
type="password",
value=old_keys.get(vp, ""),
key=f"sb_verify_key_{vp}",
help=vp_help if vp_help else None,
)
2026-01-23 15:43:03 +08:00
# Brand information group
with st.expander("🏢 品牌信息", expanded=True):
brand = st.text_input("主品牌名称", value=st.session_state.cfg.get("brand", ""), key="sb_brand")
advantages = st.text_area(
"核心优势/卖点(AI专属)",
value=st.session_state.cfg.get("advantages", ""),
height=120,
key="sb_advantages",
)
competitors = st.text_area(
"竞品品牌(每行一个)",
value=st.session_state.cfg.get("competitors", ""),
height=100,
key="sb_competitors",
)
# Advanced settings group
with st.expander("⚙️ 高级设置", expanded=False):
temperature = st.slider(
"生成温度(更稳→更低)",
0.0,
1.0,
float(st.session_state.cfg.get("temperature", 0.7)),
0.05,
key="sb_temperature",
)
tongyi_wanxiang_api_key = st.text_input(
"通义万相 API Key(图片生成)",
type="password",
value=st.session_state.cfg.get("tongyi_wanxiang_api_key", ""),
key="sb_tongyi_wanxiang_api_key",
help="阿里云 DashScope API Key,用于生成文章配图。",
)
# "Apply configuration" button
apply_cfg = st.button("应用配置", use_container_width=True, type="primary")
2026-01-23 15:43:03 +08:00
# Rebuild cfg from the widget values when the button was clicked, and also on
# any run where cfg_applied is falsy (cfg_applied is presumably initialised by
# init_session_state — confirm).
if apply_cfg or not st.session_state.cfg_applied:
# Prefer the values stored under the main widget keys (if a temporary key was
# used to update them, the value has already been synced to the main key)
brand_value = st.session_state.get("sb_brand", brand)
advantages_value = st.session_state.get("sb_advantages", advantages)
2026-01-23 15:43:03 +08:00
st.session_state.cfg = {
"gen_provider": gen_provider,
"gen_api_key": gen_api_key,
"verify_providers": verify_providers,
"verify_keys": verify_keys,
"tongyi_wanxiang_api_key": tongyi_wanxiang_api_key,
"brand": brand_value,
"advantages": advantages_value,
2026-01-23 15:43:03 +08:00
"competitors": competitors,
"temperature": temperature,
}
# validate_cfg returns (ok, messages); messages hold errors followed by warnings
ok, errs = validate_cfg(st.session_state.cfg)
st.session_state.cfg_valid = ok
st.session_state.cfg_errors = errs
if ok:
# Only when the config is valid: persist it to the local config files and
# mark it as applied
save_cfg_to_file(st.session_state.cfg)
st.session_state.cfg_applied = True
else:
st.session_state.cfg_applied = False
2026-01-23 15:43:03 +08:00
# Status banner: list the validation messages when the config is invalid,
# otherwise show the "ready" notice
if not st.session_state.cfg_valid:
2026-04-30 23:35:06 +08:00
with st.container(border=True):
st.markdown("**⚠️ 完成配置后即可使用全部功能**")
for err in st.session_state.cfg_errors:
st.markdown(f"{err}")
2026-01-23 15:43:03 +08:00
else:
2026-04-30 23:35:06 +08:00
with st.container(border=True):
st.markdown("**✅ 配置已就绪**")
st.caption("所有功能已解锁,可以开始使用")
2026-01-23 15:43:03 +08:00
st.markdown("---")
# Reset button: clears the generated results held in session state; cfg is not touched
if st.button("重置全部结果(不删除配置)", use_container_width=True, key="sb_reset_all"):
st.session_state.keywords = []
st.session_state.generated_contents = []
st.session_state.zip_bytes = None
st.session_state.zip_filename = ""
st.session_state.optimized_article = ""
st.session_state.opt_changes = ""
st.session_state.verify_combined = None
st.session_state.config_optimization_result = None
st.session_state.config_hash = None
2026-01-23 15:43:03 +08:00
st.toast("已重置全部结果。")
st.caption("闭环:关键词 → 创作 → 优化 → 验证")
2026-05-30 12:51:19 +08:00
def model_defaults(provider: str) -> str:
    """Return the default model for ``provider``.

    Delegates to ``modules.llm_factory.get_default_model``; the module is
    imported at call time.
    """
    import modules.llm_factory as llm_factory

    return llm_factory.get_default_model(provider)
# ---- Cached LLM clients (cuts down on repeated "Loading") ----
@st.cache_resource(show_spinner=False)
def build_llm(provider: str, api_key: str, model: str, temperature: float):
    """Build an LLM client via ``modules.llm_factory.build_llm``.

    Wrapped in ``st.cache_resource`` so the client is not rebuilt on every
    rerun; all construction goes through the llm_factory module.
    """
    import modules.llm_factory as llm_factory

    return llm_factory.build_llm(provider, api_key, model, temperature)
# ---- Values derived from the current configuration ----
cfg = st.session_state.cfg
brand = cfg["brand"]
advantages = cfg["advantages"]
temperature = float(cfg.get("temperature", 0.7))

# Competitors are entered one per line. Blank lines are dropped, then the
# brand itself and case-insensitive duplicates are removed (first one wins).
competitor_list = [name for name in map(str.strip, cfg["competitors"].split("\n")) if name]
_seen = set()
clean_competitors = []
for c in competitor_list:
    cl = c.lower()
    if cl == brand.lower() or cl in _seen:
        continue
    _seen.add(cl)
    clean_competitors.append(c)
competitor_list = clean_competitors

# ---- LLM initialisation (only when cfg_valid; build_llm is cached) ----
gen_llm = None
verify_llms = {}
if st.session_state.cfg_valid:
    try:
        gen_llm = build_llm(
            cfg["gen_provider"],
            cfg["gen_api_key"],
            model_defaults(cfg["gen_provider"]),
            temperature,
        )
    except Exception as e:
        st.error(f"生成LLM加载失败:{e}")

    for vp in cfg["verify_providers"]:
        key = cfg["verify_keys"].get(vp, "").strip()
        if key:
            # A provider without a key is skipped rather than reported here.
            try:
                verify_llms[vp] = build_llm(vp, key, model_defaults(vp), temperature)
            except Exception as e:
                st.error(f"{vp}验证LLM加载失败:{e}")
# ---- KPI overview ----
k1, k2, k3, k4 = st.columns(4)
k1.metric("关键词", len(st.session_state.keywords), border=True)
k2.metric("内容包", len(st.session_state.generated_contents), border=True)
_optimize_status = "已生成" if st.session_state.optimized_article else "未生成"
k3.metric("文章优化", _optimize_status, border=True)
_verify_status = "未生成" if st.session_state.verify_combined is None else "已生成"
k4.metric("验证结果", _verify_status, border=True)

st.markdown("---")

# ---- Main navigation: one tab per workflow step ----
_tab_labels = [
    "🎯 关键词蒸馏",
    "✍️ 自动创作",
    "🔧 文章优化",
    "✅ 多模型验证",
    "📚 历史记录",
    "📊 AI 数据报表",
    "⚙️ 工作流自动化",
    "📦 GEO 资源库",
    "🔄 平台同步",
    "🛠️ 配置优化助手",
    "📚 品牌知识库",
]
(
    tab1, tab2, tab3, tab4, tab5, tab6,
    tab7, tab8, tab9, tab10, tab11,
) = st.tabs(_tab_labels)
2026-01-23 15:43:03 +08:00
# Tab 1: keyword distillation
with tab1:
    tab_keywords.render_tab_keywords(storage, ss_init, gen_llm, brand, advantages)

# Tab 2: automatic content creation (incl. batch ZIP / GitHub templates)
with tab2:
    tab_autowrite.render_tab_autowrite(
        storage, ss_init, gen_llm, brand, advantages,
        cfg, record_api_cost, model_defaults,
    )

# Tab 3: article optimisation
with tab3:
    tab_optimize.render_tab_optimize(
        storage, ss_init, gen_llm, brand, advantages,
        cfg, record_api_cost, model_defaults,
    )

# Tab 4: multi-model verification & competitor comparison
with tab4:
    tab_validation.render_tab_validation(
        storage, ss_init, brand, advantages,
        competitor_list, verify_llms, record_api_cost, model_defaults,
    )

# Tab 5: history
with tab5:
    tab_history.render_tab_history(storage, brand)

# Tab 6: AI data reports
with tab6:
    tab_reports.render_tab_reports(
        storage, ss_init, gen_llm, brand, advantages,
        competitor_list, verify_llms, record_api_cost, model_defaults,
    )

# Tab 7: workflow automation
with tab7:
    tab_workflow.render_tab_workflow(
        storage, ss_init, gen_llm, brand, advantages,
        competitor_list, verify_llms, record_api_cost, model_defaults,
    )

# Tab 8: GEO resource library
with tab8:
    tab_resources.render_tab_resources(storage, brand)

# Tab 9: platform sync
with tab9:
    tab_platform_sync.render_tab_platform_sync(storage, brand)

# Tab 10: configuration optimisation assistant
with tab10:
    tab_config_optimizer.render_tab_config_optimizer(
        storage, cfg, brand, advantages,
        competitor_list, build_llm, model_defaults,
    )

# Tab 11: brand knowledge base (RAG)
with tab11:
    render_tab_knowledge(kb)

# Footer
st.caption("一站式GEO优化平台| 多模型验证 + 文章优化 + RAG知识库 • GEO全闭环,专注AI品牌影响力")