Blog Full Notice
back to main page

1 분 소요

motivation:

https://docs.wandb.ai/quickstart

  • linux에서 pip install wandb 실행
  • 이후 wandb login 은 python 코드 안에서 진행하게 된다.
  

import wandb
wandb.login(key="42a1f3a.....")
# start a new wandb run to track this script
wandb.init(
    # set the wandb project where this run will be logged
    project="my-awesome-project",
    config={
    "learning_rate": CONFIG.LR, # 3e-4,
    "number_of_epochs" : CONFIG.N_EPOCHS, # 5,
    "batch_size": CONFIG.BATCH_SIZE, # 96,
    "sample_rate": CONFIG.SR, # 32000,
    "N_MFCC": CONFIG.N_MFCC, # 13,
    "SEED": CONFIG.SEED, # 42,
    "architecture": "MLP",
    "hidden_dim":CONFIG.hidden_dim,
    # "dataset": "audio_test50000_train55438",
    # "N_CLASSES":  2,
    }
)

이런 식으로 정의하고,

나중에 for i in epochs: loop에서,
        wandb.log({“train_loss”: _train_loss, “_val_loss”: _val_loss}) 이런 식으로 logging하면 된다. 그러면 자동으로 기록이됨.


def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
  
seed_everything(CONFIG.SEED) # Seed 고정

2. watch

wandb.watch(model, log='all')

def train(model, optimizer, train_loader, val_loader, device):

    model.to(device)
    criterion = nn.BCELoss().to(device)
    
    wandb.watch(model, log='all')
    
    best_val_score = 0
    best_model = None
    for epoch in range(1, CONFIG.N_EPOCHS+1):
	    # .......


이런 식으로 model이 정의되면 wandb.watch를 하면 gradient를 추적해준다.

3. sweep

https://docs.wandb.ai/guides/sweeps/define-sweep-configuration

근데 hyperparamter도 자동으로 찾게 하게 하려면,

wandb sweep --project <propject-name> <path-to-config file> wandb agent <sweep-ID>

를 linux에서 쳐주면 된다.

나의 경우,

wandb sweep --project my-awesome-project sweep_config.yaml 이거를 실행하면 sweep-ID가 리눅스 창에서 나온다. 나의 경우 g7kdxvhl 이거 였다.

그러면

wandb agent 'wys000112-Kyung Hee University/my-awesome-project/415m1qg2' 를 하면 된다.


infer_model = train(model, optimizer, train_loader, val_loader, device)

import config # 이거는 wandb sweep --project my-awesome-project sweep_config.yaml 이거를 해야 보이는 것.
sweep_id = wandb.sweep(config.sweep_config)
wandb.agent(sweep_id, train, count=2)

댓글남기기