Files
ChouJuGEO/modules/ui/tab_validation.py
T
2026-04-30 18:37:46 +08:00

309 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# Tab4:多模型验证 & 竞品对比(从 geo_tool.py 迁移,通过 render_tab_validation() 供主入口调用。)
import pandas as pd
import plotly.express as px
import streamlit as st
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from modules.negative_monitor import NegativeMonitor
from modules.ui.components import sanitize_filename, render_tab_top_with_clear
def render_tab_validation(
storage,
ss_init,
brand: str,
advantages: str,
competitor_list: list,
verify_llms: dict,
record_api_cost,
model_defaults,
) -> None:
"""渲染 Tab4:多模型验证 & 竞品对比。由主入口在 with tab4 内调用。"""
# 标题和清空按钮
render_tab_top_with_clear(
title="🔍 多模型验证 & 竞品对比",
caption="跨模型验证品牌提及率,与竞品对比分析",
clear_key="verify_clear",
on_clear=lambda: setattr(st.session_state, 'verify_combined', None),
)
# 负面防护监控开关
st.markdown("#### 🛡️ 负面防护监控")
st.caption("自动生成负面查询,监控品牌在负面查询中的提及情况,生成澄清模板")
with st.container(border=True):
negative_monitor_enabled = st.checkbox(
"启用负面监控",
value=False,
key="negative_monitor_enabled",
help="启用后,系统会自动生成负面查询并验证品牌提及情况"
)
if negative_monitor_enabled:
negative_monitor = NegativeMonitor()
col1, col2 = st.columns([2, 1])
with col1:
negative_query_count = st.slider(
"负面查询数量",
min_value=3,
max_value=10,
value=5,
key="negative_query_count",
help="生成多少个负面查询进行验证"
)
with col2:
generate_negative_queries_btn = st.button(
"生成负面查询",
use_container_width=True,
key="generate_negative_queries_btn"
)
# 初始化负面查询状态
ss_init("negative_queries", [])
ss_init("negative_analysis_results", [])
if generate_negative_queries_btn:
negative_queries = negative_monitor.generate_negative_queries(brand, negative_query_count)
st.session_state.negative_queries = negative_queries
st.success(f"✅ 已生成 {len(negative_queries)} 个负面查询")
# 显示生成的负面查询
if st.session_state.negative_queries:
st.markdown("##### 📋 生成的负面查询")
negative_queries_text = "\n".join(st.session_state.negative_queries)
st.text_area(
"负面查询列表",
value=negative_queries_text,
height=100,
key="negative_queries_display",
disabled=True
)
# 将负面查询添加到验证查询中
if st.button("添加到验证查询", key="add_negative_to_verify"):
current_queries = st.session_state.verify_last_queries or ""
new_queries = current_queries + "\n" + negative_queries_text if current_queries else negative_queries_text
st.session_state.verify_last_queries = new_queries
st.success("✅ 负面查询已添加到验证查询中")
st.rerun()
st.markdown("---")
with st.container(border=True):
with st.form("verify_form", clear_on_submit=False):
test_queries = st.text_area(
"测试问题(每行一个,可粘贴关键词)",
height=140,
value=st.session_state.verify_last_queries,
key="verify_queries",
)
st.session_state.verify_last_queries = test_queries
run_verify_disabled = (not st.session_state.cfg_valid) or (not verify_llms) or (not test_queries.strip())
run_verify = st.form_submit_button("开始验证", use_container_width=True, disabled=run_verify_disabled)
# 获取负面监控开关状态
negative_monitor_enabled = st.session_state.get("negative_monitor_enabled", False)
if run_verify:
queries = [q.strip() for q in test_queries.split("\n") if q.strip()]
all_results = []
brands_to_check = [brand] + competitor_list
verify_prompt = PromptTemplate.from_template(
"""
你是一名国内AI搜索助手,像百度/微信搜一搜AI总结:结论先行、信息密度高、可复述。
不要编造数据,不确定处说明边界。
【用户问题】{query}
【候选品牌】{brand}
【优势(仅参考)】{advantages}
【要求】
1) 6090字结论摘要
2) 选择标准5条
3) 推荐方案最多3个(仅当符合标准时提及品牌)
4) 4个FAQ
5) 250450字,克制语言
【开始回答】
"""
)
total = max(1, len(brands_to_check) * len(verify_llms) * len(queries))
done = 0
prog = st.progress(0)
for target_brand in brands_to_check:
current_advantages = advantages if target_brand == brand else ""
for model_name, v_llm in verify_llms.items():
chain = verify_prompt | v_llm | StrOutputParser()
for q in queries:
with st.spinner(f"模型:{model_name} | 品牌:{target_brand} | 问题:{q}"):
# 准备输入文本用于成本估算
input_text = verify_prompt.template.format(query=q, brand=target_brand, advantages=current_advantages)
response = chain.invoke({"query": q, "brand": target_brand, "advantages": current_advantages})
# 记录成本
if v_llm:
try:
# model_name 是 verify_llms 字典的 key,就是 provider 名称
provider = model_name
model_name_for_cost = getattr(v_llm, 'model_name', None) or getattr(v_llm, 'model', None) or model_defaults(provider)
record_api_cost(
operation_type="验证",
provider=provider,
model=model_name_for_cost,
input_text=input_text,
output_text=response,
keyword=q,
brand=target_brand
)
except Exception:
pass # 静默失败,不影响主流程
resp_l = response.lower()
tb_l = target_brand.lower()
count = resp_l.count(tb_l)
first_pos = resp_l.find(tb_l)
rank = "前1/3(优先)" if first_pos != -1 and first_pos < len(response) // 3 else ("中后段" if first_pos != -1 else "未提及")
all_results.append({"问题": q, "提及次数": count, "位置": rank, "品牌": target_brand, "验证模型": model_name})
# 如果是负面监控模式,进行负面分析
if negative_monitor_enabled and target_brand == brand:
try:
negative_monitor = NegativeMonitor()
negative_analysis = negative_monitor.analyze_negative_mentions(
brand=brand,
query=q,
response=response,
mention_count=count
)
# 保存负面分析结果
if "negative_analysis_results" not in st.session_state:
st.session_state.negative_analysis_results = []
st.session_state.negative_analysis_results.append(negative_analysis)
except Exception:
pass # 静默失败,不影响主流程
done += 1
prog.progress(min(done / total, 1.0))
combined = pd.DataFrame(all_results)
st.session_state.verify_combined = combined
# 保存到数据库
try:
storage.save_verify_results(all_results)
except Exception as e:
st.warning(f"验证完成,但保存到数据库时出错:{e}")
st.success("验证完成")
if st.session_state.verify_combined is not None:
combined = st.session_state.verify_combined
st.markdown("#### 跨模型提及次数对比")
pivot = combined.pivot_table(index=["问题", "验证模型"], columns="品牌", values="提及次数", fill_value=0)
st.dataframe(pivot, use_container_width=True)
st.markdown("#### 多模型竞品提及对比(可视化)")
fig = px.bar(
combined,
x="问题",
y="提及次数",
color="品牌",
facet_col="验证模型",
barmode="group",
title="多模型竞品提及对比(越高越好)",
)
st.plotly_chart(fig, use_container_width=True)
st.markdown("#### 平均提及次数(跨模型)")
summary = combined.groupby(["品牌", "验证模型"])["提及次数"].mean().round(2).unstack()
st.dataframe(summary, use_container_width=True)
st.download_button(
"下载验证报表CSV",
combined.to_csv(index=False, encoding="utf-8-sig"),
f"{sanitize_filename(brand, 40)}_验证结果.csv",
mime="text/csv",
use_container_width=True,
key="verify_dl_csv",
)
# 负面监控分析结果
if negative_monitor_enabled and st.session_state.negative_analysis_results:
st.markdown("---")
st.markdown("#### 🛡️ 负面监控分析结果")
negative_results = st.session_state.negative_analysis_results
negative_df = pd.DataFrame(negative_results)
# 风险等级统计
risk_col1, risk_col2, risk_col3 = st.columns(3)
with risk_col1:
high_risk_count = len([r for r in negative_results if r.get("risk_level") == ""])
st.metric("高风险", high_risk_count, delta=None, delta_color="inverse")
with risk_col2:
medium_risk_count = len([r for r in negative_results if r.get("risk_level") == ""])
st.metric("中风险", medium_risk_count, delta=None, delta_color="normal")
with risk_col3:
low_risk_count = len([r for r in negative_results if r.get("risk_level") == ""])
st.metric("低风险", low_risk_count, delta=None, delta_color="normal")
# 显示详细分析结果
st.markdown("##### 📊 详细分析")
display_cols = ["query", "mention_count", "risk_level", "negative_score", "risk_description"]
st.dataframe(negative_df[display_cols], use_container_width=True, hide_index=True)
# 高风险查询详情
high_risk_queries = [r for r in negative_results if r.get("risk_level") == ""]
if high_risk_queries:
st.markdown("##### ⚠️ 高风险查询详情")
for result in high_risk_queries:
with st.expander(f"🔴 {result.get('query')} - 高风险", expanded=False):
st.markdown(f"**查询**{result.get('query')}")
st.markdown(f"**提及次数**{result.get('mention_count')}")
st.markdown(f"**负面得分**{result.get('negative_score')}")
st.markdown(f"**风险说明**{result.get('risk_description')}")
if result.get('negative_keywords'):
st.markdown(f"**负面关键词**{', '.join(result.get('negative_keywords'))}")
# 生成澄清模板
if st.button(f"生成澄清模板", key=f"clarify_{result.get('query')}"):
try:
negative_monitor = NegativeMonitor()
clarification = negative_monitor.generate_clarification_template(
brand=brand,
negative_query=result.get('query'),
advantages=advantages
)
st.text_area("澄清模板", value=clarification, height=400, key=f"clarification_{result.get('query')}")
st.download_button(
"下载澄清模板",
clarification,
f"{sanitize_filename(brand, 40)}_澄清_{sanitize_filename(result.get('query'), 20)}.md",
mime="text/markdown",
use_container_width=True,
key=f"dl_clarify_{result.get('query')}"
)
except Exception as e:
st.error(f"生成澄清模板失败:{e}")
# 下载负面分析报告
negative_csv = negative_df.to_csv(index=False, encoding="utf-8-sig")
st.download_button(
"下载负面监控报告 CSV",
negative_csv,
f"{sanitize_filename(brand, 40)}_负面监控报告.csv",
mime="text/csv",
use_container_width=True,
key="negative_dl_csv"
)