art package
Submodules
art.checks module
- class art.checks.Check
Bases:
ABC
Abstract base class for defining checks.
- name
Name of the check.
- Type:
str
- description
Description of the check.
- Type:
str
- required_files
List of files that are required for this check.
- Type:
List[str]
- abstract check(step) → ResultOfCheck
Abstract method to execute the check on the provided step.
- Parameters:
step – The step to check.
- Returns:
The result of the check.
- Return type:
ResultOfCheck
- description: str
- name: str
- required_files: List[str]
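A minimal sketch of how a concrete check could subclass Check. It assumes that a step exposes its results as a `results` dictionary (the attribute name is an assumption; the reference only says a step starts with an empty results dictionary). `ResultOfCheck` is re-declared locally with its documented fields so the snippet runs on its own, and `HasResultsCheck` is a hypothetical example, not part of art.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class ResultOfCheck:
    # Local stand-in with the documented fields of art.checks.ResultOfCheck.
    is_positive: bool = True
    error: str = ""


class Check(ABC):
    # Class attributes as listed in the reference.
    name: str = ""
    description: str = ""
    required_files: List[str] = []

    @abstractmethod
    def check(self, step: Any) -> ResultOfCheck:
        """Inspect the step and report success or failure."""


class HasResultsCheck(Check):
    # Hypothetical check: passes if the step stored at least one result.
    name = "Has results"
    description = "Passes if the step produced at least one result."

    def check(self, step: Any) -> ResultOfCheck:
        # `results` is an assumed attribute name.
        results: Dict[str, Any] = getattr(step, "results", {})
        if results:
            return ResultOfCheck()
        return ResultOfCheck(is_positive=False, error="The step has no results.")
```

Because `check` is abstract, Check itself cannot be instantiated; only subclasses that implement `check` can.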
- class art.checks.CheckResult
Bases:
Check
Abstract class for checks that are based on the results of a step.
- check(step) → ResultOfCheck
Executes the check on the result of the provided step.
- Parameters:
step – The step whose results are to be checked.
- Returns:
The result of the check.
- Return type:
ResultOfCheck
- class art.checks.CheckResultExists(required_key)
Bases:
CheckResult
Concrete check class to verify the existence of a required key in the results.
- required_key
The key that should exist in the results.
- Type:
str
- __init__(required_key)
Constructor for the CheckResultExists class.
- Parameters:
required_key (str) – The key that should exist in the results.
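The behaviour described for CheckResultExists can be sketched as follows. This is a standalone approximation, not the library's code: `ResultExistsSketch` is a hypothetical name, the results are assumed to live in `step.results`, and the real check may build its error message differently.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class ResultOfCheck:
    is_positive: bool = True
    error: str = ""


class ResultExistsSketch:
    """Hypothetical stand-in for CheckResultExists."""

    def __init__(self, required_key: str):
        self.required_key = required_key

    def check(self, step: Any) -> ResultOfCheck:
        # Assumes the step keeps its results in a `results` dict.
        results: Dict[str, Any] = getattr(step, "results", {})
        if self.required_key in results:
            return ResultOfCheck()
        return ResultOfCheck(
            is_positive=False,
            error=f"'{self.required_key}' not found in the step results",
        )
```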
- class art.checks.CheckScore(metric: Union[str, Any], value: float)
Bases:
CheckResult
Base class for checking scores based on a specific metric.
- metric
An object used to calculate the metric or the string with the name of the metric.
- value
The expected value of the metric.
- Type:
float
- build_required_key(step, metric)
Constructs the key for the metric based on the metric’s name, the model’s name, the current step’s name, and the current check stage.
- Parameters:
step – The step in which the metric was calculated.
metric – The metric object.
- check(step) → ResultOfCheck
Executes the check on the result of the provided step.
- Parameters:
step – The step whose results are to be checked.
- Returns:
The result of the check.
- Return type:
ResultOfCheck
- class art.checks.CheckScoreCloseTo(metric, value: float, rel_tol: float = 1e-09, abs_tol: float = 0.0)
Bases:
CheckScore
Check to verify if a score in the results based on a specific metric is close to an expected value.
- rel_tol
Relative tolerance. Defaults to 1e-09.
- Type:
float
- abs_tol
Absolute tolerance. Defaults to 0.0.
- Type:
float
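The default tolerances above are the same as those of Python's `math.isclose`, which suggests (but does not confirm) that CheckScoreCloseTo compares scores with it. Under that assumption, the comparison reduces to the sketch below; `score_close_to` is a hypothetical helper that takes the score directly instead of looking it up in a step's results.

```python
import math
from dataclasses import dataclass


@dataclass
class ResultOfCheck:
    is_positive: bool = True
    error: str = ""


def score_close_to(score: float, value: float,
                   rel_tol: float = 1e-09, abs_tol: float = 0.0) -> ResultOfCheck:
    # math.isclose passes when
    # |score - value| <= max(rel_tol * max(|score|, |value|), abs_tol)
    if math.isclose(score, value, rel_tol=rel_tol, abs_tol=abs_tol):
        return ResultOfCheck()
    return ResultOfCheck(
        is_positive=False,
        error=f"Score {score} is not close to {value} "
              f"(rel_tol={rel_tol}, abs_tol={abs_tol})",
    )
```

With the default `abs_tol=0.0`, a score counts as close to an expected value of `0.0` only if it is exactly zero, so checks against zero need a positive `abs_tol`.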
- class art.checks.CheckScoreEqualsTo(metric: Union[str, Any], value: float)
Bases:
CheckScore
Check to verify if a score in the results based on a specific metric is equal to an expected value.
- class art.checks.CheckScoreExists(metric)
Bases:
CheckScore
Check to verify the existence of a score in the results based on a specific metric.
- metric
An object used to calculate the metric.
- class art.checks.CheckScoreGreaterThan(metric: Union[str, Any], value: float)
Bases:
CheckScore
Check to verify if a score in the results based on a specific metric is greater than an expected value.
- class art.checks.CheckScoreLessThan(metric: Union[str, Any], value: float)
Bases:
CheckScore
Check to verify if a score in the results based on a specific metric is less than an expected value.
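CheckScoreEqualsTo, CheckScoreGreaterThan and CheckScoreLessThan differ only in the comparison they apply. The sketch below assumes they use plain `==`, `>` and `<` between the score and the expected value; the `COMPARISONS` mapping and `compare_score` are hypothetical helpers for illustration.

```python
import operator
from dataclasses import dataclass
from typing import Callable, Dict


@dataclass
class ResultOfCheck:
    is_positive: bool = True
    error: str = ""


# Hypothetical mapping from check class name to the comparison it performs.
COMPARISONS: Dict[str, Callable[[float, float], bool]] = {
    "CheckScoreEqualsTo": operator.eq,
    "CheckScoreGreaterThan": operator.gt,
    "CheckScoreLessThan": operator.lt,
}


def compare_score(kind: str, score: float, value: float) -> ResultOfCheck:
    # Look up the comparison for this check type and apply it.
    if COMPARISONS[kind](score, value):
        return ResultOfCheck()
    return ResultOfCheck(
        is_positive=False,
        error=f"{kind} failed: score {score}, expected value {value}",
    )
```

If the equality check really uses `==`, it is exact for floats, so CheckScoreCloseTo is usually the safer choice for computed metrics.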
- class art.checks.ResultOfCheck(is_positive: bool = True, error: str = '')
Bases:
object
Dataclass representing the result of a check operation.
- is_positive
Indicates if the check was successful. Defaults to True.
- Type:
bool
- error
Error message if the check was not successful. Defaults to an empty string.
- Type:
str
- error: str = ''
- is_positive: bool = True
art.core module
- class art.core.ArtModule
Bases:
LightningModule, ABC
- baseline_train(data: Dict)
Baseline training step.
- Parameters:
data (Dict) – Data to train.
- Returns:
Data with loss.
- Return type:
Dict
- check_setup()
Check if the metric calculator has been set.
- Raises:
ValueError – If the metric calculator has not been set.
- compute_loss(data: Dict)
Compute loss.
- Parameters:
data (Dict) – Data to compute loss.
- Returns:
Data with loss.
- Return type:
Dict
- compute_metrics(data: Dict)
Compute metrics.
- Parameters:
data (Dict) – Data to compute metrics.
- Returns:
Data with metrics.
- Return type:
Dict
- abstract log_params()
- ml_parse_data(data: Dict)
Parse data for machine learning training.
- Parameters:
data (Dict) – Data to parse.
- Returns:
Parsed data.
- Return type:
Dict
- ml_train(data: Dict)
Machine learning training step.
- Parameters:
data (Dict) – Data to train.
- Returns:
Data with loss.
- Return type:
Dict
- parse_data(data: Dict)
Parse data.
- Parameters:
data (Dict) – Data to parse.
- Returns:
Parsed data.
- Return type:
Dict
- predict(data: Dict)
Predict.
- Parameters:
data (Dict) – Data to predict.
- Returns:
Data with predictions.
- Return type:
Dict
- prepare_for_metric(data: Dict)
Prepare data for metric calculation.
- Parameters:
data (Dict) – Data to prepare.
- Returns:
Data with unified type.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- set_metric_calculator(metric_calculator: MetricCalculator)
Set metric calculator.
- Parameters:
metric_calculator (MetricCalculator) – A metric calculator.
- set_pipelines()
Reset pipelines for training, validation, and testing.
- test_step(batch: Union[Dict[str, Any], DataLoader, Tensor], batch_idx: int)
Test step.
- Parameters:
batch (Union[Dict[str, Any], DataLoader, torch.Tensor]) – Batch to test.
batch_idx (int) – Batch index.
- training_step(batch: Union[Dict[str, Any], DataLoader, Tensor], batch_idx: int)
Training step.
- Parameters:
batch (Union[Dict[str, Any], DataLoader, torch.Tensor]) – Batch to train.
batch_idx (int) – Batch index.
- Returns:
Data with loss.
- Return type:
Dict
- unify_type(x: Any)
Convert x to a torch.Tensor.
- Parameters:
x (Any) – Data to unify type.
- Returns:
Data with unified type.
- Return type:
torch.Tensor
- validation_step(batch: Union[Dict[str, Any], DataLoader, Tensor], batch_idx: int)
Validation step.
- Parameters:
batch (Union[Dict[str, Any], DataLoader, torch.Tensor]) – Batch to validate.
batch_idx (int) – Batch index.
art.decorators module
- class art.decorators.BatchSaver(how_many_batches=10, image_key_name='input')
Bases:
object
Saves images from batches to the debug_images folder.
- __init__(how_many_batches=10, image_key_name='input')
- Parameters:
how_many_batches (int, optional) – How many batches to save. Defaults to 10.
image_key_name (str, optional) – Key under which the images are stored in the batch. Defaults to “input”.
- class art.decorators.EvolutionSaver(wanted_class_id: int)
Bases:
object
Tracks the evolution of logits for a given class.
- __init__(wanted_class_id: int)
- Parameters:
wanted_class_id (int) – ID of the class to track.
- visualize()
Visualizes evolution of logits for a given class.
- class art.decorators.LogInputStats(suppress_stdout=True, custom_logger=None)
Bases:
object
Logs input statistics to the art logger.
- __init__(suppress_stdout=True, custom_logger=None)
- Parameters:
suppress_stdout (bool, optional) – Whether to suppress stdout. Defaults to True.
custom_logger (_type_, optional) – Custom logger to use instead of the default art_logger. Defaults to None.
- class art.decorators.ModelDecorator(funcion_name: str, input_decorator: Optional[Callable] = None, output_decorator: Optional[Callable] = None)
Bases:
object
- funcion_name: str
- input_decorator: Optional[Callable] = None
- output_decorator: Optional[Callable] = None
- art.decorators.art_decorate(functions: List[Tuple[object, str]], input_decorator: Optional[Callable] = None, output_decorator: Optional[Callable] = None)
Decorates a list of object methods. It doesn’t modify the output of a function but can be used for logging additional information during training.
- Parameters:
functions (List[Tuple[object, str]]) – List of tuples of objects and methods to decorate.
input_decorator (Callable, optional) – Function applied on the input. Defaults to None.
output_decorator (Callable, optional) – Function applied on the output. Defaults to None.
- art.decorators.art_decorate_single_func(visualizing_function_in=None, visualizing_function_out=None)
Decorates the input and output of a function.
- Parameters:
visualizing_function_in (function, optional) – Function applied on the input. Defaults to None.
visualizing_function_out (function, optional) – Function applied on the output. Defaults to None.
- Returns:
Decorated function.
- Return type:
function
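The idea behind art_decorate and art_decorate_single_func, observing a function's input and output without changing its return value, can be sketched with `functools.wraps` and `setattr`. The helper names and the calling convention of the callbacks (the input callback receives the call's arguments, the output callback receives the return value) are assumptions, not the library's API.

```python
import functools
from typing import Callable, List, Optional, Tuple


def decorate_single_func(func: Callable,
                         input_fn: Optional[Callable] = None,
                         output_fn: Optional[Callable] = None) -> Callable:
    """Wrap func so input_fn sees its arguments and output_fn sees its result."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if input_fn is not None:
            input_fn(*args, **kwargs)
        result = func(*args, **kwargs)
        if output_fn is not None:
            output_fn(result)
        return result  # the output is returned unchanged

    return wrapper


def decorate_methods(functions: List[Tuple[object, str]],
                     input_fn: Optional[Callable] = None,
                     output_fn: Optional[Callable] = None) -> None:
    # Replace each (obj, method_name) pair with a wrapped bound method.
    for obj, name in functions:
        original = getattr(obj, name)
        setattr(obj, name, decorate_single_func(original, input_fn, output_fn))
```

Setting the wrapper on the instance leaves the class and other instances untouched, which keeps this kind of debugging hook local to one model object.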
art.loggers module
- class art.loggers.LoggerFlags(value)
Bases:
Enum
An enumeration.
- SUPRESS_STDOUT = 'supress_stdout'
- class art.loggers.NeptuneLoggerAdapter(*args, **kwargs)
Bases:
NeptuneLogger
A wrapper around the Lightning logger that simplifies basic functionality shared across different loggers.
- add_tags(tags: Union[List[str], str])
Adds tags to the Neptune experiment.
- Parameters:
tags (Union[List[str], str]) – Tag or list of tags to add.
- download_ckpt(id: str, name: Optional[str] = None, type: str = 'last', path: str = './checkpoints')
Downloads a checkpoint from Neptune.
- Parameters:
id (str) – Run ID.
name (str, optional) – Name of the checkpoint. Defaults to None.
type (str, optional) – Type of the checkpoint. Defaults to “last”.
path (str, optional) – Path to download checkpoint to. Defaults to “./checkpoints”.
- Raises:
Exception – If the checkpoint does not exist.
Exception – If the type is not “last” or “best”.
- Returns:
Path to downloaded checkpoint.
- Return type:
str
- log_config(configFile, path: str = 'hydra/config')
Logs a config file to Neptune.
- Parameters:
configFile (str) – Path to config file.
path (str, optional) – Path to log config file to. Defaults to “hydra/config”.
- log_figure(figure, path: str = 'figure')
Logs a figure to Neptune.
- Parameters:
figure (Any) – Figure to log.
path (str, optional) – Path to log figure to. Defaults to “figure”.
- log_img(image, path: str = 'image')
Logs an image to Neptune.
- Parameters:
image (np.ndarray) – Image to log.
path (str, optional) – Path to log image to. Defaults to “image”.
- stop()
- class art.loggers.WandbLoggerAdapter(*args, **kwargs)
Bases:
WandbLogger
A wrapper around the Lightning logger that simplifies basic functionality shared across different loggers. Logging plots in Wandb supports Plotly only. To log matplotlib figures, convert them to Plotly first or log them as images.
- add_tags(tags: Union[List[str], str])
Adds tags to the Wandb run.
- Parameters:
tags (Union[List[str], str]) – Tag or list of tags to add.
- log_config(configFile: str)
Logs a config file to Wandb.
Works only when run as an admin.
- Parameters:
configFile (str) – Path to config file.
- log_figure(figure, path='figure')
Logs a figure to Wandb.
- Parameters:
figure (Any) – Figure to log.
path (str, optional) – Path to log figure to. Defaults to “figure”.
- log_img(image, path: Union[str, ndarray] = 'image')
Logs an image to Wandb.
- Parameters:
image (np.ndarray) – Image to log.
path (str, optional) – Path to log image to. Defaults to “image”.
- art.loggers.add_logger(log_file_path: Path) → int
- art.loggers.get_new_log_file_name(run_id: str) → str
- art.loggers.get_run_id() → str
- art.loggers.log_yellow_warning(message: str)
- art.loggers.remove_logger(logger_id: int)
- art.loggers.supress_stdout(current_logger: Logger) → Logger
art.metrics module
- class art.metrics.DefaultMetric
Bases:
object
Placeholder for a default metric.
- class art.metrics.DefaultModel
Bases:
object
Placeholder for a default model.
- class art.metrics.MetricCalculator(experiment: ArtProject)
Bases:
object
Facilitates the management and application of metrics during different stages of training.
This class makes preparing templates for different kinds of projects easy.
- add_metrics(metric: Any)
Add metrics to the list.
- Parameters:
metric (Any) – The metric to add.
- build_name(metric: Any) → str
Builds a name for the metric based on its type and the current stage.
- Parameters:
metric (Any) – The metric being calculated.
- compile(skipped_metrics: List[SkippedMetric])
Organize metrics based on stages, skipping specified ones.
- Parameters:
skipped_metrics (List[SkippedMetric]) – A list of SkippedMetric instances.
- to(device: str)
Move all metrics to a specified device.
- Parameters:
device (str) – The device to move the metrics to.
- class art.metrics.SkippedMetric(metric, stages: List[str] = ['train', 'validate'])
Bases:
object
Represents a metric that should be skipped during certain training stages.
- art.metrics.build_metric_name(metric: Any, stage: str) → str
Builds a name for the metric based on its type and the given training stage.
- Parameters:
metric (Any) – The metric being calculated.
stage (str) – The current stage of training.
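A sketch of how metrics could be grouped per stage while honouring SkippedMetric entries, in the spirit of MetricCalculator.compile and build_metric_name. Several details are hypothetical: the stage list (only 'train' and 'validate' appear in the reference; 'test' is assumed), identity-based matching of skipped metrics, and the '<MetricClass>-<stage>' naming scheme.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List

# Assumed stage names; 'train' and 'validate' come from SkippedMetric's default.
STAGES = ["train", "validate", "test"]


@dataclass
class SkippedMetricSketch:
    metric: Any
    stages: List[str] = field(default_factory=lambda: ["train", "validate"])


def build_metric_name(metric: Any, stage: str) -> str:
    # Hypothetical naming scheme: "<MetricClass>-<stage>".
    return f"{type(metric).__name__}-{stage}"


def compile_metrics(metrics: List[Any],
                    skipped: List[SkippedMetricSketch]) -> Dict[str, List[Any]]:
    """Group metrics per stage, leaving out those skipped for that stage."""
    compiled: Dict[str, List[Any]] = {}
    for stage in STAGES:
        compiled[stage] = [
            m for m in metrics
            if not any(s.metric is m and stage in s.stages for s in skipped)
        ]
    return compiled
```

Using `default_factory` for `stages` avoids sharing one mutable list between instances, which a literal list default in a signature would do.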
art.project module
- class art.project.ArtProject(name: str, datamodule: LightningDataModule, use_metric_calculator: bool = True, **kwargs)
Bases:
object
Represents a single Art project, encapsulating steps, state, metrics, and logging.
- __init__(name: str, datamodule: LightningDataModule, use_metric_calculator: bool = True, **kwargs)
Initialize an Art project.
- Parameters:
name (str) – The name of the project.
datamodule (L.LightningDataModule) – Data module to be used in this project.
use_metric_calculator (bool) – Whether to use the metric calculator.
**kwargs – Additional keyword arguments.
- add_step(step: Step, checks: Optional[List[Check]] = [], skipped_metrics: List[SkippedMetric] = [])
Add a step to the project.
- Parameters:
step (Step) – The step to be added.
checks (List[Check]) – A list of checks associated with the step.
skipped_metrics (List[SkippedMetric]) – A list of metrics to skip for this step.
- check_if_must_be_run(step: Step, checks: List[Check]) → bool
Check if a given step needs to be executed or if it can be skipped.
- fill_step_states(step: Step)
Update step states with the results from the given step.
- Parameters:
step (Step) – The step whose results need to be recorded.
- get_step(step_id: int) → Step
Retrieve a specific step by its ID.
- Parameters:
step_id (int) – The ID of the step to retrieve.
- Returns:
The specified step.
- Return type:
Step
- get_steps()
Retrieve all steps in the project.
- Returns:
List of steps.
- Return type:
List[Dict[str, Any]]
- print_summary()
Prints a summary of the project.
- register_metrics(metrics: List[Any])
Register metrics to the project.
- Parameters:
metrics (List[Any]) – A list of metrics to be registered.
- replace_step(step: Step, step_id: int = -1)
Replace an existing step with a new one.
- Parameters:
step (Step) – The new step.
step_id (int) – The ID of the step to replace. Default is the last step.
- run_all(force_rerun=False, model_decorators: List[ModelDecorator] = [], trainer_kwargs: Dict[str, Any] = {})
Execute all steps in the project.
- Parameters:
force_rerun (bool) – Whether to force a rerun of all steps.
model_decorators (List[ModelDecorator]) – List of model decorators to be applied.
trainer_kwargs (Dict[str, Any]) – Arguments to be passed to the trainer.
- run_step(step: Step, skipped_metrics: List[SkippedMetric], model_decorators: List[ModelDecorator], trainer_kwargs: Dict[str, Any], run_id: str)
Run a given step.
- Parameters:
step (Step) – The step to run.
skipped_metrics (List[SkippedMetric]) – List of metrics to skip for this step.
model_decorators (List[ModelDecorator]) – List of model decorators to be applied.
trainer_kwargs (Dict[str, Any]) – Arguments to be passed to the trainer.
run_id (str) – The ID of the run.
- update_datamodule(datamodule: LightningDataModule)
Update the data module of the project.
- Parameters:
datamodule (L.LightningDataModule) – New data module to be used in the project.
- class art.project.ArtProjectState
Bases:
object
- current_stage: TrainingStage = 'train'
- get_current_stage()
Gets the current stage.
- Returns:
Current stage.
- Return type:
TrainingStage
- get_steps()
Returns all steps that were run.
- Returns:
Step states of all steps that were run, keyed by model name and then by step name.
- Return type:
Dict[str, Dict[str, Dict[str, str]]]
- status: str
- A class for managing the state of a project. Step states are stored per model and per step:
steps: {
    "model_name": {
        "step_name": {/step state/},
        "step_name2": {/step state/},
    },
    "model2_name": {
        "step_name": {/step state/},
        "step_name2": {/step state/},
    },
}
- step_states: Dict[str, Dict[str, Dict[str, str]]]
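To make the nested layout of step_states concrete, here is an illustrative value and a small lookup helper. The model names, step names and the `status` key inside each step state are made-up placeholders; the reference only fixes the nesting of model name, then step name, then step state.

```python
from typing import Dict

# model name -> step name -> step state (all contents are placeholders)
StepStates = Dict[str, Dict[str, Dict[str, str]]]

step_states: StepStates = {
    "MyModel": {
        "Overfit One Batch": {"status": "passed"},
        "Overfit": {"status": "failed"},
    },
    "Baseline": {
        "Evaluate Baseline": {"status": "passed"},
    },
}


def get_step_state(states: StepStates, model_name: str, step_name: str) -> Dict[str, str]:
    # Returns an empty dict when the model or step has no recorded state.
    return states.get(model_name, {}).get(step_name, {})
```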
art.steps module
- class art.steps.CheckLossOnInit(model: ArtModule)
Bases:
ModelStep
This step checks whether the loss on init is as expected
- description = 'Checks loss on init'
- do(previous_states: Dict)
This method checks loss on init. It validates the model on the train dataloader and checks whether the loss is as expected.
- Parameters:
previous_states (Dict) – previous states
- name = 'Check Loss On Init'
- requires_ckpt_callback = False
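The reference does not say what "as expected" means for CheckLossOnInit. A common rule of thumb for classifiers trained with cross-entropy is that an untrained model with near-uniform predictions over C classes should start with a loss close to ln(C). The sketch below illustrates that heuristic in plain Python; the helper names and the 10% tolerance are assumptions, not the step's actual criterion.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits: List[float], target: int) -> float:
    # Negative log-probability assigned to the target class.
    return -math.log(softmax(logits)[target])


def expected_loss_on_init(num_classes: int) -> float:
    # Uniform predictions give probability 1/C to the true class: -ln(1/C) = ln(C).
    return math.log(num_classes)


def loss_on_init_ok(observed: float, num_classes: int, rel_tol: float = 0.1) -> bool:
    # Hypothetical acceptance rule: within rel_tol of ln(C).
    return math.isclose(observed, expected_loss_on_init(num_classes), rel_tol=rel_tol)
```

A loss far above ln(C) at initialization often points to overconfident initial logits or a bug in the loss computation.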
- class art.steps.EvaluateBaseline(baseline: ArtModule, device: Optional[str] = 'cpu')
Bases:
ModelStep
This class takes a baseline and evaluates/trains it on the dataset
- description = 'Evaluates a baseline on the dataset'
- do(previous_states: Dict)
This method evaluates the baseline on the dataset
- Parameters:
previous_states (Dict) – previous states
- name = 'Evaluate Baseline'
- requires_ckpt_callback = False
- class art.steps.ExploreData
Bases:
Step
This class checks whether there is a markdown file describing the dataset and whether visualizations have been implemented
- description = 'This step allows you to perform data analysis and extract information that is necessery in next steps'
- name = 'Data analysis'
- class art.steps.ModelStep(model_class: ArtModule, trainer_kwargs: Dict = {}, model_kwargs: Dict = {}, model_modifiers: List[Callable] = [], datamodule_modifiers: List[Callable] = [], logger: Optional[Logger] = None)
Bases:
Step
A specialized step in the project, representing a model-based step.
- __init__(model_class: ArtModule, trainer_kwargs: Dict = {}, model_kwargs: Dict = {}, model_modifiers: List[Callable] = [], datamodule_modifiers: List[Callable] = [], logger: Optional[Logger] = None)
Initialize a model-based step.
- Parameters:
model_class (ArtModule) – The model’s class associated with this step.
trainer_kwargs (Dict, optional) – Arguments to be passed to the trainer. Defaults to {}.
model_kwargs (Dict, optional) – Arguments to be passed to the model. Defaults to {}.
model_modifiers (List[Callable], optional) – List of functions to be applied to the model. Defaults to [].
datamodule_modifiers (List[Callable], optional) – List of functions to be applied to the data module. Defaults to [].
logger (Optional[Logger], optional) – Logger to be used. Defaults to None.
- check_ckpt_callback(trainer_kwargs: Dict)
- get_check_stage() → str
Get the validation stage value from the TrainingStage enum.
- Returns:
Validation stage value.
- Return type:
str
- get_current_stage() → str
Retrieve the current training stage of the trainer.
- Returns:
Current training stage.
- Return type:
str
- get_full_step_name() → str
Retrieve the step ID, combining model name (if available) with the index.
- Returns:
The step ID.
- Return type:
str
- get_hash() → str
Compute a hash based on the source code of the step’s class.
- Returns:
MD5 hash of the step’s source code.
- Return type:
str
- get_trainloader()
- get_valloader()
- log_data_params()
- log_model_params(model)
- requires_ckpt_callback = True
- reset_trainer(logger: Optional[Logger] = None, trainer_kwargs: Dict = {})
Reset the trainer.
- Parameters:
trainer_kwargs (Dict) – Arguments to be passed to the trainer.
logger (Optional[Logger], optional) – Logger to be used. Defaults to None.
- test(trainer_kwargs: Dict)
Test the model using the provided trainer arguments.
- Parameters:
trainer_kwargs (Dict) – Arguments to be passed to the trainer for testing the model.
- train(trainer_kwargs: Dict)
Train the model using the provided trainer arguments.
- Parameters:
trainer_kwargs (Dict) – Arguments to be passed to the trainer for training the model.
- validate(trainer_kwargs: Dict)
Validate the model using the provided trainer arguments.
- Parameters:
trainer_kwargs (Dict) – Arguments to be passed to the trainer for validating the model.
- class art.steps.NoModelUsed
Bases:
object
- class art.steps.Overfit(model: ArtModule, logger: Optional[Logger] = None, max_epochs: int = 1)
Bases:
ModelStep
This step tries to overfit the model
- description = 'Overfits model'
- do(previous_states: Dict)
This method overfits the model
- Parameters:
previous_states (Dict) – previous states
- get_check_stage()
Returns check stage
- log_model_params(model)
- name = 'Overfit'
- class art.steps.OverfitOneBatch(model: ArtModule, number_of_steps: int = 50)
Bases:
ModelStep
This step tries to overfit a single training batch
- description = 'Overfits one batch'
- do(previous_states: Dict)
This method overfits one batch
- Parameters:
previous_states (Dict) – previous states
- get_check_stage()
Returns check stage
- log_model_params(model)
- name = 'Overfit One Batch'
- requires_ckpt_callback = False
- class art.steps.Regularize(model: ArtModule, logger: Optional[Logger] = None, trainer_kwargs: Dict = {}, model_kwargs: Dict = {}, model_modifiers: List[Callable] = [], datamodule_modifiers: List[Callable] = [])
Bases:
ModelStep
This step tries applying regularization to the model
- __init__(model: ArtModule, logger: Optional[Logger] = None, trainer_kwargs: Dict = {}, model_kwargs: Dict = {}, model_modifiers: List[Callable] = [], datamodule_modifiers: List[Callable] = [])
- Parameters:
model (ArtModule) – model
logger (Logger, optional) – logger. Defaults to None.
trainer_kwargs (Dict, optional) – Kwargs passed to lightning Trainer. Defaults to {}.
model_kwargs (Dict, optional) – Kwargs passed to model. Defaults to {}.
model_modifiers (List[Callable], optional) – model modifiers. Defaults to [].
datamodule_modifiers (List[Callable], optional) – datamodule modifiers. Defaults to [].
- check_if_already_tried()
Helps the project decide whether there is any remaining reason not to run the step, even though it has failed.
- continue_on_failure = True
- description = 'Regularizes model'
- do(previous_states: Dict)
This method regularizes the model
- Parameters:
previous_states (Dict) – previous states
- get_latest_run() → Dict
If the step was run, returns itself; otherwise, returns the latest run from the JSONStepSaver.
In the case of regularization, we are interested in the latest successful run.
- Returns:
The latest run.
- Return type:
Dict
- name = 'Regularize'
- save_to_disk()
- set_successful()
- stringify_modifiers(modifiers: List[Callable])
- update_was_already_tried()
This method verifies whether these Regularize parameters were already tried and updates self.was_already_tried
- class art.steps.Squeeze(model_class: ArtModule, trainer_kwargs: Dict = {}, model_kwargs: Dict = {}, model_modifiers: List[Callable] = [], datamodule_modifiers: List[Callable] = [], logger: Optional[Logger] = None)
Bases:
ModelStep
- class art.steps.Step
Bases:
ABC
An abstract base class representing a generic step in a project.
- __init__()
Initialize the step with an empty results dictionary.
- add_result(name: str, value: Any)
Add a result to the step’s results dictionary.
- Parameters:
name (str) – Name of the result.
value (Any) – Value of the result.
- check_if_already_tried()
Helps the project decide whether there is any remaining reason not to run the step, even though it has failed.
- continue_on_failure = False
- abstract do(previous_states: Dict)
Abstract method to execute the step. Must be implemented by child classes.
- Parameters:
previous_states (Dict) – Dictionary containing the previous step states.
- fill_basic_results()
Fills basic results such as the hash and commit ID.
- get_full_step_name() → str
Retrieve the full name of the step, which is a combination of its ID and name.
- Returns:
The full step name.
- Return type:
str
- get_hash() → str
Compute a hash based on the source code of the step’s class.
- Returns:
MD5 hash of the step’s source code.
- Return type:
str
- get_latest_run() → Dict
If the step was run, returns itself; otherwise, returns the latest run from the JSONStepSaver.
- Returns:
The latest run.
- Return type:
Dict
- is_successful()
- model = <art.steps.NoModelUsed object>
- name = 'Data analysis'
- save_to_disk()
- set_successful()
- was_run() → bool
Check if the step was already executed based on the existence of saved results.
- Returns:
True if the step was run, otherwise False.
- Return type:
bool
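Step.get_hash and ModelStep.get_hash are documented as returning an MD5 hash of the step class's source code. A minimal sketch of that idea with `hashlib` and `inspect` follows; the helper names are hypothetical, and the actual implementation may normalise the source differently.

```python
import hashlib
import inspect


def md5_of_source(source: str) -> str:
    # Hex digest of the UTF-8 encoded source text (32 hex characters).
    return hashlib.md5(source.encode("utf-8")).hexdigest()


def class_source_hash(cls: type) -> str:
    # inspect.getsource raises OSError (or TypeError) when the class source
    # cannot be located, e.g. for classes defined in an interactive session.
    return md5_of_source(inspect.getsource(cls))
```

Because the hash covers the raw source text, any edit to the class, even to a comment or docstring, produces a different hash.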
- class art.steps.TransferLearning(model: ArtModule, model_modifiers: List[Callable] = [], logger: Optional[Logger] = None, freezed_trainer_kwargs: Dict = {}, unfreezed_trainer_kwargs: Dict = {}, freeze_names: Optional[list[str]] = None, keep_unfrozen: Optional[int] = None, fine_tune_lr: float = 1e-05, fine_tune: bool = True)
Bases:
ModelStep
This step tries performing proper transfer learning
- __init__(model: ArtModule, model_modifiers: List[Callable] = [], logger: Optional[Logger] = None, freezed_trainer_kwargs: Dict = {}, unfreezed_trainer_kwargs: Dict = {}, freeze_names: Optional[list[str]] = None, keep_unfrozen: Optional[int] = None, fine_tune_lr: float = 1e-05, fine_tune: bool = True)
This method initializes the step
- Parameters:
model (ArtModule) – model
model_modifiers (List[Callable], optional) – model modifiers. Defaults to [].
logger (Logger, optional) – logger. Defaults to None.
freezed_trainer_kwargs (Dict, optional) – Trainer kwargs used for transfer learning with frozen weights. Defaults to {}.
unfreezed_trainer_kwargs (Dict, optional) – Trainer kwargs used for fine-tuning with unfrozen weights. Defaults to {}.
freeze_names (Optional[list[str]], optional) – Names of the model parts to freeze, as they appear in the layer names. Defaults to None.
keep_unfrozen (Optional[int], optional) – Number of last layers to keep unfrozen. Defaults to None.
fine_tune_lr (float, optional) – Learning rate used for fine-tuning. Defaults to 1e-5.
fine_tune (bool, optional) – Whether to perform fine-tuning. Defaults to True.
- add_freezing()
Adds freezing to the model
- add_lr_change()
Adds lr change to the model
- add_unfreezing()
Adds unfreezing to the model
- description = 'This step tries performing proper transfer learning'
- do(previous_states: Dict)
This method trains the model.
- Parameters:
previous_states (Dict) – previous states
- get_check_stage()
Returns check stage
- name = 'TransferLearning'
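The interplay of freeze_names and keep_unfrozen can be sketched on a plain list of layer names. The semantics here are assumptions based on the parameter descriptions: a layer is frozen if any entry of freeze_names occurs in its name (or every layer is a candidate when freeze_names is None), and the last keep_unfrozen layers always stay trainable. `layers_to_freeze` is a hypothetical helper, not part of art.

```python
from typing import List, Optional, Set


def layers_to_freeze(layer_names: List[str],
                     freeze_names: Optional[List[str]] = None,
                     keep_unfrozen: Optional[int] = None) -> Set[str]:
    """Select layers to freeze (hypothetical semantics, see lead-in)."""
    if freeze_names is None:
        candidates = list(layer_names)  # assumption: all layers are candidates
    else:
        candidates = [n for n in layer_names if any(f in n for f in freeze_names)]
    if keep_unfrozen:  # the guard also skips 0, since names[-0:] is the whole list
        last = set(layer_names[-keep_unfrozen:])
        candidates = [n for n in candidates if n not in last]
    return set(candidates)
```

In PyTorch, the selected layers would then typically be frozen by calling `requires_grad_(False)` on their parameters.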