Files

530 lines
17 KiB
Python
Raw Permalink Normal View History

"""
RAG 知识库模块
支持用户上传品牌文档,自动分块、索引,生成内容时自动检索相关内容
"""
import json
import hashlib
import logging
from pathlib import Path
from typing import List, Dict, Optional, Any
from datetime import datetime
logger = logging.getLogger(__name__)
class DocumentChunk:
"""文档分块"""
def __init__(self, content: str, metadata: Dict[str, Any]):
self.content = content
self.metadata = metadata
self.chunk_id = hashlib.md5(content.encode()).hexdigest()[:12]
def to_dict(self) -> Dict:
return {
"chunk_id": self.chunk_id,
"content": self.content,
"metadata": self.metadata
}
@classmethod
def from_dict(cls, data: Dict) -> 'DocumentChunk':
chunk = cls(data["content"], data["metadata"])
chunk.chunk_id = data["chunk_id"]
return chunk
class KnowledgeBase:
"""知识库管理器"""
def __init__(self, storage_path: str = "knowledge_base"):
"""
Args:
storage_path: 知识库存储路径
"""
self.storage_path = Path(storage_path)
self.storage_path.mkdir(parents=True, exist_ok=True)
# 文档元数据
self.documents_file = self.storage_path / "documents.json"
# 分块数据
self.chunks_file = self.storage_path / "chunks.json"
self.documents: Dict[str, Dict] = self._load_json(self.documents_file, {})
self.chunks: List[Dict] = self._load_json(self.chunks_file, [])
def _load_json(self, path: Path, default: Any) -> Any:
"""加载 JSON 文件"""
if path.exists():
try:
with open(path, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.warning(f"加载 {path} 失败: {e}")
return default
def _save_json(self, path: Path, data: Any):
"""保存 JSON 文件"""
try:
with open(path, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)
except Exception as e:
logger.error(f"保存 {path} 失败: {e}")
def add_document(self, filename: str, content: str, doc_type: str = "text",
metadata: Optional[Dict] = None) -> Dict:
"""
添加文档到知识库
Args:
filename: 文件名
content: 文档内容
doc_type: 文档类型 (text, markdown, faq, case, product)
metadata: 额外元数据
Returns:
文档信息
"""
doc_id = hashlib.md5(f"{filename}{datetime.now().isoformat()}".encode()).hexdigest()[:12]
doc_info = {
"doc_id": doc_id,
"filename": filename,
"doc_type": doc_type,
"content_length": len(content),
"chunk_count": 0,
"created_at": datetime.now().isoformat(),
"metadata": metadata or {}
}
# 分块
chunks = self._split_document(content, doc_id, filename, doc_type)
doc_info["chunk_count"] = len(chunks)
# 保存
self.documents[doc_id] = doc_info
self.chunks.extend([c.to_dict() for c in chunks])
self._save_json(self.documents_file, self.documents)
self._save_json(self.chunks_file, self.chunks)
logger.info(f"文档 '{filename}' 已添加,分为 {len(chunks)} 个分块")
return doc_info
def _split_document(self, content: str, doc_id: str, filename: str,
doc_type: str, chunk_size: int = 500, overlap: int = 50) -> List[DocumentChunk]:
"""
将文档分割为多个分块
Args:
content: 文档内容
doc_id: 文档 ID
filename: 文件名
doc_type: 文档类型
chunk_size: 分块大小(字符数)
overlap: 重叠字符数
Returns:
分块列表
"""
chunks = []
# 根据文档类型选择分块策略
if doc_type == "faq":
# FAQ 文档按 Q&A 对分块
chunks = self._split_faq(content, doc_id, filename)
elif doc_type == "product":
# 产品文档按功能/特性分块
chunks = self._split_by_sections(content, doc_id, filename, doc_type)
else:
# 通用文档按段落/长度分块
chunks = self._split_by_length(content, doc_id, filename, doc_type,
chunk_size, overlap)
return chunks
def _split_faq(self, content: str, doc_id: str, filename: str) -> List[DocumentChunk]:
"""FAQ 文档分块:每个 Q&A 对为一个分块"""
chunks = []
lines = content.split('\n')
current_q = ""
current_a = ""
for line in lines:
line = line.strip()
if not line:
continue
# 检测问题行
if line.startswith('Q:') or line.startswith('问:') or line.startswith('Q'):
# 保存上一个 Q&A 对
if current_q and current_a:
chunk_content = f"问题:{current_q}\n回答:{current_a}"
chunks.append(DocumentChunk(
content=chunk_content,
metadata={
"doc_id": doc_id,
"filename": filename,
"type": "faq",
"question": current_q
}
))
current_q = line[2:].strip()
current_a = ""
elif line.startswith('A:') or line.startswith('答:') or line.startswith('A'):
current_a = line[2:].strip()
elif current_a:
current_a += "\n" + line
elif current_q and not current_a:
current_q += "\n" + line
# 保存最后一个 Q&A 对
if current_q and current_a:
chunk_content = f"问题:{current_q}\n回答:{current_a}"
chunks.append(DocumentChunk(
content=chunk_content,
metadata={
"doc_id": doc_id,
"filename": filename,
"type": "faq",
"question": current_q
}
))
return chunks
def _split_by_sections(self, content: str, doc_id: str, filename: str,
doc_type: str) -> List[DocumentChunk]:
"""按章节分块(适用于产品文档、Markdown 等)"""
chunks = []
sections = content.split('\n# ')
for i, section in enumerate(sections):
if not section.strip():
continue
# 提取标题
lines = section.split('\n', 1)
title = lines[0].strip('# ').strip()
body = lines[1].strip() if len(lines) > 1 else ""
if body:
chunks.append(DocumentChunk(
content=f"## {title}\n{body}",
metadata={
"doc_id": doc_id,
"filename": filename,
"type": doc_type,
"section_title": title
}
))
return chunks
def _split_by_length(self, content: str, doc_id: str, filename: str,
doc_type: str, chunk_size: int, overlap: int) -> List[DocumentChunk]:
"""按长度分块"""
chunks = []
paragraphs = content.split('\n\n')
current_chunk = ""
for para in paragraphs:
para = para.strip()
if not para:
continue
if len(current_chunk) + len(para) > chunk_size and current_chunk:
chunks.append(DocumentChunk(
content=current_chunk,
metadata={
"doc_id": doc_id,
"filename": filename,
"type": doc_type
}
))
# 保留重叠部分
if overlap > 0:
current_chunk = current_chunk[-overlap:] + "\n" + para
else:
current_chunk = para
else:
current_chunk = current_chunk + "\n\n" + para if current_chunk else para
# 保存最后一个分块
if current_chunk.strip():
chunks.append(DocumentChunk(
content=current_chunk,
metadata={
"doc_id": doc_id,
"filename": filename,
"type": doc_type
}
))
return chunks
def search(self, query: str, top_k: int = 5,
doc_type: Optional[str] = None) -> List[Dict]:
"""
搜索知识库
Args:
query: 查询文本
top_k: 返回结果数量
doc_type: 过滤文档类型
Returns:
相关分块列表
"""
if not self.chunks:
return []
# 计算相似度分数
scored_chunks = []
query_lower = query.lower()
query_keywords = set(query_lower.split())
for chunk_data in self.chunks:
content_lower = chunk_data["content"].lower()
# 计算关键词匹配分数
keyword_matches = sum(1 for kw in query_keywords if kw in content_lower)
keyword_score = keyword_matches / len(query_keywords) if query_keywords else 0
# 计算内容相关性分数(包含查询词的比例)
content_score = 0
for kw in query_keywords:
if kw in content_lower:
# 计算关键词在内容中的密度
count = content_lower.count(kw)
content_score += count * len(kw) / len(content_lower)
# 综合分数
total_score = keyword_score * 0.6 + content_score * 0.4
if total_score > 0:
# 过滤文档类型
if doc_type and chunk_data["metadata"].get("type") != doc_type:
continue
scored_chunks.append({
"chunk": chunk_data,
"score": total_score
})
# 按分数排序并返回 top_k
scored_chunks.sort(key=lambda x: x["score"], reverse=True)
return [
{
"content": item["chunk"]["content"],
"metadata": item["chunk"]["metadata"],
"score": item["score"]
}
for item in scored_chunks[:top_k]
]
def get_context_for_generation(self, query: str, brand: str,
platform: str, top_k: int = 3) -> str:
"""
获取用于内容生成的上下文
Args:
query: 查询/主题
brand: 品牌名
platform: 目标平台
Returns:
格式化的上下文字符串
"""
# 搜索相关文档
results = self.search(query, top_k=top_k)
if not results:
return ""
# 组装上下文
context_parts = ["以下是相关的品牌/产品信息,可用于内容生成:\n"]
for i, result in enumerate(results, 1):
source = result["metadata"].get("filename", "未知来源")
content = result["content"]
context_parts.append(f"--- 参考资料 {i}(来源:{source}---")
context_parts.append(content)
context_parts.append("")
return "\n".join(context_parts)
def list_documents(self) -> List[Dict]:
"""列出所有文档"""
return list(self.documents.values())
def delete_document(self, doc_id: str) -> bool:
"""删除文档及其分块"""
if doc_id not in self.documents:
return False
# 删除文档
del self.documents[doc_id]
# 删除相关分块
self.chunks = [c for c in self.chunks if c["metadata"].get("doc_id") != doc_id]
# 保存
self._save_json(self.documents_file, self.documents)
self._save_json(self.chunks_file, self.chunks)
return True
def get_stats(self) -> Dict:
"""获取知识库统计信息"""
doc_types = {}
for doc in self.documents.values():
doc_type = doc.get("doc_type", "unknown")
doc_types[doc_type] = doc_types.get(doc_type, 0) + 1
return {
"total_documents": len(self.documents),
"total_chunks": len(self.chunks),
"document_types": doc_types
}
class SourceVerifier:
"""来源验证器"""
def __init__(self):
self.claim_patterns = [
"根据",
"",
"报告显示",
"数据表明",
"研究表明",
"调查发现",
"据统计",
"根据报告",
"根据数据",
"根据研究",
"according to",
"based on",
"as reported by",
"research shows",
"data shows"
]
def extract_claims(self, content: str) -> List[Dict]:
"""
从内容中提取来源声明
Args:
content: 文本内容
Returns:
声明列表
"""
claims = []
sentences = content.replace('', '\n').replace('.', '.\n').split('\n')
for sentence in sentences:
sentence = sentence.strip()
if not sentence:
continue
for pattern in self.claim_patterns:
if pattern in sentence.lower():
claims.append({
"text": sentence,
"pattern": pattern,
"verified": False,
"verification_result": None
})
break
return claims
def generate_verification_prompt(self, claim: str) -> str:
"""
生成验证提示词
Args:
claim: 来源声明
Returns:
验证提示词
"""
return f"""请验证以下声明的真实性:
声明:{claim}
请回答:
1. 这个声明是否包含可验证的具体来源(如具体报告名称、机构名称、数据年份)?
2. 如果包含,请判断这个来源是否可能存在且可信。
3. 如果无法验证或来源可疑,请说明原因。
回答格式:
- 来源具体性:[具体/模糊/无来源]
- 可信度评估:[高/中/低/无法判断]
- 建议:[保留/修改/删除]
- 原因:[简要说明]"""
def assess_source_quality(self, content: str) -> Dict:
"""
评估内容的来源质量
Args:
content: 文本内容
Returns:
质量评估结果
"""
claims = self.extract_claims(content)
if not claims:
return {
"has_sources": False,
"claim_count": 0,
"quality_score": 0,
"suggestions": ["内容中没有引用任何来源,建议添加数据支撑"]
}
# 分析来源质量
specific_count = 0
vague_count = 0
for claim in claims:
text = claim["text"]
# 检查是否有具体来源指标
has_specific = any([
any(year in text for year in ["2020", "2021", "2022", "2023", "2024", "2025"]),
any(org in text for org in ["Gartner", "IDC", "Forrester", "McKinsey",
"哈佛", "MIT", "斯坦福", "中科院"]),
"报告" in text and ("" in text or "" in text),
"数据" in text and any(c.isdigit() for c in text)
])
if has_specific:
specific_count += 1
else:
vague_count += 1
quality_score = min(100, (specific_count / len(claims)) * 100)
suggestions = []
if vague_count > 0:
suggestions.append(f"{vague_count} 个来源描述模糊,建议补充具体报告名称或数据年份")
if specific_count == 0:
suggestions.append("所有来源都不够具体,建议引用真实的行业报告或权威机构数据")
return {
"has_sources": True,
"claim_count": len(claims),
"specific_count": specific_count,
"vague_count": vague_count,
"quality_score": quality_score,
"claims": claims,
"suggestions": suggestions
}