vdl_logger.py 633 B

123456789101112131415161718192021
  1. from .base_logger import BaseLogger
  2. from visualdl import LogWriter
  3. class VDLLogger(BaseLogger):
  4. def __init__(self, save_dir):
  5. super().__init__(save_dir)
  6. self.vdl_writer = LogWriter(logdir=save_dir)
  7. def log_metrics(self, metrics, prefix=None, step=None):
  8. if not prefix:
  9. prefix = ""
  10. updated_metrics = {prefix + "/" + k: v for k, v in metrics.items()}
  11. for k, v in updated_metrics.items():
  12. self.vdl_writer.add_scalar(k, v, step)
  13. def log_model(self, is_best, prefix, metadata=None):
  14. pass
  15. def close(self):
  16. self.vdl_writer.close()