[AI Agent Knowledge Base] 21: Agent Memory System Design

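Contents

  1. Memory System Overview
  2. Short-Term Memory
  3. Long-Term Memory
  4. Structured Memory
  5. Memory Retrieval
  6. Memory Management
  7. Memory Optimization
  8. Implementation Example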

1. Memory System Overview

1.1 Memory Types

┌─────────────────────────────────────────────────────────┐
│            Agent Memory System Architecture             │
├─────────────────────────────────────────────────────────┤
│                                                         │
│  ┌─────────────┐    ┌─────────────┐    ┌─────────────┐  │
│  │ Short-term  │    │  Long-term  │    │ Structured  │  │
│  │  (session)  │    │ (vector DB) │    │ (key-value) │  │
│  └──────┬──────┘    └──────┬──────┘    └──────┬──────┘  │
│         │                  │                  │         │
│         ▼                  ▼                  ▼         │
│  ┌───────────────────────────────────────────────────┐  │
│  │                  Memory Manager                   │  │
│  │  - Add memories                                   │  │
│  │  - Retrieve memories                              │  │
│  │  - Update memories                                │  │
│  │  - Delete memories                                │  │
│  └───────────────────────────────────────────────────┘  │
└─────────────────────────────────────────────────────────┘

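1.2 Memory Type Comparison

Memory type        Storage medium     Persistent  Retrieval method             Use case
Short-term memory  In-process memory  No          Sequential / sliding window  Current session context
Long-term memory   Vector database    Yes         Semantic similarity search   Past conversations, knowledge
Structured memory  Database / Redis   Yes         Key-value lookup             User preferences, state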

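1.3 Value of Memory

┌─────────────────────────────────────────────────┐
│            Value of a Memory System             │
├─────────────────────────────────────────────────┤
│  ✓ Keeps conversations coherent                 │
│  ✓ Learns user preferences                      │
│  ✓ Avoids repeating content                     │
│  ✓ Supports cross-session references            │
│  ✓ Enables personalized interaction             │
│  ✓ Remembers important facts                    │
└─────────────────────────────────────────────────┘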

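2. Short-Term Memory

2.1 Basic Implementation

from typing import List, Dict, Optional
from dataclasses import dataclass
from datetime import datetime
import json

def count_tokens(text: str) -> int:
    """Placeholder token counter (rough guess: one token per two characters).

    Assumption for illustration only; use the target model's tokenizer in practice.
    """
    return (len(text) + 1) // 2

@dataclass
class Message:
    role: str  # "system", "user", "assistant", "tool"
    content: str
    timestamp: Optional[datetime] = None
    metadata: Optional[Dict] = None

    def __post_init__(self):
        if self.timestamp is None:
            self.timestamp = datetime.utcnow()
        if self.metadata is None:
            self.metadata = {}

class ShortTermMemory:
    """Short-term memory: the conversation history."""

    def __init__(
        self,
        max_messages: int = 100,
        max_tokens: int = 4000,
        strategy: str = "sliding"  # sliding, lru, summary
    ):
        self.max_messages = max_messages
        self.max_tokens = max_tokens
        self.strategy = strategy
        self.messages: List[Message] = []

    def add(self, role: str, content: str, metadata: Optional[Dict] = None) -> bool:
        """Add a message, evicting older ones first if a limit would be exceeded."""
        message = Message(role=role, content=content, metadata=metadata or {})

        # Evict until the new message fits or nothing is left to evict
        while self.messages and not self._check_limits(message):
            before = len(self.messages)
            self._evict()
            if len(self.messages) >= before:
                # The strategy made no progress (e.g. too few messages to
                # summarize), so fall back to dropping the oldest message
                self.messages.pop(0)

        self.messages.append(message)
        return True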

    def _check_limits(self, message: Message) -> bool:
        """Return True if the message can be added without exceeding a limit."""
        message_tokens = count_tokens(message.content)

        # Check the message count
        if len(self.messages) >= self.max_messages:
            return False

        # Check the token count
        current_tokens = sum(count_tokens(m.content) for m in self.messages)
        if current_tokens + message_tokens > self.max_tokens:
            return False

        return True

    def _evict(self):
        """Evict messages according to the configured strategy."""
        if self.strategy == "sliding":
            # Sliding window: drop the oldest message
            if self.messages:
                self.messages.pop(0)
        elif self.strategy == "lru":
            # LRU: drop the least recently used message. Access times are not
            # tracked here, so this simplification also drops the oldest one
            if self.messages:
                self.messages.pop(0)
        elif self.strategy == "summary":
            # Summary: replace old messages with a summary
            self._summarize_old()

    def _summarize_old(self):
        """Replace the older half of the history with a single summary message."""
        # Too little history to be worth summarizing
        if len(self.messages) < 10:
            return

        # Summarize the first half of the messages
        split_point = len(self.messages) // 2
        old_messages = self.messages[:split_point]
        summary = self._create_summary(old_messages)

        # Replace them with the summary
        self.messages = [
            Message(role="system", content=summary)
        ] + self.messages[split_point:]

    def _create_summary(self, messages: List[Message]) -> str:
        """Create a summary of the given messages."""
        # Placeholder: only reports how many messages were folded away.
        # A real implementation could call an LLM to summarize them.
        return f"Conversation history summary: {len(messages)} messages"

    def get_recent(self, n: int = 10) -> List[Message]:
        """Return the most recent n messages."""
        return self.messages[-n:]

    def get_all(self) -> List[Message]:
        """Return all messages."""
        return self.messages.copy()

    def to_messages_format(self) -> List[Dict]:
        """Convert to the LLM chat message format."""
        return [
            {
                "role": m.role,
                "content": m.content
            }
            for m in self.messages
        ]

    def clear(self):
        """Clear the memory."""
        self.messages.clear()

    def filter_by_role(self, role: str) -> List[Message]:
        """Return the messages with the given role."""
        return [m for m in self.messages if m.role == role]

    def search(self, keyword: str) -> List[Message]:
        """Return the messages whose content contains the keyword (case-insensitive)."""
        return [
            m for m in self.messages
            if keyword.lower() in m.content.lower()
        ]
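
The limit checks above depend on a count_tokens helper that the listing never defines. As a minimal sketch, a heuristic estimator for mixed Chinese/English text could look like the following. The name estimate_tokens and the ratios (one token per CJK character, about one token per four other characters) are assumptions for illustration, not the behavior of any real tokenizer.

```python
import re

# Assumption: one token per CJK character and roughly one token per four
# characters of any other text. This is a heuristic, not a real tokenizer.
_CJK = re.compile(r"[\u4e00-\u9fff\u3040-\u30ff\uac00-\ud7af]")


def estimate_tokens(text: str) -> int:
    """Return a rough token estimate for mixed CJK / Latin text."""
    if not text:
        return 0
    cjk_chars = len(_CJK.findall(text))
    other_chars = len(text) - cjk_chars
    # Round the non-CJK part up so short strings never count as zero
    return cjk_chars + (other_chars + 3) // 4
```

When accuracy matters (for example, near a hard context limit), the tokenizer that matches the target model is the safer choice; a heuristic like this is mainly useful for cheap budgeting.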

2.2 Sliding Window Implementation

class SlidingWindowMemory:
    """Sliding-window memory."""

    def __init__(
        self,
        window_size: int = 10,
        keep_system: bool = True,
        keep_first_user: bool = True
    ):
        self.window_size = window_size
        self.keep_system = keep_system
        self.keep_first_user = keep_first_user
        self.messages: List[Message] = []
        self.first_user_message: Optional[Message] = None

    def add(self, message: Message):
        """Add a message and re-apply the window."""
        # Remember the first user message so it can be pinned
        if message.role == "user" and self.first_user_message is None:
            self.first_user_message = message

        self.messages.append(message)
        self._apply_window()

    def _apply_window(self):
        """Trim the history to window_size while keeping pinned messages."""
        # Pinned messages: system messages and, optionally, the first user message
        pinned = []
        if self.keep_system:
            pinned.extend(m for m in self.messages if m.role == "system")
        if self.keep_first_user and self.first_user_message is not None:
            pinned.append(self.first_user_message)

        # Compare by identity so that two messages with equal fields are
        # not mistaken for duplicates of each other
        pinned_ids = {id(m) for m in pinned}
        rest = [m for m in self.messages if id(m) not in pinned_ids]

        # Fill the remaining slots with the most recent unpinned messages
        keep_count = max(self.window_size - len(pinned), 0)
        recent = rest[-keep_count:] if keep_count > 0 else []

        self.messages = pinned + recent

    def get_context(self) -> List[Dict]:
        """Return the window in the LLM chat message format."""
        return [
            {"role": m.role, "content": m.content}
            for m in self.messages
        ]

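2.3 Conversation Summarization

import openai

class ConversationSummarizer:
    """Conversation summarizer."""

    def __init__(self, llm=None):
        # Expects an object with an awaitable create(...) method. The default
        # assumes openai>=1.0, where the legacy openai.ChatCompletion
        # interface is no longer available.
        self.llm = llm or openai.AsyncOpenAI().chat.completions

    async def summarize(
        self,
        messages: List[Message],
        max_length: int = 500
    ) -> str:
        """Summarize a conversation."""
        # Format the messages
        conversation = "\n".join([
            f"{m.role}: {m.content}"
            for m in messages
        ])

        prompt = f"""Summarize the following conversation briefly, keeping the key information.

Conversation:
{conversation}

Output the summary (no more than {max_length} characters):"""

        response = await self.llm.create(
            model="gpt-4",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=200,
            temperature=0.3
        )

        return response.choices[0].message.content

    async def summarize_with_key_points(
        self,
        messages: List[Message]
    ) -> Dict:
        """Summarize a conversation and extract its key points."""
        conversation = "\n".join([
            f"{m.role}: {m.content}"
            for m in messages
        ])

        prompt = f"""Analyze the following conversation and provide a summary and key points.

Conversation:
{conversation}

Output in JSON format:
{{
    "summary": "conversation summary",
    "key_points": ["key point 1", "key point 2"],
    "topics": ["topic 1", "topic 2"]
}}"""

        # JSON mode needs a model that supports response_format;
        # "gpt-4o" is used here as an example of such a model.
        response = await self.llm.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": prompt}],
            response_format={"type": "json_object"}
        )

        return json.loads(response.choices[0].message.content)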
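
To connect a summarizer to the "summary" eviction strategy from section 2.1, the older part of the history can be folded into one system message. The sketch below is a simplified, synchronous illustration: compact_history and its summarize callback are hypothetical names, and in practice the callback would wrap an LLM call.

```python
from typing import Callable, Dict, List

ChatMessage = Dict[str, str]


def compact_history(
    history: List[ChatMessage],
    summarize: Callable[[List[ChatMessage]], str],
    keep_recent: int = 4,
) -> List[ChatMessage]:
    """Replace everything except the newest keep_recent messages with a summary."""
    if len(history) <= keep_recent:
        return list(history)
    split = len(history) - keep_recent
    old, recent = history[:split], history[split:]
    # The summary is stored as a system message so the model treats it as context
    summary = {"role": "system", "content": summarize(old)}
    return [summary] + recent


def fake_summarize(messages: List[ChatMessage]) -> str:
    """Stand-in for an LLM call: only reports how many messages were folded."""
    return f"{len(messages)} earlier messages"
```

Keeping the summarizer behind a plain callable makes the compaction logic testable without network access.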

3. Long-Term Memory

3.1 Vector Store Implementation

from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List, Optional
from sentence_transformers import SentenceTransformer
import numpy as np

@dataclass
class MemoryItem:
    id: str
    content: str
    embedding: Optional[List[float]] = None
    metadata: Optional[Dict] = None
    importance: float = 1.0  # importance score
    access_count: int = 0
    last_accessed: Optional[datetime] = None
    created_at: Optional[datetime] = None

    def __post_init__(self):
        if self.metadata is None:
            self.metadata = {}
        if self.last_accessed is None:
            self.last_accessed = datetime.utcnow()
        if self.created_at is None:
            self.created_at = datetime.utcnow()

class LongTermMemory:
    """Long-term memory backed by vector embeddings."""

    def __init__(
        self,
        embedding_model: str = "BAAI/bge-small-zh-v1.5",
        storage_path: Optional[str] = None
    ):
        self.embedder = SentenceTransformer(embedding_model)
        self.storage_path = storage_path
        self._memories: Dict[str, MemoryItem] = {}
        self._load_from_storage()

    def add(
        self,
        content: str,
        metadata: Optional[Dict] = None,
        importance: float = 1.0
    ) -> str:
        """Add a memory and return its ID."""
        # Generate an ID
        memory_id = self._generate_id()

        # Compute the embedding
        embedding = self.embedder.encode(content).tolist()

        # Create the memory item
        memory = MemoryItem(
            id=memory_id,
            content=content,
            embedding=embedding,
            metadata=metadata or {},
            importance=importance
        )

        # Store it
        self._memories[memory_id] = memory

        # Persist
        self._save_to_storage()

        return memory_id

    def search(
        self,
        query: str,
        top_k: int = 5,
        min_score: float = 0.5
    ) -> List[tuple[MemoryItem, float]]:
        """
        Semantic search over the stored memories.

        Returns:
            [(memory, score), ...]
        """
        # Embed the query
        query_embedding = self.embedder.encode(query)

        # Score every memory (a linear scan; fine for small stores)
        results = []
        for memory in self._memories.values():
            score = self._cosine_similarity(
                query_embedding,
                memory.embedding
            )

            # Weight the similarity by importance
            final_score = score * memory.importance

            if final_score >= min_score:
                results.append((memory, final_score))

        # Sort by score and keep the best top_k
        results.sort(key=lambda x: x[1], reverse=True)
        top_results = results[:top_k]

        # Record the access only for the memories actually returned
        now = datetime.utcnow()
        for memory, _ in top_results:
            memory.access_count += 1
            memory.last_accessed = now

        return top_results

    def _cosine_similarity(
        self,
        vec1: np.ndarray,
        vec2: List[float]
    ) -> float:
        """Compute the cosine similarity of two vectors."""
        vec2 = np.array(vec2)
        dot_product = np.dot(vec1, vec2)
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)
        return float(dot_product / (norm1 * norm2)) if norm1 * norm2 > 0 else 0.0

    def get(self, memory_id: str) -> Optional[MemoryItem]:
        """Return a memory by ID and record the access."""
        memory = self._memories.get(memory_id)
        if memory:
            memory.access_count += 1
            memory.last_accessed = datetime.utcnow()
        return memory

    def update(self, memory_id: str, **kwargs):
        """Update fields of a memory."""
        if memory_id in self._memories:
            memory = self._memories[memory_id]

            for key, value in kwargs.items():
                setattr(memory, key, value)

            # Re-embed when the content changes
            if "content" in kwargs:
                memory.embedding = self.embedder.encode(
                    kwargs["content"]
                ).tolist()

            self._save_to_storage()

    def delete(self, memory_id: str) -> bool:
        """Delete a memory. Returns True if it existed."""
        if memory_id in self._memories:
            del self._memories[memory_id]
            self._save_to_storage()
            return True
        return False

    def get_all(self) -> List[MemoryItem]:
        """Return all memories."""
        return list(self._memories.values())

    def filter_by_metadata(
        self,
        key: str,
        value: object
    ) -> List[MemoryItem]:
        """Return the memories whose metadata[key] equals value."""
        return [
            m for m in self._memories.values()
            if m.metadata.get(key) == value
        ]

    def get_important_memories(
        self,
        min_importance: float = 0.8
    ) -> List[MemoryItem]:
        """Return the memories at or above the importance threshold."""
        return [
            m for m in self._memories.values()
            if m.importance >= min_importance
        ]

    def _generate_id(self) -> str:
        """Generate a memory ID."""
        import uuid
        return str(uuid.uuid4())

    def _load_from_storage(self):
        """Load memories from storage."""
        if not self.storage_path:
            return

        # Stub: a simple version would load from a JSON file.
        # A real application should use a vector database.
        pass

    def _save_to_storage(self):
        """Save memories to storage."""
        if not self.storage_path:
            return

        # Stub: a simple version would save to a JSON file.
        # A real application should use a vector database.
        pass
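
The two storage methods above are left as stubs. One possible way to fill them in for small stores is a JSON file written atomically. The sketch below is an assumption-laden illustration: save_memories and load_memories are hypothetical helpers, and records are assumed to be plain dicts (timestamps already converted to ISO strings) rather than MemoryItem objects.

```python
import json
import os
import tempfile
from typing import Dict, List


def save_memories(path: str, memories: List[Dict]) -> None:
    """Write memory records to path via a temp file, then rename over the target."""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        # ensure_ascii=False keeps Chinese text readable in the file
        json.dump(memories, f, ensure_ascii=False)
    # os.replace swaps the file in one step, so readers never see a partial write
    os.replace(tmp_path, path)


def load_memories(path: str) -> List[Dict]:
    """Read memory records from path, or return an empty list if it is missing."""
    if not os.path.exists(path):
        return []
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
```

Rewriting the whole file on every add scales poorly; beyond a few thousand items a vector database, as the stub comments suggest, is the more realistic option.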

3.2 Hierarchical Memory

class HierarchicalLongTermMemory:
    """Hierarchical long-term memory."""

    def __init__(self):
        # Three layers. Note: each LongTermMemory loads its own embedding
        # model here; sharing one embedder would save memory.
        self.immediate = LongTermMemory()  # immediate memory (most recent)
        self.recent = LongTermMemory()     # recent memory (within a week)
        self.long_term = LongTermMemory()  # long-term memory (history)

        # Age thresholds (days)
        self.immediate_threshold = 1
        self.recent_threshold = 7

    def add(
        self,
        content: str,
        metadata: Dict = None,
        importance: float = 1.0
    ) -> str:
        """Add a memory to the immediate layer."""
        return self.immediate.add(
            content, metadata, importance
        )

    def search(
        self,
        query: str,
        top_k: int = 5
    ) -> List[tuple[MemoryItem, float]]:
        """
        Search all layers.

        Results from newer layers are weighted higher: immediate first,
        then recent, then long-term.
        """
        results = []

        # Search each layer
        immediate_results = self.immediate.search(query, top_k * 2)
        recent_results = self.recent.search(query, top_k * 2)
        long_results = self.long_term.search(query, top_k * 2)

        # Merge and deduplicate by ID
        seen_ids = set()
        for memory, score in immediate_results:
            if memory.id not in seen_ids:
                results.append((memory, score * 1.0))  # immediate layer weight 1.0
                seen_ids.add(memory.id)

        for memory, score in recent_results:
            if memory.id not in seen_ids:
                results.append((memory, score * 0.8))  # recent layer weight 0.8
                seen_ids.add(memory.id)

        for memory, score in long_results:
            if memory.id not in seen_ids:
                results.append((memory, score * 0.5))  # long-term layer weight 0.5
                seen_ids.add(memory.id)

        # Sort by weighted score
        results.sort(key=lambda x: x[1], reverse=True)

        return results[:top_k]

    def consolidate(self):
        """Move aged memories down to the next layer."""
        now = datetime.utcnow()

        # Move memories from the immediate layer to the recent layer
        for memory in self.immediate.get_all():
            days_old = (now - memory.created_at).days
            if days_old > self.immediate_threshold:
                self.recent._memories[memory.id] = memory
                del self.immediate._memories[memory.id]

        # Move memories from the recent layer to the long-term layer
        for memory in self.recent.get_all():
            days_old = (now - memory.created_at).days
            if days_old > self.recent_threshold:
                self.long_term._memories[memory.id] = memory
                del self.recent._memories[memory.id]
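
The age rule that consolidate applies can be isolated into a small pure function, which makes the boundaries easy to check. This is a sketch under the same thresholds as above (1 and 7 days, compared with a strict "greater than" on whole days); tier_for is a hypothetical helper name.

```python
from datetime import datetime, timedelta


def tier_for(
    created_at: datetime,
    now: datetime,
    immediate_days: int = 1,
    recent_days: int = 7,
) -> str:
    """Return the layer a memory belongs to, based on its age in whole days."""
    age_days = (now - created_at).days
    if age_days > recent_days:
        return "long_term"
    if age_days > immediate_days:
        return "recent"
    return "immediate"
```

Because timedelta.days truncates to whole days and the comparison is strict, a memory only leaves the immediate layer once it is at least two full days old, and leaves the recent layer once it is at least eight.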

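3.3 Memory Importance Evaluation

class MemoryImportanceEvaluator:
    """Memory importance evaluator."""

    def __init__(self, llm=None):
        # Expects an object with an awaitable create(...) method. The default
        # assumes openai>=1.0, where the legacy openai.ChatCompletion
        # interface is no longer available.
        self.llm = llm or openai.AsyncOpenAI().chat.completions

    async def evaluate(self, content: str) -> float:
        """Evaluate how important a piece of content is.

        Returns:
            An importance score between 0.0 and 1.0
        """
        prompt = f"""Rate the importance of the following information.

Score ranges:
- 0.0-0.2: unimportant, can be ignored
- 0.2-0.4: low importance
- 0.4-0.6: medium importance
- 0.6-0.8: high importance
- 0.8-1.0: very important, must be remembered

Information:
{content}

Output only a number between 0.0 and 1.0:"""

        response = await self.llm.create(
            model="gpt-4",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,
            max_tokens=10
        )

        try:
            score = float(response.choices[0].message.content.strip())
            return max(0.0, min(1.0, score))
        except (ValueError, AttributeError, IndexError):
            return 0.5  # fall back to medium importance

    def evaluate_by_signals(
        self,
        content: str,
        signals: Dict
    ) -> float:
        """
        Evaluate importance from explicit signals.

        Signals:
        - user_explicit: the user explicitly asked to remember this
        - repeated: the content has come up repeatedly
        - emotional: contains emotional words
        - actionable: actionable information
        - personal: personal information
        """
        score = 0.5  # base score

        if signals.get("user_explicit"):
            score += 0.3

        if signals.get("repeated"):
            score += 0.2

        if signals.get("emotional"):
            score += 0.1

        if signals.get("actionable"):
            score += 0.15

        if signals.get("personal"):
            score += 0.2

        # Inspect the content itself
        content_lower = content.lower()

        # Personal-information phrases (Chinese originals plus English equivalents)
        personal_keywords = [
            "我叫", "我的名字", "我住在", "我的电话",
            "my name is", "i live in", "my phone",
        ]
        if any(kw in content_lower for kw in personal_keywords):
            score += 0.2

        # Phrases that flag importance (Chinese originals plus English equivalents)
        important_keywords = [
            "记住", "重要", "别忘了", "关键",
            "remember", "important", "don't forget",
        ]
        if any(kw in content_lower for kw in important_keywords):
            score += 0.15

        return min(1.0, score)

    async def extract_facts(self, content: str) -> List[Dict]:
        """Extract the facts worth remembering from the content."""
        prompt = f"""Extract the facts worth remembering from the following content.

Content:
{content}

Output the extracted facts in JSON format:
{{
    "facts": [
        {{
            "fact": "the fact",
            "category": "category",
            "importance": 0.8
        }}
    ]
}}"""

        # The client is asynchronous, so the call must be awaited.
        # JSON mode needs a model that supports response_format;
        # "gpt-4o" is used here as an example of such a model.
        response = await self.llm.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": prompt}],
            response_format={"type": "json_object"}
        )

        data = json.loads(response.choices[0].message.content)
        return data.get("facts", [])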
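
Models do not always answer a "number only" prompt with a bare number; replies such as "Score: 0.8" are common. A slightly more forgiving parser is sketched below. parse_score is a hypothetical helper, and its behavior (take the first number found, ignore any sign, clamp to the 0.0 to 1.0 range) is one possible policy rather than the only reasonable one.

```python
import re
from typing import Optional

# Matches decimals such as "0.35" or ".5" first, then plain integers
_NUMBER = re.compile(r"\d*\.\d+|\d+")


def parse_score(text: Optional[str], default: float = 0.5) -> float:
    """Extract the first number from an LLM reply and clamp it to [0.0, 1.0]."""
    match = _NUMBER.search(text or "")
    if match is None:
        return default
    return max(0.0, min(1.0, float(match.group())))
```

Falling back to a neutral default keeps one malformed reply from either discarding a memory or pinning it as critical.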

4. Structured Memory

4.1 User Preference Memory

from typing import Any, Dict, List, Optional

class UserPreferenceMemory:
    """User preference memory."""

    def __init__(self, storage: Optional[Dict] = None):
        # "storage or {}" would silently discard an empty dict passed by the caller
        self.storage = storage if storage is not None else {}

    def set(self, user_id: str, key: str, value: Any):
        """Set a user preference."""
        if user_id not in self.storage:
            self.storage[user_id] = {}

        self.storage[user_id][key] = value

    def get(self, user_id: str, key: str, default=None) -> Any:
        """Get a user preference."""
        if user_id not in self.storage:
            return default

        return self.storage[user_id].get(key, default)

    def get_all(self, user_id: str) -> Dict:
        """Get all preferences of a user."""
        return self.storage.get(user_id, {}).copy()

    def delete(self, user_id: str, key: str):
        """Delete a user preference."""
        if user_id in self.storage and key in self.storage[user_id]:
            del self.storage[user_id][key]

    def update(self, user_id: str, preferences: Dict):
        """Update several preferences at once."""
        if user_id not in self.storage:
            self.storage[user_id] = {}

        self.storage[user_id].update(preferences)

    # Common preferences
    def set_language(self, user_id: str, language: str):
        """Set the language preference."""
        self.set(user_id, "language", language)

    def get_language(self, user_id: str) -> str:
        """Get the language preference."""
        return self.get(user_id, "language", "zh-CN")

    def set_response_style(self, user_id: str, style: str):
        """Set the response style."""
        # style: concise, detailed, formal, casual, balanced (the default)
        self.set(user_id, "response_style", style)

    def get_response_style(self, user_id: str) -> str:
        """Get the response style."""
        return self.get(user_id, "response_style", "balanced")

    def set_topics(self, user_id: str, topics: List[str]):
        """Set the topics the user is interested in."""
        self.set(user_id, "topics", topics)

    def get_topics(self, user_id: str) -> List[str]:
        """Get the topics the user is interested in."""
        return self.get(user_id, "topics", [])
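
When preferences live in a flat key-value store such as Redis rather than a nested dict, a common approach is to encode the user and preference name into the key. The following sketch assumes a "pref:{user_id}:{key}" scheme and JSON-encoded values; FlatPreferenceStore is a hypothetical class, and a plain dict stands in for the real backend.

```python
import json
from typing import Any, Dict, MutableMapping


class FlatPreferenceStore:
    """Preferences stored under flat keys of the form pref:{user_id}:{key}."""

    def __init__(self, backend: MutableMapping[str, str], prefix: str = "pref:"):
        self.backend = backend
        self.prefix = prefix

    def _key(self, user_id: str, key: str) -> str:
        # Assumes user_id contains no ":"; otherwise keys become ambiguous
        return f"{self.prefix}{user_id}:{key}"

    def set(self, user_id: str, key: str, value: Any) -> None:
        # JSON-encode so lists and numbers survive a string-only backend
        self.backend[self._key(user_id, key)] = json.dumps(value, ensure_ascii=False)

    def get(self, user_id: str, key: str, default: Any = None) -> Any:
        raw = self.backend.get(self._key(user_id, key))
        return default if raw is None else json.loads(raw)

    def get_all(self, user_id: str) -> Dict[str, Any]:
        start = f"{self.prefix}{user_id}:"
        return {
            k[len(start):]: json.loads(v)
            for k, v in self.backend.items()
            if k.startswith(start)
        }
```

Scanning every key in get_all is acceptable for a dict; against a real server, storing each user's preferences in a single hash would avoid the full scan.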

4.2 Session State Memory

from typing import Any, Dict, Optional

class SessionStateMemory:
    """Session state memory."""

    def __init__(self, ttl: int = 3600):
        """
        Args:
            ttl: time to live, in seconds, counted from the last write
        """
        self.states: Dict[str, Dict] = {}
        self.timestamps: Dict[str, datetime] = {}
        self.ttl = ttl

    def set(self, session_id: str, key: str, value: Any):
        """Set a session state value and refresh the session's timestamp."""
        if session_id not in self.states:
            self.states[session_id] = {}

        self.states[session_id][key] = value
        self.timestamps[session_id] = datetime.utcnow()

    def get(self, session_id: str, key: str, default=None) -> Any:
        """Get a session state value."""
        self._check_ttl(session_id)

        if session_id not in self.states:
            return default

        return self.states[session_id].get(key, default)

    def get_all(self, session_id: str) -> Dict:
        """Get all state of a session."""
        self._check_ttl(session_id)
        return self.states.get(session_id, {}).copy()

    def delete(self, session_id: str, key: str = None):
        """Delete one key, or the whole session when key is None."""
        if session_id not in self.states:
            return

        if key is None:
            # Delete the whole session
            del self.states[session_id]
            self.timestamps.pop(session_id, None)
        else:
            # Delete a single key
            if key in self.states[session_id]:
                del self.states[session_id][key]

    def _check_ttl(self, session_id: str):
        """Drop the session if its TTL has expired."""
        if session_id not in self.timestamps:
            return

        age = (datetime.utcnow() - self.timestamps[session_id]).total_seconds()
        if age > self.ttl:
            self.delete(session_id)

    # Common state helpers
    def set_context(self, session_id: str, context: Dict):
        """Set the context."""
        self.set(session_id, "context", context)

    def get_context(self, session_id: str) -> Dict:
        """Get the context."""
        return self.get(session_id, "context", {})

    def set_user_info(self, session_id: str, user_info: Dict):
        """Set the user info."""
        self.set(session_id, "user_info", user_info)

    def get_user_info(self, session_id: str) -> Dict:
        """Get the user info."""
        return self.get(session_id, "user_info", {})

    def set_step(self, session_id: str, step: str):
        """Set the current step."""
        self.set(session_id, "current_step", step)

    def get_step(self, session_id: str) -> Optional[str]:
        """Get the current step."""
        return self.get(session_id, "current_step")
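
TTL logic is awkward to test when it reads the wall clock directly. A common remedy is to inject the clock, as in the minimal sketch below. TTLStore and its clock parameter are hypothetical names for illustration; expiry is lazy (checked on read), mirroring the approach of the class above.

```python
import time
from typing import Any, Callable, Dict, Tuple


class TTLStore:
    """Key-value store whose entries expire ttl seconds after being written."""

    def __init__(self, ttl: float, clock: Callable[[], float] = time.monotonic):
        self.ttl = ttl
        self.clock = clock  # injectable so tests can control time
        self._data: Dict[str, Tuple[float, Any]] = {}

    def set(self, key: str, value: Any) -> None:
        # Store the absolute expiry time next to the value
        self._data[key] = (self.clock() + self.ttl, value)

    def get(self, key: str, default: Any = None) -> Any:
        entry = self._data.get(key)
        if entry is None:
            return default
        expires_at, value = entry
        if self.clock() >= expires_at:
            # Lazy expiry: the entry is removed when it is next read
            del self._data[key]
            return default
        return value
```

Lazy expiry never frees sessions that are not read again, so a long-running process would also need a periodic sweep or a backend with native TTL support.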

4.3 Fact Memory

class FactMemory:
    """Fact memory: stores key facts."""

    def __init__(self):
        self.facts: Dict[str, Dict] = {}

    def add_fact(
        self,
        fact_id: str,
        fact: str,
        category: str = "general",
        confidence: float = 1.0
    ):
        """Add a fact (overwrites any existing fact with the same ID)."""
        self.facts[fact_id] = {
            "fact": fact,
            "category": category,
            "confidence": confidence,
            "created_at": datetime.utcnow(),
            "verified": False
        }

    def get_fact(self, fact_id: str) -> Optional[Dict]:
        """Get a fact by ID."""
        return self.facts.get(fact_id)

    def get_by_category(self, category: str) -> List[Dict]:
        """Get the facts in a category."""
        return [
            {"id": k, **v}
            for k, v in self.facts.items()
            if v["category"] == category
        ]

    def verify_fact(self, fact_id: str, is_correct: bool):
        """Record the result of verifying a fact."""
        if fact_id in self.facts:
            self.facts[fact_id]["verified"] = is_correct

    def search_facts(self, keyword: str) -> List[Dict]:
        """Search facts by keyword (case-insensitive)."""
        keyword_lower = keyword.lower()
        return [
            {"id": k, **v}
            for k, v in self.facts.items()
            if keyword_lower in v["fact"].lower()
        ]

    # Personal-information facts
    def set_personal_info(self, user_id: str, info_type: str, value: str):
        """Store a piece of personal information."""
        fact_id = f"{user_id}:{info_type}"
        self.add_fact(
            fact_id=fact_id,
            fact=value,
            category="personal",
            confidence=0.95
        )

    def get_personal_info(self, user_id: str, info_type: str) -> Optional[str]:
        """Get a piece of personal information."""
        fact_id = f"{user_id}:{info_type}"
        fact = self.get_fact(fact_id)
        return fact["fact"] if fact else None

5. Memory Retrieval

5.1 Hybrid Memory Retrieval

class HybridMemoryRetriever:
    """Hybrid memory retriever."""

    def __init__(
        self,
        short_term: ShortTermMemory,
        long_term: LongTermMemory,
        preferences: UserPreferenceMemory
    ):
        self.short_term = short_term
        self.long_term = long_term
        self.preferences = preferences

    def retrieve(
        self,
        query: str,
        user_id: str = None,
        top_k: int = 10
    ) -> Dict:
        """
        Retrieve from all memory types.

        Returns:
            {
                "short_term": [...],
                "long_term": [...],
                "preferences": {...},
                "context": "the combined context"
            }
        """
        # 1. Short-term memory (keyword search; the whole query is matched
        #    as one substring, so long queries rarely match)
        short_term_results = self.short_term.search(query)

        # 2. Long-term memory (semantic search)
        long_term_results = self.long_term.search(query, top_k=top_k)

        # 3. User preferences
        user_preferences = None
        if user_id:
            user_preferences = self.preferences.get_all(user_id)

        # 4. Build the combined context
        context = self._build_context(
            short_term_results,
            long_term_results,
            user_preferences
        )

        return {
            "short_term": short_term_results,
            "long_term": long_term_results,
            "preferences": user_preferences,
            "context": context
        }

    @staticmethod
    def _truncate(text: str, limit: int) -> str:
        """Cut text to limit characters, adding an ellipsis only when it was cut."""
        return text if len(text) <= limit else text[:limit] + "..."

    def _build_context(
        self,
        short_term: List,
        long_term: List,
        preferences: Optional[Dict]
    ) -> str:
        """Build the context string."""
        parts = []

        # User preferences
        if preferences:
            parts.append("User preferences:")
            for key, value in preferences.items():
                parts.append(f"  - {key}: {value}")
            parts.append("")

        # Short-term memory
        if short_term:
            parts.append("Recent conversation:")
            for msg in short_term[-5:]:  # at most 5 messages
                parts.append(f"  {msg.role}: {self._truncate(msg.content, 100)}")
            parts.append("")

        # Long-term memory
        if long_term:
            parts.append("Related information:")
            for memory, score in long_term:
                parts.append(
                    f"  - {self._truncate(memory.content, 150)} (relevance: {score:.2f})"
                )

        return "\n".join(parts)
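
The short-term step above matches the entire query as a single substring, which rarely hits for natural-language questions. One lightweight alternative is to score messages by term overlap, sketched here. tokenize and rank_by_overlap are hypothetical helpers; splitting Chinese into single characters is a crude assumption that stands in for a real word segmenter.

```python
import re
from typing import List, Set, Tuple


def tokenize(text: str) -> Set[str]:
    """Lowercased English words and digits, plus individual CJK characters."""
    return set(re.findall(r"[a-z0-9]+|[\u4e00-\u9fff]", text.lower()))


def rank_by_overlap(
    query: str, documents: List[str], top_k: int = 3
) -> List[Tuple[str, float]]:
    """Rank documents by the fraction of query terms they contain."""
    query_terms = tokenize(query)
    if not query_terms:
        return []
    scored = []
    for doc in documents:
        overlap = len(query_terms & tokenize(doc))
        if overlap:
            scored.append((doc, overlap / len(query_terms)))
    # sort() is stable, so equal scores keep their original (chronological) order
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:top_k]
```

This ignores word order and synonyms, which is why the long-term layer uses embeddings; for a short session history, term overlap is often enough.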

5.2 Memory Deduplication

class MemoryDeduplicator:
    """Memory deduplicator."""

    def __init__(self, similarity_threshold: float = 0.95):
        self.similarity_threshold = similarity_threshold

    def deduplicate_by_content(
        self,
        memories: List[MemoryItem]
    ) -> List[MemoryItem]:
        """Deduplicate by content, using embedding similarity."""
        unique = []
        seen_embeddings = []

        for memory in memories:
            # Memories without an embedding cannot be compared; keep them
            if not memory.embedding:
                unique.append(memory)
                continue

            # Check whether it is similar to a memory that was already kept
            is_duplicate = False
            for seen_embedding in seen_embeddings:
                similarity = self._cosine_similarity(
                    memory.embedding,
                    seen_embedding
                )
                if similarity >= self.similarity_threshold:
                    is_duplicate = True
                    break

            if not is_duplicate:
                unique.append(memory)
                seen_embeddings.append(memory.embedding)

        return unique

    def deduplicate_by_id(self, memories: List[MemoryItem]) -> List[MemoryItem]:
        """Deduplicate by ID, keeping the first occurrence."""
        seen_ids = set()
        unique = []

        for memory in memories:
            if memory.id not in seen_ids:
                unique.append(memory)
                seen_ids.add(memory.id)

        return unique

    def _cosine_similarity(
        self,
        vec1: List[float],
        vec2: List[float]
    ) -> float:
        """Compute the cosine similarity of two vectors."""
        import numpy as np
        vec1 = np.array(vec1)
        vec2 = np.array(vec2)
        dot = np.dot(vec1, vec2)
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)
        return float(dot / (norm1 * norm2)) if norm1 * norm2 > 0 else 0.0
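
To make the threshold behavior concrete, here is the same greedy deduplication on toy two-dimensional vectors, written with the standard library only. cosine and dedupe_indices are hypothetical helper names; the vectors are made up for illustration and are not real embeddings.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; defined as 0.0 when either vector has zero length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def dedupe_indices(vectors: List[Sequence[float]], threshold: float = 0.95) -> List[int]:
    """Greedy pass: keep a vector only if it is not too similar to any kept one."""
    kept: List[int] = []
    for i, vec in enumerate(vectors):
        if all(cosine(vec, vectors[j]) < threshold for j in kept):
            kept.append(i)
    return kept
```

The greedy pass is order-dependent and quadratic in the worst case: the first of two near-duplicates always wins, so sorting by importance beforehand decides which copy survives.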

6. Memory Management

6.1 Complete Memory Manager

class MemoryManager:
    """
    Complete memory manager.

    Integrates:
    - short-term memory
    - long-term memory
    - user preferences
    - session state
    """

    def __init__(
        self,
        user_id: str,
        session_id: str = None
    ):
        self.user_id = user_id
        self.session_id = session_id or self._generate_session_id()

        # Initialize each memory layer
        self.short_term = ShortTermMemory()
        self.long_term = LongTermMemory()
        self.preferences = UserPreferenceMemory()
        self.session_state = SessionStateMemory()

        # Retriever
        self.retriever = HybridMemoryRetriever(
            self.short_term,
            self.long_term,
            self.preferences
        )

    def add_message(self, role: str, content: str):
        """Add a message to short-term memory."""
        self.short_term.add(role, content)

    def add_long_term_memory(
        self,
        content: str,
        metadata: Dict = None,
        importance: float = 1.0
    ) -> str:
        """Add a long-term memory."""
        # Copy the metadata so the caller's dict is not modified,
        # then tag it with the user and session IDs
        metadata = dict(metadata or {})
        metadata["user_id"] = self.user_id
        metadata["session_id"] = self.session_id

        return self.long_term.add(content, metadata, importance)

    def retrieve(self, query: str, top_k: int = 10) -> Dict:
        """Retrieve memories."""
        return self.retriever.retrieve(query, self.user_id, top_k)

    def get_context(self, max_tokens: int = 4000) -> str:
        """Build a context string that stays within roughly max_tokens."""
        # Short-term memory
        recent_messages = self.short_term.get_recent(5)

        # User preferences
        prefs = self.preferences.get_all(self.user_id)

        context_parts = []

        # User preferences
        if prefs:
            context_parts.append("User preferences:")
            for key, value in prefs.items():
                context_parts.append(f"  - {key}: {value}")
            context_parts.append("")

        used = count_tokens("\n".join(context_parts))

        # Walk backwards so the newest messages win when the budget is tight
        lines = []
        for msg in reversed(recent_messages):
            line = f"  {msg.role}: {msg.content}"
            cost = count_tokens(line)
            if used + cost > max_tokens:
                break
            lines.append(line)
            used += cost

        # Recent conversation, restored to chronological order
        if lines:
            context_parts.append("Recent conversation:")
            context_parts.extend(reversed(lines))

        return "\n".join(context_parts)

    def set_preference(self, key: str, value: Any):
        """Set a user preference."""
        self.preferences.set(self.user_id, key, value)

    def get_preference(self, key: str, default=None) -> Any:
        """Get a user preference."""
        return self.preferences.get(self.user_id, key, default)

    def set_session_state(self, key: str, value: Any):
        """Set a session state value."""
        self.session_state.set(self.session_id, key, value)

    def get_session_state(self, key: str, default=None) -> Any:
        """Get a session state value."""
        return self.session_state.get(self.session_id, key, default)

    def clear_session(self):
        """Clear the session."""
        self.short_term.clear()
        self.session_state.delete(self.session_id)

    def _generate_session_id(self) -> str:
        """Generate a session ID."""
        import uuid
        return str(uuid.uuid4())

7. 记忆优化

7.1 记忆压缩

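from datetime import datetime, timedelta

class MemoryCompressor:
    """Memory compressor."""

    def compress_long_term_memory(self, memory: LongTermMemory):
        """Compress long-term memory."""
        # Strategy 1: delete memories with low access counts
        # Strategy 2: merge similar memories
        # Strategy 3: summarize old memories

        # Simplified implementation: delete memories not accessed in the last 90 days
        cutoff_date = datetime.utcnow() - timedelta(days=90)

        # Collect the IDs first so the dict is not modified while iterating.
        # The loop variable is `item` so that it does not shadow the `memory` store.
        to_delete = [
            memory_id
            for memory_id, item in memory._memories.items()
            if item.last_accessed < cutoff_date
        ]

        for memory_id in to_delete:
            memory.delete(memory_id)

    def compress_session_memory(self, memory: ShortTermMemory):
        """Compress session memory."""
        # Strategy: summarize the older half and keep the recent messages
        if len(memory.messages) < 20:
            return

        # Summarize the first half of the messages
        split = len(memory.messages) // 2
        old_messages = memory.messages[:split]

        summary = self._create_summary(old_messages)

        # Replace the older half with a single summary message
        memory.messages = [
            Message(role="system", content=summary)
        ] + memory.messages[split:]

    def _create_summary(self, messages: List[Message]) -> str:
        """Create a summary."""
        # Simplified implementation: a placeholder instead of a real summary
        return f"[Summary of {len(messages)} messages]"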
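Strategy 1 from the comments above (dropping rarely used memories) could be sketched as a retention score that combines importance, access count, and recency, rather than a fixed 90-day cutoff. This is only a minimal sketch: `MemoryRecord`, `retention_score`, `select_for_eviction`, and the 30-day half-life are illustrative assumptions and not part of the classes defined in this article.

```python
import math
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import List


@dataclass
class MemoryRecord:
    """Hypothetical record; the field names mirror those used in the text."""
    memory_id: str
    importance: float       # expected range 0.0 to 1.0
    access_count: int
    last_accessed: datetime


def retention_score(record: MemoryRecord, now: datetime,
                    half_life_days: float = 30.0) -> float:
    """Return a score where higher means more worth keeping."""
    age_days = max((now - record.last_accessed).total_seconds() / 86400.0, 0.0)
    recency = 0.5 ** (age_days / half_life_days)   # halves every half_life_days
    usage = math.log1p(record.access_count)        # diminishing returns on access count
    return record.importance * (1.0 + usage) * recency


def select_for_eviction(records: List[MemoryRecord], keep: int,
                        now: datetime) -> List[str]:
    """Return the IDs of every record outside the `keep` highest-scoring ones."""
    ranked = sorted(records, key=lambda r: retention_score(r, now), reverse=True)
    return [r.memory_id for r in ranked[keep:]]
```

The returned IDs could then be passed to the store's `delete` method, in the same way as the `to_delete` list above.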

8. Implementation Example

8.1 Complete Memory System

"""
A complete Agent memory system.

Integrates all memory components behind a single unified interface.
"""

from typing import Any, Dict, List, Optional

class AgentMemorySystem:
    """Agent memory system."""

    def __init__(self, user_id: str, config: Optional[Dict] = None):
        self.user_id = user_id
        self.config = config or {}

        # Initialize the memory components
        self.short_term = ShortTermMemory(
            max_messages=self.config.get("max_messages", 100),
            max_tokens=self.config.get("max_tokens", 4000)
        )

        self.long_term = LongTermMemory(
            embedding_model=self.config.get("embedding_model", "BAAI/bge-small-zh-v1.5")
        )

        self.preferences = UserPreferenceMemory()
        self.session_state = SessionStateMemory()
        self.fact_memory = FactMemory()

    # Conversation interface
    def add_user_message(self, content: str):
        """Add a user message."""
        self.short_term.add("user", content)

    def add_assistant_message(self, content: str):
        """Add an assistant message."""
        self.short_term.add("assistant", content)

    def get_conversation_history(self) -> List[Dict]:
        """Return the conversation history."""
        return self.short_term.to_messages_format()

    # Long-term memory interface
    def remember(
        self,
        content: str,
        category: str = None,
        importance: float = 1.0
    ):
        """Store a piece of information in long-term memory."""
        metadata = {"user_id": self.user_id}
        if category:
            metadata["category"] = category

        return self.long_term.add(content, metadata, importance)

    def recall(self, query: str, top_k: int = 5) -> List[Dict]:
        """Recall information relevant to the query."""
        results = self.long_term.search(query, top_k)

        return [
            {
                "content": memory.content,
                "score": score,
                "metadata": memory.metadata
            }
            for memory, score in results
        ]

    # User preference interface
    def set_preference(self, key: str, value: Any):
        """Set a preference."""
        self.preferences.set(self.user_id, key, value)

    def get_preference(self, key: str, default=None) -> Any:
        """Get a preference."""
        return self.preferences.get(self.user_id, key, default)

    # Session state interface
    def set_state(self, key: str, value: Any):
        """Set a state value for the current session."""
        session_id = self._get_session_id()
        self.session_state.set(session_id, key, value)

    def get_state(self, key: str, default=None) -> Any:
        """Get a state value for the current session."""
        session_id = self._get_session_id()
        return self.session_state.get(session_id, key, default)

    # Fact memory interface
    def remember_fact(self, fact_type: str, value: str):
        """Store a fact about the user."""
        self.fact_memory.set_personal_info(self.user_id, fact_type, value)

    def recall_fact(self, fact_type: str) -> Optional[str]:
        """Recall a fact about the user."""
        return self.fact_memory.get_personal_info(self.user_id, fact_type)

    # Context building
    def build_context(self, query: str = None) -> str:
        """Build the full context string."""
        parts = []

        # 1. User preferences
        prefs = self.preferences.get_all(self.user_id)
        if prefs:
            parts.append("User preferences:")
            for key, value in prefs.items():
                parts.append(f"  - {key}: {value}")
            parts.append("")

        # 2. Relevant long-term memories
        if query:
            recalled = self.recall(query, top_k=3)
            if recalled:
                parts.append("Relevant information:")
                for item in recalled:
                    snippet = item["content"]
                    # Add an ellipsis only when the content is actually truncated
                    if len(snippet) > 100:
                        snippet = snippet[:100] + "..."
                    parts.append(f"  - {snippet}")
                parts.append("")

        # 3. Conversation history (last 5 messages, each cut to 50 characters)
        history = self.get_conversation_history()[-5:]
        if history:
            parts.append("Recent conversation:")
            for msg in history:
                content = msg["content"][:50]
                parts.append(f"  {msg['role']}: {content}")

        return "\n".join(parts)

    def _get_session_id(self) -> str:
        """Return the current session ID ("default" if none has been set)."""
        return getattr(self, "_current_session_id", "default")

    def set_session(self, session_id: str):
        """Set the current session."""
        self._current_session_id = session_id

# ============== Usage example ==============

if __name__ == "__main__":
    # Initialize the memory system
    memory = AgentMemorySystem(user_id="user123")

    # Conversation turns
    memory.add_user_message("My name is Zhang San and I live in Beijing")
    memory.add_assistant_message("Hello Zhang San! Nice to meet you.")

    # Remember important information
    memory.remember(
        content="Zhang San lives in Beijing and is a software engineer",
        category="personal",
        importance=0.9
    )

    # Set user preferences
    memory.set_preference("language", "zh-CN")
    memory.set_preference("response_style", "detailed")

    # Remember facts
    memory.remember_fact("name", "Zhang San")
    memory.remember_fact("location", "Beijing")

    # Recall information
    query = "Who is Zhang San?"
    recalled = memory.recall(query, top_k=3)
    print("Recalled information:")
    for item in recalled:
        print(f"  - {item['content']} (relevance: {item['score']:.2f})")

    # Read a preference
    language = memory.get_preference("language")
    print(f"\nLanguage preference: {language}")

    # Build the context
    context = memory.build_context(query)
    print(f"\nContext:\n{context}")

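Common Interview Questions

Q1: How do you design an Agent's memory system?

Standard answer:

Key points of memory system design:

1. Layered architecture
   - Short-term memory: conversation history of the current session
   - Long-term memory: historical information kept in a vector store
   - Structured memory: user preferences, state, and facts

2. Short-term memory
   - Sliding window: keep the most recent N messages
   - Token limit: cap the total token count
   - Conversation summary: compress old messages into summaries

3. Long-term memory
   - Vector storage: retrieval by semantic similarity
   - Importance scoring: assess importance automatically
   - Tiered storage: immediate / recent / long-term

4. Memory retrieval
   - Hybrid retrieval: combine short-term and long-term memory
   - Semantic search: similarity-based
   - Keyword search: fast matching

5. Persistence
   - Save user preferences
   - Save important facts
   - Back up regularly

Implementation example: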
```python
class MemorySystem:
    def __init__(self):
        self.short_term = ShortTermMemory()
        self.long_term = LongTermMemory()
        self.preferences = UserPreferenceMemory()

    def add_message(self, role, content):
        # Append a message to the current session history
        self.short_term.add(role, content)

    def remember(self, content, importance=1.0):
        # Same argument order as AgentMemorySystem.remember: content, metadata, importance
        self.long_term.add(content, {}, importance)

    def retrieve(self, query):
        # Semantic search over long-term memory
        return self.long_term.search(query)
```
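Point 4 above (hybrid retrieval across short-term and long-term memory, with keyword and similarity search) could look roughly like the following. This is a hedged sketch, not the article's implementation: the function names are hypothetical, plain strings stand in for messages and memories, and a bag-of-words cosine similarity stands in for a real embedding model.

```python
import math
import re
from collections import Counter
from typing import List, Tuple


def tokenize(text: str) -> List[str]:
    """Lowercase word tokens; a real system would use a proper tokenizer."""
    return re.findall(r"\w+", text.lower())


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two term-frequency vectors."""
    dot = sum(a[t] * b[t] for t in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def search_short_term(query: str, messages: List[str], limit: int) -> List[str]:
    """Keyword search: newest messages that share at least one token with the query."""
    query_tokens = set(tokenize(query))
    hits = [m for m in reversed(messages) if query_tokens & set(tokenize(m))]
    return hits[:limit]


def search_long_term(query: str, memories: List[str], limit: int) -> List[str]:
    """Similarity search: best-matching memories first (stand-in for vector search)."""
    query_vec = Counter(tokenize(query))
    scored = [(cosine(query_vec, Counter(tokenize(m))), m) for m in memories]
    scored = [pair for pair in scored if pair[0] > 0]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [m for _, m in scored[:limit]]


def hybrid_retrieve(query: str, short_term: List[str], long_term: List[str],
                    limit: int = 3) -> List[Tuple[str, str]]:
    """Merge both sources, dropping entries whose text was already returned."""
    candidates = [("short_term", t) for t in search_short_term(query, short_term, limit)]
    candidates += [("long_term", t) for t in search_long_term(query, long_term, limit)]
    seen, merged = set(), []
    for source, text in candidates:
        key = text.strip().lower()
        if key not in seen:
            seen.add(key)
            merged.append((source, text))
    return merged
```

Short-term hits are listed first here on the assumption that the current session is the most relevant source; a different ordering or a weighted score would be equally reasonable.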
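### Q2: How do you manage context in long conversations?

Standard answer:

Long-conversation management strategies:

1. Sliding window
   - Keep the most recent N messages
   - Keep the system prompt
   - Keep the initial request

2. Conversation summary
   - Summarize old messages periodically
   - Replace the original messages with the summary
   - Reduce token usage

3. Smart filtering
   - Keep messages with high importance
   - Keep content the user explicitly asked to be remembered
   - Drop low-value content

4. Tiered storage
   - Current window: kept in full
   - History summary: stored in compressed form
   - Long-term memory: retrieved on demand

Implementation: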
```python
def manage_long_conversation(messages, max_tokens):
    # count_tokens and summarize_messages are assumed helpers, not defined here

    # Step 1: check whether the limit is exceeded
    current_tokens = sum(count_tokens(m) for m in messages)
    if current_tokens <= max_tokens:
        return messages

    # Step 2: summarize the older half of the messages
    split = len(messages) // 2
    old_messages = messages[:split]
    summary = summarize_messages(old_messages)

    # Step 3: rebuild the list as summary + recent messages
    result = [
        {"role": "system", "content": summary}
    ] + messages[split:]

    return result
```
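The function above covers the summary strategy only. Strategy 1 (a sliding window that always keeps the system prompt and the initial request) might be sketched as follows. The names `estimate_tokens` and `sliding_window` are hypothetical, and the rule of roughly four characters per token is a crude assumption rather than a real tokenizer.

```python
from typing import Dict, List


def estimate_tokens(message: Dict[str, str]) -> int:
    """Rough estimate: about 4 characters per token (assumption)."""
    return max(1, len(message["content"]) // 4)


def sliding_window(messages: List[Dict[str, str]], max_tokens: int) -> List[Dict[str, str]]:
    """Keep system prompts and the first user request, then as many recent messages as fit."""
    # Pinned messages are always kept, whatever the budget.
    pinned = {i for i, m in enumerate(messages) if m["role"] == "system"}
    first_user = next((i for i, m in enumerate(messages) if m["role"] == "user"), None)
    if first_user is not None:
        pinned.add(first_user)

    budget = max_tokens - sum(estimate_tokens(messages[i]) for i in pinned)
    kept = set(pinned)

    # Walk backwards from the newest message until the budget runs out.
    for i in range(len(messages) - 1, -1, -1):
        if i in pinned:
            continue
        cost = estimate_tokens(messages[i])
        if cost > budget:
            break
        kept.add(i)
        budget -= cost

    # Restore chronological order.
    return [messages[i] for i in sorted(kept)]
```

Stopping at the first message that does not fit keeps the retained tail contiguous, so the model never sees a reply without the turn that preceded it.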

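### Q3: How do you implement cross-session memory?

Standard answer:

Implementing cross-session memory:

  1. Persistent storage

    • Vector database: stores long-term memories
    • Key-value store: stores user preferences
    • Database: stores session state
  2. Memory identifiers

    • User ID: links all memories belonging to a user
    • Session ID: links the memories within a session
    • Timestamp: when the memory was created
  3. Memory extraction

    • Semantic retrieval: fetch memories relevant to the query
    • Metadata filtering: filter by user or category
    • Importance ranking: return important memories first
  4. Memory updates

    • Access count: track how often a memory is used
    • Last accessed: track when it was last used
    • Importance assessment: adjust importance dynamically

Implementation: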

class CrossSessionMemory:
    def __init__(self, user_id):
        self.user_id = user_id
        # VectorDB and KVStore are placeholders for the actual storage backends
        self.long_term = VectorDB()
        self.preferences = KVStore()

    def remember(self, content, category):
        # Store the memory in the vector database, tagged with the user ID
        self.long_term.add(
            content=content,
            metadata={
                "user_id": self.user_id,
                "category": category
            }
        )

    def recall(self, query, top_k=5):
        # Retrieve relevant memories, restricted to this user
        results = self.long_term.search(
            query=query,
            filter={"user_id": self.user_id},
            top_k=top_k
        )
        return results
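The key-value side of the class above (user preferences that survive across sessions) could be backed by SQLite from the Python standard library. The following is only a sketch under assumptions: `SQLitePreferenceStore`, its table layout, and the temporary file path are illustrative and do not come from the article.

```python
import json
import os
import sqlite3
import tempfile
from typing import Any


class SQLitePreferenceStore:
    """Hypothetical persistent key-value store keyed by (user_id, pref_key)."""

    def __init__(self, path: str):
        self.conn = sqlite3.connect(path)
        self.conn.execute(
            "CREATE TABLE IF NOT EXISTS preferences ("
            "user_id TEXT NOT NULL, pref_key TEXT NOT NULL, pref_value TEXT NOT NULL, "
            "PRIMARY KEY (user_id, pref_key))"
        )
        self.conn.commit()

    def set(self, user_id: str, key: str, value: Any) -> None:
        # Values are stored as JSON so that non-string preferences round-trip.
        self.conn.execute(
            "INSERT OR REPLACE INTO preferences (user_id, pref_key, pref_value) "
            "VALUES (?, ?, ?)",
            (user_id, key, json.dumps(value)),
        )
        self.conn.commit()

    def get(self, user_id: str, key: str, default: Any = None) -> Any:
        row = self.conn.execute(
            "SELECT pref_value FROM preferences WHERE user_id = ? AND pref_key = ?",
            (user_id, key),
        ).fetchone()
        return json.loads(row[0]) if row else default

    def close(self) -> None:
        self.conn.close()


# Two store instances stand in for two separate sessions of the same user.
db_path = os.path.join(tempfile.mkdtemp(), "preferences.db")

session_one = SQLitePreferenceStore(db_path)
session_one.set("user123", "language", "zh-CN")
session_one.close()

session_two = SQLitePreferenceStore(db_path)
print(session_two.get("user123", "language"))
session_two.close()
```

Because every row carries the user ID, the second session can read what the first one wrote without sharing any in-process state.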


---

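## Summary

### Key Points of a Memory System

| Key point | Strategy |
|------|------|
| **Layered architecture** | Separate short-term, long-term, and structured memory |
| **Short-term memory** | Sliding window, conversation summaries |
| **Long-term memory** | Vector storage, semantic retrieval |
| **Structured memory** | Key-value storage, persistence |
| **Memory retrieval** | Hybrid retrieval, deduplication |
| **Memory optimization** | Compression, eviction policies |

### Best Practices

1. **Layered management**: use a different store for each type of memory
2. **Automatic assessment**: score the importance of memories automatically
3. **Timely compression**: compress old memories on a regular schedule
4. **Persistence**: important information must be persisted
5. **Privacy protection**: store sensitive information encrypted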