内容纲要
Agent 记忆系统设计
目录
1. 记忆系统概述
1.1 记忆类型
┌─────────────────────────────────────────────────────┐
│ Agent 记忆系统架构 │
├─────────────────────────────────────────────────────┤
│ │
│ ┌─────────┐ ┌──────────┐ ┌─────────────┐ │
│ │短期记忆 │ │长期记忆 │ │结构化记忆 │ │
│ │(会话) │ │(向量存储) │ │(键值存储) │ │
│ └────┬────┘ └────┬─────┘ └──────┬──────┘ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌─────────────────────────────────────────────┐ │
│ │ 记忆管理器 (Memory Manager) │ │
│ │ - 添加记忆 │ │
│ │ - 检索记忆 │ │
│ │ - 更新记忆 │ │
│ │ - 删除记忆 │ │
│ └─────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────┘
1.2 记忆类型对比
| 记忆类型 | 存储介质 | 持久化 | 检索方式 | 使用场景 |
|---|---|---|---|---|
| 短期记忆 | 内存 | 否 | 顺序/滑动窗口 | 当前会话上下文 |
| 长期记忆 | 向量数据库 | 是 | 语义相似度搜索 | 历史对话、知识 |
| 结构化记忆 | 数据库/Redis | 是 | 键值查询 | 用户偏好、状态 |
1.3 记忆价值
┌─────────────────────────────────────────────────┐
│ 记忆系统的价值 │
├─────────────────────────────────────────────────┤
│ ✓ 保持对话连贯性 │
│ ✓ 学习用户偏好 │
│ ✓ 避免重复内容 │
│ ✓ 支持跨会话引用 │
│ ✓ 实现个性化交互 │
│ ✓ 记住重要事实 │
└─────────────────────────────────────────────────┘
2. 短期记忆
2.1 基本实现
from typing import List, Dict, Optional
from dataclasses import dataclass
from datetime import datetime
import json
@dataclass
class Message:
    """One conversation message stored in short-term memory.

    ``timestamp`` and ``metadata`` accept ``None`` as "not provided"; both are
    filled in by ``__post_init__`` so that instances never share a default.
    """

    role: str  # "system", "user", "assistant", "tool"
    content: str
    timestamp: Optional[datetime] = None
    metadata: Optional[Dict] = None

    def __post_init__(self):
        """Default the timestamp to the current naive-UTC time and metadata to {}."""
        self.timestamp = (
            datetime.utcnow() if self.timestamp is None else self.timestamp
        )
        self.metadata = {} if self.metadata is None else self.metadata
class ShortTermMemory:
    """Short-term memory: the message history of the current conversation.

    The history is bounded by a message count (``max_messages``) and a token
    budget (``max_tokens``). When a new message would exceed either bound,
    older messages are evicted according to ``strategy``.

    NOTE(review): token counts come from ``count_tokens``, which is not
    defined in this section; it is assumed to be provided elsewhere and to
    accept a string -- confirm before use.
    """

    def __init__(
        self,
        max_messages: int = 100,
        max_tokens: int = 4000,
        strategy: str = "sliding"  # sliding, lru, summary
    ):
        self.max_messages = max_messages
        self.max_tokens = max_tokens
        self.strategy = strategy
        self.messages: List["Message"] = []

    def add(self, role: str, content: str, metadata: Dict = None) -> bool:
        """Append a message, evicting older ones until the limits are met.

        Args:
            role: Speaker role of the message.
            content: Message text.
            metadata: Optional metadata; an empty dict is used when omitted.

        Returns:
            Always ``True``; the message is appended even when eviction could
            not free enough room (e.g. a single message larger than the
            token budget).
        """
        message = Message(role=role, content=content, metadata=metadata or {})
        # One eviction is not always enough: a long message may need several
        # old ones removed before it fits the token budget.
        while self.messages and not self._check_limits(message):
            size_before = len(self.messages)
            self._evict()
            # Stop when the strategy made no progress (unknown strategy, or
            # "summary" with fewer than 10 messages) to avoid looping forever.
            if len(self.messages) >= size_before:
                break
        self.messages.append(message)
        return True

    def _check_limits(self, message: "Message") -> bool:
        """Return True when ``message`` can be added without exceeding a limit."""
        message_tokens = count_tokens(message.content)
        # Message-count limit.
        if len(self.messages) >= self.max_messages:
            return False
        # Token-budget limit.
        current_tokens = sum(count_tokens(m.content) for m in self.messages)
        if current_tokens + message_tokens > self.max_tokens:
            return False
        return True

    def _evict(self):
        """Free room according to the configured strategy."""
        if self.strategy in ("sliding", "lru"):
            # Sliding window drops the oldest message. "lru" is simplified to
            # the same behaviour because access times are not tracked here.
            if self.messages:
                self.messages.pop(0)
        elif self.strategy == "summary":
            # Replace the older half of the history with a summary message.
            self._summarize_old()

    def _summarize_old(self):
        """Collapse the older half of the history into one system message.

        Does nothing when fewer than 10 messages are stored.
        """
        if len(self.messages) < 10:
            return
        split_point = len(self.messages) // 2
        old_messages = self.messages[:split_point]
        summary = self._create_summary(old_messages)
        # The summary takes the place of the summarized messages.
        self.messages = [
            Message(role="system", content=summary)
        ] + self.messages[split_point:]

    def _create_summary(self, messages: List["Message"]) -> str:
        """Build a placeholder summary that only states the message count.

        A real implementation could delegate to an LLM.
        """
        return f"历史对话摘要:{len(messages)}条消息"

    def get_recent(self, n: int = 10) -> List["Message"]:
        """Return the last ``n`` messages (an empty list when ``n <= 0``)."""
        # Guard needed because messages[-0:] is the whole list, not an empty one.
        if n <= 0:
            return []
        return self.messages[-n:]

    def get_all(self) -> List["Message"]:
        """Return a shallow copy of the full history."""
        return self.messages.copy()

    def to_messages_format(self) -> List[Dict]:
        """Convert the history to a list of ``{"role", "content"}`` dicts."""
        return [
            {
                "role": m.role,
                "content": m.content
            }
            for m in self.messages
        ]

    def clear(self):
        """Remove every stored message."""
        self.messages.clear()

    def filter_by_role(self, role: str) -> List["Message"]:
        """Return the messages whose role equals ``role``."""
        return [m for m in self.messages if m.role == role]

    def search(self, keyword: str) -> List["Message"]:
        """Return the messages whose content contains ``keyword`` (case-insensitive)."""
        needle = keyword.lower()
        return [m for m in self.messages if needle in m.content.lower()]
2.2 滑动窗口实现
class SlidingWindowMemory:
    """Sliding-window memory that can pin system messages and the first user message.

    The window keeps the most recent messages. One slot is reserved for
    system messages when ``keep_system`` is set, and one more for the first
    user message when ``keep_first_user`` is set and such a message exists.
    Every system message is retained, so the stored list can exceed
    ``window_size`` when there is more than one.
    """

    def __init__(
        self,
        window_size: int = 10,
        keep_system: bool = True,
        keep_first_user: bool = True
    ):
        self.window_size = window_size
        self.keep_system = keep_system
        self.keep_first_user = keep_first_user
        self.messages: List["Message"] = []
        self.first_user_message: Optional["Message"] = None

    def add(self, message: "Message"):
        """Append ``message`` and re-apply the window."""
        # Remember the first user message so it can be pinned later.
        if message.role == "user" and self.first_user_message is None:
            self.first_user_message = message
        self.messages.append(message)
        self._apply_window()

    def _apply_window(self):
        """Rebuild ``self.messages`` as: system messages, first user message, recent tail."""
        pin_first_user = (
            self.keep_first_user and self.first_user_message is not None
        )
        # Reserve slots only for what will actually be pinned; reserving one
        # for the first user message while keep_first_user is False made the
        # effective window size fluctuate.
        keep_count = self.window_size
        if self.keep_system:
            keep_count -= 1
        if pin_first_user:
            keep_count -= 1
        recent_messages = self.messages[-keep_count:] if keep_count > 0 else []

        result = []
        if self.keep_system:
            result.extend(m for m in self.messages if m.role == "system")
        if pin_first_user:
            result.append(self.first_user_message)
        # De-duplicate by identity: two distinct messages that merely compare
        # equal must both be kept.
        for msg in recent_messages:
            if not any(msg is kept for kept in result):
                result.append(msg)
        self.messages = result

    def get_context(self) -> List[Dict]:
        """Return the window as a list of ``{"role", "content"}`` dicts."""
        return [
            {"role": m.role, "content": m.content}
            for m in self.messages
        ]
2.3 对话摘要
import openai
class ConversationSummarizer:
    """Summarizes a conversation through a chat-completion client.

    ``llm`` must expose ``create(**kwargs)`` (and optionally ``acreate``)
    returning an object shaped like ``response.choices[0].message.content``.
    Both synchronous and asynchronous clients are accepted.
    """

    def __init__(self, llm=None):
        # Falls back to the legacy module-level OpenAI client when none is given.
        self.llm = llm or openai.ChatCompletion

    async def _complete(self, **request):
        """Send ``request`` to the client and return its response.

        Prefers ``acreate`` when the client has one, and awaits the result
        only if it is awaitable, so a synchronous ``create`` also works.
        A synchronous client blocks the event loop for the duration of the call.
        """
        import inspect

        call = getattr(self.llm, "acreate", None) or self.llm.create
        response = call(**request)
        if inspect.isawaitable(response):
            response = await response
        return response

    @staticmethod
    def _format_conversation(messages: List["Message"]) -> str:
        """Render messages as ``role: content`` lines."""
        return "\n".join(f"{m.role}: {m.content}" for m in messages)

    async def summarize(
        self,
        messages: List["Message"],
        max_length: int = 500
    ) -> str:
        """Return a short summary of ``messages``.

        Args:
            messages: Conversation to summarize.
            max_length: Character limit stated in the prompt.

        Returns:
            The text content of the first completion choice.
        """
        conversation = self._format_conversation(messages)
        prompt = f"""请将以下对话摘要为简短的总结,保留关键信息。
对话:
{conversation}
请输出摘要(不超过{max_length}字):"""
        # NOTE(review): max_tokens is fixed at 200 regardless of max_length,
        # so long summaries may be cut off -- confirm the intended budget.
        response = await self._complete(
            model="gpt-4",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=200,
            temperature=0.3
        )
        return response.choices[0].message.content

    async def summarize_with_key_points(
        self,
        messages: List["Message"]
    ) -> Dict:
        """Return a summary plus key points and topics, parsed from JSON.

        Returns:
            The decoded JSON object produced by the model; the prompt asks
            for the keys ``summary``, ``key_points`` and ``topics``.

        Raises:
            json.JSONDecodeError: If the model output is not valid JSON.
        """
        import json

        conversation = self._format_conversation(messages)
        prompt = f"""分析以下对话,提供摘要和关键点。
对话:
{conversation}
请以JSON格式输出:
{{
"summary": "对话摘要",
"key_points": ["关键点1", "关键点2"],
"topics": ["讨论主题1", "讨论主题2"]
}}"""
        response = await self._complete(
            model="gpt-4",
            messages=[{"role": "user", "content": prompt}],
            response_format={"type": "json_object"}
        )
        return json.loads(response.choices[0].message.content)
3. 长期记忆
3.1 向量存储实现
from typing import List, Optional
from sentence_transformers import SentenceTransformer
import numpy as np
@dataclass
class MemoryItem:
    """One long-term memory record together with its bookkeeping fields.

    ``metadata``, ``last_accessed`` and ``created_at`` accept ``None`` as
    "not provided" and are filled in by ``__post_init__``.
    """

    id: str
    content: str
    embedding: Optional[List[float]] = None
    metadata: Optional[Dict] = None
    importance: float = 1.0  # importance score
    access_count: int = 0
    last_accessed: Optional[datetime] = None
    created_at: Optional[datetime] = None

    def __post_init__(self):
        """Default metadata to {} and both timestamps to the current naive-UTC time."""
        self.metadata = {} if self.metadata is None else self.metadata
        self.last_accessed = (
            datetime.utcnow() if self.last_accessed is None else self.last_accessed
        )
        self.created_at = (
            datetime.utcnow() if self.created_at is None else self.created_at
        )
class LongTermMemory:
    """Long-term memory backed by sentence embeddings and an in-memory dict.

    Items are ranked by cosine similarity to the query multiplied by their
    importance. Persistence hooks exist but are placeholders: nothing is
    read or written even when ``storage_path`` is set.
    """

    def __init__(
        self,
        embedding_model: str = "BAAI/bge-small-zh-v1.5",
        storage_path: str = None
    ):
        self.embedder = SentenceTransformer(embedding_model)
        self.storage_path = storage_path
        self._memories: Dict[str, "MemoryItem"] = {}
        self._load_from_storage()

    def add(
        self,
        content: str,
        metadata: Dict = None,
        importance: float = 1.0
    ) -> str:
        """Embed and store ``content``.

        Returns:
            The generated id of the new memory.
        """
        memory_id = self._generate_id()
        embedding = self.embedder.encode(content).tolist()
        memory = MemoryItem(
            id=memory_id,
            content=content,
            embedding=embedding,
            metadata=metadata or {},
            importance=importance
        )
        self._memories[memory_id] = memory
        self._save_to_storage()
        return memory_id

    def search(
        self,
        query: str,
        top_k: int = 5,
        min_score: float = 0.5
    ) -> "List[tuple[MemoryItem, float]]":
        """Return up to ``top_k`` memories most similar to ``query``.

        Args:
            query: Text to search for.
            top_k: Maximum number of results.
            min_score: Minimum ``similarity * importance`` for a hit.

        Returns:
            ``[(memory, score), ...]`` sorted by descending score. Only the
            returned memories have their access statistics updated.
        """
        query_embedding = self.embedder.encode(query)
        scored = []
        for memory in self._memories.values():
            # Items without an embedding cannot be compared; skip them.
            if memory.embedding is None:
                continue
            similarity = self._cosine_similarity(query_embedding, memory.embedding)
            # Importance scales the raw similarity.
            final_score = similarity * memory.importance
            if final_score >= min_score:
                scored.append((memory, final_score))
        scored.sort(key=lambda pair: pair[1], reverse=True)
        top = scored[:top_k]
        # Count an access only for what is actually handed back to the
        # caller, not for every candidate above the threshold.
        now = datetime.utcnow()
        for memory, _ in top:
            memory.access_count += 1
            memory.last_accessed = now
        return top

    def _cosine_similarity(
        self,
        vec1: np.ndarray,
        vec2: List[float]
    ) -> float:
        """Return the cosine similarity, or 0.0 when either vector has zero norm."""
        vec2 = np.asarray(vec2, dtype=float)
        denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)
        if denominator <= 0:
            return 0.0
        return float(np.dot(vec1, vec2) / denominator)

    def get(self, memory_id: str) -> Optional["MemoryItem"]:
        """Return the memory with ``memory_id`` (counting the access), or None."""
        memory = self._memories.get(memory_id)
        if memory:
            memory.access_count += 1
            memory.last_accessed = datetime.utcnow()
        return memory

    def update(self, memory_id: str, **kwargs):
        """Set attributes on an existing memory; unknown ids are ignored.

        When ``content`` is among the updated attributes the embedding is
        regenerated from the new content.
        """
        memory = self._memories.get(memory_id)
        if memory is None:
            return
        for key, value in kwargs.items():
            setattr(memory, key, value)
        if "content" in kwargs:
            memory.embedding = self.embedder.encode(kwargs["content"]).tolist()
        self._save_to_storage()

    def delete(self, memory_id: str) -> bool:
        """Delete a memory; return True if it existed."""
        if memory_id not in self._memories:
            return False
        del self._memories[memory_id]
        self._save_to_storage()
        return True

    def get_all(self) -> List["MemoryItem"]:
        """Return every stored memory as a new list."""
        return list(self._memories.values())

    def filter_by_metadata(
        self,
        key: str,
        value: object
    ) -> List["MemoryItem"]:
        """Return the memories whose ``metadata[key]`` equals ``value``."""
        return [
            m for m in self._memories.values()
            if m.metadata.get(key) == value
        ]

    def get_important_memories(
        self,
        min_importance: float = 0.8
    ) -> List["MemoryItem"]:
        """Return the memories with importance of at least ``min_importance``."""
        return [
            m for m in self._memories.values()
            if m.importance >= min_importance
        ]

    def _generate_id(self) -> str:
        """Return a fresh random UUID string."""
        import uuid
        return str(uuid.uuid4())

    def _load_from_storage(self):
        """Placeholder for loading persisted memories; currently does nothing."""
        if not self.storage_path:
            return
        # Simplified: a real implementation would read from a vector database.
        pass

    def _save_to_storage(self):
        """Placeholder for persisting memories; currently does nothing."""
        if not self.storage_path:
            return
        # Simplified: a real implementation would write to a vector database.
        pass
3.2 分层记忆
class HierarchicalLongTermMemory:
    """Long-term memory split into three age-based layers.

    New memories enter the immediate layer; ``consolidate`` moves them to
    the recent layer and then to the long-term layer as they age.
    """

    def __init__(self):
        # NOTE(review): each layer constructs its own LongTermMemory, so the
        # embedding model is instantiated three times -- confirm this is intended.
        self.immediate = LongTermMemory()   # newest memories
        self.recent = LongTermMemory()      # memories from roughly the past week
        self.long_term = LongTermMemory()   # older history
        # Age thresholds, in whole days.
        self.immediate_threshold = 1
        self.recent_threshold = 7

    def add(
        self,
        content: str,
        metadata: Dict = None,
        importance: float = 1.0
    ) -> str:
        """Add a memory to the immediate layer and return its id."""
        return self.immediate.add(
            content, metadata, importance
        )

    def search(
        self,
        query: str,
        top_k: int = 5
    ) -> "List[tuple[MemoryItem, float]]":
        """Search every layer and merge the results.

        Each layer is asked for ``top_k * 2`` hits. Scores are weighted by
        layer (immediate 1.0, recent 0.8, long-term 0.5); when the same id
        appears in several layers the first layer, in that order, wins.

        Returns:
            Up to ``top_k`` ``(memory, weighted_score)`` pairs, best first.
        """
        weighted_layers = (
            (self.immediate, 1.0),
            (self.recent, 0.8),
            (self.long_term, 0.5),
        )
        results = []
        seen_ids = set()
        for layer, weight in weighted_layers:
            for memory, score in layer.search(query, top_k * 2):
                if memory.id in seen_ids:
                    continue
                seen_ids.add(memory.id)
                results.append((memory, score * weight))
        results.sort(key=lambda pair: pair[1], reverse=True)
        return results[:top_k]

    def consolidate(self):
        """Move aged memories down one layer at a time.

        A memory older than ``immediate_threshold`` days moves to the recent
        layer; one older than ``recent_threshold`` days moves on to the
        long-term layer, possibly within the same call.
        """
        now = datetime.utcnow()
        self._demote(self.immediate, self.recent, self.immediate_threshold, now)
        self._demote(self.recent, self.long_term, self.recent_threshold, now)

    def _demote(self, source, target, threshold_days, now):
        """Move memories older than ``threshold_days`` from ``source`` to ``target``."""
        moved = False
        # get_all() returns a new list, so deleting during the loop is safe.
        for memory in source.get_all():
            if (now - memory.created_at).days > threshold_days:
                target._memories[memory.id] = memory
                del source._memories[memory.id]
                moved = True
        # The layers' dicts were edited directly, so trigger their save hook
        # explicitly to keep any persisted state in step.
        if moved:
            source._save_to_storage()
            target._save_to_storage()
3.3 记忆重要性评估
class MemoryImportanceEvaluator:
    """Scores how important a piece of content is to remember.

    ``llm`` must expose ``create(**kwargs)`` (and optionally ``acreate``)
    returning an object shaped like ``response.choices[0].message.content``.
    """

    def __init__(self, llm=None):
        # Falls back to the legacy module-level OpenAI client when none is given.
        self.llm = llm or openai.ChatCompletion

    async def evaluate(self, content: str) -> float:
        """Ask the model for an importance score.

        Works with synchronous and asynchronous clients.

        Returns:
            A score clamped to 0.0-1.0, or 0.5 when the reply cannot be
            parsed as a number.
        """
        import inspect

        prompt = f"""评估以下信息的重要性。
分数范围:
- 0.0-0.2: 不重要,可以忽略
- 0.2-0.4: 低重要性
- 0.4-0.6: 中等重要性
- 0.6-0.8: 高重要性
- 0.8-1.0: 非常重要,需要记住
信息:
{content}
请只输出一个0.0到1.0之间的数字:"""
        call = getattr(self.llm, "acreate", None) or self.llm.create
        response = call(
            model="gpt-4",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,
            max_tokens=10
        )
        # Await only when needed so a synchronous client also works.
        if inspect.isawaitable(response):
            response = await response
        try:
            score = float(response.choices[0].message.content.strip())
        except (ValueError, TypeError, AttributeError, IndexError):
            # Malformed or non-numeric reply: fall back to medium importance.
            return 0.5
        return max(0.0, min(1.0, score))

    def evaluate_by_signals(
        self,
        content: str,
        signals: Dict
    ) -> float:
        """Score importance from caller-supplied signals plus keyword checks.

        Recognised truthy keys in ``signals``:
            - user_explicit: the user explicitly asked to remember this
            - repeated: the content has appeared repeatedly
            - emotional: the content contains emotional wording
            - actionable: the content is actionable
            - personal: the content is personal information

        Args:
            content: Text to inspect for personal/importance keywords.
            signals: Signal flags; ``None`` is treated as no signals.

        Returns:
            A score starting at 0.5, raised per signal and keyword match,
            capped at 1.0.
        """
        signals = signals or {}
        signal_bonuses = (
            ("user_explicit", 0.3),
            ("repeated", 0.2),
            ("emotional", 0.1),
            ("actionable", 0.15),
            ("personal", 0.2),
        )
        score = 0.5  # base score
        for name, bonus in signal_bonuses:
            if signals.get(name):
                score += bonus

        content_lower = content.lower()
        # Phrases that introduce personal information.
        personal_keywords = ["我叫", "我的名字", "我住在", "我的电话"]
        if any(kw in content_lower for kw in personal_keywords):
            score += 0.2
        # Phrases that flag the content as important.
        important_keywords = ["记住", "重要", "别忘了", "关键"]
        if any(kw in content_lower for kw in important_keywords):
            score += 0.15
        return min(1.0, score)

    def extract_facts(self, content: str) -> List[Dict]:
        """Ask the model for memorable facts in ``content``.

        This method is synchronous and uses the result of ``llm.create``
        directly, so it requires a synchronous client.

        Returns:
            The ``facts`` list from the model's JSON reply, or an empty list
            when that key is absent.

        Raises:
            json.JSONDecodeError: If the model output is not valid JSON.
        """
        import json

        prompt = f"""从以下内容中提取需要记住的事实。
内容:
{content}
请以JSON格式输出提取的事实:
{{
"facts": [
{{
"fact": "事实内容",
"category": "类别",
"importance": 0.8
}}
]
}}"""
        response = self.llm.create(
            model="gpt-4",
            messages=[{"role": "user", "content": prompt}],
            response_format={"type": "json_object"}
        )
        data = json.loads(response.choices[0].message.content)
        return data.get("facts", [])
4. 结构化记忆
4.1 用户偏好记忆
from typing import Dict, Any
class UserPreferenceMemory:
    """Per-user preference store laid out as ``{user_id: {key: value}}``.

    ``storage`` may be any dict-like mapping; a new dict is created only
    when none is supplied.
    """

    def __init__(self, storage=None):
        # Test against None explicitly: `storage or {}` would silently
        # replace a caller-supplied *empty* mapping with a private dict.
        self.storage = storage if storage is not None else {}
        self._prefix = "pref:"

    def set(self, user_id: str, key: str, value: Any):
        """Set one preference for ``user_id``."""
        if user_id not in self.storage:
            self.storage[user_id] = {}
        self.storage[user_id][key] = value

    def get(self, user_id: str, key: str, default=None) -> Any:
        """Return a preference, or ``default`` when the user or key is unknown."""
        if user_id not in self.storage:
            return default
        return self.storage[user_id].get(key, default)

    def get_all(self, user_id: str) -> Dict:
        """Return a shallow copy of all preferences of ``user_id``."""
        return self.storage.get(user_id, {}).copy()

    def delete(self, user_id: str, key: str):
        """Remove one preference; unknown users or keys are ignored."""
        if user_id in self.storage and key in self.storage[user_id]:
            del self.storage[user_id][key]

    def update(self, user_id: str, preferences: Dict):
        """Merge ``preferences`` into the user's stored preferences."""
        if user_id not in self.storage:
            self.storage[user_id] = {}
        self.storage[user_id].update(preferences)

    # Convenience accessors for common preferences.
    def set_language(self, user_id: str, language: str):
        """Set the preferred language."""
        self.set(user_id, "language", language)

    def get_language(self, user_id: str) -> str:
        """Return the preferred language, defaulting to "zh-CN"."""
        return self.get(user_id, "language", "zh-CN")

    def set_response_style(self, user_id: str, style: str):
        """Set the response style (concise, detailed, formal, casual)."""
        self.set(user_id, "response_style", style)

    def get_response_style(self, user_id: str) -> str:
        """Return the response style, defaulting to "balanced"."""
        return self.get(user_id, "response_style", "balanced")

    def set_topics(self, user_id: str, topics: List[str]):
        """Set the topics the user is interested in."""
        self.set(user_id, "topics", topics)

    def get_topics(self, user_id: str) -> List[str]:
        """Return the topics of interest, defaulting to an empty list."""
        return self.get(user_id, "topics", [])
4.2 会话状态记忆
from typing import Optional
import pickle
class SessionStateMemory:
    """In-memory per-session key/value state with a time-to-live.

    A session expires ``ttl`` seconds after its last write. Expired
    sessions are dropped lazily, the next time they are read or written.
    """

    def __init__(self, ttl: int = 3600):
        """
        Args:
            ttl: Time-to-live in seconds, measured from the last write.
        """
        self.states: Dict[str, Dict] = {}
        self.timestamps: Dict[str, datetime] = {}
        self.ttl = ttl

    def set(self, session_id: str, key: str, value: Any):
        """Set one state value and refresh the session's timestamp."""
        # Drop an expired session first; otherwise this write would refresh
        # the timestamp and bring stale keys back to life.
        self._check_ttl(session_id)
        if session_id not in self.states:
            self.states[session_id] = {}
        self.states[session_id][key] = value
        self.timestamps[session_id] = datetime.utcnow()

    def get(self, session_id: str, key: str, default=None) -> Any:
        """Return a state value, or ``default`` if missing or expired."""
        self._check_ttl(session_id)
        if session_id not in self.states:
            return default
        return self.states[session_id].get(key, default)

    def get_all(self, session_id: str) -> Dict:
        """Return a shallow copy of the session's state (empty if unknown or expired)."""
        self._check_ttl(session_id)
        return self.states.get(session_id, {}).copy()

    def delete(self, session_id: str, key: str = None):
        """Delete one key, or the whole session when ``key`` is None."""
        if session_id not in self.states:
            return
        if key is None:
            # Remove the whole session; pop() tolerates a missing timestamp.
            del self.states[session_id]
            self.timestamps.pop(session_id, None)
        else:
            self.states[session_id].pop(key, None)

    def _check_ttl(self, session_id: str):
        """Delete the session if its last write is older than ``ttl`` seconds."""
        if session_id not in self.timestamps:
            return
        age = (datetime.utcnow() - self.timestamps[session_id]).total_seconds()
        if age > self.ttl:
            self.delete(session_id)

    # Convenience accessors for common state entries.
    def set_context(self, session_id: str, context: Dict):
        """Store the session context."""
        self.set(session_id, "context", context)

    def get_context(self, session_id: str) -> Dict:
        """Return the session context, defaulting to an empty dict."""
        return self.get(session_id, "context", {})

    def set_user_info(self, session_id: str, user_info: Dict):
        """Store the user info."""
        self.set(session_id, "user_info", user_info)

    def get_user_info(self, session_id: str) -> Dict:
        """Return the user info, defaulting to an empty dict."""
        return self.get(session_id, "user_info", {})

    def set_step(self, session_id: str, step: str):
        """Store the current step."""
        self.set(session_id, "current_step", step)

    def get_step(self, session_id: str) -> Optional[str]:
        """Return the current step, or None when unset."""
        return self.get(session_id, "current_step")
4.3 事实记忆
class FactMemory:
    """Fact memory: stores key facts by id."""

    def __init__(self):
        self.facts: Dict[str, Dict] = {}

    def add_fact(
        self,
        fact_id: str,
        fact: str,
        category: str = "general",
        confidence: float = 1.0
    ):
        """Store a fact under ``fact_id``, replacing any existing entry.

        New facts start unverified and are stamped with the current
        naive-UTC time.
        """
        record = dict(
            fact=fact,
            category=category,
            confidence=confidence,
            created_at=datetime.utcnow(),
            verified=False,
        )
        self.facts[fact_id] = record

    def get_fact(self, fact_id: str) -> Optional[Dict]:
        """Return the stored record for ``fact_id``, or None."""
        return self.facts.get(fact_id)

    def get_by_category(self, category: str) -> List[Dict]:
        """Return copies of the facts in ``category``, each with its ``id`` added."""
        matches = []
        for fact_id, record in self.facts.items():
            if record["category"] == category:
                matches.append({"id": fact_id, **record})
        return matches

    def verify_fact(self, fact_id: str, is_correct: bool):
        """Record the verification result for a known fact; unknown ids are ignored."""
        record = self.facts.get(fact_id)
        if record is not None:
            record["verified"] = is_correct

    def search_facts(self, keyword: str) -> List[Dict]:
        """Return copies of facts whose text contains ``keyword`` (case-insensitive)."""
        needle = keyword.lower()
        matches = []
        for fact_id, record in self.facts.items():
            if needle in record["fact"].lower():
                matches.append({"id": fact_id, **record})
        return matches

    # Personal-information facts.
    def set_personal_info(self, user_id: str, info_type: str, value: str):
        """Store a personal fact under the id ``"{user_id}:{info_type}"``."""
        self.add_fact(
            fact_id=f"{user_id}:{info_type}",
            fact=value,
            category="personal",
            confidence=0.95
        )

    def get_personal_info(self, user_id: str, info_type: str) -> Optional[str]:
        """Return the stored personal fact text, or None when absent."""
        record = self.get_fact(f"{user_id}:{info_type}")
        if not record:
            return None
        return record["fact"]
5. 记忆检索
5.1 混合记忆检索
class HybridMemoryRetriever:
    """Combines short-term, long-term and preference memory into one lookup."""

    def __init__(
        self,
        short_term: "ShortTermMemory",
        long_term: "LongTermMemory",
        preferences: "UserPreferenceMemory"
    ):
        self.short_term = short_term
        self.long_term = long_term
        self.preferences = preferences

    def retrieve(
        self,
        query: str,
        user_id: str = None,
        top_k: int = 10
    ) -> Dict:
        """Query every memory source and build a combined context string.

        Args:
            query: Used as a keyword for short-term memory and as the
                semantic query for long-term memory.
            user_id: When given, that user's preferences are included.
            top_k: Maximum number of long-term results.

        Returns:
            A dict with the keys ``short_term`` (matching messages),
            ``long_term`` (``(memory, score)`` pairs), ``preferences``
            (dict, or None without a ``user_id``) and ``context`` (str).
        """
        # 1. Short-term memory: keyword match.
        short_term_results = self.short_term.search(query)
        # 2. Long-term memory: semantic search.
        long_term_results = self.long_term.search(query, top_k=top_k)
        # 3. User preferences, only when a user is known.
        user_preferences = self.preferences.get_all(user_id) if user_id else None
        # 4. Combined textual context.
        context = self._build_context(
            short_term_results,
            long_term_results,
            user_preferences
        )
        return {
            "short_term": short_term_results,
            "long_term": long_term_results,
            "preferences": user_preferences,
            "context": context
        }

    @staticmethod
    def _truncate(text: str, limit: int) -> str:
        """Cut ``text`` to ``limit`` characters, adding "..." only when something was cut."""
        return text if len(text) <= limit else text[:limit] + "..."

    def _build_context(
        self,
        short_term: List,
        long_term: List,
        preferences: Dict
    ) -> str:
        """Render preferences, matching messages and memories as plain text.

        At most the last 5 short-term messages are included. Message
        content is cut at 100 characters and memory content at 150.
        """
        parts = []
        if preferences:
            parts.append("用户偏好:")
            parts.extend(f" - {key}: {value}" for key, value in preferences.items())
            parts.append("")
        if short_term:
            parts.append("最近对话:")
            for msg in short_term[-5:]:
                parts.append(f" {msg.role}: {self._truncate(msg.content, 100)}")
            parts.append("")
        if long_term:
            parts.append("相关信息:")
            for memory, score in long_term:
                snippet = self._truncate(memory.content, 150)
                parts.append(f" - {snippet} (相关度: {score:.2f})")
        return "\n".join(parts)
5.2 记忆去重
class MemoryDeduplicator:
    """Removes duplicate memories, by embedding similarity or by id."""

    def __init__(self, similarity_threshold: float = 0.95):
        self.similarity_threshold = similarity_threshold

    def deduplicate_by_content(
        self,
        memories: List["MemoryItem"]
    ) -> List["MemoryItem"]:
        """Keep the first of any memories whose embeddings are near-identical.

        A memory is a duplicate when its cosine similarity to an already
        kept embedding reaches ``similarity_threshold``. Memories with an
        empty or missing embedding are always kept and never compared.
        """
        kept = []
        kept_embeddings = []
        for candidate in memories:
            vector = candidate.embedding
            if not vector:
                kept.append(candidate)
                continue
            already_seen = any(
                self._cosine_similarity(vector, earlier) >= self.similarity_threshold
                for earlier in kept_embeddings
            )
            if already_seen:
                continue
            kept.append(candidate)
            kept_embeddings.append(vector)
        return kept

    def deduplicate_by_id(self, memories: List["MemoryItem"]) -> List["MemoryItem"]:
        """Keep the first memory for each id, preserving input order."""
        known_ids = set()
        kept = []
        for candidate in memories:
            if candidate.id in known_ids:
                continue
            known_ids.add(candidate.id)
            kept.append(candidate)
        return kept

    def _cosine_similarity(
        self,
        vec1: List[float],
        vec2: List[float]
    ) -> float:
        """Return the cosine similarity, or 0 when either vector has zero norm."""
        import numpy as np
        left = np.array(vec1)
        right = np.array(vec2)
        numerator = np.dot(left, right)
        scale = np.linalg.norm(left) * np.linalg.norm(right)
        if scale > 0:
            return numerator / scale
        return 0
6. 记忆管理
6.1 完整记忆管理器
class MemoryManager:
    """Facade over the memory stores of one user and session.

    Bundles:
        - short-term memory (conversation history)
        - long-term memory (semantic store)
        - user preferences
        - session state
    """

    def __init__(
        self,
        user_id: str,
        session_id: str = None
    ):
        self.user_id = user_id
        # A random session id is generated when none is supplied.
        self.session_id = session_id or self._generate_session_id()
        # One store per memory layer.
        self.short_term = ShortTermMemory()
        self.long_term = LongTermMemory()
        self.preferences = UserPreferenceMemory()
        self.session_state = SessionStateMemory()
        # Retriever that queries the stores together.
        self.retriever = HybridMemoryRetriever(
            self.short_term,
            self.long_term,
            self.preferences
        )

    def add_message(self, role: str, content: str):
        """Append a message to short-term memory."""
        self.short_term.add(role, content)

    def add_long_term_memory(
        self,
        content: str,
        metadata: Dict = None,
        importance: float = 1.0
    ) -> str:
        """Store ``content`` in long-term memory, tagged with user and session ids.

        The caller's ``metadata`` dict is not modified.

        Returns:
            The id of the new memory.
        """
        # Work on a copy so the ids are not injected into the caller's dict.
        tagged = dict(metadata or {})
        tagged["user_id"] = self.user_id
        tagged["session_id"] = self.session_id
        return self.long_term.add(content, tagged, importance)

    def retrieve(self, query: str, top_k: int = 10) -> Dict:
        """Run a hybrid retrieval for the current user."""
        return self.retriever.retrieve(query, self.user_id, top_k)

    def get_context(self, max_tokens: int = 4000) -> str:
        """Render preferences and the last 5 messages as plain text.

        NOTE: ``max_tokens`` is accepted but not enforced; the output is
        never truncated.
        """
        recent_messages = self.short_term.get_recent(5)
        prefs = self.preferences.get_all(self.user_id)
        context_parts = []
        if prefs:
            context_parts.append("用户偏好:")
            context_parts.extend(f" - {key}: {value}" for key, value in prefs.items())
            context_parts.append("")
        if recent_messages:
            context_parts.append("最近对话:")
            context_parts.extend(
                f" {msg.role}: {msg.content}" for msg in recent_messages
            )
        return "\n".join(context_parts)

    def set_preference(self, key: str, value: Any):
        """Set a preference for the current user."""
        self.preferences.set(self.user_id, key, value)

    def get_preference(self, key: str, default=None) -> Any:
        """Return a preference of the current user, or ``default``."""
        return self.preferences.get(self.user_id, key, default)

    def set_session_state(self, key: str, value: Any):
        """Set a state value for the current session."""
        self.session_state.set(self.session_id, key, value)

    def get_session_state(self, key: str, default=None) -> Any:
        """Return a state value of the current session, or ``default``."""
        return self.session_state.get(self.session_id, key, default)

    def clear_session(self):
        """Clear short-term memory and delete the current session's state."""
        self.short_term.clear()
        self.session_state.delete(self.session_id)

    def _generate_session_id(self) -> str:
        """Return a fresh random UUID string."""
        import uuid
        return str(uuid.uuid4())
7. 记忆优化
7.1 记忆压缩
class MemoryCompressor:
    """Shrinks memory stores by dropping or summarizing old entries."""

    def compress_long_term_memory(self, memory: "LongTermMemory"):
        """Delete long-term memories not accessed within the last 90 days.

        Other possible strategies (not implemented): drop rarely accessed
        memories, merge similar memories, summarize old memories.

        Args:
            memory: Store exposing ``_memories`` (id -> item with a
                ``last_accessed`` datetime) and ``delete(memory_id)``.
        """
        from datetime import timedelta

        cutoff_date = datetime.utcnow() - timedelta(days=90)
        # Collect the ids first: the store must not change size while it is
        # being iterated. The loop variable is deliberately not called
        # `memory`, which would shadow the store and break the delete below.
        stale_ids = [
            memory_id
            for memory_id, item in memory._memories.items()
            if item.last_accessed < cutoff_date
        ]
        for memory_id in stale_ids:
            memory.delete(memory_id)

    def compress_session_memory(self, memory: "ShortTermMemory"):
        """Replace the older half of a long session with one summary message.

        Does nothing when fewer than 20 messages are stored.
        """
        if len(memory.messages) < 20:
            return
        split = len(memory.messages) // 2
        old_messages = memory.messages[:split]
        summary = self._create_summary(old_messages)
        memory.messages = [
            Message(role="system", content=summary)
        ] + memory.messages[split:]

    def _create_summary(self, messages: List["Message"]) -> str:
        """Build a placeholder summary that only states the message count."""
        return f"[摘要:{len(messages)}条消息]"
8. 实现示例
8.1 完整记忆系统
"""
完整的Agent记忆系统
整合所有记忆组件,提供统一的接口
"""
class AgentMemorySystem:
    """Single entry point that wires together all memory components of one user."""

    def __init__(self, user_id: str, config: Dict = None):
        """
        Args:
            user_id: Owner of the stored preferences, facts and memories.
            config: Optional overrides for ``max_messages``, ``max_tokens``
                and ``embedding_model``.
        """
        self.user_id = user_id
        self.config = config or {}
        # Session id used for state until set_session() is called.
        self._current_session_id = "default"
        self.short_term = ShortTermMemory(
            max_messages=self.config.get("max_messages", 100),
            max_tokens=self.config.get("max_tokens", 4000)
        )
        self.long_term = LongTermMemory(
            embedding_model=self.config.get("embedding_model", "BAAI/bge-small-zh-v1.5")
        )
        self.preferences = UserPreferenceMemory()
        self.session_state = SessionStateMemory()
        self.fact_memory = FactMemory()

    # Conversation interface
    def add_user_message(self, content: str):
        """Append a user message to the conversation history."""
        self.short_term.add("user", content)

    def add_assistant_message(self, content: str):
        """Append an assistant message to the conversation history."""
        self.short_term.add("assistant", content)

    def get_conversation_history(self) -> List[Dict]:
        """Return the history as ``{"role", "content"}`` dicts."""
        return self.short_term.to_messages_format()

    # Long-term memory interface
    def remember(
        self,
        content: str,
        category: str = None,
        importance: float = 1.0
    ):
        """Store ``content`` in long-term memory and return what the store returns.

        The metadata always carries ``user_id``; ``category`` is added when given.
        """
        metadata = {"user_id": self.user_id}
        if category:
            metadata["category"] = category
        return self.long_term.add(content, metadata, importance)

    def recall(self, query: str, top_k: int = 5) -> List[Dict]:
        """Return long-term hits as dicts with ``content``, ``score`` and ``metadata``."""
        return [
            {
                "content": memory.content,
                "score": score,
                "metadata": memory.metadata
            }
            for memory, score in self.long_term.search(query, top_k)
        ]

    # Preference interface
    def set_preference(self, key: str, value: Any):
        """Set a preference for this user."""
        self.preferences.set(self.user_id, key, value)

    def get_preference(self, key: str, default=None) -> Any:
        """Return a preference of this user, or ``default``."""
        return self.preferences.get(self.user_id, key, default)

    # Session-state interface
    def set_state(self, key: str, value: Any):
        """Set a state value in the current session."""
        self.session_state.set(self._get_session_id(), key, value)

    def get_state(self, key: str, default=None) -> Any:
        """Return a state value of the current session, or ``default``."""
        return self.session_state.get(self._get_session_id(), key, default)

    # Fact interface
    def remember_fact(self, fact_type: str, value: str):
        """Store a personal fact of this user."""
        self.fact_memory.set_personal_info(self.user_id, fact_type, value)

    def recall_fact(self, fact_type: str) -> Optional[str]:
        """Return a stored personal fact of this user, or None."""
        return self.fact_memory.get_personal_info(self.user_id, fact_type)

    # Context building
    def build_context(self, query: str = None) -> str:
        """Render preferences, related memories and recent dialogue as text.

        Args:
            query: When given, up to 3 related long-term memories are
                included, each cut at 100 characters.

        Returns:
            The sections joined by newlines. The last 5 history messages
            are included, each cut at 50 characters.
        """
        parts = []
        # 1. User preferences.
        prefs = self.preferences.get_all(self.user_id)
        if prefs:
            parts.append("用户偏好:")
            parts.extend(f" - {key}: {value}" for key, value in prefs.items())
            parts.append("")
        # 2. Related long-term memories.
        if query:
            recalled = self.recall(query, top_k=3)
            if recalled:
                parts.append("相关信息:")
                for item in recalled:
                    text = item['content']
                    # Mark truncation only when something was actually cut.
                    suffix = "..." if len(text) > 100 else ""
                    parts.append(f" - {text[:100]}{suffix}")
                parts.append("")
        # 3. Recent conversation.
        history = self.get_conversation_history()[-5:]
        if history:
            parts.append("最近对话:")
            for msg in history:
                content = msg["content"][:50]
                parts.append(f" {msg['role']}: {content}")
        return "\n".join(parts)

    def _get_session_id(self) -> str:
        """Return the current session id ("default" until one is set)."""
        return getattr(self, "_current_session_id", "default")

    def set_session(self, session_id: str):
        """Switch the session used by ``set_state`` / ``get_state``."""
        self._current_session_id = session_id
# ============== 使用示例 ==============
if __name__ == "__main__":
    # Build a memory system for a single demo user.
    memory = AgentMemorySystem(user_id="user123")

    # Record one exchange of the conversation.
    memory.add_user_message("我叫张三,住在北京")
    memory.add_assistant_message("你好张三!很高兴认识你。")

    # Store an important piece of information in long-term memory.
    memory.remember(
        content="张三住在北京,是一名软件工程师",
        category="personal",
        importance=0.9
    )

    # Store the user's preferences.
    for pref_key, pref_value in (
        ("language", "zh-CN"),
        ("response_style", "detailed"),
    ):
        memory.set_preference(pref_key, pref_value)

    # Store personal facts.
    for fact_type, fact_value in (("name", "张三"), ("location", "北京")):
        memory.remember_fact(fact_type, fact_value)

    # Recall information related to a question.
    query = "张三是谁?"
    print("回忆的信息:")
    for item in memory.recall(query, top_k=3):
        print(f" - {item['content']} (相关度: {item['score']:.2f})")

    # Read a preference back.
    language = memory.get_preference("language")
    print(f"\n语言偏好: {language}")

    # Build the combined context for the question.
    context = memory.build_context(query)
    print(f"\n上下文:\n{context}")
面试高频问法
Q1: 如何设计Agent的记忆系统?
标准回答:
记忆系统设计要点:
1. 分层架构
- 短期记忆:当前会话对话历史
- 长期记忆:向量存储的历史信息
- 结构化记忆:用户偏好、状态、事实
2. 短期记忆
- 滑动窗口:保留最近N条消息
- Token限制:控制总token数
- 对话摘要:旧消息摘要压缩
3. 长期记忆
- 向量存储:语义相似度检索
- 重要性评分:自动评估重要性
- 分层存储:立即/近期/长期
4. 记忆检索
- 混合检索:结合短期和长期
- 语义搜索:基于相似度
- 关键词搜索:快速匹配
5. 持久化
- 保存用户偏好
- 保存重要事实
- 定期备份
实现示例:
```python
class MemorySystem:
    """Minimal memory system combining short-term, long-term and preference stores."""

    def __init__(self):
        self.short_term = ShortTermMemory()
        self.long_term = LongTermMemory()
        self.preferences = UserPreferenceMemory()

    def add_message(self, role, content):
        """Append a message to short-term memory."""
        self.short_term.add(role, content)

    def remember(self, content, importance=1.0):
        """Store ``content`` in long-term memory with the given importance."""
        # Pass importance by keyword: the second positional parameter of
        # LongTermMemory.add is `metadata`, not `importance`.
        self.long_term.add(content, importance=importance)

    def retrieve(self, query):
        """Return the long-term memories matching ``query``."""
        return self.long_term.search(query)
```
### Q2: 如何管理长对话的上下文?
标准回答:
长对话管理策略:
1. 滑动窗口
- 保留最近的N条消息
- 保留系统提示
- 保留初始请求
2. 对话摘要
- 定期摘要旧消息
- 用摘要代替原消息
- 减少token占用
3. 智能筛选
- 保留高重要性的消息
- 保留用户明确要求记住的内容
- 删除低价值的内容
4. 分层存储
- 当前窗口:完整显示
- 历史摘要:压缩存储
- 长期记忆:按需检索
实现:
```python
def manage_long_conversation(messages, max_tokens):
    """Summarize the older half of ``messages`` when they exceed ``max_tokens``.

    Returns ``messages`` unchanged when the total token count is within
    budget; otherwise returns a new list made of one system summary message
    followed by the newer half.
    """
    # Step 1: measure the conversation against the budget.
    total_tokens = 0
    for item in messages:
        total_tokens += count_tokens(item)
    if total_tokens > max_tokens:
        # Step 2: summarize the older half.
        midpoint = len(messages) // 2
        digest = summarize_messages(messages[:midpoint])
        # Step 3: summary first, then the untouched newer half.
        return [{"role": "system", "content": digest}] + messages[midpoint:]
    return messages
```
### Q3: 如何实现跨会话记忆?
标准回答:
跨会话记忆实现:
1. 持久化存储
- 向量数据库:存储长期记忆
- 键值存储:存储用户偏好
- 数据库:存储会话状态
2. 记忆标识
- 用户ID:关联用户的所有记忆
- 会话ID:关联会话内的记忆
- 时间戳:记忆的创建时间
3. 记忆提取
- 语义检索:根据查询检索相关记忆
- 元数据过滤:按用户/类别筛选
- 重要性排序:优先返回重要记忆
4. 记忆更新
- 访问计数:跟踪记忆使用频率
- 最后访问:跟踪访问时间
- 重要性评估:动态调整重要性
实现:
class CrossSessionMemory:
    """Memory that outlives a session by tagging every record with the user id."""

    def __init__(self, user_id):
        self.user_id = user_id
        self.long_term = VectorDB()
        self.preferences = KVStore()

    def remember(self, content, category):
        """Store ``content`` in the vector database, tagged with user id and category."""
        tags = {
            "user_id": self.user_id,
            "category": category
        }
        self.long_term.add(content=content, metadata=tags)

    def recall(self, query, top_k=5):
        """Return up to ``top_k`` memories of this user that match ``query``."""
        # Restrict the search to records owned by this user.
        owner_filter = {"user_id": self.user_id}
        return self.long_term.search(
            query=query,
            filter=owner_filter,
            top_k=top_k
        )
---
## 总结
### 记忆系统核心要点
| 要点 | 策略 |
|------|------|
| **分层架构** | 短期/长期/结构化分离 |
| **短期记忆** | 滑动窗口、对话摘要 |
| **长期记忆** | 向量存储、语义检索 |
| **结构化记忆** | 键值存储、持久化 |
| **记忆检索** | 混合检索、去重 |
| **记忆优化** | 压缩、淘汰机制 |
### 最佳实践
1. **分层管理**:不同类型记忆用不同存储
2. **智能评估**:自动评估记忆重要性
3. **及时压缩**:定期压缩旧记忆
4. **持久化**:重要信息必须持久化
5. **隐私保护**:敏感信息加密存储