在habitat中训练一个模型需要指定配置文件,(根据目前的学习)一般要指定两个yaml文件:
- 一个是训练的配置文件
- 一个是任务的配置文件
举例如下:
import random
import numpy as np
from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.config.default import get_config as get_baselines_config
import torch
if __name__ == "__main__":
run_type = "train" #指定是训练还是评估
#指定训练配置文件
config = get_baselines_config("../habitat_baselines/config/pointnav/ppo_pointnav_example.yaml")
#下面是在代码中对一些配置参数进行修改
config.defrost()
config.TASK_CONFIG.DATASET.DATA_PATH="/home/yons/LK/skill_transformer-main/data/datasets/pointnav/habitat-test-scenes/v1/{split}/{split}.json.gz"
config.TASK_CONFIG.DATASET.SCENES_DIR="/home/yons/LK/skill_transformer-main/data/scene_datasets"
config.freeze()
random.seed(config.TASK_CONFIG.SEED)
np.random.seed(config.TASK_CONFIG.SEED)
torch.manual_seed(config.TASK_CONFIG.SEED)
if config.FORCE_TORCH_SINGLE_THREADED and torch.cuda.is_available():
torch.set_num_threads(1)
trainer_init = baseline_registry.get_trainer(config.TRAINER_NAME)###config.TRAINER_NAME指定模型名字
assert trainer_init is not None, f"{config.TRAINER_NAME} is not supported"
trainer = trainer_init(config)
if run_type == "train":
trainer.train()
elif run_type == "eval":
trainer.eval()
上面所指定的训练文件ppo_pointnav_example.yaml中有一个配置项如下:
BASE_TASK_CONFIG_PATH: "../configs/tasks/pointnav.yaml"
从上面的代码可以看出来在代码中指定训练的配置文件,在训练配置文件中配置任务配置文件。
训练过程肯定要指定数据集(TASK_CONFIG.DATASET.DATA_PATH)(在训练配置文件中配置还是在任务配置文件中配置?目前至少看到在任务配置文件中是可以的)。
如果TASK_CONFIG.DATASET.DATA_PATH没有重新指定,会有默认值(目前知道有些默认值是从…/habitat-lab/habitat_baselines/config/default.py 中定义的)。
如果是点导航任务,需要同时指定正确的DATA_PATH和SCENES_DIR,否则会报错Could not find dataset file
具体原因见下面的代码
文件位置:.../habitat/datasets/pointnav/pointnav_dataset.py
@registry.register_dataset(name="PointNav-v1")
class PointNavDatasetV1(Dataset):
r"""Class inherited from Dataset that loads Point Navigation dataset."""
episodes: List[NavigationEpisode]
content_scenes_path: str = "{data_path}/content/{scene}.json.gz"
@staticmethod
def check_config_paths_exist(config: Config) -> bool:
return os.path.exists(
config.DATA_PATH.format(split=config.SPLIT)
) and os.path.exists(config.SCENES_DIR)
@classmethod
def get_scenes_to_load(cls, config: Config) -> List[str]:
r"""Return list of scene ids for which dataset has separate files with
episodes.
"""
dataset_dir = os.path.dirname(
config.DATA_PATH.format(split=config.SPLIT)
)
if not cls.check_config_paths_exist(config):
raise FileNotFoundError(
f"Could not find dataset file `{dataset_dir}`"
)