まず、最初にコードから。。。。
@torch.no_grad()
def preprocess_like_training(
X,
stats: dict,
*,
inplace: bool = False,
out_dtype_torch: torch.dtype = torch.float32,
return_numpy_if_numpy_input: bool = True,
):
TAIL = int(stats["TAIL"])
keep_numeric_idx = stats["keep_numeric_idx"]
means_np = stats["means"].astype(np.float32)
stds_np = stats["stds"].astype(np.float32)
keep_for_seq = stats["keep_for_seq"]
stds_np_safe = np.where(stds_np == 0.0, 1.0, stds_np)
if isinstance(X, np.ndarray):
Xw = X if inplace else X.copy()
if Xw.ndim == 2:
Xw = Xw[None, ...]
squeeze_back = True
elif Xw.ndim == 3:
squeeze_back = False
else:
raise ValueError("X は (S,D_all) か (B,S,D_all)")
assert Xw.shape[-1] >= TAIL, "TAIL 列不足(列順要確認)"
X_main = Xw[..., :-TAIL] if TAIL > 0 else Xw # ★ TAIL=0時も全部をメイン扱い
X_tail = Xw[..., -TAIL:] if TAIL > 0 else Xw[..., :0]
# 標準化(安全版)
part = X_main[:, :, keep_numeric_idx]
part = (part - means_np.reshape(1,1,-1)) / stds_np_safe.reshape(1,1,-1)
X_main[:, :, keep_numeric_idx] = part
X_main_kept = np.take(X_main, keep_for_seq, axis=2)
X_seq = np.concatenate([X_main_kept, X_tail], axis=2) if TAIL > 0 else X_main_kept
# ★ 非有界値サニタイズ
X_seq = np.where(np.isfinite(X_seq), X_seq, 0.0).astype(np.float32, copy=False)
if squeeze_back:
X_seq = X_seq[0]
return X_seq if return_numpy_if_numpy_input else torch.from_numpy(X_seq)
if torch.is_tensor(X):
Xw = X if inplace else X.clone()
if Xw.dim() == 2:
Xw = Xw.unsqueeze(0)
squeeze_back = True
elif Xw.dim() == 3:
squeeze_back = False
else:
raise ValueError("X は (S,D_all) か (B,S,D_all)")
device, dtype = Xw.device, Xw.dtype
kni = torch.as_tensor(keep_numeric_idx, dtype=torch.long, device=device)
means_t = torch.as_tensor(means_np, dtype=dtype, device=device)
stds_t = torch.as_tensor(stds_np, dtype=dtype, device=device)
# ★ 分母0対策
stds_t = torch.where(stds_t == 0, torch.ones_like(stds_t), stds_t)
kfs = torch.as_tensor(keep_for_seq, dtype=torch.long, device=device)
assert Xw.size(-1) >= TAIL, "TAIL 列不足(列順要確認)"
X_main = Xw[..., :-TAIL] if TAIL > 0 else Xw
X_tail = Xw[..., -TAIL:] if TAIL > 0 else Xw[..., :0]
part = X_main.index_select(-1, kni)
part.sub_(means_t).div_(stds_t)
X_main.index_copy_(-1, kni, part)
X_main_kept = X_main.index_select(-1, kfs)
X_seq = torch.cat([X_main_kept, X_tail], dim=-1) if TAIL > 0 else X_main_kept
# ★ 非有界値サニタイズ
X_seq = torch.where(torch.isfinite(X_seq), X_seq, torch.zeros_like(X_seq))
if X_seq.dtype != out_dtype_torch:
X_seq = X_seq.to(out_dtype_torch)
if squeeze_back:
X_seq = X_seq.squeeze(0)
return X_seq
raise TypeError("X は np.ndarray か torch.Tensor を指定")
stats = load_scale_stats("scale_stats.json")
X_seq = preprocess_like_training(
X_sequences_raw,
stats,
inplace=False,
out_dtype_torch=torch.float32,
return_numpy_if_numpy_input=True
)
機械学習モデルを実運用する時に必ず直面するのが、「学習時と推論時の前処理をどう揃えるか」という問題です。
学習時には平均・標準偏差で正規化したり、特定の列を残したりといった処理をしていますが、これを推論時にも一貫して再現しないと、モデルが「学習時と違う入力」を受け取ってしまい、精度が大きく落ちることがあります。
今回紹介する preprocess_like_training
関数は、まさにその課題を解決するために作られています。
ざっくりとした流れ
この関数のやっていることをシンプルにまとめると、
- データの形を揃える
2次元(S, 特徴数)でも3次元(B, S, 特徴数)でも受け取り、バッチ軸を整えます。 - 列を分ける
TAIL
で指定した「最後の数列」は触らずそのまま残す。- 残りは正規化対象。
- 学習時の統計量で標準化
mean
とstd
を使って (x – mean) / std に変換。std=0
の場合は 1 に置換してゼロ除算を防止。
- 学習時と同じ列順に並べ替え
keep_for_seq
に保存しておいた列インデックスで並び替え。- これで学習時と完全一致。
- NaN / Inf を安全化
- 万一異常値が入っていても 0 に置き換えて、モデル入力を壊さない。
- 戻り値を調整
- 入力が numpy なら numpy で返す。
- torch.Tensor なら指定の dtype で返す。
- 2次元入力なら最後に2次元へ戻す。
具体例
推論時にはこんな風に使います。
stats = load_scale_stats("scale_stats.json") # 学習時の統計情報をロード
X_seq = preprocess_like_training(
X_sequences_raw, # モデルに入れたい生の特徴行列
stats,
inplace=False,
out_dtype_torch=torch.float32,
return_numpy_if_numpy_input=True
)
これで X_seq
は、学習時と同じ前処理が適用されたデータに変換されます。
あとはそのままモデルへ渡すだけです。
なぜ大事か?
もし学習時と推論時で「列順がズレる」「スケールが違う」などの不整合があると、モデルは全く別の意味のデータを受け取ってしまいます。
たとえば「斤量」と思っていた列に「馬体重」が入ったり、「標準化済み」と思っていた列が未処理だったり…。
そうなるとせっかくのモデルもただの乱数生成器になってしまいます。
preprocess_like_training
はこうしたリスクをゼロにするための関数で、推論時に学習時と同じ処理を強制する守護神といえます。
まとめ
- 学習時の統計量や列順を保存しておき、推論時にもそのまま適用する。
preprocess_like_training
はそのための「再現前処理関数」。- 標準化・列選択・安全化を一括で行い、モデルが期待する形に揃えてくれる。
安心して予測を出すためには、こうした「地味だけど超大事」な仕組みが欠かせません。
コメント