学習する(10:パディング、マスク)

学習

馬ごとの履歴をレース単位でまとめ、時系列パディング+レース内パディングまで施した完全バッチ化関数

今回は、同じレースに属する馬をひとまとめにしてバッチ化するカスタム collate_fn です。
DataLoader に渡すことで、レース単位の集合入力をそのまま扱える形式に整えています。

def grouped_race_collate_fn(batch, sequence_length=3):
    races = defaultdict(list)
    for item in batch:
        races[item["race_id"]].append(item)

    batch_X, batch_y, batch_hist = [], [], []
    batch_cat = defaultdict(list)
    batch_ids = []  # [R, S]

    for race_items in races.values():
        X_list, y_list, hist_list = [], [], []
        cat_temp = defaultdict(list)
        ids_list = []

        for i in race_items:
            x_num = i["X_num"] 

            # --- 時系列Tのパディング/切り詰め ---
            if x_num.size(0) < sequence_length:
                pad_len = sequence_length - x_num.size(0)
                x_num = torch.cat([x_num, x_num[-1:].repeat(pad_len, 1)], dim=0)
            elif x_num.size(0) > sequence_length:
                x_num = x_num[-sequence_length:]
            X_list.append(x_num) 

            y_list.append(torch.as_tensor(i["y"], dtype=torch.long))

            ids_list.append(torch.as_tensor(i["horse_id"], dtype=torch.long))

            if i.get("hist", None) is not None:
                hist_list.append(torch.as_tensor(i["hist"], dtype=torch.float32))

            for col, val in i.items():
                if col in ["X_num", "y", "race_id", "hist", "horse_id"]:
                    continue
                v = val if isinstance(val, torch.Tensor) else torch.tensor(val, dtype=torch.long)
                if v.dim() == 0:
                    v = v.unsqueeze(0)
                if v.size(0) < sequence_length:
                    v = torch.cat([v, v[-1:].repeat(sequence_length - v.size(0))], dim=0)
                elif v.size(0) > sequence_length:
                    v = v[-sequence_length:]
                cat_temp[col].append(v) 

        batch_X.append(torch.stack(X_list, dim=0))     
        batch_y.append(torch.stack(y_list, dim=0))         
        batch_ids.append(torch.stack(ids_list, dim=0))

        if hist_list:
            batch_hist.append(torch.stack(hist_list, dim=0))
        else:
            batch_hist.append(torch.zeros((len(X_list), 1), dtype=torch.float32))

        for col in cat_temp:
            batch_cat[col].append(torch.stack(cat_temp[col], dim=0)) 
    X_pack   = pad_sequence(batch_X,  batch_first=True)                        
    y_pack   = pad_sequence(batch_y,  batch_first=True, padding_value=-1)    
    ids_pack = pad_sequence(batch_ids, batch_first=True, padding_value=-1)    
    mask_pack = (ids_pack == -1)                
    hist_pack = pad_sequence(batch_hist, batch_first=True, padding_value=0.0) if batch_hist[0] is not None else None
    for col in batch_cat:
        batch_cat[col] = pad_sequence(batch_cat[col], batch_first=True, padding_value=0) 

    return X_pack, y_pack, hist_pack, dict(batch_cat), ids_pack, mask_pack

この関数 grouped_race_collate_fn は、PyTorch の DataLoaderコラテ関数(collate_fn) で、ミニバッチを 「レースごとに馬をまとめたテンソル」 に変換します。
特に、**各馬の時系列履歴(T ステップ)**を統一長に揃え、**レース単位(S 頭)**でパディングし、モデルに投入できる形を作ります。


主な処理ステップ

  1. レース単位でグループ化
races = defaultdict(list)
for item in batch:
    races[item["race_id"]].append(item)

  1. 馬ごとの時系列パディング/切り詰め
    • 不足時:最後の時点を複製して延長(パディング)。
    • 過剰時:直近 sequence_length 分だけ残す。
      → 各馬の履歴長を 固定長 T に揃える。
if x_num.size(0) < sequence_length:
    pad_len = sequence_length - x_num.size(0)
    x_num = torch.cat([x_num, x_num[-1:].repeat(pad_len, 1)], dim=0)
elif x_num.size(0) > sequence_length:
    x_num = x_num[-sequence_length:]

  1. 教師ラベル・ID・履歴長・カテゴリ列の整理
    • y(着順ラベル)、horse_idhist(履歴長)をテンソル化。
    • カテゴリ列(例:騎手ID・競馬場IDなど)は同様にパディング/切り詰め。
    • cat_temp に列ごとにまとめてからスタック。

  1. レース単位のテンソル化
    • 1レース内の全馬をまとめて [S, T, D] 形式のテンソルにする。
    • hist_list が無い場合はゼロ埋めを挿入。
batch_X.append(torch.stack(X_list, dim=0))  # [S, T, D]
batch_y.append(torch.stack(y_list, dim=0))  # [S]
batch_ids.append(torch.stack(ids_list, dim=0)) # [S]

  1. レース間のパディング(pad_sequence)
    • レース数 R が異なるため、pad_sequence で揃える。
    • ids_pack-1 埋めを基に マスクテンソルを作り、モデルに渡すことで「パディング部分を無視」できる。
    • カテゴリ列(batch_cat)も同様に pad_sequence で [R,S,T] に揃える。
X_pack   = pad_sequence(batch_X,  batch_first=True)         # [R, S, T, D]
y_pack   = pad_sequence(batch_y,  batch_first=True, padding_value=-1)
ids_pack = pad_sequence(batch_ids, batch_first=True, padding_value=-1)
mask_pack = (ids_pack == -1)  # パディング位置のマスク

出力

戻り値は以下の6要素:

  1. X_pack:数値特徴([R, S, T, D])
  2. y_pack:ラベル([R, S]、-1 がパディング)
  3. hist_pack:履歴長([R, S, 1])
  4. dict(batch_cat):カテゴリ特徴(列ごとに [R, S, T])
  5. ids_pack:馬ID([R, S]、-1 がパディング)
  6. mask_pack:パディング位置マスク([R, S]、bool)

✨ まとめ

この collate_fn により、

  • 馬ごとの履歴長を統一(T)
  • レース内の馬数を統一(S、パディングで揃える)
  • レース単位でミニバッチ化(R)
  • 数値・カテゴリ・メタ情報を一貫して整列

が実現され、モデル側は [R, S, T, D] + マスク という統一入力を前提に処理できるようになります。

コメント

タイトルとURLをコピーしました