DDBDataLoader 详细介绍

安装步骤

DolphinDB Python API 自 1.30.22.2 版本起提供深度学习工具类 DDBDataLoader,提供对 DolphinDB SQL 对应的数据集进行批量拆分和重新洗牌的易用接口,将 DolphinDB 中的数据直接对接到 PyTorch 中。

pip install dolphindb-tools 

期待输出:


img

DolphinDB类型与Tensor类型对照表

DolphinDB 类型 Tensor 类型
BOOL [不含空值] torch.bool
CHAR [不含空值] torch.int8
SHORT [不含空值] torch.int16
INT [不含空值] torch.int32
LONG [不含空值] torch.int64
FLOAT torch.float32
DOUBLE torch.float64
CHAR/SHORT/INT/LONG [包含空值] torch.float64

注意

  1. 若 sql 查询的结果表中包含不支持的类型,即便其列名被包含在 targetCol 中,即表示迭代中 y 对应的列名,详细见接口说明,该数据列也不会出现在输入数据和目标数据中。
  2. 支持上述类型的 ArrayVector 类型。如果使用 ArrayVector 列,需要保证输入数据或目标数据全部为 ArrayVector 类型。
  3. torch.bool 不支持布尔型数据的空值,因此获取 BOOL 类型数据前需确保不包含空值。

接口介绍

提供 DDBDataLoader 类来加载和访问数据,接口如下:

DDBDataLoader(
    ddbSession: Session,
    sql: str,
    targetCol: List[str],
    batchSize: int = 1,
    shuffle: bool = True,
    windowSize: Union[List[int], int, None] = None,
    windowStride: Union[List[int], int, None] = None,
    *,
    inputCol: Optional[List[str]] = None,
    excludeCol: Optional[List[str]] = None,
    repartitionCol: str = None,
    repartitionScheme: List[str] = None,
    groupCol: str = None,   
    groupScheme: List[str] = None,
    seed: Optional[int] = None,
    dropLast: bool = False,
    offset: int = None,
    device: str = "cpu",
    prefetchBatch: int = 1,
    prepartitionNum: int = 2,
    groupPoolSize: int = 3,
    **kwargs
)

必选参数 (基础信息)

  • ddbSession(dolphindb.Session): 用于获取数据的 Session 连接,包含训练所需的上下文信息。
  • sql(str): 表示将数据取出用于训练的 SQL 语句,特别的,该语句必须为查询语句元代码,应尽可能简单,目前不支持 group by/context by 子句。

参数(迭代列名类)

  • targetCol(List[str]): 必填参数,字符串或者字符串列表。表示迭代中 y 对应的列名。
    • 如果指定了 inputCol,x 的数据为 inputCol 对应的列名, y 的数据为 targetCol 对应的列名,excludeCol 不生效。
    • 不指定 inputCol,指定 excludeCol:x 的数据为 所有列 - excludeCol 指定的列名,y 的数据为 targetCol 对应的列名
    • 不指定 inputCol,也不指定 excludeCol:x 的数据为所有列,y 的数据为 targetCol 对应的列名
  • inputCol(Optional[List[str]]): 可选参数,字符串或者字符串列表。表示迭代中 x 对应的列名,如果不指定则表示所有列,默认值为 None
  • excludeCol(Optional[List[str]]): 可选参数,字符串或者字符串列表。表示迭代中 x 排除的列名,默认值为 None

可选参数(取数规则类)

  • batchSize(int): 批次大小,指定每个批次数据中样本数量。表示每个批次只包含一个样本,默认值为 1。
  • shuffle(bool): 是否对数据进行随机打乱,表示不对数据进行打乱,默认值为 False。
  • seed(Optional[int]): 随机种子,该种子仅在 DDBDataLoader 对象中生效,与外界隔离。默认值为 None,表示不指定随机种子。
  • dropLast(bool): 是否丢弃不足 batchSize 的数据,其值为 True 时,如果 excludeColl 无法整除查询结果的大小,则会丢弃最后一组不足 excludeColl 的数据。默认值为 False,表示不丢弃最后一组不足 excludeColl 的数据。

可选参数(窗口类)

  • windowSize(Union[List[int], int, None]): 用于指定滑动窗口的大小,默认值为 None
    • 如果不指定该参数,表示不使用滑动窗口。
    • 如果传入一个整数值(int),例如 windowSize=3,表示 x 的滑动窗口大小为 3,y 的滑动窗口大小为 1。
    • 如果传入两个整数值的列表,例如 windowSize=[4, 2],表示 x 的滑动窗口大小为 4,y 的滑动窗口大小为 2。
  • windowStride(Union[List[int], int, None]): 用于指定滑动窗口在数据上滑动的步长,默认值为 None
    • 不指定 windowSize 时,该参数无效。
    • 如果传入一个整数值(int),例如 windowStride=2,那么表示 x 的滑动窗口步长为 2,而 y 的滑动窗口步长为 1。
    • 如果传入两个整数值的列表,例如 windowStride=[3, 1],那么表示 x 的滑动窗口步长为 3,而 y 的滑动窗口步长为 1。
  • offset(Optional[int]): y 相对于 x 偏移的行数(非负数)。不启用滑动窗口时,表示训练数据都在同一行中。指定滑动窗口时,该参数默认为 x 对应滑动窗口的大小,默认为 0。

可选参数(数据切分类)

  • repartitionCol(Optional[str]): 用于进一步拆分分组查询为子查询的列。默认值为 None
  • repartitionScheme(Optional[List[str]]): 分区点值,是一个字符串列表。每个列表元素将和 repartitionCol 指定的列一起使用,以通过条件 where repartitionCol = value 对数据做进一步筛选和分割,默认值为 None
  • groupCol(Optional[str]): 用于将查询划分成组的列。这个列的值将用于定义分组,默认值为 None
  • groupScheme(Optional[List[str]]): 分组点值,是一个字符串列表。每个列表元素将与 groupCol 指定的列一起使用,以通过条件 where groupCol = value 对数据进一步筛选和分组,默认值为 None

  1. 其中 repartitionColrepartitionScheme 功能可用于解决单个分区数据较多,无法直接进行全量运算的情况。通过将数据根据 repartitionScheme 的值进行筛选,可以将数据分割成多个子分区,每个子分区将按照 repartitionScheme 中的顺序排列。例如,如果 repartitionCol 为 date(TradeTime), repartitionScheme 为 ["2020.01.01", "2020.01.02", "2020.01.03"],则数据将被细分为三个分区,每个分区对应一个日期值。
  2. 不同于 repartitionCol/repartitionScheme,其中 groupColgroupScheme 的分组之间不会出现跨分组的数据,例如,如果 groupCol 为 Code,groupScheme 为 [“`000001.SH”, “`000002.SH”, “`000003.SH”],则数据将被划分为三个不相交的分组,每个分组对应一个股票代码。

其他可选参数(不常用类)

  • device(Optional[str]): 用于指定张量将被创建在哪个设备上。你可以将其设置为 “cuda“ 或其他支持的设备名称,以便在 GPU 上创建张量。默认值为 "cpu" ,表示在 CPU 上创建张量。
  • prefetchBatch(int): 表示预加载的批数,用于控制一次性加载多少批次的数据,默认值为 1
  • prepartitionNum(int): 表示每个数据源预加载的分区数。工作线程将会在后台预加载分区到内存中。如果预载分区过多可能导致内存不足, 默认值为 2
  • groupPoolSize(int): 如果指定 groupColgroupScheme,所有数据将被划分为若干个数据源,并在其中选择 groupPoolSize 个数据源准备数据。当有数据源中的数据被全部使用,新的数据源将被加入,直至所有数据源中的数据都被使用,默认值为 3