馬ごとの履歴をレース単位でまとめ、時系列パディング+レース内パディングまで施した完全バッチ化関数
今回は、同じレースに属する馬をひとまとめにしてバッチ化するカスタム collate_fn
です。DataLoader
に渡すことで、レース単位の集合入力をそのまま扱える形式に整えています。
def grouped_race_collate_fn(batch, sequence_length=3):
races = defaultdict(list)
for item in batch:
races[item["race_id"]].append(item)
batch_X, batch_y, batch_hist = [], [], []
batch_cat = defaultdict(list)
batch_ids = [] # [R, S]
for race_items in races.values():
X_list, y_list, hist_list = [], [], []
cat_temp = defaultdict(list)
ids_list = []
for i in race_items:
x_num = i["X_num"]
# --- 時系列Tのパディング/切り詰め ---
if x_num.size(0) < sequence_length:
pad_len = sequence_length - x_num.size(0)
x_num = torch.cat([x_num, x_num[-1:].repeat(pad_len, 1)], dim=0)
elif x_num.size(0) > sequence_length:
x_num = x_num[-sequence_length:]
X_list.append(x_num)
y_list.append(torch.as_tensor(i["y"], dtype=torch.long))
ids_list.append(torch.as_tensor(i["horse_id"], dtype=torch.long))
if i.get("hist", None) is not None:
hist_list.append(torch.as_tensor(i["hist"], dtype=torch.float32))
for col, val in i.items():
if col in ["X_num", "y", "race_id", "hist", "horse_id"]:
continue
v = val if isinstance(val, torch.Tensor) else torch.tensor(val, dtype=torch.long)
if v.dim() == 0:
v = v.unsqueeze(0)
if v.size(0) < sequence_length:
v = torch.cat([v, v[-1:].repeat(sequence_length - v.size(0))], dim=0)
elif v.size(0) > sequence_length:
v = v[-sequence_length:]
cat_temp[col].append(v)
batch_X.append(torch.stack(X_list, dim=0))
batch_y.append(torch.stack(y_list, dim=0))
batch_ids.append(torch.stack(ids_list, dim=0))
if hist_list:
batch_hist.append(torch.stack(hist_list, dim=0))
else:
batch_hist.append(torch.zeros((len(X_list), 1), dtype=torch.float32))
for col in cat_temp:
batch_cat[col].append(torch.stack(cat_temp[col], dim=0))
X_pack = pad_sequence(batch_X, batch_first=True)
y_pack = pad_sequence(batch_y, batch_first=True, padding_value=-1)
ids_pack = pad_sequence(batch_ids, batch_first=True, padding_value=-1)
mask_pack = (ids_pack == -1)
hist_pack = pad_sequence(batch_hist, batch_first=True, padding_value=0.0) if batch_hist[0] is not None else None
for col in batch_cat:
batch_cat[col] = pad_sequence(batch_cat[col], batch_first=True, padding_value=0)
return X_pack, y_pack, hist_pack, dict(batch_cat), ids_pack, mask_pack
この関数 grouped_race_collate_fn
は、PyTorch の DataLoader
用 コラテ関数(collate_fn) で、ミニバッチを 「レースごとに馬をまとめたテンソル」 に変換します。
特に、**各馬の時系列履歴(T ステップ)**を統一長に揃え、**レース単位(S 頭)**でパディングし、モデルに投入できる形を作ります。
主な処理ステップ
- レース単位でグループ化
races = defaultdict(list)
for item in batch:
races[item["race_id"]].append(item)
- 馬ごとの時系列パディング/切り詰め
- 不足時:最後の時点を複製して延長(パディング)。
- 過剰時:直近
sequence_length
分だけ残す。
→ 各馬の履歴長を 固定長 T に揃える。
if x_num.size(0) < sequence_length:
pad_len = sequence_length - x_num.size(0)
x_num = torch.cat([x_num, x_num[-1:].repeat(pad_len, 1)], dim=0)
elif x_num.size(0) > sequence_length:
x_num = x_num[-sequence_length:]
- 教師ラベル・ID・履歴長・カテゴリ列の整理
y
(着順ラベル)、horse_id
、hist
(履歴長)をテンソル化。- カテゴリ列(例:騎手ID・競馬場IDなど)は同様にパディング/切り詰め。
cat_temp
に列ごとにまとめてからスタック。
- レース単位のテンソル化
- 1レース内の全馬をまとめて [S, T, D] 形式のテンソルにする。
hist_list
が無い場合はゼロ埋めを挿入。
batch_X.append(torch.stack(X_list, dim=0)) # [S, T, D]
batch_y.append(torch.stack(y_list, dim=0)) # [S]
batch_ids.append(torch.stack(ids_list, dim=0)) # [S]
- レース間のパディング(pad_sequence)
- レース数 R が異なるため、
pad_sequence
で揃える。 ids_pack
の-1
埋めを基に マスクテンソルを作り、モデルに渡すことで「パディング部分を無視」できる。- カテゴリ列(
batch_cat
)も同様にpad_sequence
で [R,S,T] に揃える。
- レース数 R が異なるため、
X_pack = pad_sequence(batch_X, batch_first=True) # [R, S, T, D]
y_pack = pad_sequence(batch_y, batch_first=True, padding_value=-1)
ids_pack = pad_sequence(batch_ids, batch_first=True, padding_value=-1)
mask_pack = (ids_pack == -1) # パディング位置のマスク
出力
戻り値は以下の6要素:
X_pack
:数値特徴([R, S, T, D])y_pack
:ラベル([R, S]、-1
がパディング)hist_pack
:履歴長([R, S, 1])dict(batch_cat)
:カテゴリ特徴(列ごとに [R, S, T])ids_pack
:馬ID([R, S]、-1
がパディング)mask_pack
:パディング位置マスク([R, S]、bool)
✨ まとめ
この collate_fn
により、
- 馬ごとの履歴長を統一(T)
- レース内の馬数を統一(S、パディングで揃える)
- レース単位でミニバッチ化(R)
- 数値・カテゴリ・メタ情報を一貫して整列
が実現され、モデル側は [R, S, T, D] + マスク という統一入力を前提に処理できるようになります。
コメント