Source code for dunedn.training.callbacks

"""This module implements objects to keep track of training details."""
from pprint import pformat


class Callback:
    """Base class exposing the set of callback hooks.

    Every hook accepts an optional ``logs`` argument and does nothing in
    this base class; subclasses override only the hooks they care about.
    """

    def __init__(self):
        """Create a callback; the base class keeps no state."""

    def on_train_begin(self, logs=None):
        """No-op ``on_train_begin`` hook."""

    def on_train_end(self, logs=None):
        """No-op ``on_train_end`` hook."""

    def on_train_batch_begin(self, logs=None):
        """No-op ``on_train_batch_begin`` hook."""

    def on_train_batch_end(self, logs=None):
        """No-op ``on_train_batch_end`` hook."""

    def on_epoch_begin(self, logs=None):
        """No-op ``on_epoch_begin`` hook."""

    def on_epoch_end(self, logs=None):
        """No-op ``on_epoch_end`` hook."""

    def on_eval_begin(self, logs=None):
        """No-op ``on_eval_begin`` hook."""

    def on_eval_end(self, logs=None):
        """No-op ``on_eval_end`` hook."""
class CallbackList(Callback):
    """Callback that forwards every hook to each callback in a list.

    The list passed to the constructor is stored as ``callback_list``
    (the same list object, not a copy) and is iterated in order.
    """

    def __init__(self, callbacks: list[Callback]):
        self.callback_list = callbacks

    def hook(self, hook_name: str, logs: dict):
        """Call the method named ``hook_name`` on every stored callback.

        Parameters
        ----------
        hook_name: str
            The name of the method to be called.
        logs: dict
            The dictionary to be logged; passed unchanged to each callback.
        """
        for member in self.callback_list:
            # Look the method up by name on each callback and invoke it.
            getattr(member, hook_name)(logs)

    def on_train_begin(self, logs=None):
        """Forward ``on_train_begin`` to every stored callback."""
        self.hook("on_train_begin", logs)

    def on_train_end(self, logs=None):
        """Forward ``on_train_end`` to every stored callback."""
        self.hook("on_train_end", logs)

    def on_train_batch_begin(self, logs=None):
        """Forward ``on_train_batch_begin`` to every stored callback."""
        self.hook("on_train_batch_begin", logs)

    def on_train_batch_end(self, logs=None):
        """Forward ``on_train_batch_end`` to every stored callback."""
        self.hook("on_train_batch_end", logs)

    def on_epoch_begin(self, logs=None):
        """Forward ``on_epoch_begin`` to every stored callback."""
        self.hook("on_epoch_begin", logs)

    def on_epoch_end(self, logs=None):
        """Forward ``on_epoch_end`` to every stored callback."""
        self.hook("on_epoch_end", logs)

    def on_eval_begin(self, logs=None):
        """Forward ``on_eval_begin`` to every stored callback."""
        self.hook("on_eval_begin", logs)

    def on_eval_end(self, logs=None):
        """Forward ``on_eval_end`` to every stored callback."""
        self.hook("on_eval_end", logs)
class History(Callback):
    """Callback that accumulates logged values into per-key lists.

    ``logs`` is ``None`` on a fresh instance, becomes an empty dict on
    ``on_train_begin`` / ``reset``, and afterwards maps each logged key to
    the list of values appended for it, in call order.
    """

    def __init__(self):
        # Stays None until reset() or the first non-empty append().
        self.logs = None

    def __repr__(self):
        return pformat(self.logs, indent=2)

    def on_train_begin(self, logs=None):
        """Discard any previously stored values; ``logs`` is ignored."""
        self.reset()

    def on_train_batch_end(self, logs=None):
        """Store the values in ``logs``."""
        self.append(logs)

    def on_epoch_end(self, logs=None):
        """Store the values in ``logs``."""
        self.append(logs)

    def append(self, logs=None):
        """Append each value in ``logs`` to the list stored under its key.

        Parameters
        ----------
        logs: dict, optional
            Mapping of names to values. ``None`` or an empty mapping leaves
            the history untouched.
        """
        if not logs:
            return
        # The store is None until reset() runs; create it on demand so that
        # appending before on_train_begin() does not raise AttributeError.
        if self.logs is None:
            self.logs = {}
        for key, value in logs.items():
            self.logs.setdefault(key, []).append(value)

    def reset(self):
        """Replace the stored history with an empty dict."""
        self.logs = {}