【AI Agent Knowledge Base】29 - Model Inference Optimization

Outline

Model Inference Optimization

Large-model acceleration, quantization, distillation, and batching optimization




Core Concepts

Concept  Definition  Core Value
Quantization  Model quantization  Lower memory use, higher throughput
Distillation  Knowledge distillation  A small model learns from a large one
Batching  Batch processing  Higher throughput
KV Cache  Key-value cache  Faster autoregressive generation
Tensor Parallelism  Tensor parallelism  Distributed inference
PagedAttention  Paged attention  Memory-efficient KV cache allocation
Speculative Decoding  Speculative decoding  Faster sequence generation
FlashAttention  Flash attention  Faster attention computation
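
Two of these concepts can be tried directly from Hugging Face transformers before touching a dedicated serving stack: the KV cache (on by default in generate) and assisted generation, a speculative-decoding-style setup where a small draft model proposes tokens for the target model to verify. A minimal sketch, assuming a recent transformers version and using small placeholder models that share a tokenizer:

# Minimal sketch: KV cache toggle and assisted (speculative-style) generation.
# Model names are small placeholders; substitute any compatible pair that
# shares a tokenizer.
import time
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
draft = AutoModelForCausalLM.from_pretrained("distilgpt2")  # small draft model

inputs = tokenizer("LLM inference can be accelerated by", return_tensors="pt")

# KV cache: disabling it recomputes attention over the whole prefix each step.
for use_cache in (True, False):
    start = time.time()
    model.generate(**inputs, max_new_tokens=64, do_sample=False, use_cache=use_cache)
    print(f"use_cache={use_cache}: {time.time() - start:.2f}s")

# Assisted generation: the draft model proposes tokens, the target model verifies.
out = model.generate(**inputs, assistant_model=draft, max_new_tokens=64, do_sample=False)
print(tokenizer.decode(out[0], skip_special_tokens=True))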

Model Quantization

1.1 Quantization Strategies

Strategy  Precision  Memory Saving  Speedup  Quality Loss
FP32  32-bit float  1x (baseline)  1x  0%
FP16/BF16  16-bit float  2x  1.5-2x  ≈0%
INT8  8-bit integer  4x  2-3x  1-2%
INT4  4-bit integer  8x  3-4x  2-5%
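
The memory column can be sanity-checked with simple arithmetic: weight memory is roughly parameter count times bytes per parameter (activations and the KV cache come on top). A minimal sketch for a 7B-parameter model:

# Back-of-the-envelope weight memory for a 7B-parameter model at each precision.
# Weights only; activations, KV cache, and framework overhead are extra.
NUM_PARAMS = 7e9
bytes_per_param = {"FP32": 4.0, "FP16/BF16": 2.0, "INT8": 1.0, "INT4": 0.5}

for precision, nbytes in bytes_per_param.items():
    print(f"{precision:>10}: ~{NUM_PARAMS * nbytes / 1024**3:.1f} GB of weights")
# FP32 ≈ 26 GB, FP16 ≈ 13 GB, INT8 ≈ 6.5 GB, INT4 ≈ 3.3 GB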

1.2 Quantization Implementation

# inference/quantization.py
"""
Model quantization implementation
Covers: GPTQ, AWQ, and GGUF formats
"""

from typing import Optional, Dict, List
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig
)
from peft import PeftModel

class QuantizationConfig:
    """量化配置"""

    @staticmethod
    def fp16() -> Dict:
        """FP16量化"""
        return {
            "torch_dtype": torch.float16,
            "device_map": "auto"
        }

    @staticmethod
    def int8() -> Dict:
        """INT8量化"""
        return {
            "quantization_config": BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=6.0
            ),
            "device_map": "auto"
        }

    @staticmethod
    def int4_nf4() -> Dict:
        """INT4 quantization via bitsandbytes NF4 (QLoRA-style).

        Note: despite the 4-bit target, this is not GPTQ; offline GPTQ
        quantization is handled by GPTQQuantizer further below.
        """
        return {
            "quantization_config": BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
            ),
            "device_map": "auto"
        }

    @staticmethod
    def awq(
        bits: int = 4,
        group_size: int = 128
    ) -> Dict:
        """AWQ量化配置"""
        return {
            "bits": bits,
            "group_size": group_size,
            "version": "GEMM"
        }

class QuantizedModelLoader:
    """量化模型加载器"""

    def __init__(self, cache_dir: str = "./models"):
        self.cache_dir = cache_dir

    def load_model(
        self,
        model_path: str,
        quant_config: Dict,
        lora_path: Optional[str] = None
    ) -> tuple:
        """
        加载量化模型

        Returns:
            (model, tokenizer)
        """
        # 加载分词器
        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            cache_dir=self.cache_dir
        )

        # Load the model
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            cache_dir=self.cache_dir,
            **quant_config
        )

        # Apply a LoRA adapter if provided
        if lora_path:
            model = PeftModel.from_pretrained(
                model,
                lora_path,
                is_trainable=False
            )

        return model, tokenizer

    def load_gguf_model(
        self,
        model_path: str,
        n_ctx: int = 4096,
        n_gpu_layers: int = -1  # -1 = offload all layers to the GPU
    ):
        """
        Load a GGUF-format model (llama.cpp backend)
        """
        try:
            from llama_cpp import Llama

            model = Llama(
                model_path=model_path,
                n_ctx=n_ctx,
                n_gpu_layers=n_gpu_layers,
                verbose=False
            )

            return model

        except ImportError:
            raise ImportError(
                "llama-cpp-python not installed. "
                "Install with: pip install llama-cpp-python"
            )

# ============== GPTQ quantization (offline, post-training) ==============

class GPTQQuantizer:
    """GPTQ quantizer"""

    def __init__(
        self,
        bits: int = 4,
        group_size: int = 128,
        damp_percent: float = 0.01
    ):
        self.bits = bits
        self.group_size = group_size
        self.damp_percent = damp_percent

    def quantize(
        self,
        model_path: str,
        output_path: str,
        calibration_data: List[str]
    ):
        """
        量化模型

        Args:
            model_path: 原始模型路径
            output_path: 量化后模型保存路径
            calibration_data: 校准数据(提示词列表)
        """
        try:
            from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

            # 量化配置
            quantize_config = BaseQuantizeConfig(
                bits=self.bits,
                group_size=self.group_size,
                damp_percent=self.damp_percent,
                desc_act=False,
                sym=True,
                true_sequential=True
            )

            # 加载模型
            model = AutoGPTQForCausalLM.from_pretrained(
                model_path,
                quantize_config=quantize_config
            )

            # 校准
            model.quantize(
                calibration_data,
                batch_size=128,
                use_triton=False
            )

            # 保存
            model.save_quantized(output_path)

        except ImportError:
            raise ImportError(
                "auto-gptq not installed. "
                "Install with: pip install auto-gptq"
            )

# ============== AWQ quantization ==============

class AWQQuantizer:
    """AWQ quantizer"""

    def __init__(
        self,
        bits: int = 4,
        group_size: int = 128,
        zero_point: bool = True
    ):
        self.bits = bits
        self.group_size = group_size
        self.zero_point = zero_point

    def quantize(
        self,
        model_path: str,
        output_path: str,
        calibration_data: List[str]
    ):
        """
        AWQ量化
        """
        try:
            from awq import AutoAWQForCausalLM

            # 加载并量化
            model = AutoAWQForCausalLM.from_pretrained(
                model_path,
                device_map="auto"
            )

            # 校准并量化
            quant_config = {
                "zero_point": self.zero_point,
                "q_group_size": self.group_size,
                "w_bit": self.bits,
                "version": "GEMM"
            }

            model.quantize(
                quant_config,
                calibration_data
            )

            # 保存
            model.save_quantized(output_path)

        except ImportError:
            raise ImportError(
                "awq not installed. "
                "Install with: pip install awq"
            )

# ============== Usage example ==============

if __name__ == "__main__":
    loader = QuantizedModelLoader(cache_dir="./models")

    # 1. Load an FP16 model
    config_fp16 = QuantizationConfig.fp16()
    model_fp16, tokenizer = loader.load_model(
        model_path="meta-llama/Llama-2-7b-hf",
        quant_config=config_fp16
    )
    print("FP16 model loaded")

    # 2. Load an INT8 model
    config_int8 = QuantizationConfig.int8()
    model_int8, tokenizer = loader.load_model(
        model_path="meta-llama/Llama-2-7b-hf",
        quant_config=config_int8
    )
    print("INT8 model loaded")

    # 3. Load an INT4 model (bitsandbytes NF4)
    config_int4 = QuantizationConfig.int4_nf4()
    model_int4, tokenizer = loader.load_model(
        model_path="meta-llama/Llama-2-7b-hf",
        quant_config=config_int4
    )
    print("INT4 model loaded")

    # 4. Load a GGUF model (llama.cpp)
    # gguf_model = loader.load_gguf_model(
    #     model_path="models/llama-2-7b.gguf",
    #     n_ctx=4096
    # )

    # 5. GPTQ quantization (offline)
    # quantizer = GPTQQuantizer(bits=4, group_size=128)
    # quantizer.quantize(
    #     model_path="meta-llama/Llama-2-7b-hf",
    #     output_path="./models/llama-2-7b-gptq",
    #     calibration_data=["example calibration prompt 1", "example calibration prompt 2"]
    # )
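
    # 6. Sketch: generating with a loaded GGUF model via llama-cpp-python.
    # The file path and prompt are placeholders; Llama objects are callable and
    # return an OpenAI-style completion dict.
    # gguf_model = loader.load_gguf_model("models/llama-2-7b.Q4_K_M.gguf", n_ctx=4096)
    # output = gguf_model(
    #     "Q: What is quantization? A:",
    #     max_tokens=128,
    #     stop=["Q:"],
    #     echo=False
    # )
    # print(output["choices"][0]["text"])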

Model Distillation

2.1 Distillation Strategies

Strategy  Description  Advantage
Logit Matching  Match output logits  Simple and effective
Hidden State  Match hidden states  Preserves features
Attention Transfer  Match attention maps  Preserves semantics
Progressive  Progressive distillation  More stable training
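
For logit matching in particular, the loss most setups use blends a temperature-scaled KL term (soft targets from the teacher) with ordinary cross-entropy on the ground-truth labels. A minimal sketch of just that loss follows; the full version, including the optional hidden-state term, is implemented in 2.2 below:

# Minimal logit-matching KD loss: alpha * T^2 * KL(student || teacher) + beta * CE.
# Illustrative only; section 2.2 implements the full version used in this document.
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5, beta=0.5):
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean"
    ) * (T * T)  # T^2 keeps gradient magnitudes comparable across temperatures
    hard = F.cross_entropy(
        student_logits.view(-1, student_logits.size(-1)),
        labels.view(-1),
        ignore_index=-100
    )
    return alpha * soft + beta * hard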

2.2 Distillation Implementation

# inference/distillation.py
"""
Knowledge distillation implementation
Covers: teacher-student training
"""

from typing import Optional, Dict

import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import DataLoader

class DistillationConfig:
    """蒸馏配置"""

    def __init__(
        self,
        temperature: float = 2.0,
        alpha: float = 0.5,  # KL散度权重
        beta: float = 0.5,     # 硬标签权重
        hidden_layer_weight: float = 0.1
    ):
        self.temperature = temperature
        self.alpha = alpha
        self.beta = beta
        self.hidden_layer_weight = hidden_layer_weight

class DistillationLoss(nn.Module):
    """蒸馏损失函数"""

    def __init__(self, config: DistillationConfig):
        super().__init__()
        self.config = config
        self.kl_loss = nn.KLDivLoss(reduction="batchmean")
        self.mse_loss = nn.MSELoss()

    def forward(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: torch.Tensor,
        student_hidden: Optional[torch.Tensor] = None,
        teacher_hidden: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Compute the distillation loss

        Args:
            student_logits: student model output logits
            teacher_logits: teacher model output logits
            labels: ground-truth labels
            student_hidden: student hidden states (optional)
            teacher_hidden: teacher hidden states (optional)
        """
        # 1. KL-divergence loss (soft targets)
        teacher_probs = nn.functional.softmax(
            teacher_logits / self.config.temperature,
            dim=-1
        )
        student_log_probs = nn.functional.log_softmax(
            student_logits / self.config.temperature,
            dim=-1
        )

        kl_loss = self.kl_loss(student_log_probs, teacher_probs)
        kl_loss = kl_loss * (self.config.temperature ** 2)  # standard T^2 scaling

        # 2. Hard-label loss (ground truth); flatten to (N*seq, vocab) vs (N*seq)
        ce_loss = nn.functional.cross_entropy(
            student_logits.view(-1, student_logits.size(-1)),
            labels.view(-1),
            ignore_index=-100
        )

        # 3. Hidden-state loss (optional)
        hidden_loss = torch.tensor(0.0, device=student_logits.device)
        if student_hidden is not None and teacher_hidden is not None:
            hidden_loss = self.mse_loss(student_hidden, teacher_hidden)

        # Weighted total; keep total_loss as a tensor so it can be backpropagated
        total_loss = (
            self.config.alpha * kl_loss +
            self.config.beta * ce_loss +
            self.config.hidden_layer_weight * hidden_loss
        )

        return {
            "total_loss": total_loss,
            "kl_loss": kl_loss.detach(),
            "ce_loss": ce_loss.detach(),
            "hidden_loss": hidden_loss.detach()
        }

class DistillationTrainer:
    """蒸馏训练器"""

    def __init__(
        self,
        teacher_model: nn.Module,
        student_model: nn.Module,
        tokenizer: any,
        config: DistillationConfig
    ):
        self.teacher = teacher_model
        self.student = student_model
        self.tokenizer = tokenizer
        self.config = config

        self.teacher.eval()
        self.criterion = DistillationLoss(config)

    def train(
        self,
        train_loader: DataLoader,
        val_loader: DataLoader,
        num_epochs: int = 3,
        learning_rate: float = 1e-4,
        output_dir: str = "./output"
    ):
        """
        训练学生模型
        """
        optimizer = optim.AdamW(
            self.student.parameters(),
            lr=learning_rate
        )

        best_val_loss = float('inf')

        for epoch in range(num_epochs):
            # 训练阶段
            train_loss = self._train_epoch(train_loader, optimizer)
            print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}")

            # 验证阶段
            val_loss = self._validate_epoch(val_loader)
            print(f"Epoch {epoch+1}/{num_epochs}, Val Loss: {val_loss:.4f}")

            # 保存最佳模型
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self._save_checkpoint(output_dir, epoch, val_loss)

    def _train_epoch(
        self,
        dataloader: DataLoader,
        optimizer: optim.Optimizer
    ) -> float:
        """训练一个epoch"""
        self.student.train()
        total_loss = 0
        num_batches = 0

        for batch in dataloader:
            inputs = batch["input_ids"].to(self.student.device)
            attention_mask = batch["attention_mask"].to(self.student.device)
            labels = batch["labels"].to(self.student.device)

            # 前向传播 - 教师模型
            with torch.no_grad():
                teacher_outputs = self.teacher(
                    input_ids=inputs,
                    attention_mask=attention_mask
                )

            # 前向传播 - 学生模型
            student_outputs = self.student(
                input_ids=inputs,
                attention_mask=attention_mask,
                output_hidden_states=True
            )

            # 计算损失
            loss_dict = self.criterion(
                student_logits=student_outputs.logits,
                teacher_logits=teacher_outputs.logits,
                labels=labels,
                student_hidden=student_outputs.hidden_states[-1].mean(dim=1),
                teacher_hidden=teacher_outputs.hidden_states[-1].mean(dim=1)
            )

            loss = torch.tensor(loss_dict["total_loss"])
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            total_loss += loss_dict["total_loss"]
            num_batches += 1

        return total_loss / num_batches

    def _validate_epoch(self, dataloader: DataLoader) -> float:
        """验证一个epoch"""
        self.student.eval()
        total_loss = 0
        num_batches = 0

        with torch.no_grad():
            for batch in dataloader:
                inputs = batch["input_ids"].to(self.student.device)
                attention_mask = batch["attention_mask"].to(self.student.device)
                labels = batch["labels"].to(self.student.device)

                # 教师模型
                teacher_outputs = self.teacher(
                    input_ids=inputs,
                    attention_mask=attention_mask
                )

                # 学生模型
                student_outputs = self.student(
                    input_ids=inputs,
                    attention_mask=attention_mask,
                    output_hidden_states=True
                )

                # 损失
                loss_dict = self.criterion(
                    student_logits=student_outputs.logits,
                    teacher_logits=teacher_outputs.logits,
                    labels=labels,
                    student_hidden=student_outputs.hidden_states[-1].mean(dim=1),
                    teacher_hidden=teacher_outputs.hidden_states[-1].mean(dim=1)
                )

                total_loss += loss_dict["total_loss"]
                num_batches += 1

        return total_loss / num_batches

    def _save_checkpoint(self, output_dir: str, epoch: int, loss: float):
        """保存检查点"""
        import os

        os.makedirs(output_dir, exist_ok=True)
        checkpoint_path = os.path.join(output_dir, f"checkpoint_{epoch}")

        torch.save({
            "epoch": epoch,
            "model_state_dict": self.student.state_dict(),
            "loss": loss
        }, checkpoint_path)

# ============== Usage example ==============

if __name__ == "__main__":
    # Load teacher and student models
    teacher_model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        torch_dtype=torch.float16
    )

    # Placeholder id for a smaller student; substitute any causal LM that
    # shares the teacher's tokenizer/vocabulary
    student_model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-1b-hf",
        torch_dtype=torch.float16
    )

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

    # Distillation configuration
    config = DistillationConfig(
        temperature=2.0,
        alpha=0.5,
        beta=0.5,
        hidden_layer_weight=0.1
    )

    # Create the trainer
    trainer = DistillationTrainer(
        teacher_model=teacher_model,
        student_model=student_model,
        tokenizer=tokenizer,
        config=config
    )

    # Prepare data
    # train_loader = ...
    # val_loader = ...

    # Train
    # trainer.train(
    #     train_loader=train_loader,
    #     val_loader=val_loader,
    #     num_epochs=3,
    #     learning_rate=1e-4
    # )
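
    # Sketch: one minimal way to build the elided train_loader / val_loader
    # (hypothetical helper, not a fixed API): tokenize raw texts and reuse the
    # input_ids as labels, as in standard causal-LM distillation.
    # Set tokenizer.pad_token = tokenizer.eos_token first if the tokenizer has no pad token.
    # from torch.utils.data import Dataset
    #
    # class TextDataset(Dataset):
    #     def __init__(self, texts, tokenizer, max_length=512):
    #         self.enc = tokenizer(
    #             texts, truncation=True, max_length=max_length,
    #             padding="max_length", return_tensors="pt"
    #         )
    #     def __len__(self):
    #         return self.enc["input_ids"].size(0)
    #     def __getitem__(self, idx):
    #         item = {k: v[idx] for k, v in self.enc.items()}
    #         item["labels"] = item["input_ids"].clone()
    #         return item
    #
    # train_loader = DataLoader(TextDataset(train_texts, tokenizer), batch_size=4, shuffle=True)
    # val_loader = DataLoader(TextDataset(val_texts, tokenizer), batch_size=4)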

Inference Frameworks

3.1 Mainstream Inference Frameworks

Framework  Highlights  Typical Use
vLLM  PagedAttention, high throughput  Production serving
TGI  Hugging Face integration, easy to use  Quick deployment
LM Studio  Local GUI, easy to use  Personal development
llama.cpp  CPU-friendly, lightweight  Edge devices
TensorRT-LLM  NVIDIA-optimized, peak performance  NVIDIA GPUs
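
Most of these stacks expose an OpenAI-compatible HTTP API (vLLM and TGI both do), so a single client works across them. A minimal client sketch assuming a server is already running locally; the base URL, API key, and model name are placeholders:

# Query an OpenAI-compatible endpoint (vLLM / TGI) with the openai client.
# base_url, api_key, and model are placeholders for a locally started server.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="meta-llama/Llama-2-7b-hf",
    messages=[{"role": "user", "content": "Explain the KV cache in one sentence."}],
    max_tokens=128,
    temperature=0.7,
)
print(response.choices[0].message.content)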

3.2 vLLM Implementation

# inference/vllm_serving.py
"""
vLLM inference serving implementation
"""

from typing import List, Optional
from vllm import LLM, SamplingParams

class VLLMEngineInference:
    """vLLM推理引擎"""

    def __init__(
        self,
        model_path: str,
        tensor_parallel_size: int = 1,
        gpu_memory_utilization: float = 0.9,
        max_model_len: int = 4096,
        trust_remote_code: bool = False
    ):
        # Keep the constructor arguments around for get_model_info()
        self.model_path = model_path
        self.tensor_parallel_size = tensor_parallel_size
        self.max_model_len = max_model_len

        self.llm = LLM(
            model=model_path,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            max_model_len=max_model_len,
            trust_remote_code=trust_remote_code
        )

    def generate(
        self,
        prompts: List[str],
        max_tokens: int = 512,
        temperature: float = 1.0,
        top_p: float = 0.9,
        top_k: int = 40,
        stop: Optional[List[str]] = None
    ) -> List[str]:
        """
        批量生成

        Args:
            prompts: 输入提示词列表(支持批处理)
            max_tokens: 最大生成长度
            temperature: 采样温度
            top_p: nucleus采样
            top_k: top-k采样
            stop: 停止词列表

        Returns:
            生成的文本列表
        """
        sampling_params = SamplingParams(
            n=1,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_tokens=max_tokens,
            stop=stop
        )

        # Batched inference
        outputs = self.llm.generate(prompts, sampling_params)

        # Extract the generated text
        generated_texts = [output.outputs[0].text for output in outputs]

        return generated_texts

    def generate_with_logprobs(
        self,
        prompts: List[str],
        logprobs: int = 1,
        **kwargs
    ) -> List[dict]:
        """
        Generate and return per-token log-probabilities of the generated text

        Returns:
            [
                {
                    "text": str,
                    "logprobs": list of per-token logprob entries
                },
                ...
            ]
        """
        sampling_params = SamplingParams(
            logprobs=logprobs,
            **kwargs
        )

        outputs = self.llm.generate(prompts, sampling_params)

        results = []
        for output in outputs:
            result = {
                "text": output.outputs[0].text,
                "logprobs": output.outputs[0].logprobs
            }
            results.append(result)

        return results

    def get_model_info(self) -> dict:
        """Return basic model info (values captured at construction time)"""
        return {
            "model_name": self.model_path,
            "max_model_len": self.max_model_len,
            "tensor_parallel_size": self.tensor_parallel_size
        }

# ============== vLLM OpenAI-compatible server ==============

class VLLMOpenAIServer:
    """vLLM OpenAI-compatible server"""

    def __init__(
        self,
        model_path: str,
        host: str = "0.0.0.0",
        port: int = 8000
    ):
        self.model_path = model_path
        self.host = host
        self.port = port

    def start(self):
        """Launch the bundled OpenAI-compatible API server (blocking)"""
        import subprocess
        import sys

        # Run the documented vllm.entrypoints.openai.api_server module as a
        # subprocess instead of relying on vLLM-internal helper functions.
        subprocess.run([
            sys.executable, "-m", "vllm.entrypoints.openai.api_server",
            "--model", self.model_path,
            "--host", self.host,
            "--port", str(self.port)
        ], check=True)

# ============== Usage example ==============

if __name__ == "__main__":
    # 1. Single-GPU inference
    engine = VLLMEngineInference(
        model_path="meta-llama/Llama-2-7b-hf",
        tensor_parallel_size=1,
        gpu_memory_utilization=0.9
    )

    # 2. Batched generation
    prompts = [
        "Explain what artificial intelligence is.",
        "Write a short poem about spring.",
        "What is the difference between a list and a tuple in Python?"
    ]

    results = engine.generate(
        prompts=prompts,
        max_tokens=256,
        temperature=0.7
    )

    for prompt, result in zip(prompts, results):
        print(f"\nQ: {prompt}")
        print(f"A: {result}")

    # 3. Per-token log-probabilities
    with_logprobs = engine.generate_with_logprobs(
        prompts=["What is machine learning?"],
        logprobs=5,
        max_tokens=50
    )

    print(f"\nGenerated text: {with_logprobs[0]['text']}")
    print(f"Token logprobs: {with_logprobs[0]['logprobs']}")

    # 4. Multi-GPU inference (tensor parallelism)
    # multi_gpu_engine = VLLMEngineInference(
    #     model_path="meta-llama/Llama-2-70b-hf",
    #     tensor_parallel_size=4  # 4 GPUs
    # )

    # 5. Start the OpenAI-compatible server
    # server = VLLMOpenAIServer(
    #     model_path="meta-llama/Llama-2-7b-hf",
    #     host="0.0.0.0",
    #     port=8000
    # )
    # server.start()

Batching Optimization

4.1 Dynamic Batching

# inference/batching.py
"""
Batching optimization
Covers: dynamic batching and continuous batching
"""

from typing import List, Dict, Optional, Callable
from dataclasses import dataclass
import time
from collections import deque
import threading

@dataclass
class Request:
    """推理请求"""
    id: str
    prompt: str
    max_tokens: int
    temperature: float
    created_at: float = None
    callback: callable = None

@dataclass
class Batch:
    """批处理"""
    requests: List[Request]
    tokens: int
    max_tokens: int

    def add(self, request: Request):
        """添加请求"""
        self.requests.append(request)
        self.tokens += len(request.prompt.split())
        if request.max_tokens > self.max_tokens:
            self.max_tokens = request.max_tokens

class DynamicBatcher:
    """动态批处理器"""

    def __init__(
        self,
        max_batch_size: int = 32,
        max_batch_tokens: int = 4096,
        max_wait_time: float = 0.05  # 50ms
    ):
        self.max_batch_size = max_batch_size
        self.max_batch_tokens = max_batch_tokens
        self.max_wait_time = max_wait_time

        self.request_queue = deque()
        self.active_requests: Dict[str, Request] = {}
        self.lock = threading.Lock()

    def add_request(self, request: Request):
        """添加请求到队列"""
        with self.lock:
            request.created_at = time.time()
            self.request_queue.append(request)

    def get_batch(self) -> Optional[Batch]:
        """
        获取一个批处理
        Returns:
            如果有足够请求或等待超时,返回Batch
            否则返回None
        """
        with self.lock:
            if not self.request_queue:
                return None

            # 检查是否可以立即批处理
            if len(self.request_queue) >= self.max_batch_size:
                return self._build_batch()

            # 检查等待时间
            first_request = self.request_queue[0]
            wait_time = time.time() - first_request.created_at

            if wait_time >= self.max_wait_time:
                return self._build_batch()

            # 检查是否满足token条件
            batch = self._estimate_batch()
            if batch and batch.tokens >= self.max_batch_tokens:
                return batch

            return None

    def _build_batch(self) -> Batch:
        """构建批处理"""
        batch = Batch(requests=[], tokens=0, max_tokens=0)

        while (
            len(batch.requests) < self.max_batch_size
            and self.request_queue
        ):
            request = self.request_queue.popleft()
            batch.add(request)
            self.active_requests[request.id] = request

        return batch

    def _estimate_batch(self) -> Optional[Batch]:
        """估算批处理(不实际移除)"""
        batch = Batch(requests=[], tokens=0, max_tokens=0)

        for request in self.request_queue:
            if (
                len(batch.requests) >= self.max_batch_size
                or batch.tokens + len(request.prompt.split()) > self.max_batch_tokens
            ):
                break
            batch.add(request)

        return batch if batch.requests else None

    def complete_request(self, request_id: str, result: str):
        """完成请求"""
        with self.lock:
            if request_id in self.active_requests:
                request = self.active_requests.pop(request_id)
                if request.callback:
                    request.callback(result)

# ============== Continuous batching ==============

class ContinuousBatcher:
    """Continuous batcher (continuous batching)

    Unlike plain dynamic batching, continuous batching allows:
    1. new requests to join a batch that is already being processed
    2. finished requests to leave the batch at any time
    3. requests of different lengths to be processed in parallel
    """

    def __init__(
        self,
        model,
        tokenizer,
        max_batch_size: int = 32
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.max_batch_size = max_batch_size

        self.active_sequences: Dict[str, any] = {}
        self.lock = threading.Lock()

    def add_sequence(self, request: Request):
        """添加序列"""
        with self.lock:
            if len(self.active_sequences) < self.max_batch_size:
                self.active_sequences[request.id] = {
                    "request": request,
                    "tokens": self.tokenizer.encode(request.prompt),
                    "generated": [],
                    "done": False
                }
                return True
            return False

    def step(self) -> Dict[str, any]:
        """
        执行一步推理

        Returns:
            {request_id: generated_token or None}
        """
        with self.lock:
            if not self.active_sequences:
                return {}

            # 准备输入
            sequences = list(self.active_sequences.values())

            # 使用模型的前向传播
            # 这里简化处理,实际需要使用支持连续批处理的模型
            outputs = self._model_forward(sequences)

            # 处理输出
            results = {}
            for seq, output in zip(sequences, outputs):
                if output.done:
                    results[seq["request"].id] = "".join(seq["generated"])
                    del self.active_sequences[seq["request"].id]
                else:
                    seq["generated"].append(output.token)
                    results[seq["request"].id] = output.token

            return results

    def _model_forward(self, sequences: List[dict]):
        """Model forward pass (placeholder)"""
        # A real implementation would:
        # 1. pad/align sequences of different lengths
        # 2. build the attention mask
        # 3. call the model
        # 4. post-process the outputs
        return []

# ============== Usage example ==============

if __name__ == "__main__":
    # Create a dynamic batcher
    batcher = DynamicBatcher(
        max_batch_size=8,
        max_batch_tokens=4096,
        max_wait_time=0.05
    )

    # Simulate incoming requests
    for i in range(10):
        request = Request(
            id=f"req-{i}",
            prompt=f"This is the content of request {i}",
            max_tokens=100,
            temperature=0.7
        )
        batcher.add_request(request)

    # Pull a batch
    batch = batcher.get_batch()
    if batch:
        print(f"Batch contains {len(batch.requests)} requests")
        print(f"Total tokens: {batch.tokens}")
        print(f"Largest max_tokens: {batch.max_tokens}")

        # Run inference...
        # results = engine.generate([...])

        # Complete the requests
        # for req in batch.requests:
        #     batcher.complete_request(req.id, "generated text")
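
    # Sketch: a background worker draining the batcher (hypothetical wiring).
    # A serving loop typically polls get_batch() and hands each batch to an
    # engine such as VLLMEngineInference, then routes results back through
    # complete_request().
    # def worker(batcher: DynamicBatcher, engine):
    #     while True:
    #         pending = batcher.get_batch()
    #         if pending is None:
    #             time.sleep(0.005)  # nothing ready yet
    #             continue
    #         outputs = engine.generate(
    #             prompts=[r.prompt for r in pending.requests],
    #             max_tokens=pending.max_tokens
    #         )
    #         for req, text in zip(pending.requests, outputs):
    #             batcher.complete_request(req.id, text)
    #
    # threading.Thread(target=worker, args=(batcher, engine), daemon=True).start()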

KV Cache Optimization

5.1 KV Cache Management

# inference/kv_cache.py
"""
KV cache optimization
"""

from typing import Dict, List, Tuple, Optional
import torch
from dataclasses import dataclass

@dataclass
class KVCacheEntry:
    """A KV cache entry (one tensor per layer)"""
    key: List[torch.Tensor]    # per layer: [num_heads, seq_len, head_dim]
    value: List[torch.Tensor]  # per layer: [num_heads, seq_len, head_dim]
    sequence_length: int

class KVCacheManager:
    """KV Cache管理器"""

    def __init__(
        self,
        num_layers: int,
        num_heads: int,
        head_dim: int,
        max_cache_size: int = 1000  # maximum number of cached entries
    ):
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_cache_size = max_cache_size

        self.cache: Dict[str, KVCacheEntry] = {}
        self.lru_keys: List[str] = []

    def get_or_create(
        self,
        prompt_hash: str,
        device: torch.device
    ) -> KVCacheEntry:
        """获取或创建Cache"""
        if prompt_hash in self.cache:
            # 更新LRU
            self._update_lru(prompt_hash)
            return self.cache[prompt_hash]

        # 创建新Cache
        cache_entry = self._create_empty_cache(device)
        self.cache[prompt_hash] = cache_entry
        self.lru_keys.append(prompt_hash)

        # 检查大小限制
        if len(self.cache) > self.max_cache_size:
            self._evict()

        return cache_entry

    def update(
        self,
        prompt_hash: str,
        layer_idx: int,
        key: torch.Tensor,  # [batch, num_heads, seq_len, head_dim]
        value: torch.Tensor  # [batch, num_heads, seq_len, head_dim]
    ):
        """更新Cache"""
        if prompt_hash not in self.cache:
            return

        cache_entry = self.cache[prompt_hash]

        # 拼接到现有cache
        if cache_entry.sequence_length == 0:
            # 初始化
            cache_entry.key[layer_idx] = key[0]  # 去掉batch维度
            cache_entry.value[layer_idx] = value[0]
        else:
            # 拼接
            cache_entry.key[layer_idx] = torch.cat(
                [cache_entry.key[layer_idx], key[0]],
                dim=-2
            )
            cache_entry.value[layer_idx] = torch.cat(
                [cache_entry.value[layer_idx], value[0]],
                dim=-2
            )

        cache_entry.sequence_length += key.shape[-2]

    def get(
        self,
        prompt_hash: str,
        layer_idx: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """获取指定层的KV"""
        if prompt_hash not in self.cache:
            return None, None

        cache_entry = self.cache[prompt_hash]
        return (
            cache_entry.key[layer_idx],
            cache_entry.value[layer_idx]
        )

    def clear(self):
        """清空Cache"""
        self.cache.clear()
        self.lru_keys.clear()

    def _create_empty_cache(self, device: torch.device) -> KVCacheEntry:
        """创建空Cache"""
        key = [torch.empty(0, 0, 0, device=device)
                for _ in range(self.num_layers)]
        value = [torch.empty(0, 0, 0, device=device)
                 for _ in range(self.num_layers)]

        return KVCacheEntry(
            key=key,
            value=value,
            sequence_length=0
        )

    def _update_lru(self, key: str):
        """更新LRU"""
        if key in self.lru_keys:
            self.lru_keys.remove(key)
        self.lru_keys.append(key)

    def _evict(self):
        """驱逐最老的条目"""
        if not self.lru_keys:
            return

        oldest_key = self.lru_keys.pop(0)
        del self.cache[oldest_key]

# ============== PagedAttention (paged KV cache) ==============

class PagedAttention:
    """
    Simplified PagedAttention-style page allocator

    Core ideas:
    1. store the KV cache in fixed-size blocks ("pages")
    2. track each sequence's pages with a page table
    3. avoid the memory fragmentation caused by variable-length sequences
    """

    def __init__(
        self,
        page_size: int = 16,   # tokens per page
        num_pages: int = 1000  # total number of pages
    ):
        self.page_size = page_size
        self.num_pages = num_pages
        # Illustrative figure only: K+V in float16 per element pair; a real
        # sizing also multiplies by num_layers, num_heads, and head_dim
        self.page_size_bytes = page_size * 2 * 2

        # Page bookkeeping
        self.page_bitmap = torch.zeros(num_pages, dtype=torch.bool)
        self.page_tables: Dict[str, List[int]] = {}  # sequence_id -> [page_idx, ...]

    def allocate_pages(self, sequence_id: str, num_tokens: int) -> List[int]:
        """为序列分配页"""
        num_pages_needed = (num_tokens + self.page_size - 1) // self.page_size
        allocated_pages = []

        for _ in range(num_pages_needed):
            page_idx = self._find_free_page()
            if page_idx is None:
                raise RuntimeError("No free pages available")

            self.page_bitmap[page_idx] = True
            allocated_pages.append(page_idx)

        self.page_tables[sequence_id] = allocated_pages
        return allocated_pages

    def free_pages(self, sequence_id: str):
        """释放序列的页"""
        if sequence_id not in self.page_tables:
            return

        for page_idx in self.page_tables[sequence_id]:
            self.page_bitmap[page_idx] = False

        del self.page_tables[sequence_id]

    def get_pages(self, sequence_id: str) -> List[int]:
        """获取序列的页"""
        return self.page_tables.get(sequence_id, [])

    def _find_free_page(self) -> Optional[int]:
        """找到空闲页"""
        free_pages = (self.page_bitmap == 0).nonzero(as_tuple=True)[0]
        if len(free_pages) > 0:
            return free_pages[0].item()
        return None

    def get_stats(self) -> dict:
        """获取统计信息"""
        allocated = self.page_bitmap.sum().item()
        return {
            "total_pages": self.num_pages,
            "allocated_pages": allocated,
            "free_pages": self.num_pages - allocated,
            "utilization": allocated / self.num_pages
        }

# ============== Usage example ==============

if __name__ == "__main__":
    # Create the KV cache manager
    cache_manager = KVCacheManager(
        num_layers=32,
        num_heads=32,
        head_dim=128,
        max_cache_size=100
    )

    # Simulate processing a prompt
    prompt = "Explain what artificial intelligence is."
    prompt_hash = str(hash(prompt))

    # Fetch or create the cache entry
    cache_entry = cache_manager.get_or_create(
        prompt_hash=prompt_hash,
        device=torch.device("cuda")
    )

    # Simulate a layer-by-layer cache update for a 10-token chunk
    for layer_idx in range(32):
        key = torch.randn(1, 32, 10, 128, device="cuda")
        value = torch.randn(1, 32, 10, 128, device="cuda")

        cache_manager.update(prompt_hash, layer_idx, key, value)

    print(f"Sequence length: {cache_entry.sequence_length}")

    # PagedAttention
    paged_attn = PagedAttention(page_size=16, num_pages=1000)

    pages = paged_attn.allocate_pages("seq-1", num_tokens=100)
    print(f"Allocated pages: {pages}")

    stats = paged_attn.get_stats()
    print(f"Page utilization: {stats['utilization']:.2%}")

    paged_attn.free_pages("seq-1")
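
    # Sketch: the same idea with Hugging Face's built-in KV cache.
    # transformers maintains past_key_values when use_cache=True, so each decode
    # step only feeds the newest token. Model name is a small placeholder.
    # from transformers import AutoModelForCausalLM, AutoTokenizer
    #
    # tok = AutoTokenizer.from_pretrained("gpt2")
    # lm = AutoModelForCausalLM.from_pretrained("gpt2")
    # ids = tok("Explain the KV cache:", return_tensors="pt").input_ids
    #
    # out = lm(input_ids=ids, use_cache=True)
    # past = out.past_key_values              # cached K/V for the whole prompt
    # next_id = out.logits[:, -1:].argmax(-1)
    # for _ in range(20):                     # incremental decoding, one token per step
    #     out = lm(input_ids=next_id, past_key_values=past, use_cache=True)
    #     past = out.past_key_values
    #     next_id = out.logits[:, -1:].argmax(-1)
    #     ids = torch.cat([ids, next_id], dim=-1)
    # print(tok.decode(ids[0]))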

Tensor Parallelism

6.1 Model Parallelism Strategy

# inference/tensor_parallel.py
"""
Tensor parallelism implementation
"""

import torch
import torch.distributed as dist
from typing import List

class TensorParallelStrategy:
    """张量并行策略"""

    def __init__(
        self,
        model,
        tensor_parallel_size: int,
        world_rank: int
    ):
        self.model = model
        self.tensor_parallel_size = tensor_parallel_size
        self.world_rank = world_rank

        # Initialize the process group
        if not dist.is_initialized():
            dist.init_process_group(backend='nccl')

    def parallelize_linear(
        self,
        linear: torch.nn.Linear,
        parallel_type: str  # "column" or "row"
    ) -> torch.nn.Linear:
        """
        并行化线性层

        Args:
            linear: 原始线性层
            parallel_type: 并行类型
                - "column": 列并行(按输出维度切分)
                - "row": 行并行(按输入维度切分)
        """
        if parallel_type == "column":
            return self._column_parallel_linear(linear)
        elif parallel_type == "row":
            return self._row_parallel_linear(linear)
        else:
            raise ValueError(f"Unknown parallel_type: {parallel_type}")

    def _column_parallel_linear(
        self,
        linear: torch.nn.Linear
    ) -> torch.nn.Linear:
        """列并行"""
        original_weight = linear.weight.data
        original_bias = linear.bias.data if linear.bias is not None else None

        # 计算切分大小
        output_dim = original_weight.shape[0]
        per_dim = output_dim // self.tensor_parallel_size

        # 切分权重
        start_idx = self.world_rank * per_dim
        end_idx = start_idx + per_dim
        sliced_weight = original_weight[start_idx:end_idx, :]

        # 切分偏置
        sliced_bias = None
        if original_bias is not None:
            sliced_bias = original_bias[start_idx:end_idx]

        # 创建新的线性层
        new_linear = torch.nn.Linear(
            in_features=linear.in_features,
            out_features=per_dim,
            bias=(sliced_bias is not None)
        )

        new_linear.weight.data = sliced_weight
        if sliced_bias is not None:
            new_linear.bias.data = sliced_bias

        return new_linear

    def _row_parallel_linear(
        self,
        linear: torch.nn.Linear
    ) -> torch.nn.Linear:
        """行并行"""
        original_weight = linear.weight.data
        original_bias = linear.bias.data if linear.bias is not None else None

        # 计算切分大小
        input_dim = original_weight.shape[1]
        per_dim = input_dim // self.tensor_parallel_size

        # 切分权重
        start_idx = self.world_rank * per_dim
        end_idx = start_idx + per_dim
        sliced_weight = original_weight[:, start_idx:end_idx]

        # 创建新的线性层
        new_linear = torch.nn.Linear(
            in_features=per_dim,
            out_features=linear.out_features,
            bias=(original_bias is not None)
        )

        new_linear.weight.data = sliced_weight
        if original_bias is not None:
            new_linear.bias.data = original_bias

        return new_linear

    def all_reduce(self, tensor: torch.Tensor) -> torch.Tensor:
        """All-reduce partial outputs across tensor-parallel ranks"""
        if self.tensor_parallel_size > 1:
            # Row-parallel partial results are summed, not averaged
            dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        return tensor

    def all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
        """All-Gather操作"""
        if self.tensor_parallel_size > 1:
            gathered = torch.zeros(
                tensor.size(0) * self.tensor_parallel_size,
                *tensor.shape[1:],
                dtype=tensor.dtype,
                device=tensor.device
            )
            dist.all_gather_into_tensor(gathered, tensor)
            return gathered
        return tensor

# ============== Usage example ==============

if __name__ == "__main__":
    import torch.multiprocessing as mp

    def run(rank, world_size):
        """Run distributed inference on one rank"""
        # Load the model
        model = ...  # load model here

        # Initialize tensor parallelism
        tp = TensorParallelStrategy(
            model=model,
            tensor_parallel_size=world_size,
            world_rank=rank
        )

        # Shard each layer
        for layer in model.layers:
            # Column parallel
            layer.fc1 = tp.parallelize_linear(layer.fc1, "column")

            # Row parallel
            layer.fc2 = tp.parallelize_linear(layer.fc2, "row")

        # Run inference...
        pass

    # Launch one process per GPU
    # world_size = 4  # 4 GPUs
    # mp.spawn(run, args=(world_size,), nprocs=world_size)
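
    # Sketch: checking the column/row sharding math on CPU (no GPUs or process
    # group needed). Splitting fc1 by output columns and fc2 by input rows, then
    # summing the partial fc2 outputs (the all-reduce), reproduces the unsharded
    # result.
    # torch.manual_seed(0)
    # x = torch.randn(2, 8)
    # fc1 = torch.nn.Linear(8, 16, bias=False)
    # fc2 = torch.nn.Linear(16, 8, bias=False)
    # full = fc2(torch.relu(fc1(x)))
    #
    # tp_size, shard = 2, 8
    # partial_sum = torch.zeros_like(full)
    # for rank in range(tp_size):
    #     w1 = fc1.weight[rank * shard:(rank + 1) * shard, :]   # column-parallel shard
    #     w2 = fc2.weight[:, rank * shard:(rank + 1) * shard]   # row-parallel shard
    #     partial_sum += torch.relu(x @ w1.t()) @ w2.t()        # local compute + "all-reduce" sum
    # print(torch.allclose(full, partial_sum, atol=1e-5))       # True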

PagedAttention

7.1 vLLM PagedAttention Integration

# inference/paged_attention.py
"""
Advanced PagedAttention usage (via vLLM)
"""

from typing import List, Optional
from vllm import LLM, SamplingParams

class PagedAttentionEngine:
    """PagedAttention引擎"""

    def __init__(
        self,
        model_path: str,
        block_size: int = 16,  # 块大小
        max_num_blocks: int = 1000,  # 最大块数
        max_num_seqs: int = 256  # 最大序列数
    ):
        self.model_path = model_path
        self.block_size = block_size
        self.max_num_blocks = max_num_blocks
        self.max_num_seqs = max_num_seqs

        self.llm = LLM(
            model=model_path,
            block_size=block_size,
            max_num_blocks=max_num_blocks,
            max_num_seqs=max_num_seqs
        )

    def generate_stream(
        self,
        prompts: List[str],
        sampling_params: SamplingParams
    ):
        """
        流式生成

        Returns:
            生成器,每次yield一个token
        """
        outputs = self.llm.generate(prompts, sampling_params)

        for output in outputs:
            for request_output in output.outputs:
                yield {
                    "request_id": output.request_id,
                    "text": request_output.text,
                    "finish_reason": request_output.finish_reason
                }

    def get_memory_stats(self) -> dict:
        """Return configuration and approximate GPU memory usage"""
        import torch

        gpu_memory_gb = (
            torch.cuda.memory_allocated() / 1024**3
            if torch.cuda.is_available() else 0.0
        )
        return {
            "block_size": self.block_size,
            "gpu_memory_utilization": self.gpu_memory_utilization,
            "max_num_seqs": self.max_num_seqs,
            "gpu_memory_gb": gpu_memory_gb
        }

# ============== Usage example ==============

if __name__ == "__main__":
    # Create the PagedAttention engine
    engine = PagedAttentionEngine(
        model_path="meta-llama/Llama-2-7b-hf",
        block_size=16,
        gpu_memory_utilization=0.9,
        max_num_seqs=256
    )

    # Generate (results are yielded per finished request)
    sampling_params = SamplingParams(
        temperature=0.7,
        max_tokens=512
    )

    prompts = ["Write a short story about AI"]

    for chunk in engine.generate_stream(prompts, sampling_params):
        print(chunk["text"], end="", flush=True)
        if chunk["finish_reason"]:
            print(f"\nFinished: {chunk['finish_reason']}")

    # Memory statistics
    stats = engine.get_memory_stats()
    print(f"\nGPU Memory: {stats['gpu_memory_gb']:.2f} GB")

Performance Benchmarking

# inference/benchmark.py
"""
Performance benchmarking
"""

import time
from typing import List, Dict
import numpy as np

class InferenceBenchmark:
    """推理基准测试"""

    def __init__(self, engine):
        self.engine = engine

    def benchmark_throughput(
        self,
        prompts: List[str],
        max_tokens: int = 100,
        num_iterations: int = 10
    ) -> Dict:
        """
        吞吐量测试

        Returns:
            {
                "tokens_per_second": float,
                "requests_per_second": float,
                "avg_latency_ms": float,
                "p50_latency_ms": float,
                "p95_latency_ms": float,
                "p99_latency_ms": float
            }
        """
        latencies = []
        total_tokens = 0

        for i in range(num_iterations):
            start_time = time.time()

            results = self.engine.generate(
                prompts=prompts,
                max_tokens=max_tokens
            )

            end_time = time.time()
            latency_ms = (end_time - start_time) * 1000
            latencies.append(latency_ms)

            # Approximation: this counts characters of the generated text;
            # pass the outputs through a tokenizer for exact token counts
            total_tokens += sum(len(r) for r in results)

        # Aggregate statistics
        total_time = sum(latencies) / 1000
        total_requests = num_iterations * len(prompts)

        return {
            "total_time_s": total_time,
            "total_requests": total_requests,
            "total_tokens": total_tokens,
            "tokens_per_second": total_tokens / total_time,
            "requests_per_second": total_requests / total_time,
            "avg_latency_ms": np.mean(latencies),
            "p50_latency_ms": np.percentile(latencies, 50),
            "p95_latency_ms": np.percentile(latencies, 95),
            "p99_latency_ms": np.percentile(latencies, 99),
            "min_latency_ms": np.min(latencies),
            "max_latency_ms": np.max(latencies)
        }

    def benchmark_memory(self) -> Dict:
        """内存使用测试"""
        import torch

        if torch.cuda.is_available():
            return {
                "gpu_allocated_gb": torch.cuda.memory_allocated() / 1e9,
                "gpu_reserved_gb": torch.cuda.memory_reserved() / 1e9,
                "gpu_max_allocated_gb": torch.cuda.max_memory_allocated() / 1e9
            }
        return {}

    def print_benchmark_report(self, results: Dict):
        """Print a benchmark report"""
        print("=" * 50)
        print("Inference benchmark report")
        print("=" * 50)

        print(f"\nTotal time: {results['total_time_s']:.2f}s")
        print(f"Total requests: {results['total_requests']}")
        print(f"Total generated tokens: {results['total_tokens']}")

        print("\n--- Throughput ---")
        print(f"Tokens/s: {results['tokens_per_second']:.2f}")
        print(f"Requests/s: {results['requests_per_second']:.2f}")

        print("\n--- Latency ---")
        print(f"Mean: {results['avg_latency_ms']:.2f}ms")
        print(f"P50: {results['p50_latency_ms']:.2f}ms")
        print(f"P95: {results['p95_latency_ms']:.2f}ms")
        print(f"P99: {results['p99_latency_ms']:.2f}ms")
        print(f"Min: {results['min_latency_ms']:.2f}ms")
        print(f"Max: {results['max_latency_ms']:.2f}ms")
        print("=" * 50)

# ============== Usage example ==============

if __name__ == "__main__":
    from inference.vllm_serving import VLLMEngineInference

    # Create the inference engine
    engine = VLLMEngineInference(
        model_path="meta-llama/Llama-2-7b-hf"
    )

    # Create the benchmark harness
    benchmark = InferenceBenchmark(engine)

    # Test data
    test_prompts = [
        "Explain what machine learning is.",
        "Write a Python function that computes Fibonacci numbers.",
        "What is the attention mechanism?"
    ] * 10  # 30 requests

    # Run the throughput benchmark
    results = benchmark.benchmark_throughput(
        prompts=test_prompts,
        max_tokens=100,
        num_iterations=10
    )

    # Print the report
    benchmark.print_benchmark_report(results)
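
    # Sketch: sweeping batch size with the same harness. Throughput usually
    # rises with batch size until memory or compute saturates.
    # for batch_size in (1, 4, 8, 16):
    #     res = benchmark.benchmark_throughput(
    #         prompts=test_prompts[:batch_size],
    #         max_tokens=100,
    #         num_iterations=3
    #     )
    #     print(f"batch={batch_size:>2}: "
    #           f"{res['tokens_per_second']:.1f} tok/s, "
    #           f"p95={res['p95_latency_ms']:.0f} ms")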



Document version: 1.0
Last updated: 2026-01-22
