Fine-Tuning DeepSeek R1 Distilled LLaMA 8B using QLoRA
In this post, I'll walk through fine-tuning the DeepSeek R1 Distilled LLaMA 8B model using QLoRA.
Fine-tuning with QLoRA lets us train a small adapter network instead of modifying the underlying DeepSeek model weights (see the parameter-count sketch after this list). This is crucial because:
- Full model training would require enormous computational resources
- We don't have access to the original training data, making it difficult to prevent catastrophic forgetting
- We only need to store a tiny fraction of the parameters (the adapter) rather than a complete model copy
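As a concrete illustration (this is separate from the training script below), wrapping the base model in the same LoRA configuration we'll use later and printing the trainable parameter count shows the adapter is a fraction of a percent of the 8B base model:

```python
import torch
import peft
from transformers import AutoModelForCausalLM

# Load the base model and attach an (untrained) LoRA adapter just to count parameters.
base = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/deepseek-r1-distill-llama-8b",
    torch_dtype=torch.float16,
    device_map="auto",
)
lora = peft.LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
model = peft.get_peft_model(base, lora)
model.print_trainable_parameters()  # trainable params come out well under 1% of the total
```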
Key aspects of this approach:
- Supervised Fine-Tuning (SFT) - Training the model on specific question-answer pairs
- QLoRA (Quantized Low-Rank Adaptation) - Uses 4-bit quantization to reduce memory requirements
- PEFT (Parameter-Efficient Fine-Tuning) - Only trains a small subset of model parameters
- Consistent formatting - Using the same format for training and inference (sketched just after this list)
- Repetition - Each Q&A pair is repeated multiple times to reinforce learning
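The formatting point deserves emphasis: the model sees each answer wrapped in the same `<question>`/`<answer>` tags at training time and at inference time. A tiny illustrative helper (not part of the script, which inlines the f-strings) makes the shared template explicit:

```python
from typing import Optional

def format_example(question: str, answer: Optional[str] = None) -> str:
    """Build the shared prompt template. Training supplies the answer;
    inference stops at the opening <answer> tag and lets the model fill it in."""
    prompt = f"<question>\n{question}\n</question>\n\n<answer>"
    if answer is None:
        return prompt  # inference-time prompt
    return f"{prompt}\n{answer}\n</answer>"  # full training example
```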
I ran this on an EC2 instance for better Nvidia/CUDA support and somewhat faster performance than my older MBP. See the EC2 Nodes section at the end for more details.
This was built in tandem with Cursor + ChatGPT + Claude.
#!/usr/bin/env python3
"""
A solution for fine-tuning DeepSeek with exact response control.
- Uses repetition to embed precise answers
- Exact formatting between training and inference
- Strict constraints on generation
"""
import os
import argparse
import warnings
from pathlib import Path
import gc
import torch
import peft
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments
)
from trl import SFTTrainer
# Configuration
MODEL_NAME = "deepseek-ai/deepseek-r1-distill-llama-8b"
OUTPUT_DIR = "deepseek-finetuned"
MAX_SEQ_LENGTH = 1024
def parse_args():
    parser = argparse.ArgumentParser(description="Fine-tune DeepSeek with QLoRA")
    parser.add_argument("--retrain", action="store_true", help="Force retraining even if model exists")
    return parser.parse_args()
def create_sample_dataset():
"""Create a dataset with exact Q&A pairs repeated multiple times."""
print("Creating sample dataset...")
# Define our knowledge base - the exact answers we want
knowledge_base = {
"What is our custom meowMix function and how is it used?":
"The meowMix function is a simple Kotlin function:\n"
"fun meowMix() = print(\"meow mix is yummy in my tummy\")\n\n"
"It takes no parameters and is used in our cat-themed demo applications.",
"What is our custom purr() function?":
"The purr() function is a Kotlin utility:\n"
"fun purr(intensity: Int = 5) = \"p${'r'.repeat(intensity)}\"\n\n"
"It takes an optional intensity parameter that determines how many 'r's appear in the purr.",
"Tell me about Kenny Cason's background":
"Kenny Cason is a software engineer who enjoys game development, AI, and really loves Kotlin."
}
# Create training examples with consistent prompt format
formatted_data = []
# Include each Q&A pair multiple times to really drive the point home
repetitions = 3
for _ in range(repetitions):
for question, answer in knowledge_base.items():
# Create the exact format we'll use at inference time too
text = (
f"<question>\n{question}\n</question>\n\n"
f"<answer>\n{answer}\n</answer>"
)
formatted_data.append({"text": text})
dataset = Dataset.from_list(formatted_data)
print(f"Created dataset with {len(dataset)} examples (each Q&A repeated {repetitions} times)")
return dataset, knowledge_base
def train_model(dataset):
"""Fine-tune the model using QLoRA."""
print(f"\n=== Starting QLoRA fine-tuning on {MODEL_NAME} ===")
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
# Setup 4-bit quantization configuration
quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
)
# Load model with quantization
print("Loading model with 4-bit quantization...")
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
quantization_config=quant_config,
device_map="auto"
)
# Configure LoRA adapter
peft_config = peft.LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
# Training arguments - more epochs for better memorization
training_args = TrainingArguments(
output_dir=OUTPUT_DIR,
num_train_epochs=20,
per_device_train_batch_size=1,
gradient_accumulation_steps=4,
learning_rate=2e-4,
max_steps=75,
logging_steps=1,
report_to="none", # Disable wandb
)
# Adjust tokenizer max length
tokenizer.model_max_length = MAX_SEQ_LENGTH
# Create trainer
trainer = SFTTrainer(
model=model,
train_dataset=dataset,
peft_config=peft_config,
args=training_args,
tokenizer=tokenizer,
)
# Train and save
print("Starting training...")
trainer.train()
print(f"Saving model to {OUTPUT_DIR}")
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
return True
def test_model(knowledge_base, model=None, tokenizer=None):
"""Test the fine-tuned model using exact question format."""
print("\n=== Testing fine-tuned model ===")
# Test each question from our knowledge base
for question in knowledge_base.keys():
print("\n" + "="*60)
print(f"Question: {question}")
# Use exact same format as training
prompt = f"<question>\n{question}\n</question>\n\n<answer>"
# Process input
inputs = tokenizer(prompt, return_tensors="pt")
device = next(model.parameters()).device
inputs = {k: v.to(device) for k, v in inputs.items()}
# Generate with minimal parameters to ensure stability
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=200,
do_sample=True, # Enable sampling for temperature control
temperature=0.01 # Near-deterministic sampling for consistent outputs
)
# Decode the response
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Extract just the answer part
if "<answer>" in response:
answer_part = response.split("<answer>")[1].strip()
# Remove closing tag if present
if "</answer>" in answer_part:
answer_part = answer_part.split("</answer>")[0].strip()
print("\nGenerated answer:")
print(answer_part)
# Compare with expected answer
expected = knowledge_base[question]
print("\nExpected answer:")
print(expected)
# Calculate match percentage
match_percentage = calculate_match(answer_part, expected)
print(f"\nMatch percentage: {match_percentage:.2f}%")
else:
print("\nResponse (raw):")
print(response)
return True
def calculate_match(generated, expected):
"""Calculate a simple match percentage between generated and expected text."""
# Simple character-level matching
total_chars = len(expected)
if total_chars == 0:
return 100.0
# Count matches using a simple algorithm
matches = 0
for i in range(min(len(generated), len(expected))):
if generated[i] == expected[i]:
matches += 1
return (matches / total_chars) * 100.0
def test_non_training_data(model, tokenizer):
"""Test the model on questions that weren't in the training data."""
print("\n=== Testing model on non-training data ===")
# Define some questions that weren't in the training data
test_questions = [
"What is the capital of Japan?",
"Explain how to implement a binary search tree"
]
for question in test_questions:
print("\n" + "="*60)
print(f"Non-training question: {question}")
# Use exact same format as training
prompt = f"<question>\n{question}\n</question>\n\n<answer>"
# Process input
inputs = tokenizer(prompt, return_tensors="pt")
device = next(model.parameters()).device
inputs = {k: v.to(device) for k, v in inputs.items()}
# Generate with controlled randomness for more natural responses
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=500,
do_sample=True, # Enable sampling for temperature control
temperature=0.7, # Higher temperature for more diverse outputs
num_beams=1, # No beam search
use_cache=True # Enable KV caching
)
# Decode the response
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Extract just the answer part
if "<answer>" in response:
answer_part = response.split("<answer>")[1].strip()
# Remove closing tag if present
if "</answer>" in answer_part:
answer_part = answer_part.split("</answer>")[0].strip()
print("\nGenerated answer:")
print(answer_part)
else:
print("\nResponse (raw):")
print(response)
return True
def main():
    args = parse_args()

    # Create dataset and get knowledge base
    dataset, knowledge_base = create_sample_dataset()

    # If model exists and no --retrain, skip training
    if os.path.exists(OUTPUT_DIR) and not args.retrain:
        print(f"Model already exists at {OUTPUT_DIR}. Use --retrain to force retraining...")
    else:
        print("Training new model...")
        if not train_model(dataset):
            return

    # Always run tests
    print("\nRunning tests on model...")

    # Create an offload directory for model parts that don't fit in GPU memory
    offload_dir = os.path.join(OUTPUT_DIR, "offload")
    os.makedirs(offload_dir, exist_ok=True)

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(OUTPUT_DIR)

    # Load base model with explicit offloading parameters
    print("Loading base model with offloading enabled...")
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,
        device_map="auto",
        offload_folder=offload_dir  # Offload layers to CPU/disk when GPU VRAM is insufficient
    )

    # Load LoRA adapter with the same offload directory
    print("Loading LoRA adapter...")
    model = peft.PeftModel.from_pretrained(
        base_model,
        OUTPUT_DIR,
        offload_folder=offload_dir  # Offload layers to CPU/disk when GPU VRAM is insufficient
    )

    # Run knowledge base tests
    print("\nRunning knowledge base tests...")
    test_model(knowledge_base, model, tokenizer)

    # Free up memory before running non-training tests
    gc.collect()
    torch.cuda.empty_cache()

    # Run non-training data tests
    print("\nTesting model on non-training data...")
    test_non_training_data(model, tokenizer)

if __name__ == "__main__":
    main()
Setup
This tutorial assumes you have Python 3 and pip installed, and that you're using venv to manage your Python environment.
python3 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
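The contents of requirements.txt aren't shown here, but based on the script's imports it should contain roughly the following (pin versions as needed; bitsandbytes supplies the 4-bit quantization kernels and accelerate backs device_map="auto"):
torch
transformers
datasets
accelerate
peft
trl
bitsandbytes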
Train & Test
python3 finetune_deepseek_r1_distilled_llama_8b.py --retrain
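If a trained adapter already exists in deepseek-finetuned, you can re-run just the tests by omitting the flag; the script detects the existing output directory and skips training:
python3 finetune_deepseek_r1_distilled_llama_8b.py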
Testing
After training completes, the script tests the model both on the knowledge-base Q&A pairs it was trained on and on questions outside the training data.
Testing fine-tuned model
Question 1
Question: What is our custom meowMix function and how is it used?
Generated answer:
The meowMix function is a simple Kotlin function:
fun meowMix() = print("meow mix is yummy in my tummy")
It takes no parameters and is used in our cat-themed demo applications.
Expected answer:
The meowMix function is a simple Kotlin function:
fun meowMix() = print("meow mix is yummy in my tummy")
It takes no parameters and is used in our cat-themed demo applications.
Match percentage: 100.00%
Question 2
Question: What is our custom purr() function?
Generated answer:
The purr() function is a Kotlin utility:
fun purr(intensity: Int = 5) = "p${'r'.repeat(intensity)}"
It takes an optional intensity parameter that determines how many 'r's appear in the purr.
Expected answer:
The purr() function is a Kotlin utility:
fun purr(intensity: Int = 5) = "p${'r'.repeat(intensity)}"
It takes an optional intensity parameter that determines how many 'r's appear in the purr.
Match percentage: 100.00%
Question 3
Question: Tell me about Kenny Cason's background
Generated answer:
Kenny Cason is a software engineer who enjoys game development, AI, and really loves Kotlin.
Expected answer:
Kenny Cason is a software engineer who enjoys game development, AI, and really loves Kotlin.
Match percentage: 100.00%
Testing non-training data
Question 1
=== Testing model on non-training data ===
============================================================
Non-training question: What is the capital of Japan?
Generated answer:
Japan's capital is Tokyo.
Question 2
Non-training question: Explain how to implement a binary search tree
Generated answer:
A binary search tree (BST) is a data structure that has a root and two children. Each node in a BST:
- Is null or contains key data.
- Has a left child that is a BST.
- Has a right child that is a BST.
- Maintains an invariant: For every node, its value is greater than all values in its left subtree and less than all values in its right subtree.
To implement a BST, you can create a Node class and a BST class:
1. **Node Class**:
```java
public class Node {
    public int key;
    public Node left;
    public Node right;

    public Node(int key) {
        this.key = key;
    }
}
```
2. **BST Class**:
```java
public class BST {
    public Node root;

    public BST(int key) {
        root = new Node(key);
    }

    public void insert(Node node) {
        if (node == null) return;
        if (node.key < root.key) {
            root.left = node;
        } else {
            root.right = node;
        }
    }

    public Node search(int key) {
        if (root.key == key) return root;
        if (key < root.key) {
            return search(root.left, key);
        } else {
            return search(root.right, key);
        }
    }
}
```
Training output should look something like this:
(venv) $ python finetune_deepseek_r1_distilled_llama_8b.py --retrain
Creating sample dataset...
Created dataset with 9 examples (each Q&A repeated 3 times)
Training new model...
=== Starting QLoRA fine-tuning on deepseek-ai/deepseek-r1-distill-llama-8b ===
Loading model with 4-bit quantization...
Loading checkpoint shards: 100%|██████████| 2/2 [02:01<00:00, 60.82s/it]
/home/ec2-user/arrived_llm/deepseek/deepseek-finetune/finetune_deepseek_r1_distilled_llama_8b.py:123: FutureWarning: `tokenizer` is deprecated and removed starting from version 0.16.0 for `SFTTrainer.__init__`. Use `processing_class` instead.
trainer = SFTTrainer(
Converting train dataset to ChatML: 100%|██████████| 9/9 [00:00<00:00, 581.81 examples/s]
Applying chat template to train dataset: 100%|██████████| 9/9 [00:00<00:00, 3744.17 examples/s]
Tokenizing train dataset: 100%|██████████| 9/9 [00:00<00:00, 381.74 examples/s]
Truncating train dataset: 100%|██████████| 9/9 [00:00<00:00, 1249.59 examples/s]
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
Starting training...
{'loss': 3.3385, 'grad_norm': 3.106924295425415, 'learning_rate': 0.00019733333333333335, 'mean_token_accuracy': 0.47806233167648315, 'epoch': 0.44}
{'loss': 3.2154, 'grad_norm': 2.8407278060913086, 'learning_rate': 0.0001946666666666667, 'mean_token_accuracy': 0.4890420734882355, 'epoch': 0.89}
{'loss': 2.4194, 'grad_norm': 3.7109458446502686, 'learning_rate': 0.000192, 'mean_token_accuracy': 0.6164383292198181, 'epoch': 1.0}
{'loss': 2.1271, 'grad_norm': 2.1629104614257812, 'learning_rate': 0.00018933333333333335, 'mean_token_accuracy': 0.6363870799541473, 'epoch': 1.44}
{'loss': 1.9, 'grad_norm': 1.9662961959838867, 'learning_rate': 0.0001866666666666667, 'mean_token_accuracy': 0.6581116914749146, 'epoch': 1.89}
{'loss': 2.3352, 'grad_norm': 2.779590606689453, 'learning_rate': 0.00018400000000000003, 'mean_token_accuracy': 0.5853658318519592, 'epoch': 2.0}
{'loss': 1.2907, 'grad_norm': 2.3935580253601074, 'learning_rate': 0.00018133333333333334, 'mean_token_accuracy': 0.777546152472496, 'epoch': 2.44}
{'loss': 1.1757, 'grad_norm': 2.1039140224456787, 'learning_rate': 0.00017866666666666668, 'mean_token_accuracy': 0.7798195779323578, 'epoch': 2.89}
{'loss': 1.5495, 'grad_norm': 2.8364014625549316, 'learning_rate': 0.00017600000000000002, 'mean_token_accuracy': 0.7317073345184326, 'epoch': 3.0}
{'loss': 0.8492, 'grad_norm': 1.758329153060913, 'learning_rate': 0.00017333333333333334, 'mean_token_accuracy': 0.8454606682062149, 'epoch': 3.44}
{'loss': 0.5398, 'grad_norm': 1.576306939125061, 'learning_rate': 0.00017066666666666668, 'mean_token_accuracy': 0.9070318639278412, 'epoch': 3.89}
{'loss': 0.4739, 'grad_norm': 3.4256539344787598, 'learning_rate': 0.000168, 'mean_token_accuracy': 0.8985507488250732, 'epoch': 4.0}
{'loss': 0.385, 'grad_norm': 2.0756847858428955, 'learning_rate': 0.00016533333333333333, 'mean_token_accuracy': 0.9369035065174103, 'epoch': 4.44}
{'loss': 0.311, 'grad_norm': 2.1463024616241455, 'learning_rate': 0.00016266666666666667, 'mean_token_accuracy': 0.9443701803684235, 'epoch': 4.89}
{'loss': 0.1619, 'grad_norm': 3.3488707542419434, 'learning_rate': 0.00016, 'mean_token_accuracy': 0.9726027250289917, 'epoch': 5.0}
{'loss': 0.2111, 'grad_norm': 2.5160067081451416, 'learning_rate': 0.00015733333333333333, 'mean_token_accuracy': 0.9565653204917908, 'epoch': 5.44}
{'loss': 0.1916, 'grad_norm': 3.1428894996643066, 'learning_rate': 0.00015466666666666667, 'mean_token_accuracy': 0.9653140008449554, 'epoch': 5.89}
{'loss': 0.1357, 'grad_norm': 0.8235794901847839, 'learning_rate': 0.000152, 'mean_token_accuracy': 0.9855072498321533, 'epoch': 6.0}
{'loss': 0.2807, 'grad_norm': 5.5705413818359375, 'learning_rate': 0.00014933333333333335, 'mean_token_accuracy': 0.9619111120700836, 'epoch': 6.44}
{'loss': 0.1832, 'grad_norm': 1.7909528017044067, 'learning_rate': 0.00014666666666666666, 'mean_token_accuracy': 0.9702845215797424, 'epoch': 6.89}
{'loss': 0.1339, 'grad_norm': 2.4750988483428955, 'learning_rate': 0.000144, 'mean_token_accuracy': 0.9855072498321533, 'epoch': 7.0}
{'loss': 0.2259, 'grad_norm': 1.9564673900604248, 'learning_rate': 0.00014133333333333334, 'mean_token_accuracy': 0.965137243270874, 'epoch': 7.44}
{'loss': 0.1943, 'grad_norm': 2.4208903312683105, 'learning_rate': 0.00013866666666666669, 'mean_token_accuracy': 0.9704830497503281, 'epoch': 7.89}
{'loss': 0.139, 'grad_norm': 2.276552677154541, 'learning_rate': 0.00013600000000000003, 'mean_token_accuracy': 0.9710144996643066, 'epoch': 8.0}
{'loss': 0.1426, 'grad_norm': 1.158094048500061, 'learning_rate': 0.00013333333333333334, 'mean_token_accuracy': 0.9704830497503281, 'epoch': 8.44}
{'loss': 0.1618, 'grad_norm': 0.9444708824157715, 'learning_rate': 0.00013066666666666668, 'mean_token_accuracy': 0.9649387151002884, 'epoch': 8.89}
{'loss': 0.1435, 'grad_norm': 2.507293939590454, 'learning_rate': 0.00012800000000000002, 'mean_token_accuracy': 0.9710144996643066, 'epoch': 9.0}
{'loss': 0.1385, 'grad_norm': 1.9922103881835938, 'learning_rate': 0.00012533333333333334, 'mean_token_accuracy': 0.9672569185495377, 'epoch': 9.44}
{'loss': 0.1316, 'grad_norm': 0.5764124989509583, 'learning_rate': 0.00012266666666666668, 'mean_token_accuracy': 0.9683633744716644, 'epoch': 9.89}
{'loss': 0.1062, 'grad_norm': 0.5229787230491638, 'learning_rate': 0.00012, 'mean_token_accuracy': 0.9855072498321533, 'epoch': 10.0}
{'loss': 0.1013, 'grad_norm': 0.40718427300453186, 'learning_rate': 0.00011733333333333334, 'mean_token_accuracy': 0.9769353121519089, 'epoch': 10.44}
{'loss': 0.1479, 'grad_norm': 1.3826055526733398, 'learning_rate': 0.00011466666666666667, 'mean_token_accuracy': 0.9619111120700836, 'epoch': 10.89}
{'loss': 0.111, 'grad_norm': 1.7971856594085693, 'learning_rate': 0.00011200000000000001, 'mean_token_accuracy': 0.9726027250289917, 'epoch': 11.0}
{'loss': 0.0986, 'grad_norm': 0.8073880672454834, 'learning_rate': 0.00010933333333333333, 'mean_token_accuracy': 0.9698874652385712, 'epoch': 11.44}
{'loss': 0.0856, 'grad_norm': 0.49761784076690674, 'learning_rate': 0.00010666666666666667, 'mean_token_accuracy': 0.9737091809511185, 'epoch': 11.89}
{'loss': 0.143, 'grad_norm': 3.0372018814086914, 'learning_rate': 0.00010400000000000001, 'mean_token_accuracy': 0.9512194991111755, 'epoch': 12.0}
{'loss': 0.0823, 'grad_norm': 0.8601316213607788, 'learning_rate': 0.00010133333333333335, 'mean_token_accuracy': 0.9698874652385712, 'epoch': 12.44}
{'loss': 0.0716, 'grad_norm': 0.5071003437042236, 'learning_rate': 9.866666666666668e-05, 'mean_token_accuracy': 0.9687604308128357, 'epoch': 12.89}
{'loss': 0.0636, 'grad_norm': 1.115402340888977, 'learning_rate': 9.6e-05, 'mean_token_accuracy': 0.95652174949646, 'epoch': 13.0}
{'loss': 0.0621, 'grad_norm': 0.44717466831207275, 'learning_rate': 9.333333333333334e-05, 'mean_token_accuracy': 0.9580893963575363, 'epoch': 13.44}
{'loss': 0.0505, 'grad_norm': 0.49897530674934387, 'learning_rate': 9.066666666666667e-05, 'mean_token_accuracy': 0.9737091809511185, 'epoch': 13.89}
{'loss': 0.0539, 'grad_norm': 1.6923249959945679, 'learning_rate': 8.800000000000001e-05, 'mean_token_accuracy': 0.9863013625144958, 'epoch': 14.0}
{'loss': 0.043, 'grad_norm': 0.41969457268714905, 'learning_rate': 8.533333333333334e-05, 'mean_token_accuracy': 0.9841817021369934, 'epoch': 14.44}
{'loss': 0.0394, 'grad_norm': 0.9888722896575928, 'learning_rate': 8.266666666666667e-05, 'mean_token_accuracy': 0.9866560697555542, 'epoch': 14.89}
{'loss': 0.0127, 'grad_norm': 0.4191919267177582, 'learning_rate': 8e-05, 'mean_token_accuracy': 1.0, 'epoch': 15.0}
{'loss': 0.0275, 'grad_norm': 0.6723085641860962, 'learning_rate': 7.733333333333333e-05, 'mean_token_accuracy': 0.9866560697555542, 'epoch': 15.44}
{'loss': 0.0298, 'grad_norm': 0.3247329592704773, 'learning_rate': 7.466666666666667e-05, 'mean_token_accuracy': 0.9841817021369934, 'epoch': 15.89}
{'loss': 0.023, 'grad_norm': 1.0682705640792847, 'learning_rate': 7.2e-05, 'mean_token_accuracy': 0.9863013625144958, 'epoch': 16.0}
{'loss': 0.019, 'grad_norm': 0.13458505272865295, 'learning_rate': 6.933333333333334e-05, 'mean_token_accuracy': 0.9904777854681015, 'epoch': 16.44}
{'loss': 0.0182, 'grad_norm': 0.1721552610397339, 'learning_rate': 6.666666666666667e-05, 'mean_token_accuracy': 0.9902792572975159, 'epoch': 16.89}
{'loss': 0.037, 'grad_norm': 0.9847946763038635, 'learning_rate': 6.400000000000001e-05, 'mean_token_accuracy': 0.9756097793579102, 'epoch': 17.0}
{'loss': 0.026, 'grad_norm': 0.6007171869277954, 'learning_rate': 6.133333333333334e-05, 'mean_token_accuracy': 0.978084146976471, 'epoch': 17.44}
{'loss': 0.0243, 'grad_norm': 0.5682116150856018, 'learning_rate': 5.866666666666667e-05, 'mean_token_accuracy': 0.9824796468019485, 'epoch': 17.89}
{'loss': 0.0448, 'grad_norm': 1.3603119850158691, 'learning_rate': 5.6000000000000006e-05, 'mean_token_accuracy': 0.9710144996643066, 'epoch': 18.0}
{'loss': 0.0224, 'grad_norm': 0.3100469708442688, 'learning_rate': 5.333333333333333e-05, 'mean_token_accuracy': 0.9929521530866623, 'epoch': 18.44}
{'loss': 0.0299, 'grad_norm': 0.6363517642021179, 'learning_rate': 5.0666666666666674e-05, 'mean_token_accuracy': 0.9826781749725342, 'epoch': 18.89}
{'loss': 0.0206, 'grad_norm': 0.627301812171936, 'learning_rate': 4.8e-05, 'mean_token_accuracy': 0.9855072498321533, 'epoch': 19.0}
{'loss': 0.0279, 'grad_norm': 0.625548243522644, 'learning_rate': 4.5333333333333335e-05, 'mean_token_accuracy': 0.9826781749725342, 'epoch': 19.44}
{'loss': 0.0188, 'grad_norm': 0.2290528565645218, 'learning_rate': 4.266666666666667e-05, 'mean_token_accuracy': 0.9904777854681015, 'epoch': 19.89}
{'loss': 0.0337, 'grad_norm': 0.7975773811340332, 'learning_rate': 4e-05, 'mean_token_accuracy': 0.9756097793579102, 'epoch': 20.0}
{'loss': 0.0173, 'grad_norm': 0.12529949843883514, 'learning_rate': 3.733333333333334e-05, 'mean_token_accuracy': 0.9904777854681015, 'epoch': 20.44}
{'loss': 0.0236, 'grad_norm': 0.2748052775859833, 'learning_rate': 3.466666666666667e-05, 'mean_token_accuracy': 0.984380230307579, 'epoch': 20.89}
{'loss': 0.0185, 'grad_norm': 0.809395968914032, 'learning_rate': 3.2000000000000005e-05, 'mean_token_accuracy': 0.9863013625144958, 'epoch': 21.0}
{'loss': 0.0171, 'grad_norm': 0.11877240985631943, 'learning_rate': 2.9333333333333336e-05, 'mean_token_accuracy': 0.9902792572975159, 'epoch': 21.44}
{'loss': 0.0227, 'grad_norm': 0.2608608901500702, 'learning_rate': 2.6666666666666667e-05, 'mean_token_accuracy': 0.9841817021369934, 'epoch': 21.89}
{'loss': 0.0293, 'grad_norm': 1.0545191764831543, 'learning_rate': 2.4e-05, 'mean_token_accuracy': 0.9855072498321533, 'epoch': 22.0}
{'loss': 0.0173, 'grad_norm': 0.15324412286281586, 'learning_rate': 2.1333333333333335e-05, 'mean_token_accuracy': 0.9902792572975159, 'epoch': 22.44}
{'loss': 0.0228, 'grad_norm': 0.47179096937179565, 'learning_rate': 1.866666666666667e-05, 'mean_token_accuracy': 0.9866560697555542, 'epoch': 22.89}
{'loss': 0.0352, 'grad_norm': 0.7394832968711853, 'learning_rate': 1.6000000000000003e-05, 'mean_token_accuracy': 0.9756097793579102, 'epoch': 23.0}
{'loss': 0.0185, 'grad_norm': 0.5680955648422241, 'learning_rate': 1.3333333333333333e-05, 'mean_token_accuracy': 0.989130437374115, 'epoch': 23.44}
{'loss': 0.0189, 'grad_norm': 0.3210718035697937, 'learning_rate': 1.0666666666666667e-05, 'mean_token_accuracy': 0.9878048896789551, 'epoch': 23.89}
{'loss': 0.0302, 'grad_norm': 0.6938251256942749, 'learning_rate': 8.000000000000001e-06, 'mean_token_accuracy': 0.9756097793579102, 'epoch': 24.0}
{'loss': 0.0189, 'grad_norm': 0.36323145031929016, 'learning_rate': 5.333333333333334e-06, 'mean_token_accuracy': 0.9878048896789551, 'epoch': 24.44}
{'loss': 0.018, 'grad_norm': 0.2214939445257187, 'learning_rate': 2.666666666666667e-06, 'mean_token_accuracy': 0.9866560697555542, 'epoch': 24.89}
{'loss': 0.0175, 'grad_norm': 0.6966819763183594, 'learning_rate': 0.0, 'mean_token_accuracy': 0.9855072498321533, 'epoch': 25.0}
{'train_runtime': 312.0504, 'train_samples_per_second': 0.961, 'train_steps_per_second': 0.24, 'train_loss': 0.3531599778061112, 'epoch': 25.0}
100%|██████████| 75/75 [05:12<00:00, 4.16s/it]
Saving model to deepseek-finetuned
EC2 Nodes
Here's a comparison of EC2 instances suitable for fine-tuning this model:
| Instance Type | Specs | Cost | Notes |
|---|---|---|---|
| g5.xlarge | 4 vCPUs, 16 GB RAM, 1x A10G GPU (24 GB VRAM) | $1.01/hr ($24.14/day, $724/mo) | Model fits entirely in GPU memory; fastest training |
| g4dn.xlarge | 4 vCPUs, 16 GB RAM, 1x T4 GPU (16 GB VRAM) | $0.53/hr ($12.72/day, $382/mo) | Requires some CPU offloading; slightly slower |
I used the g5.xlarge for this tutorial as the extra VRAM allows the entire model to stay in GPU memory, which speeds up training considerably. I also tested on the smaller g4dn.xlarge to demonstrate that CPU offloading works as expected, though with a modest increase in training time.
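If you're unsure whether a given instance can hold the model in VRAM, a quick check from Python (assuming torch is already installed) before kicking off training can save an offloading surprise:

```python
import torch

# Report the name and total VRAM of each visible CUDA device.
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        print(f"GPU {i}: {props.name}, {props.total_memory / 1024**3:.1f} GiB VRAM")
else:
    print("No CUDA device visible")
```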