habitat中的坑(一):训练模型的时候找不到数据

在habitat中训练一个模型需要指定配置文件,(根据目前的学习)一般要指定两个yaml文件:

  • 一个是训练的配置文件
  • 一个是任务的配置文件

举例如下:

import random
import numpy as np
from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.config.default import get_config as get_baselines_config
import torch

if __name__ == "__main__":
    run_type = "train"      #指定是训练还是评估
    #指定训练配置文件
    config = get_baselines_config("../habitat_baselines/config/pointnav/ppo_pointnav_example.yaml")

  #下面是在代码中对一些配置参数进行修改
    config.defrost()
    config.TASK_CONFIG.DATASET.DATA_PATH="/home/yons/LK/skill_transformer-main/data/datasets/pointnav/habitat-test-scenes/v1/{split}/{split}.json.gz"
    config.TASK_CONFIG.DATASET.SCENES_DIR="/home/yons/LK/skill_transformer-main/data/scene_datasets"
    config.freeze()
    
    random.seed(config.TASK_CONFIG.SEED)
    np.random.seed(config.TASK_CONFIG.SEED)
    torch.manual_seed(config.TASK_CONFIG.SEED)
    if config.FORCE_TORCH_SINGLE_THREADED and torch.cuda.is_available():
        torch.set_num_threads(1)

    trainer_init = baseline_registry.get_trainer(config.TRAINER_NAME)###config.TRAINER_NAME指定模型名字
    assert trainer_init is not None, f"{config.TRAINER_NAME} is not supported"
    trainer = trainer_init(config)

    if run_type == "train":
        trainer.train()
    elif run_type == "eval":
        trainer.eval()

上面所指定的训练文件ppo_pointnav_example.yaml中有一个配置项如下:

BASE_TASK_CONFIG_PATH: "../configs/tasks/pointnav.yaml"

从上面的代码可以看出来在代码中指定训练的配置文件,在训练配置文件中配置任务配置文件。

训练过程肯定要指定数据集(TASK_CONFIG.DATASET.DATA_PATH)(在训练配置文件中配置还是在任务配置文件中配置?目前至少看到在任务配置文件中是可以的)。

如果TASK_CONFIG.DATASET.DATA_PATH没有重新指定,会有默认值(目前知道有些默认值是从…/habitat-lab/habitat_baselines/config/default.py 中定义的)。

如果是点导航任务,需要同时指定正确的DATA_PATH和SCENES_DIR,否则会报错Could not find dataset file

具体原因见下面的代码
文件位置:.../habitat/datasets/pointnav/pointnav_dataset.py

@registry.register_dataset(name="PointNav-v1")
class PointNavDatasetV1(Dataset):
    r"""Class inherited from Dataset that loads Point Navigation dataset."""

    episodes: List[NavigationEpisode]
    content_scenes_path: str = "{data_path}/content/{scene}.json.gz"

    @staticmethod
    def check_config_paths_exist(config: Config) -> bool:
        return os.path.exists(
            config.DATA_PATH.format(split=config.SPLIT)
        ) and os.path.exists(config.SCENES_DIR)

    @classmethod
    def get_scenes_to_load(cls, config: Config) -> List[str]:
        r"""Return list of scene ids for which dataset has separate files with
        episodes.
        """
        dataset_dir = os.path.dirname(
            config.DATA_PATH.format(split=config.SPLIT)
        )
        if not cls.check_config_paths_exist(config):
            raise FileNotFoundError(
                f"Could not find dataset file `{dataset_dir}`"
            )

相关推荐

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-03-11 09:58:03       98 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-03-11 09:58:03       106 阅读
  3. 在Django里面运行非项目文件

    2024-03-11 09:58:03       87 阅读
  4. Python语言-面向对象

    2024-03-11 09:58:03       96 阅读

热门阅读

  1. 【RHCSA问答题】第十章 配置和保护SSH

    2024-03-11 09:58:03       39 阅读
  2. Day41| 416 分割等和子集

    2024-03-11 09:58:03       46 阅读
  3. 【FreeRTOS任务调度机制学习】

    2024-03-11 09:58:03       37 阅读
  4. 归并排序

    2024-03-11 09:58:03       46 阅读
  5. 微信小程序-wxml语法

    2024-03-11 09:58:03       50 阅读
  6. Keepalived工具的基本介绍(原理:VRRP协议)

    2024-03-11 09:58:03       42 阅读
  7. MongoDB聚合运算符:$dayOfYear

    2024-03-11 09:58:03       48 阅读