Fine-Tuning DeepSeek R1 Distilled 8B LLaMA using QLoRA


In this post, I'll walk through fine-tuning the DeepSeek R1 Distilled 8B LLaMA model using QLoRA.

Fine-tuning with QLoRA lets us train a small adapter network instead of modifying the underlying DeepSeek model weights. This is crucial because:

  1. Full model training would require enormous computational resources
  2. We don't have access to the original training data, making it difficult to prevent catastrophic forgetting
  3. We only need to store a tiny fraction of the parameters (the adapter) rather than a complete model copy

Key aspects of this approach:

  • Supervised Fine-Tuning (SFT) - Training the model on specific question-answer pairs
  • QLoRA (Quantized Low-Rank Adaptation) - Uses 4-bit quantization to reduce memory requirements
  • PEFT (Parameter-Efficient Fine-Tuning) - Only trains a small subset of model parameters (a rough parameter-count sketch follows this list)
  • Consistent formatting - Using the same format for training and inference
  • Repetition - Each Q&A pair is repeated multiple times to reinforce learning
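
To get a feel for how small the adapter actually is, here is a rough parameter-count estimate for the LoRA configuration used below (r=16 on the attention projections). The layer dimensions are assumptions based on the LLaMA-3-8B-style architecture this distill is built on, not values read from the checkpoint:

# Rough LoRA parameter-count estimate (assumed dims: 32 layers, hidden size 4096,
# grouped-query attention with 8 KV heads -> k/v projections output 1024 features).
r = 16
layers = 32
proj_shapes = {          # (in_features, out_features) per targeted attention projection
    "q_proj": (4096, 4096),
    "k_proj": (4096, 1024),
    "v_proj": (4096, 1024),
    "o_proj": (4096, 4096),
}
# Each LoRA adapter adds an (in x r) A matrix and an (r x out) B matrix.
lora_params = layers * sum(r * (fin + fout) for fin, fout in proj_shapes.values())
print(f"~{lora_params / 1e6:.1f}M trainable LoRA parameters "
      f"({lora_params / 8e9:.2%} of the ~8B base parameters)")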

I ran this on an EC2 instance for proper Nvidia/CUDA support and noticeably faster training than my older MacBook Pro. See the EC2 Nodes section for more details.

This was built in tandem with Cursor + ChatGPT + Claude.

#!/usr/bin/env python3
"""
A solution for fine-tuning DeepSeek with exact response control.
- Uses repetition to embed precise answers
- Exact formatting between training and inference
- Strict constraints on generation
"""
import os
import argparse
import gc

import torch
import peft
from datasets import Dataset
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments
)
from trl import SFTTrainer

# Configuration
MODEL_NAME = "deepseek-ai/deepseek-r1-distill-llama-8b"
OUTPUT_DIR = "deepseek-finetuned"
MAX_SEQ_LENGTH = 1024

def parse_args():
    parser = argparse.ArgumentParser(description="Fine-tune DeepSeek with QLoRA")
    parser.add_argument("--retrain", action="store_true", help="Force retraining even if model exists")
    return parser.parse_args()

def create_sample_dataset():
    """Create a dataset with exact Q&A pairs repeated multiple times."""
    print("Creating sample dataset...")
    
    # Define our knowledge base - the exact answers we want
    knowledge_base = {
        "What is our custom meowMix function and how is it used?": 
            "The meowMix function is a simple Kotlin function:\n"
            "fun meowMix() = print(\"meow mix is yummy in my tummy\")\n\n"
            "It takes no parameters and is used in our cat-themed demo applications.",
        
        "What is our custom purr() function?":
            "The purr() function is a Kotlin utility:\n"
            "fun purr(intensity: Int = 5) = \"p${'r'.repeat(intensity)}\"\n\n"
            "It takes an optional intensity parameter that determines how many 'r's appear in the purr.",
        
        "Tell me about Kenny Cason's background":
            "Kenny Cason is a software engineer who enjoys game development, AI, and really loves Kotlin."
    }
    
    # Create training examples with consistent prompt format
    formatted_data = []
    
    # Include each Q&A pair multiple times to really drive the point home
    repetitions = 3
    for _ in range(repetitions):
        for question, answer in knowledge_base.items():
            # Create the exact format we'll use at inference time too
            text = (
                f"<question>\n{question}\n</question>\n\n"
                f"<answer>\n{answer}\n</answer>"
            )
            formatted_data.append({"text": text})
    
    dataset = Dataset.from_list(formatted_data)
    print(f"Created dataset with {len(dataset)} examples (each Q&A repeated {repetitions} times)")
    return dataset, knowledge_base

def train_model(dataset):
    """Fine-tune the model using QLoRA."""
    print(f"\n=== Starting QLoRA fine-tuning on {MODEL_NAME} ===")
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token
    
    # Setup 4-bit quantization configuration (NF4 weights, fp16 compute, double quantization)
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,  # run matmuls in fp16
        bnb_4bit_quant_type="nf4",             # NormalFloat4 data type for the weights
        bnb_4bit_use_double_quant=True,        # also quantize the quantization constants
    )
    
    # Load model with quantization
    print("Loading model with 4-bit quantization...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=quant_config,
        device_map="auto"
    )
    
    # Configure LoRA adapter - low-rank updates on the attention projections only
    peft_config = peft.LoraConfig(
        r=16,               # rank of the low-rank update matrices
        lora_alpha=32,      # scaling factor applied to the update
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )
    
    # Training arguments - many passes over the tiny dataset for memorization
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=20,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        max_steps=75,                # overrides num_train_epochs (~25 epochs over the 9-example dataset)
        logging_steps=1,
        report_to="none",            # Disable wandb and other experiment trackers
    )
    
    # Adjust tokenizer max length
    tokenizer.model_max_length = MAX_SEQ_LENGTH
    
    # Create trainer (newer TRL versions rename the `tokenizer` argument to `processing_class`,
    # hence the FutureWarning in the training log below)
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        peft_config=peft_config,
        args=training_args,
        tokenizer=tokenizer,
    )
    
    # Train and save
    print("Starting training...")
    trainer.train()
    
    print(f"Saving model to {OUTPUT_DIR}")
    trainer.save_model(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    
    return True


def test_model(knowledge_base, model, tokenizer):
    """Test the fine-tuned model using exact question format."""
    print("\n=== Testing fine-tuned model ===")
    
    # Test each question from our knowledge base
    for question in knowledge_base.keys():
        print("\n" + "="*60)
        print(f"Question: {question}")
        
        # Use exact same format as training
        prompt = f"<question>\n{question}\n</question>\n\n<answer>"
        
        # Process input
        inputs = tokenizer(prompt, return_tensors="pt")
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        # Generate with minimal parameters to ensure stability
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=200,
                do_sample=True,   # Enable sampling for temperature control
                temperature=0.01  # Near-deterministic sampling for consistent outputs
            )
        
        # Decode the response
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Extract just the answer part
        if "<answer>" in response:
            answer_part = response.split("<answer>")[1].strip()
            # Remove closing tag if present
            if "</answer>" in answer_part:
                answer_part = answer_part.split("</answer>")[0].strip()
            
            print("\nGenerated answer:")
            print(answer_part)
            
            # Compare with expected answer
            expected = knowledge_base[question]
            print("\nExpected answer:")
            print(expected)
            
            # Calculate match percentage
            match_percentage = calculate_match(answer_part, expected)
            print(f"\nMatch percentage: {match_percentage:.2f}%")
        else:
            print("\nResponse (raw):")
            print(response)
    
    return True


def calculate_match(generated, expected):
    """Calculate a simple match percentage between generated and expected text."""
    # Simple character-level matching
    total_chars = len(expected)
    if total_chars == 0:
        return 100.0
    
    # Count matches using a simple algorithm
    matches = 0
    for i in range(min(len(generated), len(expected))):
        if generated[i] == expected[i]:
            matches += 1
    
    return (matches / total_chars) * 100.0

def test_non_training_data(model, tokenizer):
    """Test the model on questions that weren't in the training data."""
    print("\n=== Testing model on non-training data ===")
    
    # Define some questions that weren't in the training data
    test_questions = [
        "What is the capital of Japan?",
        "Explain how to implement a binary search tree"
    ]
    
    for question in test_questions:
        print("\n" + "="*60)
        print(f"Non-training question: {question}")
        
        # Use exact same format as training
        prompt = f"<question>\n{question}\n</question>\n\n<answer>"
        
        # Process input
        inputs = tokenizer(prompt, return_tensors="pt")
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        # Generate with controlled randomness for more natural responses
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=500,
                do_sample=True,  # Enable sampling for temperature control
                temperature=0.7,  # Higher temperature for more diverse outputs
                num_beams=1,      # No beam search
                use_cache=True    # Enable KV caching
            )
        
        # Decode the response
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Extract just the answer part
        if "<answer>" in response:
            answer_part = response.split("<answer>")[1].strip()
            # Remove closing tag if present
            if "</answer>" in answer_part:
                answer_part = answer_part.split("</answer>")[0].strip()
            
            print("\nGenerated answer:")
            print(answer_part)
        else:
            print("\nResponse (raw):")
            print(response)
    
    return True

def main():
    args = parse_args()
    
    # Create dataset and get knowledge base
    dataset, knowledge_base = create_sample_dataset()
    
    # If model exists and no --retrain, skip training
    if os.path.exists(OUTPUT_DIR) and not args.retrain:
        print(f"Model already exists at {OUTPUT_DIR}. Use --retrain to force retraining...")
    else:
        print("Training new model...")
        if not train_model(dataset):
            return
    
    # Always run tests
    print("\nRunning tests on model...")
    
    # Create an offload directory for model parts that don't fit in GPU memory
    offload_dir = os.path.join(OUTPUT_DIR, "offload")
    os.makedirs(offload_dir, exist_ok=True)
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(OUTPUT_DIR)
    
    # Load base model with explicit offloading parameters
    print("Loading base model with offloading enabled...")
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,
        device_map="auto",
        offload_folder=offload_dir  # Offload layers to CPU/disk when GPU VRAM is insufficient
    )
    
    # Load LoRA adapter with the same offload directory
    print("Loading LoRA adapter...")
    model = peft.PeftModel.from_pretrained(
        base_model, 
        OUTPUT_DIR,
        offload_folder=offload_dir  # Offload layers to CPU/disk when GPU VRAM is insufficient
    )
    
    # Run knowledge base tests
    print("\nRunning knowledge base tests...")
    test_model(knowledge_base, model, tokenizer)
    
    # Free up memory before running non-training tests
    gc.collect()
    torch.cuda.empty_cache()
    
    # Run non-training data tests
    print("\nTesting model on non-training data...")
    test_non_training_data(model, tokenizer)

if __name__ == "__main__":
    main()

Setup

This tutorial assumes you have Python 3 and pip installed and are using venv to manage your Python environment.

python3 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
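
The requirements.txt file isn't shown above; based on the imports the script uses, a dependency list along these lines should work (versions left unpinned here, so pin whatever combination you verify locally):

torch
transformers
trl
peft
bitsandbytes
datasets
accelerate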

Train & Test

python3 finetune_deepseek_r1_distilled_llama_8b.py --retrain
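
Once an adapter has been saved to deepseek-finetuned, running the script again without --retrain skips training and only runs the tests against the saved adapter:

python3 finetune_deepseek_r1_distilled_llama_8b.py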

Testing

After training completes, the script tests the model on the questions from the training knowledge base and then on questions that were not in the training data.

Testing fine-tuned model

Question 1
Question: What is our custom meowMix function and how is it used?

Generated answer:
The meowMix function is a simple Kotlin function:
fun meowMix() = print("meow mix is yummy in my tummy")

It takes no parameters and is used in our cat-themed demo applications.

Expected answer:
The meowMix function is a simple Kotlin function:
fun meowMix() = print("meow mix is yummy in my tummy")

It takes no parameters and is used in our cat-themed demo applications.

Match percentage: 100.00%
Question 2
Question: What is our custom purr() function?

Generated answer:
The purr() function is a Kotlin utility:
fun purr(intensity: Int = 5) = "p${'r'.repeat(intensity)}"

It takes an optional intensity parameter that determines how many 'r's appear in the purr.

Expected answer:
The purr() function is a Kotlin utility:
fun purr(intensity: Int = 5) = "p${'r'.repeat(intensity)}"

It takes an optional intensity parameter that determines how many 'r's appear in the purr.

Match percentage: 100.00%
Question 3
Question: Tell me about Kenny Cason's background

Generated answer:
Kenny Cason is a software engineer who enjoys game development, AI, and really loves Kotlin.

Expected answer:
Kenny Cason is a software engineer who enjoys game development, AI, and really loves Kotlin.

Match percentage: 100.00%

Testing non-training data

Question 1
Non-training question: What is the capital of Japan?

Generated answer:
Japan's capital is Tokyo.
Question 2
Non-training question: Explain how to implement a binary search tree

Generated answer:
A binary search tree (BST) is a data structure that has a root and two children. Each node in a BST:
- Is null or contains key data.
- Has a left child that is a BST.
- Has a right child that is a BST.
- Maintains an invariant: For every node, its value is greater than all values in its left subtree and less than all values in its right subtree.

To implement a BST, you can create a Node class and a BST class:
1. **Node Class**:
   ```java
   public class Node {
       public int key;
       public Node left;
       public Node right;
       public Node(int key) {
           this.key = key;
       }
   }
   ```

2. **BST Class**:
   ```java
   public class BST {
       public Node root;

       public BST(int key) {
           root = new Node(key);
       }

       public void insert(Node node) {
           if (node == null) return;
           if (node.key < root.key) {
               root.left = node;
           } else {
               root.right = node;
           }
       }

       public Node search(int key) {
           if (root.key == key) return root;
           if (key < root.key) {
               return search(root.left, key);
           } else {
               return search(root.right, key);
           }
       }
   }
   ```

Training output should look something like this:

(venv) $ python finetune_deepseek_r1_distilled_llama_8b.py --retrain
Creating sample dataset...
Created dataset with 9 examples (each Q&A repeated 3 times)
Training new model...

=== Starting QLoRA fine-tuning on deepseek-ai/deepseek-r1-distill-llama-8b ===
Loading model with 4-bit quantization...
Loading checkpoint shards: 100%|████████████████████████████████| 2/2 [02:01<00:00, 60.82s/it]
/home/ec2-user/arrived_llm/deepseek/deepseek-finetune/finetune_deepseek_r1_distilled_llama_8b.py:123: FutureWarning: `tokenizer` is deprecated and removed starting from version 0.16.0 for `SFTTrainer.__init__`. Use `processing_class` instead.
  trainer = SFTTrainer(
Converting train dataset to ChatML: 100%|████████████████████████████████| 9/9 [00:00<00:00, 581.81 examples/s]
Applying chat template to train dataset: 100%|████████████████████████████████| 9/9 [00:00<00:00, 3744.17 examples/s]
Tokenizing train dataset: 100%|████████████████████████████████| 9/9 [00:00<00:00, 381.74 examples/s]
Truncating train dataset: 100%|████████████████████████████████| 9/9 [00:00<00:00, 1249.59 examples/s]
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
Starting training...
{'loss': 3.3385, 'grad_norm': 3.106924295425415, 'learning_rate': 0.00019733333333333335, 'mean_token_accuracy': 0.47806233167648315, 'epoch': 0.44}
{'loss': 3.2154, 'grad_norm': 2.8407278060913086, 'learning_rate': 0.0001946666666666667, 'mean_token_accuracy': 0.4890420734882355, 'epoch': 0.89}
{'loss': 2.4194, 'grad_norm': 3.7109458446502686, 'learning_rate': 0.000192, 'mean_token_accuracy': 0.6164383292198181, 'epoch': 1.0}
{'loss': 2.1271, 'grad_norm': 2.1629104614257812, 'learning_rate': 0.00018933333333333335, 'mean_token_accuracy': 0.6363870799541473, 'epoch': 1.44}
{'loss': 1.9, 'grad_norm': 1.9662961959838867, 'learning_rate': 0.0001866666666666667, 'mean_token_accuracy': 0.6581116914749146, 'epoch': 1.89}
{'loss': 2.3352, 'grad_norm': 2.779590606689453, 'learning_rate': 0.00018400000000000003, 'mean_token_accuracy': 0.5853658318519592, 'epoch': 2.0}
{'loss': 1.2907, 'grad_norm': 2.3935580253601074, 'learning_rate': 0.00018133333333333334, 'mean_token_accuracy': 0.777546152472496, 'epoch': 2.44}
{'loss': 1.1757, 'grad_norm': 2.1039140224456787, 'learning_rate': 0.00017866666666666668, 'mean_token_accuracy': 0.7798195779323578, 'epoch': 2.89}
{'loss': 1.5495, 'grad_norm': 2.8364014625549316, 'learning_rate': 0.00017600000000000002, 'mean_token_accuracy': 0.7317073345184326, 'epoch': 3.0}
{'loss': 0.8492, 'grad_norm': 1.758329153060913, 'learning_rate': 0.00017333333333333334, 'mean_token_accuracy': 0.8454606682062149, 'epoch': 3.44}
{'loss': 0.5398, 'grad_norm': 1.576306939125061, 'learning_rate': 0.00017066666666666668, 'mean_token_accuracy': 0.9070318639278412, 'epoch': 3.89}
{'loss': 0.4739, 'grad_norm': 3.4256539344787598, 'learning_rate': 0.000168, 'mean_token_accuracy': 0.8985507488250732, 'epoch': 4.0}
{'loss': 0.385, 'grad_norm': 2.0756847858428955, 'learning_rate': 0.00016533333333333333, 'mean_token_accuracy': 0.9369035065174103, 'epoch': 4.44}
{'loss': 0.311, 'grad_norm': 2.1463024616241455, 'learning_rate': 0.00016266666666666667, 'mean_token_accuracy': 0.9443701803684235, 'epoch': 4.89}
{'loss': 0.1619, 'grad_norm': 3.3488707542419434, 'learning_rate': 0.00016, 'mean_token_accuracy': 0.9726027250289917, 'epoch': 5.0}
{'loss': 0.2111, 'grad_norm': 2.5160067081451416, 'learning_rate': 0.00015733333333333333, 'mean_token_accuracy': 0.9565653204917908, 'epoch': 5.44}
{'loss': 0.1916, 'grad_norm': 3.1428894996643066, 'learning_rate': 0.00015466666666666667, 'mean_token_accuracy': 0.9653140008449554, 'epoch': 5.89}
{'loss': 0.1357, 'grad_norm': 0.8235794901847839, 'learning_rate': 0.000152, 'mean_token_accuracy': 0.9855072498321533, 'epoch': 6.0}
{'loss': 0.2807, 'grad_norm': 5.5705413818359375, 'learning_rate': 0.00014933333333333335, 'mean_token_accuracy': 0.9619111120700836, 'epoch': 6.44}
{'loss': 0.1832, 'grad_norm': 1.7909528017044067, 'learning_rate': 0.00014666666666666666, 'mean_token_accuracy': 0.9702845215797424, 'epoch': 6.89}
{'loss': 0.1339, 'grad_norm': 2.4750988483428955, 'learning_rate': 0.000144, 'mean_token_accuracy': 0.9855072498321533, 'epoch': 7.0}
{'loss': 0.2259, 'grad_norm': 1.9564673900604248, 'learning_rate': 0.00014133333333333334, 'mean_token_accuracy': 0.965137243270874, 'epoch': 7.44}
{'loss': 0.1943, 'grad_norm': 2.4208903312683105, 'learning_rate': 0.00013866666666666669, 'mean_token_accuracy': 0.9704830497503281, 'epoch': 7.89}
{'loss': 0.139, 'grad_norm': 2.276552677154541, 'learning_rate': 0.00013600000000000003, 'mean_token_accuracy': 0.9710144996643066, 'epoch': 8.0}
{'loss': 0.1426, 'grad_norm': 1.158094048500061, 'learning_rate': 0.00013333333333333334, 'mean_token_accuracy': 0.9704830497503281, 'epoch': 8.44}
{'loss': 0.1618, 'grad_norm': 0.9444708824157715, 'learning_rate': 0.00013066666666666668, 'mean_token_accuracy': 0.9649387151002884, 'epoch': 8.89}
{'loss': 0.1435, 'grad_norm': 2.507293939590454, 'learning_rate': 0.00012800000000000002, 'mean_token_accuracy': 0.9710144996643066, 'epoch': 9.0}
{'loss': 0.1385, 'grad_norm': 1.9922103881835938, 'learning_rate': 0.00012533333333333334, 'mean_token_accuracy': 0.9672569185495377, 'epoch': 9.44}
{'loss': 0.1316, 'grad_norm': 0.5764124989509583, 'learning_rate': 0.00012266666666666668, 'mean_token_accuracy': 0.9683633744716644, 'epoch': 9.89}
{'loss': 0.1062, 'grad_norm': 0.5229787230491638, 'learning_rate': 0.00012, 'mean_token_accuracy': 0.9855072498321533, 'epoch': 10.0}
{'loss': 0.1013, 'grad_norm': 0.40718427300453186, 'learning_rate': 0.00011733333333333334, 'mean_token_accuracy': 0.9769353121519089, 'epoch': 10.44}
{'loss': 0.1479, 'grad_norm': 1.3826055526733398, 'learning_rate': 0.00011466666666666667, 'mean_token_accuracy': 0.9619111120700836, 'epoch': 10.89}
{'loss': 0.111, 'grad_norm': 1.7971856594085693, 'learning_rate': 0.00011200000000000001, 'mean_token_accuracy': 0.9726027250289917, 'epoch': 11.0}
{'loss': 0.0986, 'grad_norm': 0.8073880672454834, 'learning_rate': 0.00010933333333333333, 'mean_token_accuracy': 0.9698874652385712, 'epoch': 11.44}
{'loss': 0.0856, 'grad_norm': 0.49761784076690674, 'learning_rate': 0.00010666666666666667, 'mean_token_accuracy': 0.9737091809511185, 'epoch': 11.89}
{'loss': 0.143, 'grad_norm': 3.0372018814086914, 'learning_rate': 0.00010400000000000001, 'mean_token_accuracy': 0.9512194991111755, 'epoch': 12.0}
{'loss': 0.0823, 'grad_norm': 0.8601316213607788, 'learning_rate': 0.00010133333333333335, 'mean_token_accuracy': 0.9698874652385712, 'epoch': 12.44}
{'loss': 0.0716, 'grad_norm': 0.5071003437042236, 'learning_rate': 9.866666666666668e-05, 'mean_token_accuracy': 0.9687604308128357, 'epoch': 12.89}
{'loss': 0.0636, 'grad_norm': 1.115402340888977, 'learning_rate': 9.6e-05, 'mean_token_accuracy': 0.95652174949646, 'epoch': 13.0}
{'loss': 0.0621, 'grad_norm': 0.44717466831207275, 'learning_rate': 9.333333333333334e-05, 'mean_token_accuracy': 0.9580893963575363, 'epoch': 13.44}
{'loss': 0.0505, 'grad_norm': 0.49897530674934387, 'learning_rate': 9.066666666666667e-05, 'mean_token_accuracy': 0.9737091809511185, 'epoch': 13.89}
{'loss': 0.0539, 'grad_norm': 1.6923249959945679, 'learning_rate': 8.800000000000001e-05, 'mean_token_accuracy': 0.9863013625144958, 'epoch': 14.0}
{'loss': 0.043, 'grad_norm': 0.41969457268714905, 'learning_rate': 8.533333333333334e-05, 'mean_token_accuracy': 0.9841817021369934, 'epoch': 14.44}
{'loss': 0.0394, 'grad_norm': 0.9888722896575928, 'learning_rate': 8.266666666666667e-05, 'mean_token_accuracy': 0.9866560697555542, 'epoch': 14.89}
{'loss': 0.0127, 'grad_norm': 0.4191919267177582, 'learning_rate': 8e-05, 'mean_token_accuracy': 1.0, 'epoch': 15.0}
{'loss': 0.0275, 'grad_norm': 0.6723085641860962, 'learning_rate': 7.733333333333333e-05, 'mean_token_accuracy': 0.9866560697555542, 'epoch': 15.44}
{'loss': 0.0298, 'grad_norm': 0.3247329592704773, 'learning_rate': 7.466666666666667e-05, 'mean_token_accuracy': 0.9841817021369934, 'epoch': 15.89}
{'loss': 0.023, 'grad_norm': 1.0682705640792847, 'learning_rate': 7.2e-05, 'mean_token_accuracy': 0.9863013625144958, 'epoch': 16.0}
{'loss': 0.019, 'grad_norm': 0.13458505272865295, 'learning_rate': 6.933333333333334e-05, 'mean_token_accuracy': 0.9904777854681015, 'epoch': 16.44}
{'loss': 0.0182, 'grad_norm': 0.1721552610397339, 'learning_rate': 6.666666666666667e-05, 'mean_token_accuracy': 0.9902792572975159, 'epoch': 16.89}
{'loss': 0.037, 'grad_norm': 0.9847946763038635, 'learning_rate': 6.400000000000001e-05, 'mean_token_accuracy': 0.9756097793579102, 'epoch': 17.0}
{'loss': 0.026, 'grad_norm': 0.6007171869277954, 'learning_rate': 6.133333333333334e-05, 'mean_token_accuracy': 0.978084146976471, 'epoch': 17.44}
{'loss': 0.0243, 'grad_norm': 0.5682116150856018, 'learning_rate': 5.866666666666667e-05, 'mean_token_accuracy': 0.9824796468019485, 'epoch': 17.89}
{'loss': 0.0448, 'grad_norm': 1.3603119850158691, 'learning_rate': 5.6000000000000006e-05, 'mean_token_accuracy': 0.9710144996643066, 'epoch': 18.0}
{'loss': 0.0224, 'grad_norm': 0.3100469708442688, 'learning_rate': 5.333333333333333e-05, 'mean_token_accuracy': 0.9929521530866623, 'epoch': 18.44}
{'loss': 0.0299, 'grad_norm': 0.6363517642021179, 'learning_rate': 5.0666666666666674e-05, 'mean_token_accuracy': 0.9826781749725342, 'epoch': 18.89}
{'loss': 0.0206, 'grad_norm': 0.627301812171936, 'learning_rate': 4.8e-05, 'mean_token_accuracy': 0.9855072498321533, 'epoch': 19.0}
{'loss': 0.0279, 'grad_norm': 0.625548243522644, 'learning_rate': 4.5333333333333335e-05, 'mean_token_accuracy': 0.9826781749725342, 'epoch': 19.44}
{'loss': 0.0188, 'grad_norm': 0.2290528565645218, 'learning_rate': 4.266666666666667e-05, 'mean_token_accuracy': 0.9904777854681015, 'epoch': 19.89}
{'loss': 0.0337, 'grad_norm': 0.7975773811340332, 'learning_rate': 4e-05, 'mean_token_accuracy': 0.9756097793579102, 'epoch': 20.0}
{'loss': 0.0173, 'grad_norm': 0.12529949843883514, 'learning_rate': 3.733333333333334e-05, 'mean_token_accuracy': 0.9904777854681015, 'epoch': 20.44}
{'loss': 0.0236, 'grad_norm': 0.2748052775859833, 'learning_rate': 3.466666666666667e-05, 'mean_token_accuracy': 0.984380230307579, 'epoch': 20.89}
{'loss': 0.0185, 'grad_norm': 0.809395968914032, 'learning_rate': 3.2000000000000005e-05, 'mean_token_accuracy': 0.9863013625144958, 'epoch': 21.0}
{'loss': 0.0171, 'grad_norm': 0.11877240985631943, 'learning_rate': 2.9333333333333336e-05, 'mean_token_accuracy': 0.9902792572975159, 'epoch': 21.44}
{'loss': 0.0227, 'grad_norm': 0.2608608901500702, 'learning_rate': 2.6666666666666667e-05, 'mean_token_accuracy': 0.9841817021369934, 'epoch': 21.89}
{'loss': 0.0293, 'grad_norm': 1.0545191764831543, 'learning_rate': 2.4e-05, 'mean_token_accuracy': 0.9855072498321533, 'epoch': 22.0}
{'loss': 0.0173, 'grad_norm': 0.15324412286281586, 'learning_rate': 2.1333333333333335e-05, 'mean_token_accuracy': 0.9902792572975159, 'epoch': 22.44}
{'loss': 0.0228, 'grad_norm': 0.47179096937179565, 'learning_rate': 1.866666666666667e-05, 'mean_token_accuracy': 0.9866560697555542, 'epoch': 22.89}
{'loss': 0.0352, 'grad_norm': 0.7394832968711853, 'learning_rate': 1.6000000000000003e-05, 'mean_token_accuracy': 0.9756097793579102, 'epoch': 23.0}
{'loss': 0.0185, 'grad_norm': 0.5680955648422241, 'learning_rate': 1.3333333333333333e-05, 'mean_token_accuracy': 0.989130437374115, 'epoch': 23.44}
{'loss': 0.0189, 'grad_norm': 0.3210718035697937, 'learning_rate': 1.0666666666666667e-05, 'mean_token_accuracy': 0.9878048896789551, 'epoch': 23.89}
{'loss': 0.0302, 'grad_norm': 0.6938251256942749, 'learning_rate': 8.000000000000001e-06, 'mean_token_accuracy': 0.9756097793579102, 'epoch': 24.0}
{'loss': 0.0189, 'grad_norm': 0.36323145031929016, 'learning_rate': 5.333333333333334e-06, 'mean_token_accuracy': 0.9878048896789551, 'epoch': 24.44}
{'loss': 0.018, 'grad_norm': 0.2214939445257187, 'learning_rate': 2.666666666666667e-06, 'mean_token_accuracy': 0.9866560697555542, 'epoch': 24.89}
{'loss': 0.0175, 'grad_norm': 0.6966819763183594, 'learning_rate': 0.0, 'mean_token_accuracy': 0.9855072498321533, 'epoch': 25.0}
{'train_runtime': 312.0504, 'train_samples_per_second': 0.961, 'train_steps_per_second': 0.24, 'train_loss': 0.3531599778061112, 'epoch': 25.0}
100%|████████████████████████████████| 75/75 [05:12<00:00,  4.16s/it]
Saving model to deepseek-finetuned
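
The saved directory contains only the tokenizer and the LoRA adapter weights, which is why the test code reloads the base model and then attaches the adapter. If you would rather distribute a single standalone model, PEFT can fold the adapter back into the base weights. A minimal sketch, not part of the script above (the merged output directory name is my own choice, and merging needs enough memory to hold the fp16 base model):

import torch
import peft
from transformers import AutoModelForCausalLM, AutoTokenizer

BASE = "deepseek-ai/deepseek-r1-distill-llama-8b"
ADAPTER_DIR = "deepseek-finetuned"
MERGED_DIR = "deepseek-finetuned-merged"  # hypothetical output path

# Load the base model unquantized (fp16) so the LoRA deltas can be merged into real weights
base = AutoModelForCausalLM.from_pretrained(BASE, torch_dtype=torch.float16, device_map="auto")
model = peft.PeftModel.from_pretrained(base, ADAPTER_DIR)

# Fold the LoRA updates into the base weights and drop the adapter wrappers
merged = model.merge_and_unload()
merged.save_pretrained(MERGED_DIR)
AutoTokenizer.from_pretrained(ADAPTER_DIR).save_pretrained(MERGED_DIR)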

EC2 Nodes

Here's a comparison of EC2 instances suitable for fine-tuning this model:

Instance Type | Specs                                      | Cost                           | Notes
g5.xlarge     | 4 vCPUs, 16GB RAM, 1x A10G GPU (24GB VRAM) | $1.01/hr ($24.14/day, $724/mo) | Model fits entirely in GPU memory, fastest training
g4dn.xlarge   | 4 vCPUs, 16GB RAM, 1x T4 GPU (16GB VRAM)   | $0.53/hr ($12.72/day, $382/mo) | Requires some CPU offloading, slightly slower

I used the g5.xlarge for this tutorial as the extra VRAM allows the entire model to stay in GPU memory, which speeds up training considerably. I also tested on the smaller g4dn.xlarge to demonstrate that CPU offloading works as expected, though with a modest increase in training time.
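
On the g4dn.xlarge, the T4's 16GB of VRAM is not quite enough for the fp16 base model plus generation overhead, which is why the test code passes offload_folder when loading the model. If you want tighter control over how much ends up on the GPU, a max_memory map can also be passed to from_pretrained and Accelerate will place the remaining layers on CPU (and disk, via the offload folder). A minimal sketch; the memory limits below are illustrative guesses, not measured values:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/deepseek-r1-distill-llama-8b",
    torch_dtype=torch.float16,
    device_map="auto",
    max_memory={0: "14GiB", "cpu": "24GiB"},  # leave headroom on the 16GB T4 (illustrative)
    offload_folder="offload",                 # spill anything that still doesn't fit to disk
)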


