기존의 MoeLayer 클래스에 전문가 선택 빈도를 추적하고, 이를 바탕으로 보정된 보조 손실을 계산하는 기능을 추가하는 방식으로 클래스를 재구성해 보겠습니다. 이를 위해, 각 전문가의 선택 빈도를 추적하는 메커니즘과, 이 정보를 활용하여 보조 손실을 조정하는 로직을 추가합니다.

python Copy code import torch import torch.nn as nn import torch.nn.functional as F

class MoeLayer(nn.Module): def init(self, experts: list[nn.Module], gate: nn.Module, num_experts_per_tok: int): super().init() assert len(experts) > 0 self.experts = nn.ModuleList(experts) self.gate = gate self.num_experts_per_tok = num_experts_per_tok

    # 전문가 선택 빈도 추적
    self.expert_selection_frequency = torch.zeros(len(experts), dtype=torch.float32)
    
    self.initialize_weights()

def forward(self, inputs: torch.Tensor):
    gate_logits = self.gate(inputs)
    weights, selected_experts = torch.topk(gate_logits, self.num_experts_per_tok)
    weights = F.softmax(weights, dim=-1)
    results = torch.zeros_like(inputs)

    for i, expert in enumerate(self.experts):
        batch_idx, nth_expert = torch.where(selected_experts == i)
        # 결과에 가중치 적용 및 업데이트
        results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(inputs[batch_idx])
        # 전문가 선택 빈도 업데이트
        self.expert_selection_frequency[i] += batch_idx.size(0)

    # 보조 손실 계산
    aux_loss = self.compute_auxiliary_loss(gate_logits)

    return results, aux_loss

def compute_auxiliary_loss(self, gate_logits):
    gate_probs = F.softmax(gate_logits, dim=-1)
    f_i = gate_probs.mean(dim=0)
    P_i = torch.mean(torch.eq(gate_probs, torch.max(gate_probs, dim=-1, keepdim=True).values).float(), dim=0)

    # 전문가 선택 빈도를 기반으로 보정된 보조 손실 계산
    # 선택 빈도가 낮은 전문가에 대한 가중치 증가
    selection_frequency_normalized = self.expert_selection_frequency / self.expert_selection_frequency.sum()
    corrected_f_i = f_i * (1 / selection_frequency_normalized.clamp(min=1e-5))
    
    aux_loss = corrected_f_i @ P_i * len(self.experts)
    return aux_loss

def initialize_weights(self):
    for expert in self.experts:
        if isinstance(expert, nn.Linear):
            nn.init.kaiming_normal_(expert.weight, mode='fan_out')
    nn.init.kaiming_normal_(self.gate.weight, mode='fan_out') 이 클래스의 주요 변경 사항은 다음과 같습니다:

전문가 선택 빈도 추적: self.expert_selection_frequency 변수를 추가하여 각 전문가가 선택된 횟수를 추적합니다. 보정된 보조 손실 계산: compute_auxiliary_loss 메소드에서 전문가의 선택 빈도를 고려하여 보정된 보조 손실을 계산합니다. 이는 전체적으로 균등한 전문가 선택을 유도합니다. 이 코드는 모델의 전문가 선택 메커니즘에 대한 균형을 더 잘 유지하도록 설계되었습니다. 하지만