콘텐츠로 이동

TODO: Chronos Hidden State KD 방법론 설계 및 구현

작성일: 2026-04-08
최종 업데이트: 2026-04-09
상태: 우선순위 격상 — 즉시 착수 대상 (GWN A_adp 실험 라인 종료됨)
근거 문서: docs/reference/kd_method_comparison.md §10 (TSFM Teacher 적용 가능성 분석)
전제 조건 변경: ~~A_adp 활용 실험(방법 E/C) 결과 확인 후 진행~~ → 전제 조건 해소됨 (A_adp 실험 종료)

우선순위 격상 근거 (2026-04-09)

2026-04-09 Phase 2 KD Ablation 실험 최종 결과: - GWN A_adp 기반 SC-DLinear KD 스토리라인 REJECT (2×2 Factorial + Null Test 전 조건 p > 0.05) - A_adp 대각선 지배도 23.84 (기준 5.0의 4.8배) — off-diagonal spatial information 실질적으로 없음 - FiLM conditioning이 random matrix 대비 유의미한 차이를 만들지 못함 (p=0.3711)

따라서 GWN → A_adp 경로를 통한 spatial KD는 현재 데이터로 불가능하며, Chronos Hidden State 기반 시간적(temporal) KD로 방향 전환.

전제 조건 중 남은 확인 사항: - [x] GWN Teacher 재학습 완료 (MSE=0.5049 < 0.5207) - [x] A_adp 활용 실험 종료 (REJECT 결론) - [ ] GPU 메모리 여유 확인 — 즉시 점검 필요


배경

kd_method_comparison.md §10.4의 제안:

GWN Teacher (공간 정보) + Chronos Teacher (시간 정보)의 Dual-Teacher KD가 논문 기여도 측면에서 가장 유망

현재는 GWN → DLinear의 Response-based KD(Soft-DTW)만 구현되어 있음.
Chronos-Bolt-Small (pretrained_models/chronos-bolt-small/)은 이미 로컬에 존재하며, T5 Encoder의 encoder_last_hidden_state를 추출하여 Feature-based KD를 수행하는 것이 목표.


Chronos-Bolt-Small 아키텍처 분석

항목
아키텍처 T5 Encoder-Decoder (is_encoder_decoder=True)
d_model (hidden_dim) 512
Encoder 레이어 수 6
입력 Patch 방식 input_patch_size=16, input_patch_stride=16
use_reg_token True
입력 토큰 수 (seq_len=96) ceil(96/16) + 1(reg) = 7 tokens
Encoder output shape [B, 7, 512]
출력 방식 분위수 예측 (quantile regression)

원본 Chronos (토큰 양자화)와 다름: Chronos-Bolt는 패치 임베딩 기반이며, vocab_size=2로 단순화된 구조. 내부 패치 임베딩 레이어(input_patch_embedding)가 별도로 존재.


핵심 과제: Hidden State 추출 방법 3가지

DistilTS의 Chronos.pypredict_quantiles()만 사용하므로 hidden state를 반환하지 않음.
내부 T5 모델에 직접 접근해야 함.

방법 F1: Forward Hook (비침습적, 추천)

from chronos import BaseChronosPipeline

pipe = BaseChronosPipeline.from_pretrained(
    "pretrained_models/chronos-bolt-small",
    device_map="cuda", torch_dtype=torch.bfloat16
)

cached = {}

def _hook(module, inp, out):
    # T5Stack.forward() 반환값: BaseModelOutput
    # out[0] = last_hidden_state [B, num_tokens, 512]
    cached['h'] = out[0].detach().float()  # bfloat16 → float32 변환

handle = pipe.model.encoder.register_forward_hook(_hook)

with torch.no_grad():
    _ = pipe.predict_quantiles(
        context=x_ind.squeeze(-1),   # [B, seq_len]
        prediction_length=pred_len,
        quantile_levels=[0.5],
    )

h_chronos = cached['h']  # [B, 7, 512]
handle.remove()

장점: predict_quantiles 파이프라인(정규화, 패치화)을 그대로 사용 → 입력 전처리 일관성 보장
단점: 매 forward마다 hook 등록/해제 관리 필요, gradient 흐름 차단됨(Teacher frozen이므로 무관)


방법 F2: ChronosBolt 내부 분해 (직접 호출)

Chronos-Bolt의 내부 forward를 단계별로 분해하여 직접 호출.

# pipe.model = ChronosBoltModelForForecasting
# pipe.model.model = T5ForConditionalGeneration (내부 T5)

with torch.no_grad():
    # Step 1: 입력 정규화 (내부 stochastic scale 사용)
    context = x_ind.squeeze(-1)  # [B, seq_len]
    # Chronos-Bolt 정규화: scale = context.abs().mean(dim=-1, keepdim=True).clamp(min=1e-8)
    scale = context.abs().mean(dim=-1, keepdim=True).clamp(min=1e-8)
    context_normalized = context / scale  # [B, seq_len]

    # Step 2: Patch Embedding
    # pipe.model.input_patch_embedding: Conv1d or Linear based patch tokenizer
    patched = pipe.model.input_patch_embedding(
        context_normalized.unsqueeze(1)  # [B, 1, seq_len] → [B, num_patches, d_model]
    )  # [B, 7, 512]

    # Step 3: T5 Encoder forward
    enc_out = pipe.model.model.encoder(
        inputs_embeds=patched,
        return_dict=True,
    )
    h_chronos = enc_out.last_hidden_state  # [B, 7, 512]

주의: Chronos-Bolt의 정규화 방식 및 input_patch_embedding API가 내부 구현에 종속적 → 버전 변경 시 깨질 수 있음.
권장 용도: F1이 실패할 경우의 fallback. 또는 학습 코드에서 batch 단위 처리 최적화 시.


방법 F3: output_hidden_states=True (멀티레이어 hints)

FitNet 스타일의 멀티레이어 Feature KD를 구현할 경우.

# F2의 Step 3 대신:
enc_out = pipe.model.model.encoder(
    inputs_embeds=patched,
    return_dict=True,
    output_hidden_states=True,
)
# enc_out.hidden_states: tuple of 7 tensors (embedding + 6 layers)
# 각 텐서 shape: [B, 7, 512]

# 레이어 선택 예시: 중간 레이어 (3번째) + 마지막 레이어
h_mid = enc_out.hidden_states[3]   # [B, 7, 512]
h_last = enc_out.hidden_states[-1]  # [B, 7, 512]

용도: 다층 hint KD (FitNet, SDKD). 현재 DLinear가 단순 구조이므로 1개 레이어 정렬로도 충분. Phase 3에서 고려.


Projector 설계 (Teacher Hidden → Student 차원 정렬)

구성 입력 출력 추가 params 권장
Mean Pool + Linear [B, 7, 512] → mean → [B, 512] → Linear(512, 24) [B, 24] 12,312 기본
Reg Token + Linear [B, 7, 512] → last → [B, 512] → Linear(512, 24) [B, 24] 12,312 실험 비교
Mean Pool + 2-layer MLP [B, 7, 512] → mean → [B, 512] → Linear(512, 128) → GELU → Linear(128, 24) [B, 24] 68,760 추천
DistilTS FTA 방식 [B, 7, 512] → VarTimeFactorAligner(T=7, t_dim=512) Student hidden ~수만 Phase 3 고려

추천 Projector 구현 (Mean Pool + 2-layer MLP)

class ChronosHintProjector(nn.Module):
    """
    Chronos Encoder last_hidden_state → DLinear prediction 차원 정렬
    Teacher hidden: [B, num_tokens, d_model=512]
    Student output: [B, pred_len=24]
    """
    def __init__(self, t_hidden=512, pred_len=24, mid_dim=128):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(t_hidden, mid_dim),
            nn.GELU(),
            nn.Linear(mid_dim, pred_len),
        )

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: [B, num_tokens, t_hidden]
        h_pool = h.mean(dim=1)          # [B, t_hidden] — Mean Pool
        return self.proj(h_pool)         # [B, pred_len]

설계 근거: use_reg_token=True이므로 마지막 토큰이 회귀 토큰 역할.
Mean pool이 reg token에 비해 전체 시간 패턴을 더 균등하게 반영함.
두 방식 모두 실험하여 비교 권장.


학습 전략

Loss 구성

L_total = L_task + α * L_kd_response + β * L_feat_chronos

L_task         = MSE(s_out, y_true)           # Supervised loss
L_kd_response  = SoftDTW(s_out, t_gwn_out)   # GWN Soft Target KD (기존)
L_feat_chronos = MSE(projector(h_chronos), s_out.detach())  # Chronos Feature KD

주의: L_feat_chronos에서 Student output을 target으로 사용할지,
또는 projector output을 Student에 정렬시킬지 방향 결정 필요. - 방향 A: MSE(s_out, proj(h_chronos)) — Student가 Chronos hidden의 투영을 모방 - 방향 B: MSE(proj(h_chronos), s_out.detach()) — Projector가 Student output을 근사 (projector 학습)

권장: 방향 A. Student가 직접 Chronos의 시간 표현을 흡수.

학습 파라미터

모듈 학습 여부
Chronos-Bolt (Teacher) Frozen (requires_grad=False)
GWN Teacher Frozen (학습 완료 후)
DLinear Student 학습
ChronosHintProjector 학습 (Student와 동일 optimizer)

구현 계획 (단계별)

Phase 1: Chronos Hidden State 추출 검증 (0.5일)

목표: Chronos-Bolt에서 hidden state가 올바르게 추출되는지 확인

  • tests/test_chronos_hidden.py 작성
  • Forward Hook (방법 F1) 정상 동작 확인
  • h_chronos.shape == [B, 7, 512] 검증
  • bfloat16 → float32 변환 확인
  • seq_len=96, batch_size=4 기준 smoke test
# 검증 포인트
def test_chronos_hidden_shape():
    pipe = load_chronos_pipeline("pretrained_models/chronos-bolt-small")
    hook, cached = register_encoder_hook(pipe)

    x = torch.randn(4, 96)  # [B, seq_len]
    with torch.no_grad():
        pipe.predict_quantiles(x, prediction_length=24, quantile_levels=[0.5])

    assert cached['h'].shape == (4, 7, 512)
    assert cached['h'].dtype == torch.float32
    hook.remove()

Phase 2: Projector 구현 및 ECPairTrainer 통합 (1일)

목표: ChronosHintProjectorECPairTrainer에 통합

  • src/peak_analysis/chronos_projector.py 구현
  • ChronosHintProjector 클래스
  • ChronosHiddenExtractor (Hook 방식, context manager 지원)
  • register_encoder_hook(pipe) 유틸리티

  • src/peak_analysis/ec_pair_trainer.py 수정

  • chronos_alpha 하이퍼파라미터 추가 (default: 0.0)
  • train_step()L_feat_chronos 추가
  • Chronos 입력: x_ind.squeeze(-1)[B, seq_len]

  • src/peak_analysis/config.py 업데이트

  • CHRONOS_BOLT_PATH = ROOT_DIR / "pretrained_models/chronos-bolt-small"

Phase 3: 비교 실험 (1~1.5일)

목표: Chronos Feature KD의 효과 검증

실험 ID 구성 비교 목적
C_base GWN Soft-DTW only 현재 기준선
C_feat_mean + Chronos MeanPool proj Feature KD 효과
C_feat_reg + Chronos RegToken proj Pooling 방법 비교
C_dual GWN SoftDTW + Chronos Feature Dual-Teacher KD
C_dual_E + Adj-Guided Loss (방법 E) 3-way KD

평가 지표: MSE, MAE, PAPE, 추론 속도, Student 파라미터 수

Phase 4: 논문 포지셔닝 분석 (0.5일)

  • DistilTS (ICASSP 2026)와의 차별화 포인트 정리
  • DistilTS: 일반 시계열 데이터셋 (ETTh, Weather 등)
  • 본 연구: 에너지 커뮤니티 특화, 피크 예측 목적 (PAPE 최소화)
  • 본 연구: GWN + Chronos Dual-Teacher (공간 + 시간 정보 통합)
  • 논문 기여 포인트: "Energy-Community-Aware Dual-Teacher KD"

주요 리스크 및 대응

리스크 가능성 대응
Chronos-Bolt 내부 API가 외부 접근 차단 중간 F2 방식으로 직접 분해
input_patch_embedding 접근 경로가 버전마다 다름 중간 chronos 패키지 소스 코드 직접 확인
Projector 학습이 Student를 Chronos에 과적합시킴 낮음 chronos_alpha 값을 0.1 이하로 제한, 검증 손실 모니터링
Chronos hidden이 에너지 도메인에 부적합 중간 A_adp 전달(방법 C/E) 결과와 비교 후 우선순위 결정
Dual-Teacher로 인한 학습 불안정 중간 단계적 도입: Chronos KD를 warm-up 이후에만 활성화

전제 조건 체크리스트

  • GWN Teacher (N=50) 재학습 완료 — Teacher MSE < DLinear Baseline (0.5207)
  • A_adp 활용 방법 E (Adj-Guided Loss) 실험 완료 — Chronos KD와 비교 기준 마련
  • chronos 패키지 설치 확인: uv run python -c "from chronos import BaseChronosPipeline; print('OK')"
  • GPU 메모리 여유 확인 — Chronos-Bolt-Small (약 400MB) + GWN + DLinear 동시 로드 가능 여부

참고 자료

자료 위치
DistilTS 구현 (Chronos Teacher) src/DistilTS-ICASSP2026/models/Chronos.py
DistilTS FTA 모듈 src/DistilTS-ICASSP2026/exp/exp_DistilTS.py (VarTimeFactorAligner)
Chronos-Bolt config pretrained_models/chronos-bolt-small/config.json
KD 방법론 분석 전체 kd_method_comparison.md (같은 디렉토리)