How to Train a Continual Learning Model


We dive deep into different approaches to continual learning, a method to add new information to a trained model while preserving its existing knowledge.

Photo by Google DeepMind / Unsplash

Consider a self-driving car that navigates the city with ease but struggles in a construction zone. Or think about a medical diagnosis system that does a great job of detecting well-known diseases but misses a novel variant. These situations draw attention to a crucial problem in artificial intelligence: machine learning models' limited capacity to adjust to new data without losing track of what they have already learned.

Conventional ML models learn the distribution and patterns of their training data and use them to generate or classify new data points. However, when they are fed new data with even a slightly different distribution, they tend to perform poorly. To adapt such a model to new data, it has to be retrained on a new dataset, which is time-consuming and potentially expensive.

We humans, by contrast, are capable of learning new, complex material while preserving what we already know – we learn continuously. This makes us adaptable and lets us learn things faster. The same is not true of machine learning models, at least conventional ones.

But what if we trained ML models the way we humans learn every day – systematically and continuously adding new information while preserving the existing knowledge?

Yes, we can. An approach called continual learning, also known as continuous learning in machine learning, is revolutionizing how AI systems continuously adapt and improve over time.

In this article, we will see how we can use continual learning in ML to make a model more adaptable to new information while preserving the existing information. We will discuss the following:

- What is continual learning?
- Benefits of continual learning
- Applications of continual learning
- Challenges in continual learning
- Approaches to continual learning
- Types of continual learning
- Implementation of continual learning

So let’s get started.

What is continual learning?

Continual learning goes by a few other names as well, such as continuous learning, incremental learning, and lifelong learning. For the sake of this article, we will stick with continual learning.

In AI, continual learning is the process of adding new information to a trained model while preserving the old, mimicking human cognitive processes.

Traditional learning suffers from what is known as “catastrophic forgetting”: new information overwrites old information when the model is trained on new data. This happens mainly because the model lacks any mechanism for preserving its memory.

Essentially, when the network is trained on new data, it adjusts its parameters to accommodate the new information; this erases the old information and the model loses its old abilities. For instance, if a model trained to recognize dogs is subsequently trained on birds, it may lose its ability to identify dogs due to the weight adjustments made for the new task. One general solution is to build a new training dataset that combines the old and new information, but assembling it is slow, and retraining the model from scratch is slower still.

Key Points to Understand:

- Weight Adjustment: When an ML model or neural network learns a new task, it adjusts its weights to better fit the new data. If these adjustments are substantial, they can overwrite the representations learned for previous tasks.
- Loss of Performance: Because of this overwriting, the model may lose old abilities and perform poorly on previous tasks, even as it becomes proficient at the new one.
- Interference: Interference happens because the model has no built-in mechanism to differentiate the importance of weights across tasks, so learning one task can degrade performance on another.

This is where the concept of continual learning comes into the picture. But continual learning also faces the risk of catastrophic forgetting. So how does it overcome the issue?

One of the primary approaches is to manage the trade-off between learning plasticity and memory stability, known as the stability-plasticity trade-off. Essentially, forgetting means erasing data, and learning means storing data. The stability-plasticity trade-off is the decision about how much knowledge the model should retain and how much it should erase to accommodate new data: we cannot store new information without sacrificing some of the old.

A general framework of continual learning where a trade-off between stability and plasticity is emphasized to achieve generalizability. | Source: Wang et al. (2024)

“Beyond simply balancing the ‘proportions’ of these two aspects, a desirable solution for continual learning should obtain strong generalizability to accommodate distribution differences within and between tasks.” – Wang et al. (2024)

Continual learning allows the ML system to evolve adaptively.

Benefits of continual learning

Continual learning has several advantages over conventional ML systems. These systems can offer better generalizability, robustness to unseen data, and efficiency. In this section, we will discuss the advantages of continual learning systems over conventional ML systems.

Incremental Learning

Incremental learning refers to sequentially injecting new information without completely retraining the model. This type of learning makes the models more efficient by saving time and computational resources. It also enables the model to continuously update the information within the same task domain, allowing for real-time adaptation to gradual changes.

For instance, in load forecasting, where we predict how much electricity will be required at a given time, incremental learning is useful for energy management. It allows the model to continuously update with new load and weather data to provide real-time predictions. This ensures that the power grid delivers electricity efficiently to meet consumption needs, avoiding power wastage and inefficiency.

Knowledge Retention

While incremental learning allows the continual learning system to learn new information sequentially, knowledge retention allows it to preserve old and valuable information. It addresses the issue of catastrophic forgetting by employing techniques like managing the stability-plasticity trade-off (discussed above) to maintain optimal performance on previously learned tasks. It also ensures a broad base of information when solving problems. Retained knowledge can be leveraged to learn new, related tasks more quickly and effectively, which enhances overall learning efficiency.

For instance, let's assume we train an AI model to recognize different dog breeds by learning general features like ear shape and fur texture. If we then train it to recognize cat breeds, it can reuse the general information obtained from the previous dataset and apply it to the new one. Rather than training from scratch, the model builds upon its retained knowledge of features like ear shape and fur texture, allowing it to learn cat breed recognition much faster.

Adaptation

We saw that continual learning excels at real-time updates and knowledge retention; these capabilities also enable it to adapt to new information and generalize well. While incremental learning injects data sequentially to learn new information, adaptation refers to the model's ability to adjust its learned information and strategies to new situations and changing environments. This adaptability is crucial in many real-world applications.

It is important to understand that adaptation is a consequence of incremental learning, but it requires capabilities beyond it: recognizing when learned strategies or approaches need to be modified, figuring out how to tweak them, and generalizing knowledge to new domains.

For instance, a model trained on daytime street scenes could adapt to nighttime or even severe weather conditions by adjusting its learned distribution over visual features.

Applications of continual learning

Now, let's discuss some applications of continual learning in real-world use cases:

- Autonomous Vehicles: Self-driving cars can learn to adapt to different traffic patterns, road conditions, and construction zones without requiring total retraining. This is one major benefit of continual learning: it enhances performance and safety by enabling them to refresh their knowledge base on the fly.
- Healthcare: As new research and findings become available, ML models in medical diagnostics can incorporate this information while still diagnosing well-established disorders. The ability to learn dynamically is essential for the early diagnosis and treatment of new diseases.
- Robotics: Robots with continual learning capabilities can adjust to novel surroundings and tasks without losing the ability to perform previously acquired tasks. This is especially helpful in industrial automation, where robots must carry out a range of activities in various environments.
- Natural Language Processing (NLP): Language models can maintain their relevance and accuracy over time by continuously learning new terms, slang, and changing language usage.
- Finance: Fraud detection systems benefit from continuously learning new fraudulent behaviors and adjusting to new strategies without forgetting earlier fraud patterns. This real-time learning makes financial security measures more reliable and resilient.

It is important to note that continual learning enables AI models to update their knowledge base in near real-time, making them more responsive to changing environments.

Now that we are aware of the benefits and applications of continual learning, let's examine its challenges, followed by the approaches that address them.

Challenges in continual learning

In this section, we will look at the three major challenges of continual learning: catastrophic forgetting, the stability-plasticity trade-off, and task boundary detection. These challenges highlight the need for approaches that can produce robust, agile models without retraining from scratch. But first, let us understand what these challenges are and how they impact AI models.

Catastrophic Forgetting

One of the most prevalent challenges in continual learning is catastrophic forgetting, a problem for any neural network that tends to forget previously learned information when trained on new tasks. Traditional machine learning models excel at learning comprehensive patterns from a static dataset, but when exposed to a sequential stream of data, they often struggle to retain old knowledge. This happens because the model weights, adjusted to minimize the loss on new tasks, inadvertently overwrite the weights that encoded the previous tasks. The result is a dramatic loss of performance on earlier tasks, making it difficult for the model to function effectively across a variety of tasks without substantial retraining.

Stability-Plasticity Trade-Off

The stability-plasticity dilemma captures the tension between stability and plasticity. To recap: in a continual learning framework, stability refers to preserving existing knowledge, and plasticity refers to integrating new knowledge. While plasticity enables the model to adjust to new challenges, stability guarantees that it retains previously learned knowledge. Stressing stability makes it harder for the model to learn new information, while increasing plasticity frequently results in catastrophic forgetting.

Trade-off strategies aim to maintain an optimal balance, allowing the model to learn continuously without compromising on either front. Approaches like Elastic Weight Consolidation and Synaptic Intelligence (covered in the approaches section below) address this dilemma by selectively consolidating important weights while allowing flexibility for new tasks.

Task Boundary Detection

Detecting task boundaries is another critical challenge in continual learning. Unlike traditional settings where the start and end of tasks are predefined, real-world applications often present a continuous stream of data without clear task distinction. The ability to recognize when a new task begins is essential for updating the model appropriately without causing interference with previously learned tasks.

Effective task boundary detection enables the model to allocate resources dynamically and apply appropriate learning strategies, ensuring efficient and accurate adaptation to new information. Methods such as unsupervised learning techniques and adaptive learning rate adjustments are being explored to improve task boundary detection and facilitate smoother transitions between tasks.

To understand this, consider an LLM-based chatbot like ChatGPT or Claude engaging in open-ended conversations with users. It needs to recognize when the user changes the topic. For example, if the user asks about political issues in their country and then switches to the rules of boxing or javelin throw, the model should recognize these as two different tasks, establish a boundary between them, and avoid mixing the two topics together.

Approaches to continual learning

Continual learning leverages various approaches to carefully inject new, valuable information without disrupting what the model already knows.

A hierarchical diagram of various approaches and sub-approaches in continual learning | Source: Wang et al. (2024)

Regularization-Based Approach

The regularization-based approach focuses on regularizing or penalizing updates to the model's important parameters, which prevents previously learned information from being overwritten. Some of the techniques used are Elastic Weight Consolidation (EWC), Learning without Forgetting (LwF), and Synaptic Intelligence (SI):

- EWC estimates the importance of each parameter for previously learned tasks and penalizes changes to important parameters when learning new tasks.
- LwF uses knowledge distillation, where knowledge from a larger model is transferred to a smaller one without significant loss in performance. Distillation lets the model preserve performance on old tasks while learning new ones (see the sketch below).
- SI accumulates importance information for each parameter over time and uses it to protect old memories.

These techniques aim to find a balance between plasticity for learning new tasks and stability for retaining old knowledge.

These methods typically require storing a frozen copy of the old model for reference.
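To make the distillation idea concrete, here is a minimal sketch of an LwF-style loss, assuming `old_model` is a frozen copy of the network from before the new task; the weighting `alpha` and temperature `T` are illustrative hyperparameters, not values from a specific paper:

import torch
import torch.nn.functional as F

def lwf_loss(model, old_model, x, y, alpha=0.5, T=2.0):
    """Cross-entropy on the new task plus a distillation term that keeps
    the updated model's predictions close to the frozen old model's."""
    new_logits = model(x)
    ce = F.cross_entropy(new_logits, y)  # standard loss on the new task
    with torch.no_grad():
        old_logits = old_model(x)  # targets from the pre-update snapshot
    # soften both output distributions with temperature T and match them
    distill = F.kl_div(
        F.log_softmax(new_logits / T, dim=1),
        F.softmax(old_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)
    return ce + alpha * distill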

Replay-Based Approach

The replay-based approach is inspired by how the brain preserves important and meaningful memories: during sleep, the brain replays past memories to strengthen neural pathways. This ensures that old memories are preserved and can support learning from new experiences.

A flowchart of the replay mechanism of a neural network | Source: DeepMind

Similarly, in machine learning, replay-based methods store and reuse previous information to maintain and update learned knowledge without forgetting.

Some of the methods used to leverage replay-based learning are experience replay, generative replay, and feature replay:

- Experience Replay stores a subset of data from previous tasks and periodically retrains on it along with new task data. Each training batch mixes previous-task and new-task samples, providing high-quality representations of both (see the sketch after this list).
- Generative Replay uses generative models to create synthetic data representing previous tasks, avoiding the need to store real data.
- Feature Replay is important where data privacy is crucial. It uses trained generative models to replay representations that match the hidden features of real samples, maintaining feature-level distributions instead of raw data.
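As a rough illustration of experience replay, the sketch below keeps a small reservoir of past examples and mixes them into training on the new task. The buffer capacity and reservoir-sampling scheme are assumptions made for the example:

import random
import torch

class ReplayBuffer:
    """A small memory of (input, label) pairs from previous tasks."""
    def __init__(self, capacity=1000):
        self.capacity = capacity
        self.buffer = []
        self.seen = 0

    def add(self, x, y):
        # reservoir sampling keeps a uniform sample of everything seen so far
        self.seen += 1
        if len(self.buffer) < self.capacity:
            self.buffer.append((x, y))
        else:
            idx = random.randrange(self.seen)
            if idx < self.capacity:
                self.buffer[idx] = (x, y)

    def sample(self, batch_size):
        batch = random.sample(self.buffer, min(batch_size, len(self.buffer)))
        xs, ys = zip(*batch)
        return torch.stack(xs), torch.stack(ys)

# During training on a new task, each step would mix replayed and fresh data:
#   x_old, y_old = buffer.sample(32)
#   loss = F.cross_entropy(model(x_new), y_new) + F.cross_entropy(model(x_old), y_old)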

Optimization-Based Approach

The previous two groups of methods – regularization and replay-based approaches – worked by adding additional terms to the loss functions. In an optimization-based approach, we will see how we can design and modify optimization techniques to learn new tasks while preserving the old ones.

This focuses on finding optimal parameter updates that benefit both new and old tasks. Using this approach you can develop robust optimization techniques. Some of the techniques involve gradient projection, meta-learning, and loss landscape optimization.

- Gradient projection modifies the model's parameter updates to align with certain directions. Typically, the model learns a new task by taking steps orthogonal to the gradient subspaces that are important for previous tasks, preventing it from disturbing already-established knowledge (a sketch follows this list).
- Meta-learning, also known as "learning to learn," focuses on developing models that can quickly adapt to new tasks by learning from their own learning experiences. The model tries to learn an optimal learning algorithm or strategy that allows rapid adaptation. In continual learning, this lets the model pick up new tasks quickly while remembering previous ones, yielding a model that learns efficiently in dynamic environments with minimal data for new tasks.
- Loss landscape optimization refines the optimization process to find flatter minima: regions where the model's performance remains good and stable even if the parameters shift slightly. These flat regions help the model perform well on both old and new tasks, because small parameter changes made for a new task do not incur a large loss on old tasks. By aiming for these stable, flatter areas, the model can learn new things without forgetting what it already knows, making it better at handling a series of different tasks over time.
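Here is a bare-bones sketch of the gradient-projection idea. How the matrix `M` of important old-task directions is constructed varies between methods (and is assumed here to have orthonormal columns); the point is only to show the projection step itself:

import torch

def project_orthogonal(grad, M):
    """Remove from `grad` its component inside the subspace spanned by the
    columns of M (directions deemed important for previous tasks)."""
    inside = M @ (M.t() @ grad)  # component of grad within span(M)
    return grad - inside         # step only in directions that leave old tasks alone

# After loss.backward() and before optimizer.step(), one would project
# each parameter's gradient, e.g.:
#   g = param.grad.view(-1)
#   param.grad = project_orthogonal(g, M_for_this_param).view_as(param.grad)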

Representation-Based Approach

The representation-based approach enables models to learn and maintain good feature representations that are useful across multiple tasks. It works in several ways:

- Incorporating self-supervised learning to obtain more generalizable features: This helps the model learn useful patterns from data without needing labels. The model generates its own supervision by masking part of the input and predicting the full input, learning the intrinsic patterns, representations, and structure of the data. This makes it easier to adapt to new tasks and generalize well (see the sketch after this list).
- Leveraging pre-training to provide a strong initial representation: Pre-training involves training the model, often in a self-supervised manner, on a large, diverse dataset. The goal is a strong initial set of parameters and representations that can be adapted to specific downstream tasks with minimal fine-tuning.
- Exploring continual pre-training to incrementally update large-scale models: For large models such as LLMs, continual pre-training keeps the model up to date by feeding it new information over time, ensuring that its knowledge stays current and relevant.
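As a toy sketch of the masking idea (the autoencoder architecture and 50% masking ratio are arbitrary assumptions for illustration): hide part of each input and train the model to reconstruct the whole input, so the supervision comes from the data itself:

import torch
import torch.nn as nn
import torch.nn.functional as F

# a tiny encoder/decoder pair for 32x32 RGB images
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 256), nn.ReLU())
decoder = nn.Linear(256, 3 * 32 * 32)

def masked_reconstruction_loss(x, mask_ratio=0.5):
    """Zero out a random subset of pixels and reconstruct the original image."""
    flat = x.view(x.size(0), -1)
    mask = (torch.rand_like(flat) > mask_ratio).float()  # 1 = keep, 0 = hide
    recon = decoder(encoder((flat * mask).view_as(x)))
    return F.mse_loss(recon, flat)  # the model must predict the full input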

Architecture-Based Approach

The architecture-based approach involves modifying the network architecture to accommodate new tasks while preserving old knowledge. Techniques such as parameter allocation, model decomposition, and modular networks can be leveraged to enhance continual learning.

- Parameter allocation assigns dedicated parameters to each task, so that the parts of the network important to old tasks are preserved when a new task is learned. In simpler terms, each subset of parameters is allocated to a specific task.
- Model decomposition separates the model into task-shared and task-specific components.
- Modular networks use parallel sub-networks or sub-modules for different tasks (a sketch of the task-specific-head idea follows below).

These methods reduce interference between tasks by design, balancing the preservation of task-specific knowledge against the potential for cross-task synergies, while managing the model complexity that the extra structure introduces.
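A minimal sketch of the shared-trunk, per-task-head idea behind model decomposition and modular networks (the trunk architecture and feature size are assumptions for illustration): new tasks add parameters of their own instead of overwriting shared ones.

import torch.nn as nn

class MultiHeadNet(nn.Module):
    """A shared feature extractor plus a separate output head per task."""
    def __init__(self, feature_dim=256):
        super().__init__()
        self.feature_dim = feature_dim
        self.trunk = nn.Sequential(            # task-shared component
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(4), nn.Flatten(),
            nn.Linear(32 * 4 * 4, feature_dim), nn.ReLU(),
        )
        self.heads = nn.ModuleDict()           # task-specific components

    def add_task(self, task_id, num_classes):
        # each new task gets its own head; old heads stay untouched
        self.heads[str(task_id)] = nn.Linear(self.feature_dim, num_classes)

    def forward(self, x, task_id):
        return self.heads[str(task_id)](self.trunk(x))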

Types of continual learning

In the earlier section, we saw the various approaches to continual learning. Approaches are the methods used to preserve learned information while incorporating new knowledge; types, on the other hand, refer to the different scenarios or settings in which continual learning is applied. Each has its own challenges and requirements based on task characteristics and information availability.

Diagram showing the mapping of each continual learning type with its respective scenario and its associated approaches | Source: Author

Now, let's explore the different types of continual learning: the scenario each is used in, the challenge it faces, and a real-world application.

Instance-Incremental Learning

With instance-incremental learning, additional training instances are added to the model as they become available, without requiring access to the entire initial dataset. This is a form of incremental learning. Essentially, this implies that rather than retraining on the complete dataset, the model continuously learns from a streaming non-stationary data source, adjusting its knowledge in accordance with new information.

- Scenario: Training samples belonging to the same task arrive sequentially in batches over time.
- Challenge: The model must continually learn from new samples as they are fed in, without forgetting previously learned information.
- Associated Approach: Replay-based approach and optimization-based approach
- Example: Updating a customer feedback analysis model as new reviews come in daily.

Domain-Incremental Learning

In domain-incremental learning, the model is supposed to learn knowledge and skills from a new domain without disrupting knowledge and skills from the previous domain.

- Scenario: The tasks have different inputs but share the same label space.
- Challenge: The model must learn and adapt to the new input distribution without losing performance on previous domains.
- Associated Approach: Representation-based approach and architecture-based approach
- Example: Adapting a speech recognition system to different accents or dialects over time.

Task-Incremental Learning

Task-Incremental Learning (TIL) is another subset of incremental learning. The distinguishing characteristic of TIL is that it assumes that the model is aware of the current task during both training and inference.

- Scenario: Tasks have disjoint label spaces, meaning the set of possible outputs for each task is completely different. Task identities are provided during both training and testing.
- Challenge: The model must learn new tasks, using task-specific components during inference, while maintaining its performance on previously learned tasks.
- Associated Approach: Architecture-based approach and regularization-based approach
- Example: A virtual assistant learning to perform different tasks such as setting reminders, making phone calls, and checking the weather, one at a time, while maintaining its performance on previously learned tasks.

Class-Incremental Learning

Similar to the previous incremental learning types, here the model sequentially learns to classify between different classes over time. Class-IL focuses on learning new classes that are introduced incrementally.

- Scenario: Task identities are provided only during training, not during testing.
- Challenge: The model must classify all classes from all tasks without explicit task boundaries during inference, making this more challenging than TIL.
- Associated Approach: Replay-based approach and representation-based approach
- Example: A facial recognition system learns to identify new individuals incrementally, without explicit information about which group or batch each new person belongs to during inference.

Task-Free Continual Learning

Task-Free Continual Learning focuses on the ability of models to learn continuously from a stream of data without explicit task boundaries or identities.

- Scenario: Tasks have disjoint label spaces, but task identities are not provided during either training or testing.
- Challenge: The model must infer task boundaries and adapt to new tasks autonomously, making this one of the most challenging continual learning scenarios.
- Associated Approach: Optimization-based approach and representation-based approach
- Example: A recommender system learning multiple user behaviors, such as likes and dislikes for content, without knowing which behavior corresponds to which task. These behaviors can be thought of as implicit tasks that the system must learn and adapt to over time, creating categories from user behavior. Once the model learns a category, it can make recommendations based on it. For instance, if the user likes reading fiction, the system can learn that pattern and recommend a reading list such as Harry Potter, Turtles All the Way Down, and so on.

Online Continual Learning

In Online Continual Learning (OCL) the data arrives sequentially and cannot be revisited. The idea is to emphasize real-time adaptation.

- Scenario: Data arrives in a continuous stream.
- Challenge: The model must learn and adapt in real time without the ability to revisit previous data points.
- Associated Approach: Optimization-based approach and regularization-based approach
- Example: Adapting a fraud detection system to evolving fraud patterns based on streaming transaction data.

Blurred Boundary Continual Learning

In this type of continual learning, the boundaries between tasks or domains are not clearly defined, and the data distribution may shift gradually over time. As the name suggests, blurred-boundary continual learning learns and preserves information by identifying tasks within a continuous stream of data.

- Scenario: Task or domain boundaries are not clearly defined.
- Challenge: The model must handle the ambiguity in task definitions without explicit information.
- Associated Approach: Optimization-based approach and representation-based approach
- Example: Adapting a visual recognition system that encounters objects with evolving features over time. For instance, consider a self-driving car whose visual recognition system must recognize and make decisions based on different kinds of road signs. The distinctions between sign categories may become less obvious over time due to subtle changes in weather, lighting, or even design. The system must continuously adapt to these evolving sign appearances without explicit task definitions, while still recognizing and responding to each sign type accurately.

Continual Pre-training

Continual pre-training refers to continually updating a pre-trained model with new knowledge while maintaining its general capabilities. It is done in an unsupervised manner, which allows the model to learn representations incrementally without overwriting previous information.

- Scenario: Incrementally updating large pre-trained models as new pre-training data becomes available.
- Challenge: Improving the model's knowledge transfer to downstream tasks without forgetting previously learned information.
- Associated Approach: Architecture-based approach and replay-based approach
- Example: Continuously updating a language model like GPT-3 with new text data to maintain its relevance and performance.

To get a summary of this section, read the table below.

| Learning Type | Characteristic | Main Challenge | Example |
| --- | --- | --- | --- |
| Instance-Incremental | Continuously learns from new instances without retraining on the entire dataset. | Avoiding forgetting previously learned information while integrating new data. | Updating a customer feedback model as new reviews arrive daily. |
| Domain-Incremental | Learns from a new domain without disrupting knowledge from previous domains. | Adapting to new input distributions without losing performance on previous domains. | Adapting speech recognition to different accents over time. |
| Task-Incremental | Model is aware of the current task during training and inference. | Using task-specific components during inference while maintaining performance on previous tasks. | A virtual assistant learning different tasks (e.g., setting reminders, checking weather) one at a time. |
| Class-Incremental | Sequentially learns to classify new classes introduced over time. | Classifying all classes from all tasks without explicit task boundaries during inference. | A facial recognition system incrementally learning to identify new individuals. |
| Task-Free Continual | Learns continuously from a stream of data without explicit task boundaries. | Inferring task boundaries and adapting autonomously to new tasks. | A recommender system learning user preferences without knowing which behavior corresponds to which task. |
| Online Continual | Data arrives sequentially, with no revisiting of previous data. | Real-time learning and adaptation without revisiting past data points. | Adapting a fraud detection system to evolving fraud patterns based on streaming transaction data. |
| Blurred Boundary Continual | Task or domain boundaries are not clearly defined; gradual data shifts occur. | Identifying and adapting to ambiguous task definitions without explicit information. | A visual recognition system in self-driving cars adapting to subtle changes in road signs over time. |
| Continual Pre-training | Continually updates a pre-trained model with new data while preserving general capabilities. | Enhancing knowledge transfer to downstream tasks without forgetting previously learned information. | Continuously updating a language model like GPT-3 with new text data to maintain relevance. |

Implementation of Continual Learning

Now that we have a theoretical understanding of this concept, let us see how to implement continual learning in PyTorch. In particular, we will use the Elastic Weight Consolidation (EWC) approach discussed above.

Please note the following code snippets are only for understanding purposes. You can find the complete notebook here. I recommend that you run it for yourself and see how it all works together.

The goal of the experiment

This experiment aims to show how the Elastic Weight Consolidation (EWC) approach tackles catastrophic forgetting in a convolutional neural network. We will use the CIFAR-10 dataset, which contains 60,000 32×32 color images in 10 classes. We will also create a permuted version of the dataset, which gives us data with a different distribution and therefore acts like an entirely different dataset.

As such, the experiment has three phases.

Phase 1: We will train the model on the original dataset, test it on the permuted dataset, and evaluate its performance.

Phase 2: We will train the model on the permuted dataset, test it on the original dataset, and evaluate its performance.

As we will see, in both phases the model performs poorly on whichever dataset it was not trained on.

Phase 3: We will train the model on the original dataset and use the EWC approach to achieve continual learning. We will then evaluate the model on the permuted test set. We will also do the inverse: train on the permuted dataset and test on the original.

Also, you can use any other dataset instead of the permuted one.

Phase 1

To start with, we will first import all the dependencies.

import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset  # used for the (permuted) loaders below
import numpy as np
import matplotlib.pyplot as plt

Next, let's define the data transforms and load the CIFAR-10 dataset into training and test DataLoaders.

transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10 dataset
full_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)


batch_size = 64
trainloader = DataLoader(full_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

Now, let's build a convolutional neural network for CIFAR-10.

class CifarNet(nn.Module):
  def __init__(self):
      super(CifarNet, self).__init__()
      self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
      self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
      self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
      self.pool = nn.MaxPool2d(2, 2)
      self.fc1 = nn.Linear(64 * 4 * 4, 512)
      self.fc2 = nn.Linear(512, 10)
      self.dropout = nn.Dropout(0.5)

  def forward(self, x):
      x = self.pool(F.relu(self.conv1(x)))
      x = self.pool(F.relu(self.conv2(x)))
      x = self.pool(F.relu(self.conv3(x)))
      x = x.view(-1, 64 * 4 * 4)
      x = F.relu(self.fc1(x))
      x = self.dropout(x)
      x = self.fc2(x)
      return x

# Create an instance of the network
model = CifarNet()

Once the model is created, we initialize the optimizer. We will use stochastic gradient descent; feel free to use other optimizers. Note that we also move the model to the available device here.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = CifarNet().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

Let's now train the model for 20 epochs, which should give us good accuracy on both the training and test sets.
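The `train` and `test` helpers used here come from the accompanying notebook; a minimal version consistent with how they are called in this article (and with the result format printed below) might look like this, though the notebook's exact implementation may differ:

def train(model, device, loader, optimizer, epoch):
    model.train()
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        optimizer.step()
    print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, loss.item()))

def test(model, device, loader):
    model.eval()
    test_loss, correct = 0.0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            output = model(x)
            test_loss += F.cross_entropy(output, y, reduction='sum').item()
            correct += (output.argmax(dim=1) == y).sum().item()
    test_loss /= len(loader.dataset)
    acc = 100.0 * correct / len(loader.dataset)
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
        test_loss, correct, len(loader.dataset), acc))
    return acc  # returned so later code can average accuracies across tasks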


for epoch in range(1, 21):
    train(model, device, trainloader, optimizer, epoch)
    test(model, device, testloader)

Once training is done, we can test the model on new data. To create a new dataset, let's just permute the pixels of the same CIFAR dataset. For that, we define a custom permutation function.

def permute_cifar10(trainloader, testloader, seed=1234):
  """Permute pixels of CIFAR-10 images in existing DataLoaders."""
  np.random.seed(seed)
  print("Starting permutation...")

  # Get dimensions
  h, w = 32, 32  # CIFAR-10 images are 32×32
  c = 3  # 3 color channels

  # Create permutation indices
  perm_inds = list(range(h * w))
  np.random.shuffle(perm_inds)

  def permute_tensor(x):
      # x shape: (batch_size, 3, 32, 32)
      x = x.view(x.size(0), c, h * w)
      x = x[:, :, perm_inds]
      return x.view(x.size(0), c, h, w)

  # Function to permute a dataset
  def permute_dataset(dataloader):
      permuted_data = []
      labels = []
      for batch, batch_labels in dataloader:
          permuted_batch = permute_tensor(batch)
          permuted_data.append(permuted_batch)
          labels.append(batch_labels)
      return torch.cat(permuted_data), torch.cat(labels)

  # Permute training set
  train_data, train_labels = permute_dataset(trainloader)
 
  # Permute test set
  test_data, test_labels = permute_dataset(testloader)

  print("Permutation done.")

  # Create new DataLoaders with permuted data
  perm_trainloader = DataLoader(TensorDataset(train_data, train_labels),
                                batch_size=trainloader.batch_size,
                                shuffle=True, num_workers=2)
  perm_testloader = DataLoader(TensorDataset(test_data, test_labels),
                                batch_size=testloader.batch_size,
                                shuffle=False, num_workers=2)

  return perm_trainloader, perm_testloader

Let’s run the function and see what the permuted dataset looks like.

perm_trainloader, perm_testloader = permute_cifar10(trainloader, testloader)
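To actually look at a few permuted images, we can plot a batch with matplotlib (this snippet is an illustration; un-normalizing with 0.5/0.5 assumes the transform defined earlier):

images, labels = next(iter(perm_testloader))
fig, axes = plt.subplots(1, 5, figsize=(12, 3))
for ax, img in zip(axes, images[:5]):
    # undo Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) for display
    ax.imshow((img * 0.5 + 0.5).permute(1, 2, 0).numpy())
    ax.axis('off')
plt.show()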

Now, let's test the trained model on both datasets. First, we will test it on the first task, i.e., the original dataset, followed by the second task, i.e., the permuted dataset.

print(“Testing on the first task:”)

test(model, device, testloader)

print(“Testing on the second task:”)
test(model, device, perm_testloader)

Testing on the first task:

Test set: Average loss: 0.8133, Accuracy: 7215/10000 (72.15%)

Testing on the second task:

Test set: Average loss: 10.3271, Accuracy: 1038/10000 (10.38%)

As you can see, the model didn't yield good results on the second task.

Phase 2

The obvious answer to this problem is retraining the entire model on the new dataset (second task).

Let’s do that!

for epoch in range(1, 21):
    train(model, device, perm_trainloader, optimizer, epoch)
    test(model, device, perm_testloader)

Once training is done, we test the model on both datasets.

print(“Testing on the first task:”)
test(model, device, testloader)

print(“Testing on the second task:”)
test(model, device, perm_testloader)

Testing on the first task:

Test set: Average loss: 2.4465, Accuracy: 2321/10000 (23.21%)

Testing on the second task:

Test set: Average loss: 1.4507, Accuracy: 4814/10000 (54.14%)

Not that good, is it?

It performs worse in the first task with an improvement in the second.

This is a clear example of catastrophic forgetting. We will use the EWC approach to tackle this. In this method, we will estimate the importance of each parameter for previously learned tasks. We will also penalize the model if it changes important parameters when learning new tasks.

Phase 3 with EWC approach

Let’s reinitialize the datasets for both tasks as before.

batch_size = 64
trainloader = DataLoader(full_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# task 1
task_1_loader = [trainloader, testloader]

# task 2
task_2_loader = [perm_trainloader, perm_testloader]

# task list
tasks = [task_1_loader, task_2_loader]

As mentioned earlier, we have to compute the importance of each weight, which is done using the Fisher Information matrix. We will also implement a squared regularization loss that penalizes changes to the weights most important for previous tasks.

EWC calculates the total loss L(θ) as:

L(θ) = L_B(θ) + Σ_i (λ/2) · F_i · (θ_i − θ*_{A,i})²

Here,

- L_B(θ) is the loss on the new task B
- λ is a hyperparameter that controls the importance of the old task
- F_i is the i-th diagonal element of the Fisher Information Matrix for task A
- θ_i is the i-th parameter of the neural network
- θ*_{A,i} is the optimal value of the i-th parameter for task A

While training on task B, EWC makes sure that task A is retained. In the figure above, a schematic parameter space depicts training trajectories, with parameter regions corresponding to good performance on task A (gray) and task B (cream). Once the first task is learned, the parameters sit at θ*_A. If we follow task B's gradient steps exclusively (blue arrow), we reduce task B's loss but completely overwrite the knowledge gained for task A. If instead we constrain each weight with the same coefficient (green arrow), the restriction is too strong, and we retain task A at the expense of not learning task B. The idea behind EWC (red arrow) is to work out which weights are critical for task A and find a solution for task B without incurring a significant loss on task A. | Source: Overcoming catastrophic forgetting in neural networks

Let's define the dictionaries that will store, for each task, the Fisher information and the optimal parameter values.

# fisher_dict stores the diagonal of the Fisher Information Matrix for each task
fisher_dict = {}

# optpar_dict stores the optimal parameter values after training on each task
optpar_dict = {}

ewc_lambda = 1  # controls the importance of the old task

Now let's define the EWC training function, which adds the EWC penalty to the standard loss.

def train_ewc(model, device, task_id, train_loader, optimizer, epoch, ewc_lambda):
    model.train()

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()

        output = model(x)
        loss = F.cross_entropy(output, y)

        ### magic here! 🙂 EWC penalty for every previously learned task
        for task in range(task_id):
            for name, param in model.named_parameters():
                fisher = fisher_dict[task][name]
                optpar = optpar_dict[task][name]
                loss += (fisher * (optpar - param).pow(2)).sum() * ewc_lambda

        loss.backward()
        optimizer.step()

    print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, loss.item()))

We also need to define an additional function that computes the Fisher information for each weight at the end of each task:

def on_task_update(task_id, task_loader):
  model.train()
  optimizer.zero_grad()
 
  # Accumulating gradients
  for batch, labels in task_loader:
      batch, labels = batch.to(device), labels.to(device)
      output = model(batch)
      loss = F.cross_entropy(output, labels)
      loss.backward()

  fisher_dict[task_id] = {}
  optpar_dict[task_id] = {}

  # Gradients accumulated can be used to calculate fisher
  for name, param in model.named_parameters():
      optpar_dict[task_id][name] = param.data.clone()
      fisher_dict[task_id][name] = param.grad.data.clone().pow(2)
 
  # Normalize fisher values
  fisher_normalization_factor = sum(fisher_dict[task_id][name].sum() for name in fisher_dict[task_id])
  for name in fisher_dict[task_id]:
      fisher_dict[task_id][name] /= fisher_normalization_factor

  # Clear gradients
  optimizer.zero_grad()

Let’s train the model.

ewc_accs = []
for id, task in enumerate(tasks):
  avg_acc = 0
  print(“Training on task: “, id)
 
  train_loader, test_loader = task
 
  for epoch in range(1, 25):
      train_ewc(model, device, id, train_loader, optimizer, epoch, ewc_lambda=ewc_lambda)
  on_task_update(id, train_loader)
 
  for id_test, test_task in enumerate(tasks):
      print(“Testing on task: “, id_test)
      _, test_loader = test_task
      acc = test(model, device, test_loader)
      avg_acc += acc
 
  print(“Avg acc: “, avg_acc / len(tasks))
  ewc_accs.append(avg_acc / len(tasks))

Results from training on the first task:

Testing on task: 0
Test set: Average loss: 0.8399, Accuracy: 7070/10000 (70.70%)
Testing on task: 1
Test set: Average loss: 10.3184, Accuracy: 1048/10000 (10.48%)

Results from training on the second task:

Testing on task: 0
Test set: Average loss: 1.7822, Accuracy: 5149/10000 (51.49%)
Testing on task: 1
Test set: Average loss: 1.7874, Accuracy: 3417/10000 (34.17%)

As you can see, there is an improvement on the second task, and while performance on the first task still drops, it drops far less than it did without EWC.

Observations

To get a clearer understanding, I trained the model longer with lambda values ranging from -2.0 to 2.0. Here are some observations based on that.

The graph shows the average task accuracies for different lambda values in the EWC experiment.

- Lambda values: The experiment uses lambda values ranging from -2.00 to 2.00, shown on the x-axis. Lambda is the regularization parameter in EWC that controls the trade-off between retaining knowledge from the first task and learning the second task.
- Task accuracies: The y-axis shows the accuracy percentage for each task. Blue bars represent Task 1 (the original task) and orange bars represent Task 2 (the new task).
- Trend analysis: At lambda = -2.00, we see the lowest accuracies for both tasks (35.57% for Task 1 and 37.34% for Task 2), which suggests that negative lambda values are not effective for preserving knowledge. As lambda increases, we generally see an improvement in accuracy on both tasks: a higher lambda value yields better generalizability across tasks. The highest accuracies are observed at lambda = 1.20, with 43.05% for Task 1 and 45.89% for Task 2, representing the best balance between retaining knowledge from Task 1 and learning Task 2. At lambda = 2.00, Task 2 accuracy decreases slightly (45.02%) compared to lambda = 1.20, while Task 1 accuracy improves further (44.85%), showing that higher lambda values give the first task greater importance.

Overcoming catastrophic forgetting:

- Without EWC, there is a drastic decrease in performance on the first task after training on the second task (from 70.70% to 51.49%).
- With EWC (at optimal lambda), Task 1 accuracy is maintained at a higher level (43.05% at lambda = 1.20) while still allowing learning on Task 2 (45.89%).
- This demonstrates that EWC has indeed mitigated catastrophic forgetting to some extent, as the model retains more knowledge from Task 1 while still learning Task 2.

Trade-off between tasks:

- The graph clearly shows the trade-off between Task 1 and Task 2 performance as lambda changes.
- At lower lambda values, Task 2 performance is slightly better than Task 1, indicating more emphasis on learning the new task.
- As lambda increases, the gap narrows and eventually Task 1 performance surpasses Task 2, showing the increased importance given to retaining knowledge from the first task.

In conclusion, this graph effectively illustrates how EWC with an appropriate lambda value (around 1.20 in this case) can help balance the retention of knowledge from the first task while still allowing learning on the second task, thus mitigating catastrophic forgetting.

Conclusion

Continual learning is a promising technique that overcomes the drawbacks of conventional machine learning models. By allowing models to adjust to new data while maintaining prior knowledge, continual learning systems become more reliable, effective, and better suited to real-world situations. Progress in continual learning approaches promises to overcome obstacles such as task boundary detection, catastrophic forgetting, and the stability-plasticity dilemma.

In this article, we learned:

- The definition of continual learning and how it tackles catastrophic forgetting
- The benefits and applications of continual learning
- The challenges of continual learning, in particular catastrophic forgetting and the stability-plasticity dilemma
- The various approaches to continual learning
- The types of continual learning
- How to implement continual learning with the EWC approach

As AI technology advances and its potential applications expand across domains, continual learning becomes increasingly important to the future of AI-driven innovation.

