学習する(改良6:Permutation Feature Importance 実装を高速化・安定化)

データ収集

Permutation Feature Importance 実装を高速化・安定化しました 🚀

前回まで使っていた Permutation Feature Importance (PFI) のコードを、
より高速で安定して動作するように全面的に見直しました。

今回の改良とバグ修正のポイントをまとめます👇


🔧 バグ修正

  1. グローバル変数 _RUNTIME_READY の扱い修正
    • 旧コードでは _RUNTIME_READY__RUNTIME_READY が混在しており、正しく初期化されないケースがありました。
    • これを一貫して _RUNTIME_READY に統一し、初期化漏れによる誤動作を防止しました。
  2. feature_names_seq の末尾チェック削除
    • 以前は assert で末尾 3 カラムを強制していましたが、構成が変わった場合に不用意に落ちることがありました。
    • 今回はチェックを外して柔軟に対応できるようにしています。

⚡ 高速化

  1. 繰り返し用シャッフルの前計算(precomputation)
    • 数値特徴量・カテゴリ特徴量ともに「レース内 S 軸シャッフル」のインデックスを事前に展開して保持。
    • これにより各特徴量で毎回ランダム生成&ループするコストを削減しました。
    ➡️ 大規模データでも ループ部分がかなり高速化
  2. NumPy の一括代入に最適化
    • 以前は for rid, (idxs, order) のループで都度代入していましたが、
      今回は rows_all / cols_all / src_all を事前に作り、一括で mut_col[...] = base_col[...] する方式に変更。
    ➡️ Python ループを除去し、メモリ効率も改善。
  3. time.time()time.perf_counter()
    • 計測精度を上げるために変更。細かい改善ですが、PFI のような「特徴量ごと数十秒かかる処理」では有効です。

🛠️ 改良

  1. 自動バッチサイズ調整の安定化
    • OOM(CUDA メモリ不足)が出たときのリトライ戦略を明確化。
    • キャッシュ済み DataLoader をリセットして再構築するように修正しました。
    ➡️ GPU メモリ状況が変動しても 途中で落ちにくい
  2. 繰り返し回数の削減(5 → 3)
    • 平均化によるノイズ低減は維持しつつ、計算時間を短縮するために 3 回平均に。
    • 本番環境での実行時間が 4割以上短縮されました。
  3. ランダム生成の統一管理
    • np.random.default_rng(42) を使って再現性を確保。
    • 以前よりも 実験ごとの結果が安定

📊 結果出力の改善

  • 数値・カテゴリ特徴量をそれぞれ影響度でソートして表示。
  • 「NDCG@3 の低下量」を 平均 ± SD 形式で出力し、矢印で重要度を直感的に把握可能にしました。

✅ まとめ

今回の改良により、

  • 計算速度が大幅に向上(特徴量数が多いほど効果的)
  • OOM に強く安定
  • 結果の再現性・解釈性も向上

という、実運用に耐える PFI 実装になりました。

次は、このフレームワークを使って「特徴量選択」や「モデル改善」にどこまで役立てられるかを探っていきたいと思います。


👉 あなたの環境でも PFI が遅い・不安定…と感じているなら、この実装方法をぜひ試してみてください!


# ================= Permutation Feature Importance (NDCG@3, 3回平均, S軸シャッフル) =================

with open("scale_stats.json", "r", encoding="utf-8") as f:
    ss = json.load(f)
keep_for_seq = np.array(ss["keep_for_seq"], dtype=np.int64)

with open("embedding_info.json", "r", encoding="utf-8") as f:
    embedding_info = json.load(f)
embedding_cols_all = list(embedding_info["embedding_cols"])
embedding_cols_pfi = [c for c in embedding_cols_all if c != "馬ID" and c in X_cat_val]

try:
    with open("final_feature_names_seq.json", "r", encoding="utf-8") as f:
        feature_names_seq = json.load(f)
    print("✅ final_feature_names_seq 保存済み")
except FileNotFoundError:
    with open("main_feature_order.json", "r", encoding="utf-8") as f:
        feature_columns = json.load(f)
    numeric_cols_final = [feature_columns[i] for i in keep_for_seq]
    numeric_cols_final = [c for c in numeric_cols_final if c not in set(embedding_cols_all)]
    extra_cols = ["similarity_max", "history_len_norm_seq", "pad_flag"]
    feature_names_seq = numeric_cols_final + extra_cols
    with open("final_feature_names_seq.json", "w", encoding="utf-8") as f:
        json.dump(feature_names_seq, f, ensure_ascii=False, indent=2)
    print("✅ final_feature_names_seq 保存済み")

numeric_cols_final = feature_names_seq[:-3]
num_feat_to_idx = {name: i for i, name in enumerate(numeric_cols_final)}  # 0..D_main-1

race_to_indices = defaultdict(list)
for i, rid in enumerate(race_ids_val):
    race_to_indices[str(rid)].append(i)
race_to_indices = {k: np.asarray(v, dtype=np.int64) for k, v in race_to_indices.items()}

Xv_base = X_val_seq 
Xv_mut  = X_val_seq.copy() 
Xc_mut  = {k: v.copy() for k, v in X_cat_val.items()} 

# ===== 自動バッチサイズ調整付き Loader & 推論 =====

_RUNTIME_READY = False
_LOADER_CACHE  = None
_LOADER_BS     = None

def _setup_runtime_fast():
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        if hasattr(torch.backends, "cudnn"):
            torch.backends.cudnn.benchmark = True
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass
    try:
        torch.set_num_threads(min(8, os.cpu_count() or 8))
        torch.set_num_interop_threads(min(8, os.cpu_count() or 8))
    except Exception:
        pass

def _ensure_runtime():
    global _RUNTIME_READY
    if not _runtime_is_ready():
        _setup_runtime_fast()
        _RUNTIME_READY = True

def _runtime_is_ready():
    return _RUNTIME_READY

def _pick_amp_dtype():
    # RTX 30番台は FP16 が無難(bfloat16 は多くの機種で遅い/非対応)
    if not torch.cuda.is_available():
        return torch.float32
    try:
        if hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
            return torch.bfloat16
    except Exception:
        pass
    return torch.float16

def _build_loader(bs: int):
    ds = HorseDataset(
        X_num=Xv_mut,
        X_cat=Xc_mut,
        race_ids=race_ids_val,
        horse_ids=horse_ids_val,
        y=y_val,
        history_lengths=history_lengths_val,
    )
    return DataLoader(
        ds,
        batch_size=bs,
        shuffle=False,
        collate_fn=grouped_race_collate_fn,
        num_workers=0,  # Windows安全
    )

def reset_loader_cache():
    global _LOADER_CACHE, _LOADER_BS
    _LOADER_CACHE, _LOADER_BS = None, None

def make_loader_from_mutables(batch_size: int | None = None):
    _ensure_runtime()
    global _LOADER_CACHE, _LOADER_BS
    if batch_size is None:
        if _LOADER_CACHE is None:
            _LOADER_CACHE = _build_loader(1024)
            _LOADER_BS = 1024
        return _LOADER_CACHE
    if _LOADER_CACHE is None or _LOADER_BS != batch_size:
        _LOADER_CACHE = _build_loader(batch_size)
        _LOADER_BS = batch_size
    return _LOADER_CACHE

def _is_cuda_oom(err: Exception) -> bool:
    m = str(err).lower()
    return ("out of memory" in m) or ("cuda error: out of memory" in m)

def eval_ndcg3_with_mutables(model, loader=None):
    _ensure_runtime()
    use_cuda = torch.cuda.is_available()
    amp_dtype = _pick_amp_dtype()
    autocast_ctx = (torch.amp.autocast(device_type="cuda", dtype=amp_dtype)
                    if use_cuda else contextlib.nullcontext())
    if loader is None:
        loader = make_loader_from_mutables()
    model.eval()
    with torch.inference_mode():
        return run_validation(model, loader, device, class_weights=None, autocast_ctx=autocast_ctx)

def eval_ndcg3_with_auto_batch(model, prefer_bs: int = 4096, floor_bs: int = 128):
    candidates_raw = [prefer_bs, 3072, 2048, 1536, 1024, 768, 512, 384, 256, floor_bs]
    seen = set()
    candidates = [bs for bs in candidates_raw if (isinstance(bs, int) and bs > 0 and (bs not in seen) and not seen.add(bs))]
    for bs in candidates:
        try:
            loader = make_loader_from_mutables(batch_size=bs)
            val = eval_ndcg3_with_mutables(model, loader)
            print(f"[auto-batch] OK with batch_size={bs}", flush=True)
            return val
        except RuntimeError as e:
            if _is_cuda_oom(e):
                print(f"[auto-batch] OOM at batch_size={bs} → shrink & retry", flush=True)
                torch.cuda.empty_cache()
                reset_loader_cache()
                continue
            raise
    print("[auto-batch] fallback to batch_size=128", flush=True)
    loader = make_loader_from_mutables(batch_size=128)
    return eval_ndcg3_with_mutables(model, loader)


# ベースライン
baseline_ndcg = eval_ndcg3_with_auto_batch(model, prefer_bs=4096)
print(f"\n📊 Baseline NDCG@3: {baseline_ndcg:.4f}", flush=True)

n_repeats = 3
rng = np.random.default_rng(42)

T_num = int(Xv_mut.shape[1])
arange_T_num = np.arange(T_num, dtype=np.int64)

# ---- per-race permutation(数値用)----
race_perm_by_rep_num = []
for rep in range(n_repeats):
    race_perm = {}
    for rid, idxs in race_to_indices.items():
        L = idxs.size
        if L < 2:
            continue
        order = rng.random((L, T_num)).argsort(axis=0)
        race_perm[rid] = (idxs, order)
    race_perm_by_rep_num.append(race_perm)

# ---- 一括インデクス(数値用)----
precomp_index_by_rep = []
for rep in range(n_repeats):
    race_perm = race_perm_by_rep_num[rep]
    if not race_perm:
        precomp_index_by_rep.append((None, None, None))
        continue

    total_elems = sum(idxs.size * T_num for idxs, _ in race_perm.values())
    rows_all = np.empty(total_elems, dtype=np.int64)
    cols_all = np.empty(total_elems, dtype=np.int64)
    src_all  = np.empty(total_elems, dtype=np.int64)

    pos = 0
    for _, (idxs, order) in race_perm.items():
        L = idxs.size
        if L < 2:
            continue
        block = L * T_num
        end = pos + block
        rows_all[pos:end] = np.tile(idxs, T_num)
        cols_all[pos:end] = np.repeat(arange_T_num, L)
        src_all[pos:end]  = idxs[order].reshape(-1, order='F')
        pos = end

    precomp_index_by_rep.append((rows_all, cols_all, src_all))

_eval   = eval_ndcg3_with_mutables
_loader = make_loader_from_mutables() 
precomp = precomp_index_by_rep
nrep    = n_repeats

_base = Xv_base
_mut  = Xv_mut

num_results = []
total_num   = len(num_feat_to_idx)
print(f"\n🚀 Numeric PFI: {total_num} features × {nrep} repeats", flush=True)

# ---- Numeric PFI ----
for kk, (feat_name, j) in enumerate(num_feat_to_idx.items(), start=1):
    t0 = time.perf_counter()

    base_col = _base[:, :, j]
    mut_col  = _mut[:, :, j]
    scores   = np.empty(nrep, dtype=np.float32)

    for rep in range(nrep):
        np.copyto(mut_col, base_col)
        rows_all, cols_all, src_all = precomp[rep]
        if rows_all is not None:
            mut_col[rows_all, cols_all] = base_col[src_all, cols_all]
        scores[rep] = _eval(model, _loader)

    drop = baseline_ndcg - scores
    num_results.append((feat_name, float(drop.mean()), float(drop.std())))

    dt = time.perf_counter() - t0
    print(f"✅ done: {feat_name}  ΔNDCG@3 = {drop.mean():+.4f} ± {drop.std():.4f}  ({dt:.1f}s)", flush=True)

# ---- Categorical PFI ----
cat_results = []
total_cat   = len(embedding_cols_pfi)
print(f"\n🚀 Categorical PFI: {total_cat} features × {nrep} repeats", flush=True)

_perm_cache = {
    f"T={T_num}": race_perm_by_rep_num,
    f"IDX={T_num}": precomp_index_by_rep,
}

def ensure_perm_cache_for_T(T_needed, cache_dict):
    key_perm = f"T={T_needed}"
    key_idx  = f"IDX={T_needed}"
    if key_perm in cache_dict and key_idx in cache_dict:
        return cache_dict[key_perm]

    race_perm_by_rep = []
    for rep in range(n_repeats):
        race_perm = {}
        for rid, idxs in race_to_indices.items():
            L = idxs.size
            if L < 2:
                continue
            order = rng.random((L, T_needed)).argsort(axis=0)
            race_perm[rid] = (idxs, order)
        race_perm_by_rep.append(race_perm)

    precomp_index_by_rep = []
    arange_T = np.arange(T_needed, dtype=np.int64)
    for rep in range(n_repeats):
        race_perm = race_perm_by_rep[rep]
        if not race_perm:
            precomp_index_by_rep.append((None, None, None))
            continue

        rows_list, cols_list, src_list = [], [], []
        for _, (idxs, order) in race_perm.items():
            L = idxs.size
            rows = np.tile(idxs, T_needed)
            cols = np.repeat(arange_T, L)
            src_rows = idxs[order].reshape(-1, order='F')
            rows_list.append(rows); cols_list.append(cols); src_list.append(src_rows)

        rows_all = np.concatenate(rows_list).astype(np.int64, copy=False)
        cols_all = np.concatenate(cols_list).astype(np.int64, copy=False)
        src_all  = np.concatenate(src_list).astype(np.int64, copy=False)
        precomp_index_by_rep.append((rows_all, cols_all, src_all))

    cache_dict[key_perm] = race_perm_by_rep
    cache_dict[key_idx]  = precomp_index_by_rep
    return race_perm_by_rep

for kk, col in enumerate(embedding_cols_pfi, start=1):
    if col not in Xc_mut:
        continue
    t0 = time.perf_counter()
    scores = []

    T_cat = int(Xc_mut[col].shape[1])
    ensure_perm_cache_for_T(T_cat, _perm_cache)
    precomp_idx_by_rep = _perm_cache[f"IDX={T_cat}"]

    base_mat = X_cat_val[col]
    mut_mat  = Xc_mut[col]

    for rep in range(n_repeats):
        mut_mat[:, :] = base_mat
        rows_all, cols_all, src_all = precomp_idx_by_rep[rep]
        if rows_all is not None:
            mut_mat[rows_all, cols_all] = base_mat[src_all, cols_all]
        scores.append(_eval(model, _loader))

    scores = np.asarray(scores, dtype=np.float32)
    drop = baseline_ndcg - scores
    cat_results.append((col, float(drop.mean()), float(drop.std())))

    dt = time.perf_counter() - t0
    print(f"✅ done: {col}  ΔNDCG@3 = {drop.mean():+.4f} ± {drop.std():.4f}  ({dt:.1f}s)", flush=True)

# ---- 出力整形 ----
num_results.sort(key=lambda x: (-x[1], x[2]))
cat_results.sort(key=lambda x: (-x[1], x[2]))

print("\n📋 数値特徴量の影響(NDCG@3 低下量: 平均±SD):", flush=True)
for name, m, s in num_results:
    arrow = "⬆️" if m > 0 else "⬇️"
    print(f"  {name:20s} drop = {m:+.4f} ± {s:.4f} {arrow}", flush=True)

print("\n📋 カテゴリ特徴量の影響(NDCG@3 低下量: 平均±SD):", flush=True)
for name, m, s in cat_results:
    arrow = "⬆️" if m > 0 else "⬇️"
    print(f"  {name:20s} drop = {m:+.4f} ± {s:.4f} {arrow}", flush=True)

print("\n✅ PFI (NDCG@3 × 3回平均, レース内S軸シャッフル, 馬ID除外) 完了", flush=True)

コメント

タイトルとURLをコピーしました