学習する(11:ヘイズワイズ、トップ1、nDCG@K)

学習

ペア・トップ1・上位Kの3軸から、レース内ランキングを統計的に鍛える損失パック

レース内のスコア(preds)と着順ラベル(targets)に対して、ランキング学習の3種類の損失を実装しています。目的は、

  1. ペアワイズで正しい並びを学習(勝ち負けの順序)
  2. トップ1(勝者)を確率的に当てる
  3. 上位K件の並びの良さ(nDCG@K)を最大化
    の3方向から予測を鍛えることです。
def weighted_pairwise_ranking_loss(preds, targets, weights=None):
    N = preds.size(0)
    if N < 2:
        return torch.tensor(0.0, device=preds.device)
    i, j = torch.triu_indices(N, N, offset=1, device=preds.device)
    diff = torch.sign(targets[j] - targets[i])
    pred_diff = preds[i] - preds[j]
    if weights is not None:
        w = (weights[i] + weights[j]) * 0.5
    else:
        w = 1.0
    return torch.nn.functional.softplus(-diff * pred_diff).mul(w).mean()

def soft_top1_loss(preds: torch.Tensor,
                   targets: torch.Tensor,
                   class_weights: torch.Tensor | None = None) -> torch.Tensor:
    preds   = preds.reshape(-1)
    targets = targets.reshape(-1).to(preds.device)

    if targets.numel() == 0:
        return torch.tensor(0.0, device=preds.device)

    probs = torch.softmax(preds, dim=0)
    eps = 1e-12

    if torch.all((targets == 0) | (targets == 1)) and targets.sum() > 0:
        mask = targets.float() / targets.sum()
        return -torch.sum(mask * torch.log(probs + eps))

    min_val = targets.min()
    winner_mask = (targets == min_val).float()
    winner_mask = winner_mask / winner_mask.sum()

    weight = torch.tensor(1.0, device=preds.device)
    if class_weights is not None and class_weights.numel() > 0:
        idx = int(min_val.item()) - 1
        if 0 <= idx < class_weights.numel():
            weight = class_weights[idx].to(preds.device)

    return -weight * torch.sum(winner_mask * torch.log(probs + eps))

def hard_ndcg3_loss_torch(preds: torch.Tensor,
                          targets: torch.Tensor,
                          mask: torch.Tensor | None = None,
                          k: int = 3) -> torch.Tensor:

    if preds.dim() == 1:
        preds = preds.unsqueeze(0)
        targets = targets.unsqueeze(0)
        if mask is not None and mask.dim() == 1:
            mask = mask.unsqueeze(0)

    R, S = preds.shape
    device = preds.device
    dtype = preds.dtype

    if mask is None:
        mask = torch.zeros((R, S), dtype=torch.bool, device=device)

    targ = targets.to(torch.long)
    rel = (targ.max(dim=1, keepdim=True).values - targ + 1).to(preds.dtype)

    valid = ~mask
    preds_masked = preds.masked_fill(~valid, float("-inf"))
    rel_masked   = rel.masked_fill(~valid, 0.0)

    kk = min(k, S)

    topk_idx = torch.topk(preds_masked, k=kk, dim=1).indices  
    gains    = rel_masked.gather(1, topk_idx) 
    discounts = 1.0 / torch.log2(torch.arange(2, kk + 2, device=device, dtype=dtype))
    dcg = (gains * discounts.unsqueeze(0)).sum(dim=1)  

    ideal_idx = torch.topk(rel_masked, k=kk, dim=1).indices
    ideal_gains = rel_masked.gather(1, ideal_idx)
    idcg = (ideal_gains * discounts.unsqueeze(0)).sum(dim=1) 

    ndcg = torch.where(idcg > 0, dcg / idcg, torch.ones_like(dcg))
    return (1.0 - ndcg.mean()).to(dtype)

1) weighted_pairwise_ranking_loss:ペアワイズ順位損失(重み付き)

  • 何をしている?
    レース内の全ペア (i,j)(i,j)(i,j) について、真の順序diff = sign(targets[j] - targets[i]))と 予測差pred_diff = preds[i]-preds[j])の向きが一致するように学習させます。
    損失は softplus(-diff * pred_diff)ロジスティック型マージンで、順序が正しいほど小さくなります。weights があれば (i,j)(i,j)(i,j) の重要度を平均で掛けます。
  • 実装の肝
    • torch.triu_indices上三角ペア(重複なし)を一括生成 → 完全ベクトル化
    • 2頭未満は比較できないので 0 を返して早期終了。
  • 統計学的な見立て
    • Bradley–Terry/Luce 型のペア比較モデルに相当(「勝つ確率」をスコア差のロジスティックで表現)。
    • AUC最大化に近い目標(順序一致を増やす)。weights は重要ペアのサンプリング確率調整異質性の反映に使える。

2) soft_top1_loss:トップ1(勝者)を当てるソフトクロスエントロピー

  • 何をしている?
    レース内スコアに softmax をかけて 「この馬が勝者である確率」 に変換し、
    • 二値ラベル(0/1、かつ1が存在)なら、1ラベルを 確率分布に正規化してクロスエントロピー(複数1があれば等分重み)。
    • 順位ラベル(小さいほど良い)なら、最小値(最上位)にある馬群を勝者集合として同様に確率分布化して学習。
      class_weights があれば、勝者ランク種別に重みづけを行います。
  • 実装の肝
    • probs = softmax(preds)1レース内で計算(dim=0)。
    • 同着(最小値が複数)でも、勝者確率を一様分配して安定学習。
    • eps を足して 数値安定
  • 統計学的な見立て
    • 多項ロジット/ Plackett–LuceTop-1 尤度最大化に相当(勝者の対数尤度の最大化)。
    • class_weights事前分布(クラス頻度)に基づく重み付き尤度として解釈可能。

3) hard_ndcg3_loss_torch:nDCG@K(ここではK=3)に基づく損失

  • 何をしている?
    レース内の preds上位K頭topk 抽出し、targets から作る 関連度 rel(小さい着順ほど高関連)を使って DCG@K=∑r=1Kgainrlog⁡2(r+1)\mathrm{DCG}@K = \sum_{r=1}^{K} \frac{\mathrm{gain}_r}{\log_2(r+1)}DCG@K=r=1∑K​log2​(r+1)gainr​​ を計算。理想並びの IDCG@K で割って nDCG@K を求め、損失は 1 - mean(nDCG@K)
    mask でパディング馬(出走しないスロット)を無視します。
  • 実装の肝
    • rel = max(targets) - targets + 1小さい着順ほど大きい関連度に変換。
    • masked_fill(-inf / 0)パディング除外し、topkgather高速にDCGを計算。
    • **非微分(hard top-k)**だが、評価指標一致の利点がある。
  • 統計学的な見立て
    • 情報検索のランキング指標最適化位置割引つき利得の最大化)。
    • 「上位ほど価値が大きい」現実の配当・注目度に整合。滑らかな代替としては SoftRank / SmoothDCG / LambdaRank などがあるが、ここでは指標そのものを直接最適化する近似。

使い分けの指針(ざっくり)

  • 順序全体の整合を重視 → weighted_pairwise_ranking_loss
  • **勝者当て(Top-1)**を重視 → soft_top1_loss
  • 上位K重視の実運用指標を最適化 → hard_ndcg3_loss_torch

コメント

タイトルとURLをコピーしました