fb309299bf
Made-with: Cursor
530 lines
17 KiB
Python
530 lines
17 KiB
Python
"""
|
||
RAG 知识库模块
|
||
支持用户上传品牌文档,自动分块、索引,生成内容时自动检索相关内容
|
||
"""
|
||
|
||
import json
|
||
import hashlib
|
||
import logging
|
||
from pathlib import Path
|
||
from typing import List, Dict, Optional, Any
|
||
from datetime import datetime
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class DocumentChunk:
|
||
"""文档分块"""
|
||
|
||
def __init__(self, content: str, metadata: Dict[str, Any]):
|
||
self.content = content
|
||
self.metadata = metadata
|
||
self.chunk_id = hashlib.md5(content.encode()).hexdigest()[:12]
|
||
|
||
def to_dict(self) -> Dict:
|
||
return {
|
||
"chunk_id": self.chunk_id,
|
||
"content": self.content,
|
||
"metadata": self.metadata
|
||
}
|
||
|
||
@classmethod
|
||
def from_dict(cls, data: Dict) -> 'DocumentChunk':
|
||
chunk = cls(data["content"], data["metadata"])
|
||
chunk.chunk_id = data["chunk_id"]
|
||
return chunk
|
||
|
||
|
||
class KnowledgeBase:
|
||
"""知识库管理器"""
|
||
|
||
def __init__(self, storage_path: str = "knowledge_base"):
|
||
"""
|
||
Args:
|
||
storage_path: 知识库存储路径
|
||
"""
|
||
self.storage_path = Path(storage_path)
|
||
self.storage_path.mkdir(parents=True, exist_ok=True)
|
||
|
||
# 文档元数据
|
||
self.documents_file = self.storage_path / "documents.json"
|
||
# 分块数据
|
||
self.chunks_file = self.storage_path / "chunks.json"
|
||
|
||
self.documents: Dict[str, Dict] = self._load_json(self.documents_file, {})
|
||
self.chunks: List[Dict] = self._load_json(self.chunks_file, [])
|
||
|
||
def _load_json(self, path: Path, default: Any) -> Any:
|
||
"""加载 JSON 文件"""
|
||
if path.exists():
|
||
try:
|
||
with open(path, 'r', encoding='utf-8') as f:
|
||
return json.load(f)
|
||
except Exception as e:
|
||
logger.warning(f"加载 {path} 失败: {e}")
|
||
return default
|
||
|
||
def _save_json(self, path: Path, data: Any):
|
||
"""保存 JSON 文件"""
|
||
try:
|
||
with open(path, 'w', encoding='utf-8') as f:
|
||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||
except Exception as e:
|
||
logger.error(f"保存 {path} 失败: {e}")
|
||
|
||
def add_document(self, filename: str, content: str, doc_type: str = "text",
|
||
metadata: Optional[Dict] = None) -> Dict:
|
||
"""
|
||
添加文档到知识库
|
||
|
||
Args:
|
||
filename: 文件名
|
||
content: 文档内容
|
||
doc_type: 文档类型 (text, markdown, faq, case, product)
|
||
metadata: 额外元数据
|
||
|
||
Returns:
|
||
文档信息
|
||
"""
|
||
doc_id = hashlib.md5(f"{filename}{datetime.now().isoformat()}".encode()).hexdigest()[:12]
|
||
|
||
doc_info = {
|
||
"doc_id": doc_id,
|
||
"filename": filename,
|
||
"doc_type": doc_type,
|
||
"content_length": len(content),
|
||
"chunk_count": 0,
|
||
"created_at": datetime.now().isoformat(),
|
||
"metadata": metadata or {}
|
||
}
|
||
|
||
# 分块
|
||
chunks = self._split_document(content, doc_id, filename, doc_type)
|
||
doc_info["chunk_count"] = len(chunks)
|
||
|
||
# 保存
|
||
self.documents[doc_id] = doc_info
|
||
self.chunks.extend([c.to_dict() for c in chunks])
|
||
|
||
self._save_json(self.documents_file, self.documents)
|
||
self._save_json(self.chunks_file, self.chunks)
|
||
|
||
logger.info(f"文档 '{filename}' 已添加,分为 {len(chunks)} 个分块")
|
||
return doc_info
|
||
|
||
def _split_document(self, content: str, doc_id: str, filename: str,
|
||
doc_type: str, chunk_size: int = 500, overlap: int = 50) -> List[DocumentChunk]:
|
||
"""
|
||
将文档分割为多个分块
|
||
|
||
Args:
|
||
content: 文档内容
|
||
doc_id: 文档 ID
|
||
filename: 文件名
|
||
doc_type: 文档类型
|
||
chunk_size: 分块大小(字符数)
|
||
overlap: 重叠字符数
|
||
|
||
Returns:
|
||
分块列表
|
||
"""
|
||
chunks = []
|
||
|
||
# 根据文档类型选择分块策略
|
||
if doc_type == "faq":
|
||
# FAQ 文档按 Q&A 对分块
|
||
chunks = self._split_faq(content, doc_id, filename)
|
||
elif doc_type == "product":
|
||
# 产品文档按功能/特性分块
|
||
chunks = self._split_by_sections(content, doc_id, filename, doc_type)
|
||
else:
|
||
# 通用文档按段落/长度分块
|
||
chunks = self._split_by_length(content, doc_id, filename, doc_type,
|
||
chunk_size, overlap)
|
||
|
||
return chunks
|
||
|
||
def _split_faq(self, content: str, doc_id: str, filename: str) -> List[DocumentChunk]:
|
||
"""FAQ 文档分块:每个 Q&A 对为一个分块"""
|
||
chunks = []
|
||
lines = content.split('\n')
|
||
|
||
current_q = ""
|
||
current_a = ""
|
||
|
||
for line in lines:
|
||
line = line.strip()
|
||
if not line:
|
||
continue
|
||
|
||
# 检测问题行
|
||
if line.startswith('Q:') or line.startswith('问:') or line.startswith('Q:'):
|
||
# 保存上一个 Q&A 对
|
||
if current_q and current_a:
|
||
chunk_content = f"问题:{current_q}\n回答:{current_a}"
|
||
chunks.append(DocumentChunk(
|
||
content=chunk_content,
|
||
metadata={
|
||
"doc_id": doc_id,
|
||
"filename": filename,
|
||
"type": "faq",
|
||
"question": current_q
|
||
}
|
||
))
|
||
current_q = line[2:].strip()
|
||
current_a = ""
|
||
elif line.startswith('A:') or line.startswith('答:') or line.startswith('A:'):
|
||
current_a = line[2:].strip()
|
||
elif current_a:
|
||
current_a += "\n" + line
|
||
elif current_q and not current_a:
|
||
current_q += "\n" + line
|
||
|
||
# 保存最后一个 Q&A 对
|
||
if current_q and current_a:
|
||
chunk_content = f"问题:{current_q}\n回答:{current_a}"
|
||
chunks.append(DocumentChunk(
|
||
content=chunk_content,
|
||
metadata={
|
||
"doc_id": doc_id,
|
||
"filename": filename,
|
||
"type": "faq",
|
||
"question": current_q
|
||
}
|
||
))
|
||
|
||
return chunks
|
||
|
||
def _split_by_sections(self, content: str, doc_id: str, filename: str,
|
||
doc_type: str) -> List[DocumentChunk]:
|
||
"""按章节分块(适用于产品文档、Markdown 等)"""
|
||
chunks = []
|
||
sections = content.split('\n# ')
|
||
|
||
for i, section in enumerate(sections):
|
||
if not section.strip():
|
||
continue
|
||
|
||
# 提取标题
|
||
lines = section.split('\n', 1)
|
||
title = lines[0].strip('# ').strip()
|
||
body = lines[1].strip() if len(lines) > 1 else ""
|
||
|
||
if body:
|
||
chunks.append(DocumentChunk(
|
||
content=f"## {title}\n{body}",
|
||
metadata={
|
||
"doc_id": doc_id,
|
||
"filename": filename,
|
||
"type": doc_type,
|
||
"section_title": title
|
||
}
|
||
))
|
||
|
||
return chunks
|
||
|
||
def _split_by_length(self, content: str, doc_id: str, filename: str,
|
||
doc_type: str, chunk_size: int, overlap: int) -> List[DocumentChunk]:
|
||
"""按长度分块"""
|
||
chunks = []
|
||
paragraphs = content.split('\n\n')
|
||
|
||
current_chunk = ""
|
||
for para in paragraphs:
|
||
para = para.strip()
|
||
if not para:
|
||
continue
|
||
|
||
if len(current_chunk) + len(para) > chunk_size and current_chunk:
|
||
chunks.append(DocumentChunk(
|
||
content=current_chunk,
|
||
metadata={
|
||
"doc_id": doc_id,
|
||
"filename": filename,
|
||
"type": doc_type
|
||
}
|
||
))
|
||
# 保留重叠部分
|
||
if overlap > 0:
|
||
current_chunk = current_chunk[-overlap:] + "\n" + para
|
||
else:
|
||
current_chunk = para
|
||
else:
|
||
current_chunk = current_chunk + "\n\n" + para if current_chunk else para
|
||
|
||
# 保存最后一个分块
|
||
if current_chunk.strip():
|
||
chunks.append(DocumentChunk(
|
||
content=current_chunk,
|
||
metadata={
|
||
"doc_id": doc_id,
|
||
"filename": filename,
|
||
"type": doc_type
|
||
}
|
||
))
|
||
|
||
return chunks
|
||
|
||
def search(self, query: str, top_k: int = 5,
|
||
doc_type: Optional[str] = None) -> List[Dict]:
|
||
"""
|
||
搜索知识库
|
||
|
||
Args:
|
||
query: 查询文本
|
||
top_k: 返回结果数量
|
||
doc_type: 过滤文档类型
|
||
|
||
Returns:
|
||
相关分块列表
|
||
"""
|
||
if not self.chunks:
|
||
return []
|
||
|
||
# 计算相似度分数
|
||
scored_chunks = []
|
||
query_lower = query.lower()
|
||
query_keywords = set(query_lower.split())
|
||
|
||
for chunk_data in self.chunks:
|
||
content_lower = chunk_data["content"].lower()
|
||
|
||
# 计算关键词匹配分数
|
||
keyword_matches = sum(1 for kw in query_keywords if kw in content_lower)
|
||
keyword_score = keyword_matches / len(query_keywords) if query_keywords else 0
|
||
|
||
# 计算内容相关性分数(包含查询词的比例)
|
||
content_score = 0
|
||
for kw in query_keywords:
|
||
if kw in content_lower:
|
||
# 计算关键词在内容中的密度
|
||
count = content_lower.count(kw)
|
||
content_score += count * len(kw) / len(content_lower)
|
||
|
||
# 综合分数
|
||
total_score = keyword_score * 0.6 + content_score * 0.4
|
||
|
||
if total_score > 0:
|
||
# 过滤文档类型
|
||
if doc_type and chunk_data["metadata"].get("type") != doc_type:
|
||
continue
|
||
|
||
scored_chunks.append({
|
||
"chunk": chunk_data,
|
||
"score": total_score
|
||
})
|
||
|
||
# 按分数排序并返回 top_k
|
||
scored_chunks.sort(key=lambda x: x["score"], reverse=True)
|
||
|
||
return [
|
||
{
|
||
"content": item["chunk"]["content"],
|
||
"metadata": item["chunk"]["metadata"],
|
||
"score": item["score"]
|
||
}
|
||
for item in scored_chunks[:top_k]
|
||
]
|
||
|
||
def get_context_for_generation(self, query: str, brand: str,
|
||
platform: str, top_k: int = 3) -> str:
|
||
"""
|
||
获取用于内容生成的上下文
|
||
|
||
Args:
|
||
query: 查询/主题
|
||
brand: 品牌名
|
||
platform: 目标平台
|
||
|
||
Returns:
|
||
格式化的上下文字符串
|
||
"""
|
||
# 搜索相关文档
|
||
results = self.search(query, top_k=top_k)
|
||
|
||
if not results:
|
||
return ""
|
||
|
||
# 组装上下文
|
||
context_parts = ["以下是相关的品牌/产品信息,可用于内容生成:\n"]
|
||
|
||
for i, result in enumerate(results, 1):
|
||
source = result["metadata"].get("filename", "未知来源")
|
||
content = result["content"]
|
||
|
||
context_parts.append(f"--- 参考资料 {i}(来源:{source})---")
|
||
context_parts.append(content)
|
||
context_parts.append("")
|
||
|
||
return "\n".join(context_parts)
|
||
|
||
def list_documents(self) -> List[Dict]:
|
||
"""列出所有文档"""
|
||
return list(self.documents.values())
|
||
|
||
def delete_document(self, doc_id: str) -> bool:
|
||
"""删除文档及其分块"""
|
||
if doc_id not in self.documents:
|
||
return False
|
||
|
||
# 删除文档
|
||
del self.documents[doc_id]
|
||
|
||
# 删除相关分块
|
||
self.chunks = [c for c in self.chunks if c["metadata"].get("doc_id") != doc_id]
|
||
|
||
# 保存
|
||
self._save_json(self.documents_file, self.documents)
|
||
self._save_json(self.chunks_file, self.chunks)
|
||
|
||
return True
|
||
|
||
def get_stats(self) -> Dict:
|
||
"""获取知识库统计信息"""
|
||
doc_types = {}
|
||
for doc in self.documents.values():
|
||
doc_type = doc.get("doc_type", "unknown")
|
||
doc_types[doc_type] = doc_types.get(doc_type, 0) + 1
|
||
|
||
return {
|
||
"total_documents": len(self.documents),
|
||
"total_chunks": len(self.chunks),
|
||
"document_types": doc_types
|
||
}
|
||
|
||
|
||
class SourceVerifier:
|
||
"""来源验证器"""
|
||
|
||
def __init__(self):
|
||
self.claim_patterns = [
|
||
"根据",
|
||
"据",
|
||
"报告显示",
|
||
"数据表明",
|
||
"研究表明",
|
||
"调查发现",
|
||
"据统计",
|
||
"根据报告",
|
||
"根据数据",
|
||
"根据研究",
|
||
"according to",
|
||
"based on",
|
||
"as reported by",
|
||
"research shows",
|
||
"data shows"
|
||
]
|
||
|
||
def extract_claims(self, content: str) -> List[Dict]:
|
||
"""
|
||
从内容中提取来源声明
|
||
|
||
Args:
|
||
content: 文本内容
|
||
|
||
Returns:
|
||
声明列表
|
||
"""
|
||
claims = []
|
||
sentences = content.replace('。', '。\n').replace('.', '.\n').split('\n')
|
||
|
||
for sentence in sentences:
|
||
sentence = sentence.strip()
|
||
if not sentence:
|
||
continue
|
||
|
||
for pattern in self.claim_patterns:
|
||
if pattern in sentence.lower():
|
||
claims.append({
|
||
"text": sentence,
|
||
"pattern": pattern,
|
||
"verified": False,
|
||
"verification_result": None
|
||
})
|
||
break
|
||
|
||
return claims
|
||
|
||
def generate_verification_prompt(self, claim: str) -> str:
|
||
"""
|
||
生成验证提示词
|
||
|
||
Args:
|
||
claim: 来源声明
|
||
|
||
Returns:
|
||
验证提示词
|
||
"""
|
||
return f"""请验证以下声明的真实性:
|
||
|
||
声明:{claim}
|
||
|
||
请回答:
|
||
1. 这个声明是否包含可验证的具体来源(如具体报告名称、机构名称、数据年份)?
|
||
2. 如果包含,请判断这个来源是否可能存在且可信。
|
||
3. 如果无法验证或来源可疑,请说明原因。
|
||
|
||
回答格式:
|
||
- 来源具体性:[具体/模糊/无来源]
|
||
- 可信度评估:[高/中/低/无法判断]
|
||
- 建议:[保留/修改/删除]
|
||
- 原因:[简要说明]"""
|
||
|
||
def assess_source_quality(self, content: str) -> Dict:
|
||
"""
|
||
评估内容的来源质量
|
||
|
||
Args:
|
||
content: 文本内容
|
||
|
||
Returns:
|
||
质量评估结果
|
||
"""
|
||
claims = self.extract_claims(content)
|
||
|
||
if not claims:
|
||
return {
|
||
"has_sources": False,
|
||
"claim_count": 0,
|
||
"quality_score": 0,
|
||
"suggestions": ["内容中没有引用任何来源,建议添加数据支撑"]
|
||
}
|
||
|
||
# 分析来源质量
|
||
specific_count = 0
|
||
vague_count = 0
|
||
|
||
for claim in claims:
|
||
text = claim["text"]
|
||
# 检查是否有具体来源指标
|
||
has_specific = any([
|
||
any(year in text for year in ["2020", "2021", "2022", "2023", "2024", "2025"]),
|
||
any(org in text for org in ["Gartner", "IDC", "Forrester", "McKinsey",
|
||
"哈佛", "MIT", "斯坦福", "中科院"]),
|
||
"报告" in text and ("《" in text or "年" in text),
|
||
"数据" in text and any(c.isdigit() for c in text)
|
||
])
|
||
|
||
if has_specific:
|
||
specific_count += 1
|
||
else:
|
||
vague_count += 1
|
||
|
||
quality_score = min(100, (specific_count / len(claims)) * 100)
|
||
|
||
suggestions = []
|
||
if vague_count > 0:
|
||
suggestions.append(f"有 {vague_count} 个来源描述模糊,建议补充具体报告名称或数据年份")
|
||
if specific_count == 0:
|
||
suggestions.append("所有来源都不够具体,建议引用真实的行业报告或权威机构数据")
|
||
|
||
return {
|
||
"has_sources": True,
|
||
"claim_count": len(claims),
|
||
"specific_count": specific_count,
|
||
"vague_count": vague_count,
|
||
"quality_score": quality_score,
|
||
"claims": claims,
|
||
"suggestions": suggestions
|
||
}
|