Source code for chipiron.utils.chi_nn

"""
Module for the ChiNN class
"""

import sys
import traceback

import torch
import torch.nn as nn

from chipiron.utils import path
from chipiron.utils.logger import chipiron_logger
from chipiron.utils.small_tools import resolve_package_path


class ChiNN(nn.Module):
    """
    The generic neural network base class of chipiron.

    Subclasses are expected to override ``init_weights`` and
    ``log_readable_model_weights_to_file``; the base implementations are a
    no-op and a ``NotImplementedError`` respectively.
    """

    def __init__(self) -> None:
        """
        Initializes an instance of the ChiNN class.
        """
        super().__init__()

    def __getstate__(self) -> None:
        """
        Get the state of the object for pickling.

        Returns:
            None: this class always reports ``None`` as its state.
        """
        return None

    def init_weights(self) -> None:
        """
        Initialize the weights of the model.

        No-op in the base class.
        """
        pass

    def load_weights_from_file(self, path_to_param_file: path) -> None:
        """
        Load the neural network weights from a parameter file.

        The path is resolved with ``resolve_package_path`` and the file is
        read with ``torch.load``. When CUDA is not available the tensors are
        mapped onto the CPU.

        Args:
            path_to_param_file: The path to the parameter file.

        Raises:
            SystemExit: If an ``OSError`` occurs while resolving, opening or
                reading the file (for instance when it does not exist). The
                traceback is printed to stderr and the error is logged first.
        """
        chipiron_logger.info(f"load_or_init_weights from {path_to_param_file}")

        # Bound before the ``try`` so that the error handler below can always
        # reference it, even if the path resolution itself is what failed.
        resolved_path = None
        try:
            resolved_path = resolve_package_path(str(path_to_param_file))
            with open(resolved_path, "rb") as param_file:
                chipiron_logger.info(
                    f"loading the existing param file {resolved_path}"
                )
                # ``None`` is torch.load's default (keep the stored device);
                # without CUDA, force every tensor onto the CPU.
                map_location = (
                    None if torch.cuda.is_available() else torch.device("cpu")
                )
                self.load_state_dict(
                    torch.load(param_file, map_location=map_location)
                )
        except OSError:
            # Print the full traceback to stderr before exiting.
            traceback.print_exc()
            chipiron_logger.error(f"no file {path_to_param_file} at {resolved_path}")
            sys.exit(
                "Error: no NN weights file and no rights to create it for file {}".format(
                    path_to_param_file,
                )
            )

    def log_readable_model_weights_to_file(self, file_path: str) -> None:
        """
        Write the model weights to a file in a human-readable form.

        Args:
            file_path: The path of the file to write to.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError("not implemented in base class")