添加产品规格文档并优化项目结构

Made-with: Cursor
This commit is contained in:
刘国栋
2026-04-30 18:37:46 +08:00
parent bf2551d529
commit fb309299bf
101 changed files with 9586 additions and 14386 deletions
+529
View File
@@ -0,0 +1,529 @@
"""
RAG 知识库模块
支持用户上传品牌文档,自动分块、索引,生成内容时自动检索相关内容
"""
import json
import hashlib
import logging
from pathlib import Path
from typing import List, Dict, Optional, Any
from datetime import datetime
logger = logging.getLogger(__name__)
class DocumentChunk:
"""文档分块"""
def __init__(self, content: str, metadata: Dict[str, Any]):
self.content = content
self.metadata = metadata
self.chunk_id = hashlib.md5(content.encode()).hexdigest()[:12]
def to_dict(self) -> Dict:
return {
"chunk_id": self.chunk_id,
"content": self.content,
"metadata": self.metadata
}
@classmethod
def from_dict(cls, data: Dict) -> 'DocumentChunk':
chunk = cls(data["content"], data["metadata"])
chunk.chunk_id = data["chunk_id"]
return chunk
class KnowledgeBase:
"""知识库管理器"""
def __init__(self, storage_path: str = "knowledge_base"):
"""
Args:
storage_path: 知识库存储路径
"""
self.storage_path = Path(storage_path)
self.storage_path.mkdir(parents=True, exist_ok=True)
# 文档元数据
self.documents_file = self.storage_path / "documents.json"
# 分块数据
self.chunks_file = self.storage_path / "chunks.json"
self.documents: Dict[str, Dict] = self._load_json(self.documents_file, {})
self.chunks: List[Dict] = self._load_json(self.chunks_file, [])
def _load_json(self, path: Path, default: Any) -> Any:
"""加载 JSON 文件"""
if path.exists():
try:
with open(path, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.warning(f"加载 {path} 失败: {e}")
return default
def _save_json(self, path: Path, data: Any):
"""保存 JSON 文件"""
try:
with open(path, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)
except Exception as e:
logger.error(f"保存 {path} 失败: {e}")
def add_document(self, filename: str, content: str, doc_type: str = "text",
metadata: Optional[Dict] = None) -> Dict:
"""
添加文档到知识库
Args:
filename: 文件名
content: 文档内容
doc_type: 文档类型 (text, markdown, faq, case, product)
metadata: 额外元数据
Returns:
文档信息
"""
doc_id = hashlib.md5(f"{filename}{datetime.now().isoformat()}".encode()).hexdigest()[:12]
doc_info = {
"doc_id": doc_id,
"filename": filename,
"doc_type": doc_type,
"content_length": len(content),
"chunk_count": 0,
"created_at": datetime.now().isoformat(),
"metadata": metadata or {}
}
# 分块
chunks = self._split_document(content, doc_id, filename, doc_type)
doc_info["chunk_count"] = len(chunks)
# 保存
self.documents[doc_id] = doc_info
self.chunks.extend([c.to_dict() for c in chunks])
self._save_json(self.documents_file, self.documents)
self._save_json(self.chunks_file, self.chunks)
logger.info(f"文档 '{filename}' 已添加,分为 {len(chunks)} 个分块")
return doc_info
def _split_document(self, content: str, doc_id: str, filename: str,
doc_type: str, chunk_size: int = 500, overlap: int = 50) -> List[DocumentChunk]:
"""
将文档分割为多个分块
Args:
content: 文档内容
doc_id: 文档 ID
filename: 文件名
doc_type: 文档类型
chunk_size: 分块大小(字符数)
overlap: 重叠字符数
Returns:
分块列表
"""
chunks = []
# 根据文档类型选择分块策略
if doc_type == "faq":
# FAQ 文档按 Q&A 对分块
chunks = self._split_faq(content, doc_id, filename)
elif doc_type == "product":
# 产品文档按功能/特性分块
chunks = self._split_by_sections(content, doc_id, filename, doc_type)
else:
# 通用文档按段落/长度分块
chunks = self._split_by_length(content, doc_id, filename, doc_type,
chunk_size, overlap)
return chunks
def _split_faq(self, content: str, doc_id: str, filename: str) -> List[DocumentChunk]:
"""FAQ 文档分块:每个 Q&A 对为一个分块"""
chunks = []
lines = content.split('\n')
current_q = ""
current_a = ""
for line in lines:
line = line.strip()
if not line:
continue
# 检测问题行
if line.startswith('Q:') or line.startswith('问:') or line.startswith('Q'):
# 保存上一个 Q&A 对
if current_q and current_a:
chunk_content = f"问题:{current_q}\n回答:{current_a}"
chunks.append(DocumentChunk(
content=chunk_content,
metadata={
"doc_id": doc_id,
"filename": filename,
"type": "faq",
"question": current_q
}
))
current_q = line[2:].strip()
current_a = ""
elif line.startswith('A:') or line.startswith('答:') or line.startswith('A'):
current_a = line[2:].strip()
elif current_a:
current_a += "\n" + line
elif current_q and not current_a:
current_q += "\n" + line
# 保存最后一个 Q&A 对
if current_q and current_a:
chunk_content = f"问题:{current_q}\n回答:{current_a}"
chunks.append(DocumentChunk(
content=chunk_content,
metadata={
"doc_id": doc_id,
"filename": filename,
"type": "faq",
"question": current_q
}
))
return chunks
def _split_by_sections(self, content: str, doc_id: str, filename: str,
doc_type: str) -> List[DocumentChunk]:
"""按章节分块(适用于产品文档、Markdown 等)"""
chunks = []
sections = content.split('\n# ')
for i, section in enumerate(sections):
if not section.strip():
continue
# 提取标题
lines = section.split('\n', 1)
title = lines[0].strip('# ').strip()
body = lines[1].strip() if len(lines) > 1 else ""
if body:
chunks.append(DocumentChunk(
content=f"## {title}\n{body}",
metadata={
"doc_id": doc_id,
"filename": filename,
"type": doc_type,
"section_title": title
}
))
return chunks
def _split_by_length(self, content: str, doc_id: str, filename: str,
doc_type: str, chunk_size: int, overlap: int) -> List[DocumentChunk]:
"""按长度分块"""
chunks = []
paragraphs = content.split('\n\n')
current_chunk = ""
for para in paragraphs:
para = para.strip()
if not para:
continue
if len(current_chunk) + len(para) > chunk_size and current_chunk:
chunks.append(DocumentChunk(
content=current_chunk,
metadata={
"doc_id": doc_id,
"filename": filename,
"type": doc_type
}
))
# 保留重叠部分
if overlap > 0:
current_chunk = current_chunk[-overlap:] + "\n" + para
else:
current_chunk = para
else:
current_chunk = current_chunk + "\n\n" + para if current_chunk else para
# 保存最后一个分块
if current_chunk.strip():
chunks.append(DocumentChunk(
content=current_chunk,
metadata={
"doc_id": doc_id,
"filename": filename,
"type": doc_type
}
))
return chunks
def search(self, query: str, top_k: int = 5,
doc_type: Optional[str] = None) -> List[Dict]:
"""
搜索知识库
Args:
query: 查询文本
top_k: 返回结果数量
doc_type: 过滤文档类型
Returns:
相关分块列表
"""
if not self.chunks:
return []
# 计算相似度分数
scored_chunks = []
query_lower = query.lower()
query_keywords = set(query_lower.split())
for chunk_data in self.chunks:
content_lower = chunk_data["content"].lower()
# 计算关键词匹配分数
keyword_matches = sum(1 for kw in query_keywords if kw in content_lower)
keyword_score = keyword_matches / len(query_keywords) if query_keywords else 0
# 计算内容相关性分数(包含查询词的比例)
content_score = 0
for kw in query_keywords:
if kw in content_lower:
# 计算关键词在内容中的密度
count = content_lower.count(kw)
content_score += count * len(kw) / len(content_lower)
# 综合分数
total_score = keyword_score * 0.6 + content_score * 0.4
if total_score > 0:
# 过滤文档类型
if doc_type and chunk_data["metadata"].get("type") != doc_type:
continue
scored_chunks.append({
"chunk": chunk_data,
"score": total_score
})
# 按分数排序并返回 top_k
scored_chunks.sort(key=lambda x: x["score"], reverse=True)
return [
{
"content": item["chunk"]["content"],
"metadata": item["chunk"]["metadata"],
"score": item["score"]
}
for item in scored_chunks[:top_k]
]
def get_context_for_generation(self, query: str, brand: str,
platform: str, top_k: int = 3) -> str:
"""
获取用于内容生成的上下文
Args:
query: 查询/主题
brand: 品牌名
platform: 目标平台
Returns:
格式化的上下文字符串
"""
# 搜索相关文档
results = self.search(query, top_k=top_k)
if not results:
return ""
# 组装上下文
context_parts = ["以下是相关的品牌/产品信息,可用于内容生成:\n"]
for i, result in enumerate(results, 1):
source = result["metadata"].get("filename", "未知来源")
content = result["content"]
context_parts.append(f"--- 参考资料 {i}(来源:{source}---")
context_parts.append(content)
context_parts.append("")
return "\n".join(context_parts)
def list_documents(self) -> List[Dict]:
"""列出所有文档"""
return list(self.documents.values())
def delete_document(self, doc_id: str) -> bool:
"""删除文档及其分块"""
if doc_id not in self.documents:
return False
# 删除文档
del self.documents[doc_id]
# 删除相关分块
self.chunks = [c for c in self.chunks if c["metadata"].get("doc_id") != doc_id]
# 保存
self._save_json(self.documents_file, self.documents)
self._save_json(self.chunks_file, self.chunks)
return True
def get_stats(self) -> Dict:
"""获取知识库统计信息"""
doc_types = {}
for doc in self.documents.values():
doc_type = doc.get("doc_type", "unknown")
doc_types[doc_type] = doc_types.get(doc_type, 0) + 1
return {
"total_documents": len(self.documents),
"total_chunks": len(self.chunks),
"document_types": doc_types
}
class SourceVerifier:
"""来源验证器"""
def __init__(self):
self.claim_patterns = [
"根据",
"",
"报告显示",
"数据表明",
"研究表明",
"调查发现",
"据统计",
"根据报告",
"根据数据",
"根据研究",
"according to",
"based on",
"as reported by",
"research shows",
"data shows"
]
def extract_claims(self, content: str) -> List[Dict]:
"""
从内容中提取来源声明
Args:
content: 文本内容
Returns:
声明列表
"""
claims = []
sentences = content.replace('', '\n').replace('.', '.\n').split('\n')
for sentence in sentences:
sentence = sentence.strip()
if not sentence:
continue
for pattern in self.claim_patterns:
if pattern in sentence.lower():
claims.append({
"text": sentence,
"pattern": pattern,
"verified": False,
"verification_result": None
})
break
return claims
def generate_verification_prompt(self, claim: str) -> str:
"""
生成验证提示词
Args:
claim: 来源声明
Returns:
验证提示词
"""
return f"""请验证以下声明的真实性:
声明:{claim}
请回答:
1. 这个声明是否包含可验证的具体来源(如具体报告名称、机构名称、数据年份)?
2. 如果包含,请判断这个来源是否可能存在且可信。
3. 如果无法验证或来源可疑,请说明原因。
回答格式:
- 来源具体性:[具体/模糊/无来源]
- 可信度评估:[高/中/低/无法判断]
- 建议:[保留/修改/删除]
- 原因:[简要说明]"""
def assess_source_quality(self, content: str) -> Dict:
"""
评估内容的来源质量
Args:
content: 文本内容
Returns:
质量评估结果
"""
claims = self.extract_claims(content)
if not claims:
return {
"has_sources": False,
"claim_count": 0,
"quality_score": 0,
"suggestions": ["内容中没有引用任何来源,建议添加数据支撑"]
}
# 分析来源质量
specific_count = 0
vague_count = 0
for claim in claims:
text = claim["text"]
# 检查是否有具体来源指标
has_specific = any([
any(year in text for year in ["2020", "2021", "2022", "2023", "2024", "2025"]),
any(org in text for org in ["Gartner", "IDC", "Forrester", "McKinsey",
"哈佛", "MIT", "斯坦福", "中科院"]),
"报告" in text and ("" in text or "" in text),
"数据" in text and any(c.isdigit() for c in text)
])
if has_specific:
specific_count += 1
else:
vague_count += 1
quality_score = min(100, (specific_count / len(claims)) * 100)
suggestions = []
if vague_count > 0:
suggestions.append(f"{vague_count} 个来源描述模糊,建议补充具体报告名称或数据年份")
if specific_count == 0:
suggestions.append("所有来源都不够具体,建议引用真实的行业报告或权威机构数据")
return {
"has_sources": True,
"claim_count": len(claims),
"specific_count": specific_count,
"vague_count": vague_count,
"quality_score": quality_score,
"claims": claims,
"suggestions": suggestions
}