art package

Subpackages

Submodules

art.checks module

class art.checks.Check

Bases: ABC

Abstract base class for defining checks.

name

Name of the check.

Type:

str

description

Description of the check.

Type:

str

required_files

List of files that are required for this check.

Type:

List[str]

abstract check(step) ResultOfCheck

Abstract method to execute the check on the provided step.

Parameters:

step – The step to check.

Returns:

The result of the check.

Return type:

ResultOfCheck

description: str
name: str
required_files: List[str]
class art.checks.CheckResult

Bases: Check

Abstract class for checks that are based on the results of a step.

check(step) ResultOfCheck

Executes the check on the result of the provided step.

Parameters:

step – The step whose results are to be checked.

Returns:

The result of the check.

Return type:

ResultOfCheck

class art.checks.CheckResultExists(required_key)

Bases: CheckResult

Concrete check class to verify the existence of a required key in the results.

required_key

The key that should exist in the results.

Type:

str

__init__(required_key)

Constructor for the CheckResultExists class.

Parameters:

required_key (str) – The key that should exist in the results.

class art.checks.CheckScore(metric: Union[str, Any], value: float)

Bases: CheckResult

Base class for checking scores based on a specific metric.

metric

An object used to calculate the metric, or a string giving the metric's name.

value

The expected value of the metric.

Type:

float

build_required_key(step, metric)

Constructs the key for the metric based on the metric’s name, the model’s name, the current step’s name, and the current check stage.

Parameters:
  • step – The step in which the metric was calculated.

  • metric – The metric object.

check(step) ResultOfCheck

Executes the check on the result of the provided step.

Parameters:

step – The step whose results are to be checked.

Returns:

The result of the check.

Return type:

ResultOfCheck

class art.checks.CheckScoreCloseTo(metric, value: float, rel_tol: float = 1e-09, abs_tol: float = 0.0)

Bases: CheckScore

Check to verify if a score in the results based on a specific metric is close to an expected value.

rel_tol

Relative tolerance. Defaults to 1e-09.

Type:

float

abs_tol

Absolute tolerance. Defaults to 0.0.

Type:

float

class art.checks.CheckScoreEqualsTo(metric: Union[str, Any], value: float)

Bases: CheckScore

Check to verify if a score in the results based on a specific metric is equal to an expected value.

class art.checks.CheckScoreExists(metric)

Bases: CheckScore

Check to verify the existence of a score in the results based on a specific metric.

metric

An object used to calculate the metric.

class art.checks.CheckScoreGreaterThan(metric: Union[str, Any], value: float)

Bases: CheckScore

Check to verify if a score in the results based on a specific metric is greater than an expected value.

class art.checks.CheckScoreLessThan(metric: Union[str, Any], value: float)

Bases: CheckScore

Check to verify if a score in the results based on a specific metric is less than an expected value.

class art.checks.ResultOfCheck(is_positive: bool = True, error: str = '')

Bases: object

Dataclass representing the result of a check operation.

is_positive

Indicates if the check was successful. Defaults to True.

Type:

bool

error

Error message if the check was not successful. Defaults to an empty string.

Type:

str

error: str = ''
is_positive: bool = True
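
A minimal usage sketch of this module, not taken from the library's own documentation: the metric-name strings and the “hash” result key are assumptions, and the custom subclass only illustrates the Check contract (set name, description, required_files and implement check(step)).

    import os
    from art.checks import (
        Check,
        CheckResultExists,
        CheckScoreCloseTo,
        CheckScoreExists,
        ResultOfCheck,
    )

    # Concrete checks from this module; per the CheckScore annotation the
    # metric argument may be a metric object or a string with its name.
    checks = [
        CheckResultExists("hash"),                       # assumed result key
        CheckScoreExists("Accuracy"),                    # assumed metric name
        CheckScoreCloseTo("CrossEntropyLoss", value=2.3, abs_tol=0.1),
    ]

    # A custom check implements check(step) and returns a ResultOfCheck.
    class CheckDescriptionFile(Check):
        name = "Dataset description exists"
        description = "Verifies that a dataset description file is present"
        required_files = ["dataset_description.md"]      # hypothetical file

        def check(self, step) -> ResultOfCheck:
            if all(os.path.exists(f) for f in self.required_files):
                return ResultOfCheck(is_positive=True)
            return ResultOfCheck(is_positive=False, error="missing description file")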

art.core module

class art.core.ArtModule

Bases: LightningModule, ABC

baseline_train(data: Dict)

Baseline train.

Parameters:

data (Dict) – Data to train.

Returns:

Data with loss.

Return type:

Dict

check_setup()

Check if the metric calculator has been set.

Raises:

ValueError – If the metric calculator has not been set.

compute_loss(data: Dict)

Compute loss.

Parameters:

data (Dict) – Data to compute loss.

Returns:

Data with loss.

Return type:

Dict

compute_metrics(data: Dict)

Compute metrics.

Parameters:

data (Dict) – Data to compute metrics.

Returns:

Data with metrics.

Return type:

Dict

abstract log_params()
ml_parse_data(data: Dict)

Parse data for machine learning training.

Parameters:

data (Dict) – Data to parse.

Returns:

Parsed data.

Return type:

Dict

ml_train(data: Dict)

Machine learning train.

Parameters:

data (Dict) – Data to train.

Returns:

Data with loss.

Return type:

Dict

parse_data(data: Dict)

Parse data.

Parameters:

data (Dict) – Data to parse.

Returns:

Parsed data.

Return type:

Dict

predict(data: Dict)

Predict.

Parameters:

data (Dict) – Data to predict.

Returns:

Data with predictions.

Return type:

Dict

prepare_for_metric(data: Dict)

Prepare data for metric calculation.

Parameters:

data (Dict) – Data to prepare.

Returns:

Data with unified type.

Return type:

Tuple[torch.Tensor, torch.Tensor]

set_metric_calculator(metric_calculator: MetricCalculator)

Set metric calculator.

Parameters:

metric_calculator (MetricCalculator) – A metric calculator.

set_pipelines()

Reset pipelines for training, validation, and testing.

test_step(batch: Union[Dict[str, Any], DataLoader, Tensor], batch_idx: int)

Test step.

Parameters:
  • batch (Union[Dict[str, Any], DataLoader, torch.Tensor]) – Batch to test.

  • batch_idx (int) – Batch index.

training_step(batch: Union[Dict[str, Any], DataLoader, Tensor], batch_idx: int)

Training step.

Parameters:
  • batch (Union[Dict[str, Any], DataLoader, torch.Tensor]) – Batch to train.

  • batch_idx (int) – Batch index.

Returns:

Data with loss.

Return type:

Dict

unify_type(x: Any)

Unify the type of x by converting it to torch.Tensor.

Parameters:

x (Any) – Data to unify type.

Returns:

Data with unified type.

Return type:

torch.Tensor

validation_step(batch: Union[Dict[str, Any], DataLoader, Tensor], batch_idx: int)

Validation step.

Parameters:
  • batch (Union[Dict[str, Any], DataLoader, torch.Tensor]) – Batch to validate.

  • batch_idx (int) – Batch index.
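
A minimal subclass sketch of ArtModule. The dictionary keys used below (“batch”, “input”, “target”, “prediction”, “loss”) are assumptions chosen for this example; the library may define its own key constants, and the exact contract of the abstract log_params is not spelled out in this listing, so verify both against the source.

    import torch
    import torch.nn as nn
    from art.core import ArtModule

    class MnistClassifier(ArtModule):
        """Hypothetical ArtModule reused in the other sketches on this page."""

        def __init__(self, lr: float = 1e-3):
            super().__init__()
            self.lr = lr
            self.net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
            self.loss_fn = nn.CrossEntropyLoss()

        def parse_data(self, data):
            # Assumed key: the raw dataloader batch arrives under "batch".
            x, y = data["batch"]
            return {"input": x, "target": y}

        def predict(self, data):
            data["prediction"] = self.net(data["input"])
            return data

        def compute_loss(self, data):
            data["loss"] = self.loss_fn(data["prediction"], data["target"])
            return data

        def log_params(self):
            # Abstract in ArtModule; assumed to report identifying hyperparameters.
            return {"lr": self.lr, "architecture": "flatten+linear"}

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=self.lr)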

art.decorators module

class art.decorators.BatchSaver(how_many_batches=10, image_key_name='input')

Bases: object

Save images from batch to debug_images folder

__init__(how_many_batches=10, image_key_name='input')
Parameters:
  • how_many_batches (int, optional) – How many batches to save. Defaults to 10.

  • image_key_name (str, optional) – Key under which the image is stored in the batch. Defaults to “input”.

class art.decorators.EvolutionSaver(wanted_class_id: int)

Bases: object

Track evolution of logits for a given class

__init__(wanted_class_id: int)
Parameters:

wanted_class_id (int) – Which class to track

visualize()

Visualizes evolution of logits for a given class.

class art.decorators.LogInputStats(suppress_stdout=True, custom_logger=None)

Bases: object

Log input stats to art logger

__init__(suppress_stdout=True, custom_logger=None)
Parameters:
  • suppress_stdout (bool, optional) – Whether to suppress stdout. Defaults to True.

  • custom_logger (optional) – Custom logger to use instead of the default art_logger. Defaults to None.

class art.decorators.ModelDecorator(funcion_name: str, input_decorator: Optional[Callable] = None, output_decorator: Optional[Callable] = None)

Bases: object

funcion_name: str
input_decorator: Optional[Callable] = None
output_decorator: Optional[Callable] = None
art.decorators.art_decorate(functions: List[Tuple[object, str]], input_decorator: Optional[Callable] = None, output_decorator: Optional[Callable] = None)

Decorates methods of the listed objects. It does not modify a function’s output but can be used to log additional information during training.

Parameters:
  • functions (List[Tuple[object, str]]) – List of (object, method name) tuples to decorate.

  • input_decorator (Callable, optional) – Function applied to the input. Defaults to None.

  • output_decorator (Callable, optional) – Function applied to the output. Defaults to None.

art.decorators.art_decorate_single_func(visualizing_function_in=None, visualizing_function_out=None)

Decorates input and output of a function.

Parameters:
  • visualizing_function_in (function, optional) – Function applied to the input. Defaults to None.

  • visualizing_function_out (function, optional) – Function applied to the output. Defaults to None.

Returns:

Decorated function.

Return type:

function
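
A short sketch of art_decorate. Passing LogInputStats as the input decorator is an assumption based on its description above, and MnistClassifier is the hypothetical module from the art.core sketch.

    from art.decorators import LogInputStats, art_decorate

    model = MnistClassifier()

    # Log statistics of whatever flows into predict(); the method's output is
    # left unchanged, as described above.
    art_decorate(
        functions=[(model, "predict")],
        input_decorator=LogInputStats(suppress_stdout=True),
    )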

art.loggers module

class art.loggers.LoggerFlags(value)

Bases: Enum

An enumeration.

SUPRESS_STDOUT = 'supress_stdout'
class art.loggers.NeptuneLoggerAdapter(*args, **kwargs)

Bases: NeptuneLogger

A wrapper around Lightning’s NeptuneLogger that unifies basic functionality across different loggers.

add_tags(tags: Union[List[str], str])

Adds tags to the Neptune experiment.

Parameters:

tags (Union[List[str], str]) – Tag or list of tags to add.

download_ckpt(id: str, name: Optional[str] = None, type: str = 'last', path: str = './checkpoints')

Downloads a checkpoint from Neptune.

Parameters:
  • id (str) – Run ID.

  • name (str, optional) – Name of the checkpoint. Defaults to None.

  • type (str, optional) – Type of the checkpoint. Defaults to “last”.

  • path (str, optional) – Path to download checkpoint to. Defaults to “./checkpoints”.

Raises:
  • Exception – If the checkpoint does not exist.

  • Exception – If the type is not “last” or “best”.

Returns:

Path to downloaded checkpoint.

Return type:

str

log_config(configFile, path: str = 'hydra/config')

Logs a config file to Neptune.

Parameters:
  • configFile (str) – Path to config file.

  • path (str, optional) – Path to log config file to. Defaults to “hydra/config”.

log_figure(figure, path: str = 'figure')

Logs a figure to Neptune.

Parameters:
  • figure (Any) – Figure to log.

  • path (str, optional) – Path to log figure to. Defaults to “figure”.

log_img(image, path: str = 'image')

Logs an image to Neptune.

Parameters:
  • image (np.ndarray) – Image to log.

  • path (str, optional) – Path to log image to. Defaults to “image”.

stop()
class art.loggers.WandbLoggerAdapter(*args, **kwargs)

Bases: WandbLogger

A wrapper around Lightning’s WandbLogger that unifies basic functionality across different loggers. Wandb supports logging plots only via Plotly; to log matplotlib figures, convert them to Plotly first or log them as images.

add_tags(tags: Union[List[str], str])

Adds tags to the Wandb run.

Parameters:

tags (Union[List[str], str]) – Tag or list of tags to add.

log_config(configFile: str)

Logs a config file to Wandb.

Works only when run as an admin.

Parameters:

configFile (str) – Path to config file.

log_figure(figure, path='figure')

Logs a figure to Wandb.

Parameters:
  • figure (Any) – Figure to log.

  • path (str, optional) – Path to log figure to. Defaults to “figure”.

log_img(image, path: Union[str, ndarray] = 'image')

Logs an image to Wandb.

Parameters:
  • image (np.ndarray) – Image to log.

  • path (str, optional) – Path to log image to. Defaults to “image”.

art.loggers.add_logger(log_file_path: Path) int
art.loggers.get_new_log_file_name(run_id: str) str
art.loggers.get_run_id() str
art.loggers.log_yellow_warning(message: str)
art.loggers.remove_logger(logger_id: int)
art.loggers.supress_stdout(current_logger: Logger) Logger
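
The module-level helpers above can be combined as follows. This is a usage sketch based only on the signatures listed here; the interpretation of add_logger as attaching a file sink to the art logger is an assumption.

    from pathlib import Path
    from art.loggers import add_logger, get_new_log_file_name, get_run_id, remove_logger

    run_id = get_run_id()                                     # unique id for this run
    log_file = Path("logs") / get_new_log_file_name(run_id)   # run-specific file name
    logger_id = add_logger(log_file)                          # returns an id for later removal
    try:
        pass  # run your experiment / step here
    finally:
        remove_logger(logger_id)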

art.metrics module

class art.metrics.DefaultMetric

Bases: object

Placeholder for a default metric.

class art.metrics.DefaultModel

Bases: object

Placeholder for a default model.

class art.metrics.MetricCalculator(experiment: ArtProject)

Bases: object

Facilitates the management and application of metrics during different stages of training.

This class makes preparing templates for different kinds of projects easy.

add_metrics(metric: Any)

Add metrics to the list.

Parameters:

metric (Any) – The metric to add.

build_name(metric: Any) str

Builds a name for the metric based on its type and the current stage.

Parameters:

metric (Any) – The metric being calculated.

compile(skipped_metrics: List[SkippedMetric])

Organize metrics based on stages, skipping specified ones.

Parameters:

skipped_metrics (List[SkippedMetric]) – A list of SkippedMetric instances.

to(device: str)

Move all metrics to a specified device.

Parameters:

device (str) – The device to move the metrics to.

class art.metrics.SkippedMetric(metric, stages: List[str] = ['train', 'validate'])

Bases: object

Represents a metric that should be skipped during certain training stages.

art.metrics.build_metric_name(metric: Any, stage: str) str

Builds a name for the metric based on its type and given training stage.

Parameters:
  • metric (Any) – The metric being calculated.

  • stage (str) – The current stage of training.
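
A sketch of how metric names and skipped metrics fit together, assuming torchmetrics is available (the library builds on Lightning); the exact name format produced by build_metric_name may differ from the comment below.

    from torchmetrics import Accuracy
    from art.metrics import SkippedMetric, build_metric_name

    accuracy = Accuracy(task="multiclass", num_classes=10)

    # The logged key combines the metric's type with the training stage,
    # e.g. something along the lines of "MulticlassAccuracy-train".
    print(build_metric_name(accuracy, stage="train"))

    # Skip the accuracy computation for train/validate when registering a step;
    # see ArtProject.add_step(skipped_metrics=...).
    skipped = [SkippedMetric(accuracy, stages=["train", "validate"])]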

art.project module

class art.project.ArtProject(name: str, datamodule: LightningDataModule, use_metric_calculator: bool = True, **kwargs)

Bases: object

Represents a single Art project, encapsulating steps, state, metrics, and logging.

__init__(name: str, datamodule: LightningDataModule, use_metric_calculator: bool = True, **kwargs)

Initialize an Art project.

Parameters:
  • name (str) – The name of the project.

  • datamodule (L.LightningDataModule) – Data module to be used in this project.

  • use_metric_calculator (bool) – Whether to use the metric calculator.

  • **kwargs – Additional keyword arguments.

add_step(step: Step, checks: Optional[List[Check]] = [], skipped_metrics: List[SkippedMetric] = [])

Add a step to the project.

Parameters:
  • step (Step) – The step to be added.

  • checks (List[Check]) – A list of checks associated with the step.

  • skipped_metrics (List[SkippedMetric]) – A list of metrics to skip for this step.

check_checks(step: Step, checks: List[Check])

Validate if all checks pass for a given step.

Parameters:
  • step (Step) – The step to check.

  • checks (List[Check]) – List of checks to validate.

Raises:

Exception – If any of the checks fail.

check_if_must_be_run(step: Step, checks: List[Check]) bool

Check if a given step needs to be executed or if it can be skipped.

Parameters:
  • step (Step) – The step to check.

  • checks (List[Check]) – List of checks to validate.

Returns:

True if the step must be run, False otherwise.

Return type:

bool

fill_step_states(step: Step)

Update step states with the results from the given step.

Parameters:

step (Step) – The step whose results need to be recorded.

get_step(step_id: int) Step

Retrieve a specific step by its ID.

Parameters:

step_id (int) – The ID of the step to retrieve.

Returns:

The specified step.

Return type:

Step

get_steps()

Retrieve all steps in the project.

Returns:

List of steps.

Return type:

List[Dict[str, Any]]

print_summary()

Prints a summary of the project.

register_metrics(metrics: List[Any])

Register metrics to the project.

Parameters:

metrics (List[Any]) – A list of metrics to be registered.

replace_step(step: Step, step_id: int = -1)

Replace an existing step with a new one.

Parameters:
  • step (Step) – The new step.

  • step_id (int) – The ID of the step to replace. Default is the last step.

run_all(force_rerun=False, model_decorators: List[ModelDecorator] = [], trainer_kwargs: Dict[str, Any] = {})

Execute all steps in the project.

Parameters:
  • force_rerun (bool) – Whether to force rerun all steps.

  • model_decorators (List[ModelDecorator]) – List of model decorators to be applied.

  • trainer_kwargs (Dict[str, Any]) – Arguments to be passed to the trainer.

run_step(step: Step, skipped_metrics: List[SkippedMetric], model_decorators: List[ModelDecorator], trainer_kwargs: Dict[str, Any], run_id: str)

Run a given step.

Parameters:
  • step (Step) – The step to run.

  • skipped_metrics (List[SkippedMetric]) – List of metrics to skip for this step.

  • model_decorators (List[ModelDecorator]) – List of model decorators to be applied.

  • trainer_kwargs (Dict[str, Any]) – Arguments to be passed to the trainer.

  • run_id (str) – The ID of the run.

update_datamodule(datamodule: LightningDataModule)

Update the data module of the project.

Parameters:

datamodule (L.LightningDataModule) – New data module to be used in the project.
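
An end-to-end assembly sketch tying the pieces above together. MnistClassifier is the hypothetical ArtModule from the art.core sketch, the datamodule is a placeholder, the metric-name strings are assumptions, and whether the step constructors expect a model class or an instance should be checked against art.steps.ModelStep.

    from torchmetrics import Accuracy
    from art.checks import CheckScoreCloseTo, CheckScoreLessThan
    from art.metrics import SkippedMetric
    from art.project import ArtProject
    from art.steps import CheckLossOnInit, Overfit, OverfitOneBatch

    datamodule = ...  # any LightningDataModule providing the data
    accuracy = Accuracy(task="multiclass", num_classes=10)
    loss_name = "CrossEntropyLoss"  # assumed metric-name string

    project = ArtProject(name="mnist-tutorial", datamodule=datamodule)
    project.register_metrics([accuracy])

    # An untrained 10-class classifier should start near ln(10) ~= 2.3 loss.
    project.add_step(
        CheckLossOnInit(MnistClassifier),
        checks=[CheckScoreCloseTo(loss_name, value=2.3, abs_tol=0.1)],
    )
    project.add_step(
        OverfitOneBatch(MnistClassifier, number_of_steps=100),
        checks=[CheckScoreLessThan(loss_name, value=0.1)],
        skipped_metrics=[SkippedMetric(accuracy)],
    )
    project.add_step(
        Overfit(MnistClassifier, max_epochs=10),
        checks=[CheckScoreLessThan(loss_name, value=0.5)],
    )

    project.run_all()
    project.print_summary()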

class art.project.ArtProjectState

Bases: object

current_stage: TrainingStage = 'train'
current_step: Optional[Step]
get_current_stage()

Gets current stage

Returns:

Current stage

Return type:

TrainingStage

get_current_step()

Gets current step

Returns:

Current step

Return type:

Step

get_steps()

Returns all steps that were run

Returns:

Dictionary of step states for all steps that were run.

Return type:

Dict[str, Dict[str, Dict[str, str]]]

status: str

A class for managing the state of a project.

steps: {
    "model_name": {
        "step_name": {step state},
        "step_name2": {step state},
    },
    "model2_name": {
        "step_name": {step state},
        "step_name2": {step state},
    },
}

step_states: Dict[str, Dict[str, Dict[str, str]]]

art.steps module

class art.steps.CheckLossOnInit(model: ArtModule)

Bases: ModelStep

This step checks whether the loss on init is as expected

description = 'Checks loss on init'
do(previous_states: Dict)

This method checks loss on init. It validates the model on the train dataloader and checks whether the loss is as expected.

Parameters:

previous_states (Dict) – previous states

name = 'Check Loss On Init'
requires_ckpt_callback = False
class art.steps.EvaluateBaseline(baseline: ArtModule, device: Optional[str] = 'cpu')

Bases: ModelStep

This class takes a baseline and evaluates/trains it on the dataset

description = 'Evaluates a baseline on the dataset'
do(previous_states: Dict)

This method evaluates baseline on the dataset

Parameters:

previous_states (Dict) – previous states

name = 'Evaluate Baseline'
requires_ckpt_callback = False
class art.steps.ExploreData

Bases: Step

This class checks whether a markdown description of the dataset exists and whether visualizations have been implemented

description = 'This step allows you to perform data analysis and extract information that is necessary in next steps'
name = 'Data analysis'
class art.steps.ModelStep(model_class: ArtModule, trainer_kwargs: Dict = {}, model_kwargs: Dict = {}, model_modifiers: List[Callable] = [], datamodule_modifiers: List[Callable] = [], logger: Optional[Logger] = None)

Bases: Step

A specialized step in the project, representing a model-based step.

__init__(model_class: ArtModule, trainer_kwargs: Dict = {}, model_kwargs: Dict = {}, model_modifiers: List[Callable] = [], datamodule_modifiers: List[Callable] = [], logger: Optional[Logger] = None)

Initialize a model-based step.

Parameters:
  • model_class (ArtModule) – The model’s class associated with this step.

  • trainer_kwargs (Dict, optional) – Arguments to be passed to the trainer. Defaults to {}.

  • model_kwargs (Dict, optional) – Arguments to be passed to the model. Defaults to {}.

  • model_modifiers (List[Callable], optional) – List of functions to be applied to the model. Defaults to [].

  • datamodule_modifiers (List[Callable], optional) – List of functions to be applied to the data module. Defaults to [].

  • logger (Optional[Logger], optional) – Logger to be used. Defaults to None.

check_ckpt_callback(trainer_kwargs: Dict)
get_check_stage() str

Get the validation stage value from the TrainingStage enum.

Returns:

Validation stage value.

Return type:

str

get_current_stage() str

Retrieve the current training stage of the trainer.

Returns:

Current training stage.

Return type:

str

get_full_step_name() str

Retrieve the step ID, combining model name (if available) with the index.

Returns:

The step ID.

Return type:

str

get_hash() str

Compute a hash based on the source code of the step’s class.

Returns:

MD5 hash of the step’s source code.

Return type:

str

get_trainloader()
get_valloader()
initialize_model() Optional[ArtModule]

Initializes the model.

log_data_params()
log_model_params(model)
requires_ckpt_callback = True
reset_trainer(logger: Optional[Logger] = None, trainer_kwargs: Dict = {})

Reset the trainer.

Parameters:
  • trainer_kwargs (Dict) – Arguments to be passed to the trainer.

  • logger (Optional[Logger], optional) – Logger to be used. Defaults to None.

test(trainer_kwargs: Dict)

Test the model using the provided trainer arguments.

Parameters:

trainer_kwargs (Dict) – Arguments to be passed to the trainer for testing the model.

train(trainer_kwargs: Dict)

Train the model using the provided trainer arguments.

Parameters:

trainer_kwargs (Dict) – Arguments to be passed to the trainer for training the model.

validate(trainer_kwargs: Dict)

Validate the model using the provided trainer arguments.

Parameters:

trainer_kwargs (Dict) – Arguments to be passed to the trainer for validating the model.

class art.steps.NoModelUsed

Bases: object

class art.steps.Overfit(model: ArtModule, logger: Optional[Logger] = None, max_epochs: int = 1)

Bases: ModelStep

This step tries to overfit the model

description = 'Overfits model'
do(previous_states: Dict)

This method overfits the model

Parameters:

previous_states (Dict) – previous states

get_check_stage()

Returns check stage

log_model_params(model)
name = 'Overfit'
class art.steps.OverfitOneBatch(model: ArtModule, number_of_steps: int = 50)

Bases: ModelStep

This step tries to overfit one training batch

description = 'Overfits one batch'
do(previous_states: Dict)

This method overfits one batch

Parameters:

previous_states (Dict) – previous states

get_check_stage()

Returns check stage

log_model_params(model)
name = 'Overfit One Batch'
requires_ckpt_callback = False
class art.steps.Regularize(model: ArtModule, logger: Optional[Logger] = None, trainer_kwargs: Dict = {}, model_kwargs: Dict = {}, model_modifiers: List[Callable] = [], datamodule_modifiers: List[Callable] = [])

Bases: ModelStep

This step tries applying regularization to the model

__init__(model: ArtModule, logger: Optional[Logger] = None, trainer_kwargs: Dict = {}, model_kwargs: Dict = {}, model_modifiers: List[Callable] = [], datamodule_modifiers: List[Callable] = [])
Parameters:
  • model (ArtModule) – model

  • logger (Logger, optional) – logger. Defaults to None.

  • trainer_kwargs (Dict, optional) – Kwargs passed to lightning Trainer. Defaults to {}.

  • model_kwargs (Dict, optional) – Kwargs passed to model. Defaults to {}.

  • model_modifiers (List[Callable], optional) – model modifiers. Defaults to [].

  • datamodule_modifiers (List[Callable], optional) – datamodule modifiers. Defaults to [].

check_if_already_tried()

Helps the project decide whether there is any reason not to re-run the step even though it failed (for example, because the same configuration was already tried).

continue_on_failure = True
description = 'Regularizes model'
do(previous_states: Dict)

This method regularizes the model

Parameters:

previous_states (Dict) – previous states

get_latest_run() Dict

If the step was run, returns its own results; otherwise returns the latest run recorded by the JSONStepSaver.

In the case of regularization, the latest successful run is the one of interest.

Returns:

The latest run.

Return type:

Dict

name = 'Regularize'
save_to_disk()
set_successful()
stringify_modifiers(modifiers: List[Callable])
update_was_already_tried()

Verifies whether this combination of Regularize parameters was already tried and updates self.was_already_tried.

class art.steps.Squeeze(model_class: ArtModule, trainer_kwargs: Dict = {}, model_kwargs: Dict = {}, model_modifiers: List[Callable] = [], datamodule_modifiers: List[Callable] = [], logger: Optional[Logger] = None)

Bases: ModelStep

class art.steps.Step

Bases: ABC

An abstract base class representing a generic step in a project.

__init__()

Initialize the step with an empty results dictionary.

add_result(name: str, value: Any)

Add a result to the step’s results dictionary.

Parameters:
  • name (str) – Name of the result.

  • value (Any) – Value of the result.

check_if_already_tried()

Helps the project decide whether there is any reason not to re-run the step even though it failed.

continue_on_failure = False
abstract do(previous_states: Dict)

Abstract method to execute the step. Must be implemented by child classes.

Parameters:

previous_states (Dict) – Dictionary containing the previous step states.

fill_basic_results()

Fill basic results like hash and commit id

get_full_step_name() str

Retrieve the full name of the step, which is a combination of its ID and name.

Returns:

The full step name.

Return type:

str

get_hash() str

Compute a hash based on the source code of the step’s class.

Returns:

MD5 hash of the step’s source code.

Return type:

str

get_latest_run() Dict

If the step was run, returns its own results; otherwise returns the latest run recorded by the JSONStepSaver.

Returns:

The latest run.

Return type:

Dict

is_successful()
model = <art.steps.NoModelUsed object>
name = 'Data analysis'
save_to_disk()
set_successful()
was_run() bool

Check if the step was already executed based on the existence of saved results.

Returns:

True if the step was run, otherwise False.

Return type:

bool
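
A sketch of a model-free custom step, in the spirit of ExploreData. How the datamodule is obtained inside do() is an assumption made for this example, since the abstract interface only guarantees previous_states.

    from typing import Dict
    from art.steps import Step

    class CountClasses(Step):
        """Hypothetical analysis step that records the number of target classes."""

        name = "Count classes"
        description = "Counts distinct labels in the training split"

        def __init__(self, datamodule):
            super().__init__()
            self.datamodule = datamodule

        def do(self, previous_states: Dict):
            labels = {int(y) for _, y in self.datamodule.train_dataloader().dataset}
            # add_result stores the value in this step's results dictionary.
            self.add_result("number_of_classes", len(labels))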

class art.steps.TransferLearning(model: ArtModule, model_modifiers: List[Callable] = [], logger: Optional[Logger] = None, freezed_trainer_kwargs: Dict = {}, unfreezed_trainer_kwargs: Dict = {}, freeze_names: Optional[list[str]] = None, keep_unfrozen: Optional[int] = None, fine_tune_lr: float = 1e-05, fine_tune: bool = True)

Bases: ModelStep

This step tries performing proper transfer learning

__init__(model: ArtModule, model_modifiers: List[Callable] = [], logger: Optional[Logger] = None, freezed_trainer_kwargs: Dict = {}, unfreezed_trainer_kwargs: Dict = {}, freeze_names: Optional[list[str]] = None, keep_unfrozen: Optional[int] = None, fine_tune_lr: float = 1e-05, fine_tune: bool = True)

This method initializes the step

Parameters:
  • model (ArtModule) – model

  • model_modifiers (List[Callable], optional) – model modifiers. Defaults to [].

  • logger (Logger, optional) – logger. Defaults to None.

  • freezed_trainer_kwargs (Dict, optional) – trainer kwargs used for transfer learning with frozen weights. Defaults to {}.

  • unfreezed_trainer_kwargs (Dict, optional) – trainer kwargs used for fine-tuning with unfrozen weights. Defaults to {}.

  • freeze_names (Optional[list[str]], optional) – names of model parts to freeze, matched against layer names. Defaults to None.

  • keep_unfrozen (Optional[int], optional) – number of last layers to keep unfrozen. Defaults to None.

  • fine_tune_lr (float, optional) – learning rate used for fine-tuning. Defaults to 1e-5.

  • fine_tune (bool, optional) – whether to perform fine-tuning. Defaults to True.

add_freezing()

Adds freezing to the model

add_lr_change()

Adds lr change to the model

add_unfreezing()

Adds unfreezing to the model

description = 'This step tries performing proper transfer learning'
do(previous_states: Dict)

This method trains the model.

Parameters:

previous_states (Dict) – previous states

get_check_stage()

Returns check stage

name = 'TransferLearning'
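
A configuration sketch for this step. The freeze_names value is a placeholder that must match layer names in your model, keep_unfrozen is shown commented out because this listing does not say how the two options interact, and the project and MnistClassifier come from the earlier sketches.

    from art.steps import TransferLearning

    transfer = TransferLearning(
        MnistClassifier,                          # see ModelStep regarding class vs. instance
        freeze_names=["net.0"],                   # hypothetical layer-name prefix to freeze
        # keep_unfrozen=1,                        # alternative: keep the last n layers trainable
        fine_tune=True,
        fine_tune_lr=1e-5,
        freezed_trainer_kwargs={"max_epochs": 3},
        unfreezed_trainer_kwargs={"max_epochs": 2},
    )
    project.add_step(transfer)
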
class art.steps.Tune(model: ArtModule, logger: Optional[Logger] = None)

Bases: ModelStep

This step tunes the model

description = 'Tunes model'
do(previous_states: Dict)

This method tunes the model

Parameters:

previous_states (Dict) – previous states

name = 'Tune'

Module contents