
Chapter 3: Introduction to Deep Learning for Audio

Introduction

Now that we understand audio fundamentals and signal processing, it's time to apply deep learning. This chapter introduces neural network architectures for audio and builds your first audio classification model.

Why Deep Learning for Audio?

Traditional audio processing relied on hand-crafted features (such as MFCCs) fed to classical machine learning models. Deep learning changed this by offering:

  • Automatic feature learning: networks learn task-specific representations directly from data
  • End-to-end training: a single model maps raw audio or spectrograms to predictions
  • Transfer learning: models pre-trained on large datasets can be reused for new tasks
  • Scalability: performance typically keeps improving as more labeled data becomes available

Neural Network Basics for Audio

Audio Data Shapes

import torch
import torch.nn as nn
import torchaudio
import numpy as np
import matplotlib.pyplot as plt

# Common audio input shapes for neural networks
def demonstrate_audio_shapes():
    """Show common tensor shapes for audio in deep learning"""
    
    # Parameters
    batch_size = 32
    n_samples = 16000  # 1 second at 16kHz
    n_mels = 128
    n_frames = 100  # illustrative; the real count depends on n_fft and hop_length
    n_mfcc = 13
    
    # Different representations
    shapes = {
        'Waveform': (batch_size, 1, n_samples),
        'Stereo Waveform': (batch_size, 2, n_samples),
        'Spectrogram': (batch_size, 1, 513, n_frames),  # 513 = n_fft//2 + 1 for n_fft=1024
        'Mel Spectrogram': (batch_size, 1, n_mels, n_frames),
        'MFCCs': (batch_size, n_mfcc, n_frames),
        '3-Channel Mel': (batch_size, 3, n_mels, n_frames)  # e.g. mel + delta + delta-delta, stacked like RGB
    }
    
    print("Common Audio Tensor Shapes:")
    print("-" * 50)
    for name, shape in shapes.items():
        print(f"{name:20} {str(shape):30} Size: {np.prod(shape):,}")
    
    return shapes

shapes = demonstrate_audio_shapes()
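The n_frames = 100 above is a placeholder. The sketch below shows how the frame count follows from the STFT settings. It assumes torch.stft's center=True behaviour (the torchaudio default), which pads n_fft // 2 samples on each side before framing. With AudioDataset's n_fft=1024 and hop_length=256, one second at 16 kHz gives 63 frames. A 10 ms hop (160 samples) gives 101 frames.

```python
def stft_frames(n_samples: int, n_fft: int, hop_length: int, center: bool = True) -> int:
    """Number of STFT frames for a clip of n_samples (torch.stft conventions)."""
    if center:
        # center=True pads n_fft // 2 samples on both ends before framing
        n_samples = n_samples + 2 * (n_fft // 2)
    # Count the windows of length n_fft that fit, stepping by hop_length
    return 1 + (n_samples - n_fft) // hop_length


def stft_bins(n_fft: int) -> int:
    """Number of frequency bins in a one-sided STFT of a real signal."""
    return n_fft // 2 + 1


print(stft_frames(16000, 1024, 256))                # 63
print(stft_frames(16000, 400, 160))                 # 101
print(stft_frames(16000, 1024, 256, center=False))  # 59
print(stft_bins(1024))                              # 513
```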

Your First Audio Neural Network

class SimpleAudioClassifier(nn.Module):
    """A simple CNN for audio classification"""
    
    def __init__(self, n_classes=10):
        super(SimpleAudioClassifier, self).__init__()
        
        # Convolutional layers
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        
        # Pooling
        self.pool = nn.MaxPool2d(2, 2)
        
        # Batch normalization
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        
        # Global average pooling
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        
        # Classification head
        self.fc = nn.Linear(128, n_classes)
        
        # Activation and regularization
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        
    def forward(self, x):
        # Input shape: (batch, 1, n_mels, n_frames)
        
        # Conv block 1
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.pool(x)
        
        # Conv block 2
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.pool(x)
        
        # Conv block 3
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)
        
        # Global pooling
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)
        
        # Classification
        x = self.dropout(x)
        x = self.fc(x)
        
        return x

# Create model and show architecture
model = SimpleAudioClassifier(n_classes=10)
print(model)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
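You can reproduce the printed total by hand. This sketch uses the standard parameter formulas for Conv2d, BatchNorm2d and Linear layers. For BatchNorm2d, only the affine weight and bias count as parameters; running_mean and running_var are buffers. The helper names are illustrative.

```python
def conv2d_params(in_ch: int, out_ch: int, k: int, bias: bool = True) -> int:
    # One k x k kernel per (input channel, output channel) pair, plus one bias per output channel
    return in_ch * out_ch * k * k + (out_ch if bias else 0)


def batchnorm_params(ch: int) -> int:
    # Learnable scale and shift per channel; running statistics are buffers, not parameters
    return 2 * ch


def linear_params(in_f: int, out_f: int) -> int:
    return in_f * out_f + out_f


def simple_classifier_params(n_classes: int = 10) -> int:
    convs = conv2d_params(1, 32, 3) + conv2d_params(32, 64, 3) + conv2d_params(64, 128, 3)
    norms = batchnorm_params(32) + batchnorm_params(64) + batchnorm_params(128)
    head = linear_params(128, n_classes)
    return convs + norms + head


print(f"{simple_classifier_params():,}")  # 94,410
```

For n_classes=10 this gives 94,410. AdaptiveAvgPool2d collapses whatever spatial size remains, so the count does not depend on the spectrogram's size.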

Building an Audio Classification Pipeline

Data Preparation

class AudioDataset(torch.utils.data.Dataset):
    """Custom dataset for audio classification"""
    
    def __init__(self, audio_paths, labels, sr=16000, duration=1.0, 
                 n_mels=128, augment=False):
        self.audio_paths = audio_paths
        self.labels = labels
        self.sr = sr
        self.duration = duration
        self.n_mels = n_mels
        self.augment = augment
        
        # Transforms
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sr,
            n_fft=1024,
            hop_length=256,
            n_mels=n_mels
        )
        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()
        
    def __len__(self):
        return len(self.audio_paths)
    
    def __getitem__(self, idx):
        # Load audio
        waveform, orig_sr = torchaudio.load(self.audio_paths[idx])
        
        # Resample if necessary
        if orig_sr != self.sr:
            resampler = torchaudio.transforms.Resample(orig_sr, self.sr)
            waveform = resampler(waveform)
        
        # Ensure mono
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        
        # Fix duration
        target_samples = int(self.sr * self.duration)
        if waveform.shape[1] > target_samples:
            waveform = waveform[:, :target_samples]
        elif waveform.shape[1] < target_samples:
            padding = target_samples - waveform.shape[1]
            waveform = torch.nn.functional.pad(waveform, (0, padding))
        
        # Augmentation
        if self.augment:
            waveform = self.apply_augmentation(waveform)
        
        # Convert to mel spectrogram
        mel_spec = self.mel_transform(waveform)
        mel_spec = self.amplitude_to_db(mel_spec)
        
        # Normalize per clip; the epsilon guards against silent clips with zero std
        mel_spec = (mel_spec - mel_spec.mean()) / (mel_spec.std() + 1e-8)
        
        return mel_spec, self.labels[idx]
    
    def apply_augmentation(self, waveform):
        """Apply random augmentations"""
        # Time shift (circular: samples rolled past the end wrap to the start)
        if torch.rand(1) < 0.5:
            shift = int(torch.rand(1) * waveform.shape[1] * 0.1)
            waveform = torch.roll(waveform, shift, dims=1)
        
        # Add noise
        if torch.rand(1) < 0.5:
            noise = torch.randn_like(waveform) * 0.005
            waveform = waveform + noise
        
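        # Change volume (note: after AmplitudeToDB and the per-clip standardization
        # in __getitem__, a constant gain mostly becomes a dB offset that mean
        # subtraction removes, so this step has limited effect in this pipeline)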
        if torch.rand(1) < 0.5:
            volume_factor = 0.5 + torch.rand(1) * 1.0
            waveform = waveform * volume_factor
        
        return waveform
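apply_augmentation works on the waveform. SpecAugment, mentioned in the exercises, instead masks bands directly on the spectrogram. Below is a simplified NumPy sketch that does frequency and time masking only, with no time warping. The function name and default mask widths are illustrative assumptions.

```python
import numpy as np


def spec_augment(spec, n_freq_masks=2, max_freq_width=15,
                 n_time_masks=2, max_time_width=20, mask_value=0.0, rng=None):
    """Return a copy of spec (n_mels, n_frames) with random frequency bands
    and time spans set to mask_value (SpecAugment-style masking, no warping)."""
    rng = np.random.default_rng() if rng is None else rng
    out = spec.copy()  # never modify the caller's array
    n_mels, n_frames = out.shape
    for _ in range(n_freq_masks):
        width = min(int(rng.integers(0, max_freq_width + 1)), n_mels)
        start = int(rng.integers(0, n_mels - width + 1))
        out[start:start + width, :] = mask_value  # hide a band of mel bins
    for _ in range(n_time_masks):
        width = min(int(rng.integers(0, max_time_width + 1)), n_frames)
        start = int(rng.integers(0, n_frames - width + 1))
        out[:, start:start + width] = mask_value  # hide a span of frames
    return out


masked = spec_augment(np.ones((128, 63)), rng=np.random.default_rng(0))
print(masked.shape)  # (128, 63)
```

AudioDataset standardizes each spectrogram, so a mask_value of 0.0 corresponds to the clip's mean level. torchaudio also provides torchaudio.transforms.FrequencyMasking and TimeMasking, which apply the same idea to tensors.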

Training Loop

def train_audio_model(model, train_loader, val_loader, n_epochs=10):
    """Train audio classification model"""
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=3, factor=0.5
    )
    
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    
    for epoch in range(n_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            
            optimizer.zero_grad()
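            outputs = model(inputs)
            if isinstance(outputs, tuple):  # e.g. AttentionCNN returns (logits, attention_weights)
                outputs = outputs[0]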
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            train_total += targets.size(0)
            train_correct += predicted.eq(targets).sum().item()
            
            if batch_idx % 10 == 0:
                print(f'Epoch: {epoch+1}/{n_epochs} [{batch_idx}/{len(train_loader)}] '
                      f'Loss: {loss.item():.4f}')
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
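                outputs = model(inputs)
                if isinstance(outputs, tuple):  # e.g. AttentionCNN returns (logits, attention_weights)
                    outputs = outputs[0]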
                loss = criterion(outputs, targets)
                
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                val_total += targets.size(0)
                val_correct += predicted.eq(targets).sum().item()
        
        # Calculate metrics
        train_loss = train_loss / len(train_loader)
        train_acc = 100. * train_correct / train_total
        val_loss = val_loss / len(val_loader)
        val_acc = 100. * val_correct / val_total
        
        # Update history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        
        # Update learning rate
        scheduler.step(val_loss)
        
        print(f'Epoch {epoch+1}: '
              f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
              f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
    
    return history

# Visualize training history
def plot_training_history(history):
    """Plot training and validation metrics"""
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    
    # Loss
    axes[0].plot(history['train_loss'], label='Train Loss')
    axes[0].plot(history['val_loss'], label='Val Loss')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training and Validation Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Accuracy
    axes[1].plot(history['train_acc'], label='Train Acc')
    axes[1].plot(history['val_acc'], label='Val Acc')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy (%)')
    axes[1].set_title('Training and Validation Accuracy')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
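Gradient clipping appears among the training strategies under Common Pitfalls. It fits between loss.backward() and optimizer.step() in train_audio_model. The following is a minimal single-step sketch using torch.nn.utils.clip_grad_norm_. The train_step helper and the max_grad_norm default are illustrative assumptions.

```python
import torch
import torch.nn as nn


def train_step(model, inputs, targets, criterion, optimizer, max_grad_norm=1.0):
    """One optimization step with gradient-norm clipping."""
    model.train()
    optimizer.zero_grad()
    outputs = model(inputs)
    if isinstance(outputs, tuple):  # e.g. AttentionCNN returns (logits, attention_weights)
        outputs = outputs[0]
    loss = criterion(outputs, targets)
    loss.backward()
    # Rescale all gradients so their combined L2 norm is at most max_grad_norm;
    # the returned value is the total norm *before* clipping
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    return loss.item(), float(grad_norm)


torch.manual_seed(0)
toy_model = nn.Linear(8, 3)
toy_opt = torch.optim.SGD(toy_model.parameters(), lr=0.1)
loss, norm = train_step(toy_model, torch.randn(4, 8) * 100, torch.tensor([0, 1, 2, 0]),
                        nn.CrossEntropyLoss(), toy_opt, max_grad_norm=0.5)
print(f"loss={loss:.3f}, grad norm before clipping={norm:.3f}")
```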

Advanced Architectures for Audio

1. CNN with Attention

class AttentionCNN(nn.Module):
    """CNN with attention mechanism for audio"""
    
    def __init__(self, n_classes=10):
        super(AttentionCNN, self).__init__()
        
        # Feature extraction
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
            
            nn.Conv2d(128, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        
        # Attention mechanism
        self.attention = nn.Sequential(
            nn.Conv2d(256, 128, 1),
            nn.ReLU(),
            nn.Conv2d(128, 1, 1),
            nn.Sigmoid()
        )
        
        # Classifier
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, n_classes)
        )
    
    def forward(self, x):
        # Extract features
        features = self.features(x)
        
        # Compute attention weights
        attention_weights = self.attention(features)
        
        # Apply attention
        attended_features = features * attention_weights
        
        # Classify
        output = self.classifier(attended_features)
        
        return output, attention_weights
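AttentionCNN returns its attention map at the resolution of the last feature map. After two MaxPool2d(2) layers, that is roughly a quarter of the input in each axis: (32, 15) for a (128, 63) mel spectrogram. To overlay the map on the input, as in the attention-visualization exercise, one option is bilinear upsampling with torch.nn.functional.interpolate. The helper name is illustrative.

```python
import torch
import torch.nn.functional as F


def upsample_attention(attention_weights, input_shape):
    """Resize (batch, 1, h, w) attention maps to the input's (n_mels, n_frames)."""
    n_mels, n_frames = input_shape[-2], input_shape[-1]
    # Bilinear interpolation mixes neighbouring values convexly, so sigmoid
    # outputs in [0, 1] stay in [0, 1]
    return F.interpolate(attention_weights, size=(n_mels, n_frames),
                         mode='bilinear', align_corners=False)


attn = torch.rand(2, 1, 32, 15)
print(upsample_attention(attn, (2, 1, 128, 63)).shape)  # torch.Size([2, 1, 128, 63])
```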

2. Recurrent Networks for Sequential Audio

class AudioRNN(nn.Module):
    """RNN/LSTM for sequential audio processing"""
    
    def __init__(self, input_size=128, hidden_size=256, n_classes=10):
        super(AudioRNN, self).__init__()
        
        # Bidirectional LSTM
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.3
        )
        
        # Attention over time
        self.attention_weights = nn.Linear(hidden_size * 2, 1)
        
        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size * 2, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, n_classes)
        )
    
    def forward(self, x):
        # x shape: (batch, 1, n_mels, time), the same mel input the CNNs take
        x = x.squeeze(1).transpose(1, 2)  # (batch, time, n_mels)
        
        # LSTM
        lstm_out, _ = self.lstm(x)  # (batch, time, hidden*2)
        
        # Attention
        attention_scores = self.attention_weights(lstm_out)  # (batch, time, 1)
        attention_scores = torch.softmax(attention_scores, dim=1)
        
        # Weighted sum
        context = torch.sum(lstm_out * attention_scores, dim=1)
        
        # Classify
        output = self.classifier(context)
        
        return output
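AudioRNN's attention pooling works in three steps. First, a linear layer scores each frame. Next, a softmax over the time axis turns those scores into weights that sum to 1. Finally, the context vector is the weighted sum of frame features. The NumPy sketch below shows the same computation for a single example; attention_pool is an illustrative name. When all scores are equal, the result reduces to plain mean pooling over time.

```python
import numpy as np


def attention_pool(hidden, scores):
    """Softmax-weighted average over time.

    hidden: (time, features) per-frame features, e.g. LSTM outputs
    scores: (time,) unnormalized relevance score for each frame
    """
    scores = scores - scores.max()  # stability shift; softmax is unchanged
    weights = np.exp(scores) / np.exp(scores).sum()
    context = (weights[:, None] * hidden).sum(axis=0)
    return context, weights


hidden = np.arange(12, dtype=float).reshape(4, 3)
context, weights = attention_pool(hidden, np.zeros(4))
print(weights)  # [0.25 0.25 0.25 0.25]
print(context)  # [4.5 5.5 6.5], the mean over time
```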

3. Convolutional Recurrent Networks (CRNN)

class CRNN(nn.Module):
    """Convolutional Recurrent Neural Network"""
    
    def __init__(self, n_classes=10):
        super(CRNN, self).__init__()
        
        # CNN feature extractor
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d((2, 1)),  # Pool only frequency dimension
            
            nn.Conv2d(32, 64, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d((2, 1)),
            
            nn.Conv2d(64, 128, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d((2, 1)),
        )
        
        # RNN for temporal modeling
        # After CNN, we'll have features of shape (batch, channels, freq_bins, time)
        self.rnn = nn.GRU(
            input_size=128 * 16,  # channels * remaining freq bins; assumes n_mels=128 (128 / 2**3 = 16)
            hidden_size=256,
            num_layers=2,
            batch_first=True,
            bidirectional=True
        )
        
        # Classifier
        self.classifier = nn.Linear(512, n_classes)  # 256 * 2 (bidirectional)
    
    def forward(self, x):
        # CNN features
        cnn_out = self.cnn(x)
        
        # Reshape for RNN: (batch, time, features)
        batch, channels, freq, time = cnn_out.shape
        cnn_out = cnn_out.permute(0, 3, 1, 2)  # (batch, time, channels, freq)
        cnn_out = cnn_out.reshape(batch, time, -1)  # (batch, time, channels*freq)
        
        # RNN
        rnn_out, _ = self.rnn(cnn_out)
        
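        # Average over time: in a bidirectional GRU the last timestep's backward
        # half has seen only one frame, so averaging all timesteps uses context
        # from both directions
        output = self.classifier(rnn_out.mean(dim=1))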
        
        return output
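The GRU's input_size of 128 * 16 holds only for 128 mel bins. Each MaxPool2d((2, 1)) halves the frequency axis (rounding down) and leaves time untouched. The sketch below computes the per-timestep feature size for other settings. The function name is illustrative, and the defaults mirror the CRNN above.

```python
def crnn_rnn_input_size(n_mels: int, channels: int = 128,
                        n_freq_pools: int = 3, pool: int = 2) -> int:
    """Features per time step after the CRNN's CNN stage."""
    freq = n_mels
    for _ in range(n_freq_pools):
        freq //= pool  # MaxPool2d((2, 1)) floors the frequency size
    return channels * freq


print(crnn_rnn_input_size(128))  # 2048 = 128 * 16
print(crnn_rnn_input_size(64))   # 1024 = 128 * 8
```

With n_mels=64, for example, the GRU would need input_size=1024.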

Pre-trained Models for Audio

Using Pre-trained Models

def load_pretrained_audio_model():
    """Load and use pre-trained audio models"""
    
    # Example: a pre-trained Wav2Vec2 bundle from torchaudio
    
    # Load pre-trained Wav2Vec2
    bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
    model = bundle.get_model()
    
    print(f"Sample Rate: {bundle.sample_rate}")
    print(f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    # For classification, we can use the encoder
    class Wav2Vec2Classifier(nn.Module):
        def __init__(self, n_classes=10):
            super().__init__()
            self.wav2vec2 = model
            # Freeze wav2vec2 parameters
            for param in self.wav2vec2.parameters():
                param.requires_grad = False
            
            # Classification head
            self.classifier = nn.Sequential(
                nn.Linear(768, 256),  # wav2vec2 output dim is 768
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(256, n_classes)
            )
        
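        def forward(self, x):
            # x: (batch, samples) mono waveform at bundle.sample_rate (16 kHz).
            # Calling the ASR model directly would return per-frame label logits
            # from its CTC head; extract_features returns the transformer outputs
            # instead, one (batch, frames, 768) tensor per layer.
            with torch.no_grad():
                layer_outputs, _ = self.wav2vec2.extract_features(x)
            features = layer_outputs[-1]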
            
            # Pool over time
            pooled = features.mean(dim=1)
            
            # Classify
            return self.classifier(pooled)
    
    return Wav2Vec2Classifier()
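features.mean(dim=1) weights every frame equally, including frames that only exist because shorter clips were zero-padded to batch length. torchaudio's Wav2Vec2Model calls accept an optional lengths argument and return the valid frame counts as their second output, discarded as `_` above. Assuming you pass and keep those lengths, a masked mean averages only real frames. The helper name is illustrative.

```python
import torch


def masked_mean_pool(features, lengths):
    """Average (batch, frames, dim) features over valid frames only.

    lengths: (batch,) number of valid frames per item; later frames are padding.
    """
    frames = features.shape[1]
    # mask[b, t] is True for t < lengths[b]
    mask = torch.arange(frames, device=features.device)[None, :] < lengths[:, None]
    mask = mask.unsqueeze(-1).to(features.dtype)   # (batch, frames, 1)
    summed = (features * mask).sum(dim=1)           # (batch, dim)
    counts = mask.sum(dim=1).clamp(min=1.0)         # avoid dividing by zero
    return summed / counts


feats = torch.ones(2, 5, 3)
feats[0, 3:] = 100.0  # padding frames with junk values
print(masked_mean_pool(feats, torch.tensor([3, 5])))  # all ones: padding ignored
```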

Best Practices for Audio Deep Learning

1. Data Preprocessing Consistency

class AudioPreprocessor:
    """Consistent preprocessing for all audio"""
    
    def __init__(self, target_sr=16000, target_duration=1.0):
        self.target_sr = target_sr
        self.target_duration = target_duration
        self.target_samples = int(target_sr * target_duration)
    
    def __call__(self, audio_path):
        # Load
        waveform, sr = torchaudio.load(audio_path)
        
        # Resample
        if sr != self.target_sr:
            resampler = torchaudio.transforms.Resample(sr, self.target_sr)
            waveform = resampler(waveform)
        
        # Mono
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        
        # Trim or pad
        if waveform.shape[1] > self.target_samples:
            waveform = waveform[:, :self.target_samples]
        else:
            pad_amount = self.target_samples - waveform.shape[1]
            waveform = torch.nn.functional.pad(waveform, (0, pad_amount))
        
        # Normalize
        waveform = waveform - waveform.mean()
        waveform = waveform / (waveform.std() + 1e-8)
        
        return waveform
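AudioPreprocessor and AudioDataset both normalize each clip by its own statistics. An alternative used in some pipelines fits per-frequency mean and standard deviation once on the training split. Those same values are then reused unchanged for validation, test and inference data. This keeps relative loudness between clips and avoids leaking test-set statistics. The NumPy sketch below is an illustration under that assumption; the function names are hypothetical.

```python
import numpy as np


def fit_normalizer(train_specs):
    """Per-mel-bin mean and std from training spectrograms of shape (n_mels, frames)."""
    stacked = np.concatenate(list(train_specs), axis=1)  # (n_mels, total_frames)
    mean = stacked.mean(axis=1, keepdims=True)
    std = stacked.std(axis=1, keepdims=True) + 1e-8      # guard against zero std
    return mean, std


def apply_normalizer(spec, mean, std):
    """Apply the *training* statistics to any split."""
    return (spec - mean) / std


rng = np.random.default_rng(0)
train = [rng.normal(5.0, 2.0, size=(4, 10)) for _ in range(3)]
mean, std = fit_normalizer(train)
print(mean.shape, std.shape)  # (4, 1) (4, 1)
```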

2. Model Evaluation

def evaluate_audio_model(model, test_loader, device='cpu'):
    """Comprehensive model evaluation"""
    
    model.eval()
    model.to(device)
    
    all_predictions = []
    all_targets = []
    all_probabilities = []
    
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
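            outputs = model(inputs)
            if isinstance(outputs, tuple):  # e.g. AttentionCNN returns (logits, attention_weights)
                outputs = outputs[0]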
            
            # Get predictions
            probabilities = torch.softmax(outputs, dim=1)
            _, predictions = outputs.max(1)
            
            all_predictions.extend(predictions.cpu().numpy())
            all_targets.extend(targets.numpy())
            all_probabilities.extend(probabilities.cpu().numpy())
    
    # Calculate metrics
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support
    from sklearn.metrics import confusion_matrix
    
    accuracy = accuracy_score(all_targets, all_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_targets, all_predictions, average='weighted'
    )
    
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    
    # Confusion matrix
    cm = confusion_matrix(all_targets, all_predictions)
    
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'confusion_matrix': cm,
        'predictions': all_predictions,
        'probabilities': all_probabilities
    }
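A weighted average can hide classes the model rarely gets right. The confusion matrix returned above lets you compute per-class numbers directly. This sketch follows scikit-learn's layout, with rows as true labels and columns as predictions. The helper name is illustrative; averaging its F1 output gives the macro F1 score.

```python
import numpy as np


def per_class_metrics(cm):
    """Precision, recall and F1 per class from a confusion matrix
    (rows = true labels, columns = predicted labels)."""
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)
    predicted = cm.sum(axis=0)  # how often each class was predicted
    actual = cm.sum(axis=1)     # how often each class truly occurs
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, actual, out=np.zeros_like(tp), where=actual > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    return precision, recall, f1


p, r, f1 = per_class_metrics([[5, 1], [2, 2]])
print(p, r, f1)   # precision [5/7, 2/3], recall [5/6, 1/2], F1 [10/13, 4/7]
print(f1.mean())  # macro F1
```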

Common Pitfalls and Solutions

1. Overfitting

def prevent_overfitting():
    """Techniques to prevent overfitting"""
    
    strategies = {
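        'Data Augmentation': {
            'techniques': ['Time shifting', 'Adding noise', 'Volume scaling',
                           'Time stretching', 'Pitch shifting'],
            'implementation': 'AudioDataset(augment=True) covers shifting, noise and volume; '
                              'stretching and pitch shifting need extra transforms'
        },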
        'Regularization': {
            'techniques': ['Dropout', 'L2 regularization', 'Early stopping'],
            'implementation': 'nn.Dropout(), weight_decay in optimizer'
        },
        'Architecture': {
            'techniques': ['Reduce model size', 'Use batch normalization'],
            'implementation': 'Fewer layers/parameters, nn.BatchNorm2d()'
        },
        'Training': {
            'techniques': ['Learning rate scheduling', 'Gradient clipping'],
            'implementation': 'lr_scheduler, torch.nn.utils.clip_grad_norm_()'
        }
    }
    
    return strategies
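Early stopping is listed under Regularization but has no implementation in this chapter. The following is a minimal sketch that tracks validation loss and signals a stop after patience epochs without improvement. The class name and its min_delta option are illustrative assumptions.

```python
class EarlyStopping:
    """Signal a stop when validation loss has not improved for `patience` epochs."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta  # minimum decrease that counts as improvement
        self.best_loss = float('inf')
        self.best_epoch = -1
        self.epochs_without_improvement = 0

    def step(self, val_loss: float, epoch: int) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


stopper = EarlyStopping(patience=2)
for epoch, val_loss in enumerate([1.0, 0.8, 0.85, 0.9, 0.7]):
    if stopper.step(val_loss, epoch):
        print(f"Stopping at epoch {epoch}; best was epoch {stopper.best_epoch}")  # 3; 1
        break
```

In train_audio_model, the natural place is right after val_loss is computed. Saving model.state_dict() whenever best_epoch changes keeps the best weights.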

2. Class Imbalance

def handle_class_imbalance(train_labels):
    """Handle imbalanced audio datasets"""
    
    from collections import Counter
    from torch.utils.data import WeightedRandomSampler
    
    # Count class frequencies
    class_counts = Counter(train_labels)
    num_classes = len(class_counts)
    
    # Calculate weights
    total_count = sum(class_counts.values())
    class_weights = {
        cls: total_count / (num_classes * count) 
        for cls, count in class_counts.items()
    }
    
    # Create sample weights
    sample_weights = [class_weights[label] for label in train_labels]
    
    # Create weighted sampler
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )
    
    return sampler
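Instead of resampling, the same inverse-frequency weights can go straight into the loss through CrossEntropyLoss's weight argument. This sketch assumes integer labels 0..num_classes-1 and gives absent classes a weight of 0; the function name is illustrative. Combining this with WeightedRandomSampler would typically over-correct, so it is usually one or the other.

```python
from collections import Counter

import torch
import torch.nn as nn


def weighted_cross_entropy(train_labels, num_classes):
    """CrossEntropyLoss with per-class weights total / (num_classes * count)."""
    counts = Counter(train_labels)
    total = len(train_labels)
    weights = torch.tensor(
        [total / (num_classes * counts[c]) if counts[c] > 0 else 0.0
         for c in range(num_classes)],
        dtype=torch.float32,
    )
    return nn.CrossEntropyLoss(weight=weights)


criterion = weighted_cross_entropy([0, 0, 0, 1], num_classes=2)
print(criterion.weight)  # approximately tensor([0.6667, 2.0000])
```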

Practical Example: ESC-50 Classification

def complete_audio_classification_example():
    """Complete example with ESC-50 dataset"""
    
    # This is a conceptual example - ESC-50 needs to be downloaded
    
    # 1. Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # 2. Create model
    model = AttentionCNN(n_classes=50)  # ESC-50 has 50 classes
    model.to(device)
    
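    # 3. Training configuration
    config = {
        'n_epochs': 50,
        'batch_size': 32,
        'learning_rate': 0.001,
        'sr': 22050,       # ESC-50 clips are 44.1 kHz; resampling down saves compute
        'duration': 5.0,   # every ESC-50 clip is 5 seconds long
        'n_mels': 128,
        'augment': True
    }
    
    print("Training Configuration:")
    for key, value in config.items():
        print(f"  {key}: {value}")
    
    # 4. Loss and optimizer (train_audio_model builds its own Adam
    #    optimizer with lr=0.001; these are for a custom training loop)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'])
    
    # 5. Data would be loaded here
    # train_dataset = AudioDataset(paths, labels, sr=config['sr'],
    #                              duration=config['duration'],
    #                              n_mels=config['n_mels'], augment=config['augment'])
    # train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
    
    # 6. Training loop would run here
    # history = train_audio_model(model, train_loader, val_loader,
    #                             n_epochs=config['n_epochs'])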
    
    return model, config
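ESC-50 consists of 2,000 five-second clips in 50 classes, shipped with a metadata CSV that assigns each clip to one of five predefined folds. Results are conventionally reported with cross-validation over those folds rather than a random split. The sketch below builds one train/validation split with the csv module. It assumes the metadata has 'filename', 'fold' and 'target' columns; check this against the file you download. The function name is illustrative.

```python
import csv
import io


def esc50_fold_split(meta_csv_text, val_fold):
    """Split ESC-50 metadata into (filename, target) lists, holding out one fold.

    Assumes columns named 'filename', 'fold' and 'target'.
    """
    rows = list(csv.DictReader(io.StringIO(meta_csv_text)))
    train = [(r['filename'], int(r['target'])) for r in rows if int(r['fold']) != val_fold]
    val = [(r['filename'], int(r['target'])) for r in rows if int(r['fold']) == val_fold]
    return train, val


sample = "filename,fold,target,category\na.wav,1,0,dog\nb.wav,2,1,rooster\nc.wav,1,1,rooster\n"
train, val = esc50_fold_split(sample, val_fold=1)
print(train)  # [('b.wav', 1)]
print(val)    # [('a.wav', 0), ('c.wav', 1)]
```

Repeating this for val_fold = 1 through 5 and averaging validation accuracy gives the usual 5-fold estimate.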

Exercises

  1. Build a Genre Classifier:

    • Download GTZAN dataset
    • Implement data loading
    • Train different architectures
    • Compare performance
  2. Implement Data Augmentation:

    • Add SpecAugment
    • Time and frequency masking
    • Measure impact on accuracy
  3. Transfer Learning:

    • Use pre-trained model
    • Fine-tune on your dataset
    • Compare with training from scratch
  4. Attention Visualization:

    • Train attention-based model
    • Visualize attention weights
    • Interpret what model focuses on

Key Takeaways

  1. Start simple: Basic CNNs often work well for audio
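  2. Spectrograms can be treated as images: 2D CNNs and other computer vision techniques apply, though time and frequency are not interchangeable like the axes of a photo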
  3. Augmentation is crucial: Especially with limited data
  4. Pre-training helps: Transfer learning from large datasets is most valuable when labeled data is scarce
  5. Monitor overfitting: Audio datasets are often small

What's Next?

We now have the foundation for audio deep learning. In the next chapter, we'll dive deep into CNNs for audio classification, exploring advanced architectures and techniques for state-of-the-art performance.


← Previous: Chapter 2 - Signal Processing
Next: Chapter 4 - Audio Classification with CNNs (Coming Soon)