模仿学习案例
- 模仿学习案例
概述
Kuavo IL 是一个完整的机器人模仿学习(Imitation Learning)框架,支持数据采集、转换、模型训练和部署的全流程。本框架集成了多种先进的模仿学习方法,包括 Diffusion Policy 和 LeRobot,可用于各种机器人操作任务。
⚠️ 注意:该案例适用于MAX系列机器人,其他版本需自行适配
项目地址及结构说明
kuavo_il/
├── kuavo/ # 核心功能模块
│ ├── kuavo_1convert/ # 数据采集与转换
│ ├── kuavo_2train/ # 模型训练
│ ├── kuavo_3deploy/ # 模型部署
│ └── kuavo_utils/ # 通用工具函数
├── diffusion_policy/ # Diffusion Policy 框架
└── lerobot/ # LeRobot 框架
1. 数据采集转换 (kuavo_1convert)
1.1 数据采集
数据采集步骤:
- 启动ROS环境和相关传感器(相机、机器人关节编码器等)
- 打开Gradio网页应用进行数据采集控制
- 进行人类示范操作,同时记录视觉和机器人状态数据
- 数据将以rosbag格式保存,包含多个ROS话题(topics)
1.1.1 使用 app.py UI 界面采集(方法一,适合用户)
1.1.2 使用 record.py 脚本采集(方法二)
- 组织目录结构
KUAVO_IL_WORKSPACE=~/hx/kuavo_il
cd $KUAVO_IL_WORKSPACE
DATASET_DIR=~/kuavo
TASK_NAME=TASK0_weighting
mkdir -p $DATASET_DIR/$TASK_NAME/rosbag
cp kuavo/kuavo_1convert/collect_data/record.py $DATASET_DIR/$TASK_NAME/rosbag
/home/$USER/
├── kuavo
│ └── Task2-RearangeToy
│ └── rosbag
│ └── record.py
- 在 rosbag 目录里开启终端运行
python record.py
记录
cd $DATASET_DIR/$TASK_NAME/rosbag
python record.py -b ./ -c 50 -d 20 -w 5 #-b本地rosbag目录, -c采集episodes数量, -d任务持续时间, episode间休息还原场景时间
- 若有任务变化,调整 record.py 中的
TOPIC_LIST
适应任务
TOPIC_LIST = [
"/cam_l/color/image_raw/compressed", # 左手腕部相机
"/cam_r/color/image_raw/compressed", # 右手腕部相机
"/joint_cmd", # 机器人控制话题, dualArm_joint = left_arm + right_arm = msg.joint_q[12:29] + msg.joint_q[19:26],
"/sensors_data_raw", # 机器人状态话题, dualArm_joint = left_arm + right_arm = msg.joint_data.joint_q[12:29] + msg.joint_data.joint_q[19:26],
"/control_robot_hand_position", # 灵巧手控制话题, len(msg.left_hand_position) + len(msg.right_hand_position) = 6 + 6 = 12. 0开100合,除拇指两个关节值,其他手指一个关节值,
"/control_robot_hand_position_state", # 灵巧手状态话题(目前无法获取,采集数据暂时只更新时间戳转发灵巧手控制话题)
"/zedm/zed_node/left/image_rect_color/compressed", # 头部zed左眼相机
"/zedm/zed_node/right/image_rect_color/compressed", # 头部zed右眼相机
]
- record.py会检测以上
话题
及其期望频率
,以及是否有时间戳
(具体修改record.py) - 及时检查 rosbag 包,确保话题和帧率符合要求(使用可视化软件
Foxglove
或者PlotJuggler
) - 录制过程出错及时
ctrl+c
终止并删除最新bag包然后重新运行record.py
1.2 数据转换
Kuavo IL 支持多种数据格式转换,以适应不同的训练框架需求。
1.2.1 ROS Bag 转 Zarr 格式
Zarr 格式是一种高效的多维数组存储格式,适合大规模机器学习数据集。
# bag -> zarr
python kuavo/kuavo_1convert/cvt_rosbag2zarr.py -b ~/hx/kuavo/Task12_zed_dualArm/rosbag -c kuavo/kuavo_1convert/config/Task12_zed_dualArm.yaml -n 50 -l 40
参数说明:
-b, --bag_folder_path
: ROS bag 文件目录-c, --config
: 配置文件路径(根据bag话题修改对应的yaml文件,内包含不同话题的处理方法以及抽帧值)-n, --num_of_bag
: 需要处理的bag包数量, 默认bag_folder_path
目录所有bag包-a, --append
: 如果需要合并多个任务的数据,这里可以加上-a
-l, --jpeg_compress_level
: 图像的压缩质量,越小质量越低(0-100)
转换完毕后的目录结构:
/home/$USER/
├── kuavo
├── Task12_zed_dualArm
│ ├── kuavo-zarr
│ ├── plt-check(检查电机的cmd和state曲线)
│ ├── raw-video
│ ├── rosbag
│ └── sample-video
1.2.2 ROS Bag 转 LeRobot 格式
LeRobot 格式是 LeRobot 框架专用的数据格式。
# bag -> lerobot
python kuavo/kuavo_1convert/cvt_rosbag2lerobot.py --raw_dir $DATASET_DIR/Task12_zed_dualArm/rosbag --repo_id Task12_zed_dualArm/lerobot
这会在raw_dir
上级目录生成lerobot
格式数据
参数说明:
--raw_dir
: ROS bag 文件目录--repo_id
: 数据集仓库 ID
/home/$USER/
├── kuavo
├── Task12_zed_dualArm
│ ├── kuavo-zarr
│ ├── lerobot(新生成的lerobot格式数据集)
│ ├── plt-check
│ ├── raw-video
│ ├── rosbag
│ └── sample-video
1.2.3 数据可视化
转换完成后,可以使用 LeRobot 提供的工具进行数据可视化,检查数据质量:
# 可视化转换完毕的 lerobot dataset:
python lerobot/lerobot/scripts/visualize_dataset.py --repo-id Task12_zed_dualArm/lerobot --root $DATASET_DIR/kuavo/Task12_zed_dualArm/lerobot --episode 55 --local-files-only 1
2. 模型训练 (kuavo_2train)
Kuavo IL 支持多种模型训练框架,包括 Diffusion Policy 和 LeRobot。
2.1 Diffusion Policy 训练
Diffusion Policy 是一种基于扩散模型的机器人策略学习方法。
单卡训练
python diffusion_policy/train.py --config-name=Task12_zed_dualArm
单机多卡并行训练(使用 Ray)
export CUDA_VISIBLE_DEVICES=0,1
python diffusion_policy/ray_train_multirun.py --config-dir=diffusion_policy/config --config-name=Task12_zed_dualArm.yaml
2.2 LeRobot 训练
LeRobot 是 Hugging Face 开发的机器人学习框架。
单机训练
CUDA_VISIBLE_DEVICES=2,3
python lerobot/lerobot/scripts/train.py \
--dataset.repo_id Task12_zed_dualArm/lerobot \
--policy.type act \
--dataset.local_files_only true \
--dataset.root ~/hx/kuavo/Task12_zed_dualArm/lerobot
分布式训练
CUDA_VISIBLE_DEVICES=2,3
GPUS=2
accelerate launch --num_processes=$GPUS --main_process_port 29501 \
~/hx/kuavo_il/lerobot/lerobot/scripts/train_distributed.py \
--dataset.repo_id Task12_zed_dualArm/lerobot \
--policy.type diffusion \
--dataset.local_files_only true \
--dataset.root ~/hx/kuavo/Task12_zed_dualArm/lerobot
3. 模型部署 (kuavo_3deploy)
训练完成的模型可以部署到实际机器人上进行测试和应用。Kuavo IL 提供了多种部署方式,支持不同的机器人平台。
3.1 环境检查
在部署前,建议先检查环境是否满足要求:
python kuavo/kuavo_3deploy/env.py
若没有灵巧手状态
话题,转发灵巧手控制
话题:
python kuavo/kuavo_3deploy/fake_state.py
3.2 模型评估
评估 [dp|lerobot] 模型
- 修改eval.py的test_cfg
test_cfg = {
'model_fr': 'lerobot', # 'oridp' or 'lerobot'
'task': 'Task12_zed_dualArm', # task_name
'ckpt': [ # 默认取ckpt[0]
'/home/leju-ali/hx/kuavo/Task12_zed_dualArm/train_lerobot/outputs/train/2025-03-21/02-25-36_act/checkpoints/120000/pretrained_model', #双gpu训练act
'/home/leju-ali/hx/kuavo/Task12_zed_dualArm/train_lerobot/outputs/train/2025-03-17/16-08-09_act/checkpoints/140000/pretrained_model',
'/home/leju-ali/hx/kuavo/Task12_zed_dualArm/kuavo-zar_480_toolarge/epoch=0060-train_loss=0.012.ckpt',
'/home/leju-ali/hx/kuavo/Task12_zed_dualArm/kuavo-zar_480_toolarge/epoch=0040-train_loss=0.016.ckpt',
'/home/leju-ali/hx/ckpt/epoch=0060-train_loss=0.013.ckpt',
'/home/leju-ali/hx/ckpt/wks/dataset/dataset_wason_20250307/data/outputs/2025.03.07/21.26.03_train_diffusion_unet_image_Task11_Toy/checkpoints/epoch=0140-train_loss=0.004.ckpt',
'/home/leju-ali/hx/ckpt/wks/dataset/dataset_wason_20250307/data/outputs/2025.03.09/23.31.47_train_diffusion_unet_image_Task11_Toy/checkpoints/epoch=0100-train_loss=0.007.ckpt',
],
'debug': False,
'is_just_img': False,
'bag': '/home/leju-ali/hx/kuavo/Task12_zed_dualArm/rosbag/rosbag_2025-03-15-14-35-05.bag',
'fps': 10, # 根据训练模型时数据集的帧率修改(常用10hz)
}
- 修改环境配置文件
W, H = 640, 480 # 根据训练数据调整
DEFAULT_OBS_KEY_MAP = {
"img":{
"img01": {
"topic":"/zedm/zed_node/left/image_rect_color/compressed",
"msg_type":CompressedImage,
'frequency': 30, # 根据具体相机的发布频率(要求发布频率稳定)
'handle': {
"params": {
'resize_wh': (W, H),
}
}
},
"img02": {
"topic":"/zedm/zed_node/right/image_rect_color/compressed",
"msg_type":CompressedImage,
'frequency': 30,
'handle': {
"params": {
'resize_wh': (W, H),
}
}
},
"img03": {
"topic":"/cam_l/color/compressed",
"msg_type":CompressedImage,
'frequency': 30,
'handle': {
"params": {
'resize_wh': (W, H),
}
}
},
"img04": {
"topic":"/cam_r/color/compressed",
"msg_type":CompressedImage,
'frequency': 30,
'handle': {
"params": {
'resize_wh': (W, H),
}
}
},
},
"low_dim":{
"state_joint": {
"topic":"/sensors_data_raw",
"msg_type":sensorsData,
"frequency": 500, # 机器人状态话题发布频率
'handle': {
"params": {
'slice': [(12,19), (19, 26)] # 0:12, 12:19(左手7joints), 19:26(右手7joints), 26:28.
}
},
},
"state_gripper": {
"topic":"/control_robot_hand_position_state",
"msg_type":robotHandPosition,
"frequency": 100, # 灵巧手状态话题发布频率
'handle': {
"params": {
'slice': [(0,1), (6,7)] # 一只灵巧手有6个关节。模型训练只取左,右手的大拇指第一个关节的状态表示所有关节的状态。
}
},
},
}
}
env输出接口统一:
env.obs_buffer.wait_buffer_ready(just_img = False)
obs_data, camera_obs, camera_obs_timestamps, robot_obs, robot_obs_timestamps = env.get_obs(just_img = False)
obs = {
'agent_pos': (n_obs_steps, low_dim)
'img01': (n_obs_steps, h, w, c)
'img02': (n_obs_steps, h, w, c)
'img..':(n_obs_steps, h, w, c)
"timestamp": (n_obs_steps,)
}
将机器人控制至初始位置(VR或者外部控制) TODO
运行模型推理代码
运行外部控制前:如果之前的控制来自VR,则须退出VR控制模式(长摁遥控器B)
,然后在推理终端运行rosservice call /arm_traj_change_mode "control_mode: 2"
,运行后手会小幅度抽搐
python kuavo/kuavo_3deploy/eval.py
3.3 端测部署
TODO
将训练好的模型部署到Jetson Orin.
TODO
4. 常见问题与解决方案
问题: 数据采集时 ROS 话题不正确 解决方案: 检查 record.py 中的话题配置,确保与机器人发布的话题一致
问题: 模型训练时内存不足 解决方案: 减小批量大小或使用分布式训练
问题: 模型部署时机器人动作不准确 解决方案: 检查传感器校准,调整模型参数,或增加训练数据多样性