# Travel_Assistant/modules/knowledge_base.py
# Eliot0110's picture
# improve: knowledge base and re
# 05b4419
# modules/knowledge_base.py
import json
from pathlib import Path
from utils.logger import log
class KnowledgeBase:
    """Travel knowledge base loaded from JSON, searchable by city, country and region.

    Each record is read as
    ``{"knowledge": {"travel_knowledge": {"destination_info": {...}, ...}}}``.
    Records without that nesting contribute nothing to the indexes.
    """

    # City -> country lookup (section comments give the region grouping).
    _CITY_COUNTRY = {
        # Central Europe
        "布拉格": "捷克", "布尔诺": "捷克", "库特纳霍拉": "捷克",
        "维也纳": "奥地利", "萨尔茨堡": "奥地利", "哈尔施塔特": "奥地利", "巴德伊舍": "奥地利",
        "布达佩斯": "匈牙利", "德布勒森": "匈牙利", "圣安德烈": "匈牙利",
        "布拉迪斯拉发": "斯洛伐克",
        # Western Europe
        "巴黎": "法国", "里昂": "法国", "尼斯": "法国", "马赛": "法国",
        "柏林": "德国", "慕尼黑": "德国", "汉堡": "德国", "科隆": "德国", "法兰克福": "德国",
        "阿姆斯特丹": "荷兰", "鹿特丹": "荷兰", "海牙": "荷兰",
        "布鲁塞尔": "比利时", "安特卫普": "比利时", "布吕赫": "比利时",
        "卢森堡市": "卢森堡",
        "苏黎世": "瑞士", "日内瓦": "瑞士", "因特拉肯": "瑞士",
        # Southern Europe
        "罗马": "意大利", "米兰": "意大利", "威尼斯": "意大利", "佛罗伦萨": "意大利",
        "马德里": "西班牙", "巴塞罗那": "西班牙", "塞维利亚": "西班牙",
        "里斯本": "葡萄牙", "波尔图": "葡萄牙",
        "雅典": "希腊", "圣托里尼": "希腊", "米科诺斯": "希腊",
        # Northern Europe
        "斯德哥尔摩": "瑞典", "哥德堡": "瑞典",
        "奥斯陆": "挪威", "卑尔根": "挪威",
        "哥本哈根": "丹麦", "奥胡斯": "丹麦",
        "赫尔辛基": "芬兰", "坦佩雷": "芬兰",
        "雷克雅未克": "冰岛",
        # United Kingdom
        "伦敦": "英国", "爱丁堡": "英国", "曼彻斯特": "英国",
    }

    # Country -> region. Deriving the region from the country keeps the two
    # lookups consistent: every city with a known country gets its region.
    # "英国" is deliberately absent: the original tables assign it no region.
    _COUNTRY_REGION = {
        "捷克": "中欧", "奥地利": "中欧", "匈牙利": "中欧", "斯洛伐克": "中欧",
        "法国": "西欧", "德国": "西欧", "荷兰": "西欧", "比利时": "西欧",
        "卢森堡": "西欧", "瑞士": "西欧",
        "意大利": "南欧", "西班牙": "南欧", "葡萄牙": "南欧", "希腊": "南欧",
        "瑞典": "北欧", "挪威": "北欧", "丹麦": "北欧", "芬兰": "北欧", "冰岛": "北欧",
    }

    # Explicit city -> region entries for cities that have no country entry.
    # Eastern Europe, following the knowledge base's own classification.
    _CITY_REGION = {
        "华沙": "东欧", "克拉科夫": "东欧",
        "莫斯科": "东欧", "圣彼得堡": "东欧",
    }

    def __init__(self, file_path: Path = Path("./config/general_travelplan.json")):
        """Load the knowledge records and build the lookup indexes.

        Args:
            file_path: JSON file whose top-level object holds the records
                under the ``clean_knowledge`` key.

        Raises:
            OSError: If the file cannot be opened (not caught here).
            json.JSONDecodeError: If the file is not valid JSON.
        """
        self.knowledge = []
        self.city_index = {}     # city name -> list of record positions
        self.country_index = {}  # country name -> list of record positions
        self.region_index = {}   # region type -> list of record positions

        with open(file_path, 'r', encoding='utf-8') as f:
            self.knowledge = json.load(f).get('clean_knowledge', [])

        # Without this call the indexes stay empty and search() can only
        # ever use its fuzzy fallback.
        self._build_indexes()
        log.info("✅ 知识库加载完成")

    @staticmethod
    def _destination_info(item) -> dict:
        """Return the ``destination_info`` dict of a record, or ``{}`` if absent."""
        if not isinstance(item, dict):
            return {}
        travel = (item.get('knowledge') or {}).get('travel_knowledge') or {}
        return travel.get('destination_info') or {}

    @staticmethod
    def _merge_unique(existing: list, incoming: list) -> list:
        """Concatenate two lists, dropping repeats and keeping first-seen order.

        Uses equality rather than hashing so unhashable JSON values
        (dicts, nested lists) can be merged too.
        """
        merged = []
        for value in existing + incoming:
            if value not in merged:
                merged.append(value)
        return merged

    def _build_indexes(self):
        """Rebuild the city / country / region indexes from ``self.knowledge``.

        The indexes are recreated from scratch, so calling this more than
        once does not duplicate entries.
        """
        self.city_index = {}
        self.country_index = {}
        self.region_index = {}

        for idx, item in enumerate(self.knowledge):
            dest_info = self._destination_info(item)
            if not dest_info:
                continue

            for city in dest_info.get('primary_destinations') or []:
                self.city_index.setdefault(city, []).append(idx)

            for country in dest_info.get('countries') or []:
                self.country_index.setdefault(country, []).append(idx)

            region_type = dest_info.get('region_type', '')
            if region_type:
                self.region_index.setdefault(region_type, []).append(idx)

    def _collect(self, indices, results: list) -> None:
        """Append the records at ``indices`` to ``results``, skipping repeats."""
        for idx in indices:
            record = self.knowledge[idx]
            if record not in results:
                results.append(record)

    def search(self, query: str) -> list:
        """Search the knowledge base for records relevant to ``query``.

        Matching runs in order: exact city, the city's country, the city's
        region. Substring (fuzzy) matching against destination names is
        tried only when those steps find nothing.

        Args:
            query: City name to look up. Surrounding whitespace is ignored.

        Returns:
            Matching records without repeats; ``[]`` for a blank query.
        """
        log.info(f"🔍 在知识库中搜索: '{query}'")
        results = []
        # A blank needle would substring-match every destination, so skip it.
        needle = (query or "").strip()

        if needle:
            # 1. Exact city match.
            if needle in self.city_index:
                self._collect(self.city_index[needle], results)
                log.info(f"✅ 通过城市直接匹配找到 {len(self.city_index[needle])} 条记录")

            # 2. Records about the city's country.
            matching_country = self._find_country_for_city(needle)
            if matching_country and matching_country in self.country_index:
                self._collect(self.country_index[matching_country], results)
                log.info(f"✅ 通过国家匹配({matching_country})找到额外记录")

            # 3. Records about the city's region.
            matching_region = self._find_region_for_city(needle)
            if matching_region and matching_region in self.region_index:
                self._collect(self.region_index[matching_region], results)
                log.info(f"✅ 通过地区匹配({matching_region})找到额外记录")

            # 4. Fuzzy fallback.
            if not results:
                log.info("🔍 尝试模糊匹配...")
                self._fuzzy_search(needle.lower(), results)

        log.info(f"📊 搜索完成，共找到 {len(results)} 条相关记录")
        return results

    def _fuzzy_search(self, needle_lower: str, results: list) -> None:
        """Add records whose destination names contain, or are contained in, the needle."""
        for item in self.knowledge:
            for dest in self._destination_info(item).get('primary_destinations') or []:
                # An empty name is a substring of everything; ignore it.
                if not isinstance(dest, str) or not dest:
                    continue
                dest_lower = dest.lower()
                if needle_lower in dest_lower or dest_lower in needle_lower:
                    if item not in results:
                        results.append(item)
                        log.info(f"✅ 模糊匹配找到: {dest}")
                    break  # one matching destination is enough for this record

    def _find_country_for_city(self, city_name: str) -> str:
        """Return the country of ``city_name``, or ``""`` if the city is unknown."""
        return self._CITY_COUNTRY.get(city_name, "")

    def _find_region_for_city(self, city_name: str) -> str:
        """Return the region of ``city_name``, or ``""`` if it cannot be determined.

        An explicit city entry wins; otherwise the region is derived from
        the city's country.
        """
        explicit = self._CITY_REGION.get(city_name)
        if explicit:
            return explicit
        country = self._CITY_COUNTRY.get(city_name, "")
        return self._COUNTRY_REGION.get(country, "")

    def get_knowledge_by_destination(self, destination: str) -> dict:
        """Merge all records relevant to ``destination`` into one structure.

        Returns:
            ``{}`` when nothing matches. Otherwise a dict with the keys
            ``destination_info`` (later records overwrite earlier keys),
            ``budget_analysis`` (the entry with the most keys),
            ``detailed_itinerary`` (one entry per (day_number, location))
            and ``professional_insights`` (list values merged without repeats).
        """
        relevant_items = self.search(destination)

        if not relevant_items:
            log.warning(f"⚠️ 未找到关于 '{destination}' 的知识")
            return {}

        merged_knowledge = {
            "destination_info": {},
            "budget_analysis": {},
            "detailed_itinerary": [],
            "professional_insights": {}
        }
        insights = merged_knowledge['professional_insights']

        for item in relevant_items:
            knowledge = (item.get('knowledge') or {}).get('travel_knowledge') or {}

            merged_knowledge['destination_info'].update(
                knowledge.get('destination_info') or {}
            )

            # Keep the most detailed budget analysis seen so far.
            budget = knowledge.get('budget_analysis')
            if budget and len(budget) > len(merged_knowledge['budget_analysis']):
                merged_knowledge['budget_analysis'] = budget

            merged_knowledge['detailed_itinerary'].extend(
                knowledge.get('detailed_itinerary') or []
            )

            for key, value in (knowledge.get('professional_insights') or {}).items():
                if key not in insights:
                    insights[key] = value
                elif isinstance(value, list) and isinstance(insights[key], list):
                    # Builds a new list, so the source record is never mutated.
                    insights[key] = self._merge_unique(insights[key], value)

        # Drop itinerary entries that repeat the same day and location.
        if merged_knowledge['detailed_itinerary']:
            seen_days = set()
            unique_itinerary = []
            for day_plan in merged_knowledge['detailed_itinerary']:
                day_key = (day_plan.get('day_number', 0), day_plan.get('location', ''))
                if day_key not in seen_days:
                    seen_days.add(day_key)
                    unique_itinerary.append(day_plan)
            merged_knowledge['detailed_itinerary'] = unique_itinerary

        log.info(f"📚 为 '{destination}' 合并了 {len(relevant_items)} 条知识记录")
        return merged_knowledge

    def get_similar_destinations(self, destination: str, limit: int = 5) -> list:
        """Recommend destinations similar to ``destination``.

        Destinations from records about the same country come first, then
        destinations from records about the same region.

        Args:
            destination: City to find neighbours for.
            limit: Maximum number of names to return; ``<= 0`` yields ``[]``.

        Returns:
            Up to ``limit`` distinct destination names, excluding ``destination``.
        """
        if limit <= 0:
            return []

        target_country = self._find_country_for_city(destination)
        target_region = self._find_region_for_city(destination)

        # Same-country records take priority over same-region records.
        index_groups = []
        if target_country:
            index_groups.append(self.country_index.get(target_country, []))
        if target_region:
            index_groups.append(self.region_index.get(target_region, []))

        similar_destinations = []
        for indices in index_groups:
            for idx in indices:
                dest_info = self._destination_info(self.knowledge[idx])
                for dest in dest_info.get('primary_destinations') or []:
                    if dest == destination or dest in similar_destinations:
                        continue
                    similar_destinations.append(dest)
                    if len(similar_destinations) >= limit:
                        return similar_destinations

        return similar_destinations

    def get_statistics(self) -> dict:
        """Return summary statistics about the knowledge base.

        Returns:
            A dict with record and index sizes, the number of distinct
            cities per region (``cities_by_region``) and the ten most
            frequent cities as ``(city, record_count)`` pairs
            (``popular_cities``).
        """
        stats = {
            "total_records": len(self.knowledge),
            "cities_covered": len(self.city_index),
            "countries_covered": len(self.country_index),
            "regions_covered": len(self.region_index),
            "cities_by_region": {},
            "popular_cities": []
        }

        # Count the distinct cities mentioned by each region's records.
        for region, indices in self.region_index.items():
            cities_in_region = set()
            for idx in indices:
                dest_info = self._destination_info(self.knowledge[idx])
                cities_in_region.update(dest_info.get('primary_destinations') or [])
            stats["cities_by_region"][region] = len(cities_in_region)

        # Rank cities by how many records mention them; ties keep index order.
        city_frequency = [(city, len(indices)) for city, indices in self.city_index.items()]
        city_frequency.sort(key=lambda pair: pair[1], reverse=True)
        stats["popular_cities"] = city_frequency[:10]

        return stats