基于mT5的心理问答系统设计与实现
系统概述
本项目构建了一个基于mT5多语言Transformer的心理问答系统,融合了知识图谱、大语言模型API的多层响应策略。系统通过prompt引导的序列到序列学习,实现了对心理疾病相关问题的智能回答,并具备多轮对话、上下文记忆等高级功能。
核心技术架构
1. 多层级响应策略
系统采用三层响应机制,确保问题得到准确回答:
def process_question(self, question, is_repeat=False):
"""三层响应策略:训练模型 → 知识图谱 → 大语言模型API"""
# 1. 优先使用训练好的mT5模型
if self.qa_model:
try:
model_answer = self.qa_model.predict(question)
if model_answer and model_answer != "抱歉,我现在无法回答这个问题。":
return model_answer
except Exception as e:
print(f"模型预测出错: {str(e)}")
# 2. 知识图谱查询作为备选
try:
res_classify = self.classifier.classify(question)
if res_classify:
res_sql = self.parser.parser_main(res_classify)
if res_sql:
final_answers = self.searcher.search_main(res_sql)
if final_answers:
return '\n'.join(final_answers)
except Exception as e:
print(f"知识图谱查询出错: {str(e)}")
# 3. 星火大模型API兜底
try:
llm_answer = self.call_spark_api(question, is_repeat)
return llm_answer if llm_answer else default_answer
except Exception as e:
print(f"调用星火API出错: {str(e)}")
return default_answer
2. Prompt引导的序列到序列学习
核心创新在于使用prompt模板引导mT5模型学习问答任务:
class QADataset(Dataset):
def __getitem__(self, idx):
qa_pair = self.qa_pairs[idx]
question = qa_pair['question']
answer = qa_pair['answer']
# 关键:使用prompt模板引导模型学习
prompt = f"问题:{question} 回答:"
input_enc = self.tokenizer(
prompt, max_length=self.max_input_len,
truncation=True, padding='max_length', return_tensors='pt'
)
label_enc = self.tokenizer(
answer, max_length=self.max_output_len,
truncation=True, padding='max_length', return_tensors='pt'
)
labels = label_enc['input_ids'].squeeze()
labels[labels == self.tokenizer.pad_token_id] = -100 # 忽略pad token
return {
'input_ids': input_enc['input_ids'].squeeze(),
'attention_mask': input_enc['attention_mask'].squeeze(),
'labels': labels,
'prompt': prompt # 保存prompt用于验证集推理
}
3. 自定义DataLoader与批处理优化
解决验证集推理时prompt为空的关键技术:
def collate_fn(batch):
"""自定义批处理函数,支持字符串prompt的批处理"""
input_ids = torch.stack([item['input_ids'] for item in batch])
attention_mask = torch.stack([item['attention_mask'] for item in batch])
labels = torch.stack([item['labels'] for item in batch])
prompts = [item['prompt'] for item in batch] # 字符串列表
return {
'input_ids': input_ids,
'attention_mask': attention_mask,
'labels': labels,
'prompt': prompts
}
# 训练时使用自定义collate_fn
train_loader = DataLoader(
train_dataset, batch_size=batch_size,
shuffle=True, collate_fn=collate_fn
)
4. 多轮对话上下文管理
实现对话历史的智能管理:
def _update_history(self, question, answer):
"""更新对话历史,支持多轮对话"""
self.conversation_history.append({"role": "user", "content": question})
self.conversation_history.append({"role": "assistant", "content": answer})
def call_spark_api(self, question, is_repeat=False):
"""构建完整对话上下文调用大模型API"""
messages = [{
"role": "system",
"content": "你是一个具备丰富心理学知识、善于倾听的专业心理咨询师..."
}]
# 添加历史对话上下文
if not is_repeat and self.conversation_history:
messages.extend(self.conversation_history)
messages.append({"role": "user", "content": question})
# 调用星火API
data = {
"model": "x1",
"messages": messages,
"temperature": 0.7,
"max_tokens": 1024
}
训练流程与可视化
1. 数据预处理与分割
def split_data(qa_pairs: List[Dict], test_size: float = 0.2):
"""手动实现数据分割,避免sklearn版本冲突"""
qa_pairs_shuffled = qa_pairs.copy()
random.shuffle(qa_pairs_shuffled)
split_idx = int(len(qa_pairs_shuffled) * (1 - test_size))
train_data = qa_pairs_shuffled[:split_idx]
val_data = qa_pairs_shuffled[split_idx:]
return train_data, val_data
2. 训练过程监控
def train(self, train_data, val_data, epochs=3, batch_size=8, lr=2e-5):
"""训练过程包含损失计算、验证评估、指标记录"""
history = {
'train_loss': [],
'val_loss': [],
'accuracy': [],
'f1_scores': []
}
for epoch in range(epochs):
# 训练阶段
self.model.train()
for batch in tqdm(train_loader, desc=f"训练 Epoch {epoch+1}"):
outputs = self.model(
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
labels=batch['labels']
)
loss = outputs.loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 验证阶段
self.model.eval()
with torch.no_grad():
for batch in val_loader:
# 使用batch['prompt']进行推理
prompts = batch['prompt']
input_enc = self.tokenizer(prompts, ...)
generated_ids = self.model.generate(...)
3. 训练指标可视化
def plot_metrics(metrics: Dict[str, List[float]], output_dir: str):
"""生成训练过程的可视化图表"""
# 损失曲线
plt.figure(figsize=(10, 6))
epochs = range(1, len(metrics['train_loss']) + 1)
plt.plot(epochs, metrics['train_loss'], 'b-', label='训练损失')
plt.plot(epochs, metrics['val_loss'], 'r-', label='验证损失')
plt.title('训练过程中的损失变化')
plt.savefig(os.path.join(output_dir, 'loss_history.png'))
# 准确率和F1分数曲线
plt.plot(epochs, metrics['accuracy'], 'g-', label='准确率')
plt.plot(epochs, metrics['f1_scores'], 'm-', label='F1分数')
系统优化与亮点
1. 技术亮点
- Prompt工程优化:通过”问题:xxx 回答:”模板引导mT5学习中文问答任务
- 多模态融合:结合训练模型、知识图谱、大语言模型的三层响应机制
- 上下文记忆:支持多轮对话,维护完整的对话历史
- 错误容错:每层都有异常处理,确保系统稳定性
2. 性能优化
- 批处理优化:自定义collate_fn解决字符串prompt批处理问题
- 内存管理:合理设置max_length,避免显存溢出
- 推理加速:使用beam search和early stopping优化生成质量
3. 可扩展性设计
- 模块化架构:QA模型、知识图谱、API调用相互独立
- 配置灵活:支持不同模型路径、参数调整
- 接口统一:chat_main方法提供统一的对话接口
学习成果与技能提升
核心技术掌握
Transformer架构深入理解
- mT5多语言模型的fine-tuning技术
- 序列到序列学习的prompt工程
- 注意力机制在问答任务中的应用
PyTorch深度学习框架
- 自定义Dataset和DataLoader实现
- 批处理函数的优化设计
- 训练循环的精细化控制
NLP任务工程化
- 中文tokenization和编码处理
- 问答对数据的预处理和增强
- 评估指标的计算和可视化
系统架构设计
- 多层响应策略的架构设计
- 异常处理和容错机制
- 模块化代码组织
项目亮点
- 创新性:首次将prompt引导应用于心理问答领域
- 实用性:三层响应机制确保问题得到准确回答
- 可维护性:清晰的模块划分和错误处理
- 可扩展性:支持新模型、新数据源的接入
系统架构图
用户问题输入
↓
┌─────────────────┐
│ ChatBotGraph │
│ (主控制器) │
└─────────────────┘
↓
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
│ PsychQAModel │ │ 知识图谱查询 │ │ 星火API调用 │
│ (训练模型) │ │ (Neo4j) │ │ (大语言模型) │
└─────────────────┘ └─────────────────┘ └─────────────────┘
↓ ↓ ↓
┌─────────────────────────────────────────────────────────────┐
│ 多层级响应策略 │
│ 1. 优先使用训练模型 2. 知识图谱查询 3. API兜底 │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────┐
│ 智能回答输出 │
└─────────────────┘
总结
本项目成功构建了一个功能完整、技术先进的心理问答系统。通过prompt引导的mT5模型训练、多层响应策略设计、以及完善的错误处理机制,实现了对心理疾病相关问题的智能回答。项目不仅展示了深度学习在NLP领域的应用,更体现了系统工程思维在AI应用开发中的重要性。
核心价值:将前沿的Transformer技术与实际应用场景结合,为心理健康的智能化服务提供了技术支撑,具有重要的学术价值和实用意义。
基于mT5的心理问答系统设计与实现