基于mT5的心理问答系统设计与实现

系统概述

本项目构建了一个基于mT5多语言Transformer的心理问答系统,融合了知识图谱、大语言模型API的多层响应策略。系统通过prompt引导的序列到序列学习,实现了对心理疾病相关问题的智能回答,并具备多轮对话、上下文记忆等高级功能。

核心技术架构

1. 多层级响应策略

系统采用三层响应机制,确保问题得到准确回答:

def process_question(self, question, is_repeat=False):
    """三层响应策略:训练模型 → 知识图谱 → 大语言模型API"""
    # 1. 优先使用训练好的mT5模型
    if self.qa_model:
        try:
            model_answer = self.qa_model.predict(question)
            if model_answer and model_answer != "抱歉,我现在无法回答这个问题。":
                return model_answer
        except Exception as e:
            print(f"模型预测出错: {str(e)}")
    
    # 2. 知识图谱查询作为备选
    try:
        res_classify = self.classifier.classify(question)
        if res_classify:
            res_sql = self.parser.parser_main(res_classify)
            if res_sql:
                final_answers = self.searcher.search_main(res_sql)
                if final_answers:
                    return '\n'.join(final_answers)
    except Exception as e:
        print(f"知识图谱查询出错: {str(e)}")
    
    # 3. 星火大模型API兜底
    try:
        llm_answer = self.call_spark_api(question, is_repeat)
        return llm_answer if llm_answer else default_answer
    except Exception as e:
        print(f"调用星火API出错: {str(e)}")
        return default_answer

2. Prompt引导的序列到序列学习

核心创新在于使用prompt模板引导mT5模型学习问答任务:

class QADataset(Dataset):
    def __getitem__(self, idx):
        qa_pair = self.qa_pairs[idx]
        question = qa_pair['question']
        answer = qa_pair['answer']
        # 关键:使用prompt模板引导模型学习
        prompt = f"问题:{question} 回答:"
        
        input_enc = self.tokenizer(
            prompt, max_length=self.max_input_len, 
            truncation=True, padding='max_length', return_tensors='pt'
        )
        label_enc = self.tokenizer(
            answer, max_length=self.max_output_len, 
            truncation=True, padding='max_length', return_tensors='pt'
        )
        labels = label_enc['input_ids'].squeeze()
        labels[labels == self.tokenizer.pad_token_id] = -100  # 忽略pad token
        
        return {
            'input_ids': input_enc['input_ids'].squeeze(),
            'attention_mask': input_enc['attention_mask'].squeeze(),
            'labels': labels,
            'prompt': prompt  # 保存prompt用于验证集推理
        }

3. 自定义DataLoader与批处理优化

解决验证集推理时prompt为空的关键技术:

def collate_fn(batch):
    """自定义批处理函数,支持字符串prompt的批处理"""
    input_ids = torch.stack([item['input_ids'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])
    labels = torch.stack([item['labels'] for item in batch])
    prompts = [item['prompt'] for item in batch]  # 字符串列表
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'prompt': prompts
    }

# 训练时使用自定义collate_fn
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, 
    shuffle=True, collate_fn=collate_fn
)

4. 多轮对话上下文管理

实现对话历史的智能管理:

def _update_history(self, question, answer):
    """更新对话历史,支持多轮对话"""
    self.conversation_history.append({"role": "user", "content": question})
    self.conversation_history.append({"role": "assistant", "content": answer})

def call_spark_api(self, question, is_repeat=False):
    """构建完整对话上下文调用大模型API"""
    messages = [{
        "role": "system",
        "content": "你是一个具备丰富心理学知识、善于倾听的专业心理咨询师..."
    }]
    
    # 添加历史对话上下文
    if not is_repeat and self.conversation_history:
        messages.extend(self.conversation_history)
    
    messages.append({"role": "user", "content": question})
    
    # 调用星火API
    data = {
        "model": "x1",
        "messages": messages,
        "temperature": 0.7,
        "max_tokens": 1024
    }

训练流程与可视化

1. 数据预处理与分割

def split_data(qa_pairs: List[Dict], test_size: float = 0.2):
    """手动实现数据分割,避免sklearn版本冲突"""
    qa_pairs_shuffled = qa_pairs.copy()
    random.shuffle(qa_pairs_shuffled)
    
    split_idx = int(len(qa_pairs_shuffled) * (1 - test_size))
    train_data = qa_pairs_shuffled[:split_idx]
    val_data = qa_pairs_shuffled[split_idx:]
    
    return train_data, val_data

2. 训练过程监控

def train(self, train_data, val_data, epochs=3, batch_size=8, lr=2e-5):
    """训练过程包含损失计算、验证评估、指标记录"""
    history = {
        'train_loss': [],
        'val_loss': [],
        'accuracy': [],
        'f1_scores': []
    }
    
    for epoch in range(epochs):
        # 训练阶段
        self.model.train()
        for batch in tqdm(train_loader, desc=f"训练 Epoch {epoch+1}"):
            outputs = self.model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                labels=batch['labels']
            )
            loss = outputs.loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        # 验证阶段
        self.model.eval()
        with torch.no_grad():
            for batch in val_loader:
                # 使用batch['prompt']进行推理
                prompts = batch['prompt']
                input_enc = self.tokenizer(prompts, ...)
                generated_ids = self.model.generate(...)

3. 训练指标可视化

def plot_metrics(metrics: Dict[str, List[float]], output_dir: str):
    """生成训练过程的可视化图表"""
    # 损失曲线
    plt.figure(figsize=(10, 6))
    epochs = range(1, len(metrics['train_loss']) + 1)
    plt.plot(epochs, metrics['train_loss'], 'b-', label='训练损失')
    plt.plot(epochs, metrics['val_loss'], 'r-', label='验证损失')
    plt.title('训练过程中的损失变化')
    plt.savefig(os.path.join(output_dir, 'loss_history.png'))
    
    # 准确率和F1分数曲线
    plt.plot(epochs, metrics['accuracy'], 'g-', label='准确率')
    plt.plot(epochs, metrics['f1_scores'], 'm-', label='F1分数')

系统优化与亮点

1. 技术亮点

  • Prompt工程优化:通过”问题:xxx 回答:”模板引导mT5学习中文问答任务
  • 多模态融合:结合训练模型、知识图谱、大语言模型的三层响应机制
  • 上下文记忆:支持多轮对话,维护完整的对话历史
  • 错误容错:每层都有异常处理,确保系统稳定性

2. 性能优化

  • 批处理优化:自定义collate_fn解决字符串prompt批处理问题
  • 内存管理:合理设置max_length,避免显存溢出
  • 推理加速:使用beam search和early stopping优化生成质量

3. 可扩展性设计

  • 模块化架构:QA模型、知识图谱、API调用相互独立
  • 配置灵活:支持不同模型路径、参数调整
  • 接口统一:chat_main方法提供统一的对话接口

学习成果与技能提升

核心技术掌握

  1. Transformer架构深入理解

    • mT5多语言模型的fine-tuning技术
    • 序列到序列学习的prompt工程
    • 注意力机制在问答任务中的应用
  2. PyTorch深度学习框架

    • 自定义Dataset和DataLoader实现
    • 批处理函数的优化设计
    • 训练循环的精细化控制
  3. NLP任务工程化

    • 中文tokenization和编码处理
    • 问答对数据的预处理和增强
    • 评估指标的计算和可视化
  4. 系统架构设计

    • 多层响应策略的架构设计
    • 异常处理和容错机制
    • 模块化代码组织

项目亮点

  • 创新性:首次将prompt引导应用于心理问答领域
  • 实用性:三层响应机制确保问题得到准确回答
  • 可维护性:清晰的模块划分和错误处理
  • 可扩展性:支持新模型、新数据源的接入

系统架构图

用户问题输入
    ↓
┌─────────────────┐
│   ChatBotGraph  │
│   (主控制器)     │
└─────────────────┘
    ↓
┌─────────────────┐    ┌─────────────────┐    ┌─────────────────┐
│   PsychQAModel  │    │   知识图谱查询    │    │   星火API调用   │
│   (训练模型)     │    │   (Neo4j)       │    │   (大语言模型)  │
└─────────────────┘    └─────────────────┘    └─────────────────┘
    ↓                        ↓                        ↓
┌─────────────────────────────────────────────────────────────┐
│              多层级响应策略                                  │
│  1. 优先使用训练模型 2. 知识图谱查询 3. API兜底            │
└─────────────────────────────────────────────────────────────┘
    ↓
┌─────────────────┐
│   智能回答输出   │
└─────────────────┘

总结

本项目成功构建了一个功能完整、技术先进的心理问答系统。通过prompt引导的mT5模型训练、多层响应策略设计、以及完善的错误处理机制,实现了对心理疾病相关问题的智能回答。项目不仅展示了深度学习在NLP领域的应用,更体现了系统工程思维在AI应用开发中的重要性。

核心价值:将前沿的Transformer技术与实际应用场景结合,为心理健康的智能化服务提供了技术支撑,具有重要的学术价值和实用意义。

作者

HuangZhongqi

发布于

2024-12-19

更新于

2025-10-03

许可协议

评论