Blog Full Notice
back to main page
wandb 사용하는 방법
motivation:
https://docs.wandb.ai/quickstart
- linux에서
pip install wandb실행 - 이후
wandb login은 python 코드 안에서 진행하게 된다.
import wandb
wandb.login(key="42a1f3a.....")
# start a new wandb run to track this script
wandb.init(
# set the wandb project where this run will be logged
project="my-awesome-project",
config={
"learning_rate": CONFIG.LR, # 3e-4,
"number_of_epochs" : CONFIG.N_EPOCHS, # 5,
"batch_size": CONFIG.BATCH_SIZE, # 96,
"sample_rate": CONFIG.SR, # 32000,
"N_MFCC": CONFIG.N_MFCC, # 13,
"SEED": CONFIG.SEED, # 42,
"architecture": "MLP",
"hidden_dim":CONFIG.hidden_dim,
# "dataset": "audio_test50000_train55438",
# "N_CLASSES": 2,
}
)
이런 식으로 정의하고,
나중에 for i in epochs: loop에서,
wandb.log({“train_loss”: _train_loss, “_val_loss”: _val_loss})
이런 식으로 logging하면 된다. 그러면 자동으로 기록이됨.
def seed_everything(seed):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True
seed_everything(CONFIG.SEED) # Seed 고정
2. watch
wandb.watch(model, log='all')
def train(model, optimizer, train_loader, val_loader, device):
model.to(device)
criterion = nn.BCELoss().to(device)
wandb.watch(model, log='all')
best_val_score = 0
best_model = None
for epoch in range(1, CONFIG.N_EPOCHS+1):
# .......
이런 식으로 model이 정의되면 wandb.watch를 하면 gradient를 추적해준다.
3. sweep
https://docs.wandb.ai/guides/sweeps/define-sweep-configuration
근데 hyperparamter도 자동으로 찾게 하게 하려면,
wandb sweep --project <propject-name> <path-to-config file>
wandb agent <sweep-ID>
를 linux에서 쳐주면 된다.
나의 경우,
wandb sweep --project my-awesome-project sweep_config.yaml
이거를 실행하면 sweep-ID가 리눅스 창에서 나온다. 나의 경우 g7kdxvhl 이거 였다.
그러면
wandb agent 'wys000112-Kyung Hee University/my-awesome-project/415m1qg2' 를 하면 된다.
infer_model = train(model, optimizer, train_loader, val_loader, device)
import config # 이거는 wandb sweep --project my-awesome-project sweep_config.yaml 이거를 해야 보이는 것.
sweep_id = wandb.sweep(config.sweep_config)
wandb.agent(sweep_id, train, count=2)
댓글남기기