Minimal Code Changes: Typically 2–4 Lines for Basic LLM Inference
Switching from GPU (CUDA) to TPU in PyTorch for LLM inference is straightforward, because PyTorch/XLA is designed to mimic native PyTorch semantics. As of November 2025, with PyTorch/XLA 2.5+ and tools like vLLM's TPU backend, a simple setup often needs just 2–3 code changes, primarily swapping device references and adding an import. This assumes a standard inference script (e.g., using Hugging Face Transformers or vLLM) without custom dynamic shapes or advanced distributed features. For more complex agent systems with loops or multi-host scaling, expect 4–6 changes, plus possible performance tuning.
These changes are mostly mechanical: no rewriting model architecture or training loops. The XLA compiler handles the rest under the hood. Below, I'll break it down with examples based on official docs and recent benchmarks.
Core Changes Required
- Import torch_xla: This initializes the XLA backend. (1 line)
- Change Device from 'cuda' to 'xla': Replace device='cuda' with device=xm.xla_device() (or just 'xla' in simpler cases). Remove any .cuda() or .to('cuda') calls. (1–2 lines)
- Optional: Add mark_step() for Graph Optimization: In loops (e.g., generation steps), insert xm.mark_step() (or torch_xla.sync() in newer releases) to break up large graphs and improve throughput, which is crucial for autoregressive LLM inference; see the sketch after this list. (1 line, but recommended for >10% perf gains)
- For Scaling: Wrap in xmp.spawn() or torch_xla.launch(): If using multi-core TPUs (e.g., v5p-8), add this for data parallelism. (1–2 lines, often in a launcher script)
No changes needed for optimizers, losses, or model definitions—XLA tensors are drop-in compatible.
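To make the list above concrete, here is a minimal sketch of a hand-rolled greedy decoding loop with the in-code changes applied. It is illustrative only: gpt2 stands in for your model, and real code would typically call model.generate() with a static KV cache instead of this naive loop (whose growing sequence length still triggers recompiles, as discussed in the caveats below).

```python
import torch
import torch_xla.core.xla_model as xm            # Change 1: import the XLA backend
from transformers import AutoModelForCausalLM, AutoTokenizer

device = xm.xla_device()                         # Change 2: 'cuda' -> XLA device

tok = AutoTokenizer.from_pretrained("gpt2")      # gpt2 is a placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device).eval()

inputs = tok("Hello, world!", return_tensors="pt").to(device)
generated = inputs["input_ids"]

with torch.no_grad():
    for _ in range(50):                          # greedy decoding, one token per step
        logits = model(generated).logits
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=-1)
        xm.mark_step()                           # Change 3: cut the graph each step

print(tok.decode(generated[0], skip_special_tokens=True))
```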
Example: Basic LLM Inference (Hugging Face Style)
GPU Version (original):
```python
import torch
from transformers import pipeline

model = pipeline("text-generation", model="meta-llama/Llama-3-8B", device="cuda")
output = model("Hello, world!", max_length=50)
```
TPU Version (after changes):
```python
import torch
import torch_xla.core.xla_model as xm  # Change 1: import
from transformers import pipeline

device = xm.xla_device()  # Change 2: device swap (1 line)
model = pipeline("text-generation", model="meta-llama/Llama-3-8B", device=device)
output = model("Hello, world!", max_length=50)

# Optional: for loop-based generation, add xm.mark_step() inside the loop (Change 3)
```
- Total Changes: 2 lines (import + device). Runs on a single TPU core; scale with torch_xla.distributed.xla_multiprocessing.spawn() for pods (+1 line), as sketched below.
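As a hedged sketch of that +1 line scaling path on a single TPU VM host: each spawned process binds to one local TPU device and loads its own copy of the pipeline; request sharding is omitted for brevity.

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from transformers import pipeline

def _worker(index):
    # Each spawned process is bound to one local TPU device.
    device = xm.xla_device()
    model = pipeline("text-generation", model="meta-llama/Llama-3-8B", device=device)
    print(f"[worker {index}]", model("Hello, world!", max_length=50))

if __name__ == "__main__":
    xmp.spawn(_worker)  # launches one process per local TPU device
```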
Example: vLLM for High-Throughput Serving (Production-Ready)
vLLM's TPU support (GA in Oct 2025) is even more seamless: often zero code changes if your GPU script already uses vLLM's API, since the backend is auto-detected via environment variables.
GPU Version:
```python
from vllm import LLM

llm = LLM(model="meta-llama/Llama-3-8B", tensor_parallel_size=1, device="cuda")
outputs = llm.generate(["Hello!"])
```
TPU Version:
```python
import os
os.environ["PJRT_DEVICE"] = "TPU"  # Change 1: env var (or set in the shell)

from vllm import LLM

llm = LLM(model="meta-llama/Llama-3-8B", tensor_parallel_size=8)  # auto-uses TPU; adjust for your core count
outputs = llm.generate(["Hello!"])
```
- Total Changes: 1 line (env var). Benchmarks show 2–5x throughput on TPU v5p vs. H100 for 70B models, with paged attention intact.
For Agent AI Systems (e.g., LangChain with Iterative Calls)
Agents add loops for tool calls/reasoning, so expect 3–4 changes to handle dynamic shapes:
- Same as above, plus xm.mark_step() in agent loops to avoid recompilations (e.g., per query step).
- Example: In a ReAct loop, insert the barrier after each LLM call, as in the sketch below. No other rewrites are needed; PyTorch/XLA's experimental eager mode (previewed in 2.5) hides most XLA quirks.
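Here is a hedged sketch of where the barrier goes in such a loop; call_llm and run_tool are hypothetical placeholders for your agent framework's LLM call and tool executor, not LangChain APIs:

```python
import torch_xla.core.xla_model as xm

def react_loop(call_llm, run_tool, query, max_turns=5):
    """Hypothetical ReAct-style driver; call_llm and run_tool are placeholders."""
    context = query
    for _ in range(max_turns):
        step = call_llm(context)           # one LLM generation per reasoning step
        xm.mark_step()                     # flush the pending XLA graph after each call
        if "Final Answer:" in step:
            return step
        context += step + run_tool(step)   # tool runs on the host; result fed back in
    return context
```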
| Scenario | Typical # Changes | Key Additions | Perf Notes (2025) |
|---|---|---|---|
| Simple Inference | 2 | Import + device swap | Sub-1s TTFT on v5e |
| vLLM Serving | 0–1 | Env var (PJRT_DEVICE=TPU) | 140+ tokens/s on v5p |
| Agent Loops | 3–4 | + mark_step() in iterations | Handles dynamic prompts; use StaticCache for KV |
| Multi-Host Scaling | +1–2 | xmp.spawn() wrapper | Pods up to 256 cores |
Caveats and Tips
- Performance Tweaks: Initial runs are slow due to compilation (the first forward pass can take 10–30s); subsequent runs are fast. Avoid dynamic shapes in loops without barriers, a common LLM gotcha (see the padding sketch after this list).
- Setup Overhead: Install with pip install torch_xla (plus environment setup on Google Cloud). Test on free Colab TPUs first.
- Ongoing Evolution: PyTorch/XLA 2.5 deprecates old APIs for a "more native" eager mode, potentially reducing changes to near-zero by mid-2026.
- When More Changes?: Custom kernels or non-XLA ops (rare in LLMs) might need JAX fallbacks via Torchax (+2–3 lines).
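A common mitigation for the dynamic-shape gotcha above is to pad prompts to a small set of fixed lengths so the compiler reuses cached graphs. A minimal sketch, assuming right-padding; the bucket sizes and pad_to_bucket are illustrative, not a library API (in practice you would also pad the attention mask the same way and mask out the padded positions):

```python
import torch

BUCKETS = (128, 256, 512, 1024)  # illustrative bucket sizes; tune for your prompt lengths

def pad_to_bucket(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Right-pad a [batch, seq] tensor to the next bucket so XLA reuses one graph per bucket."""
    seq_len = input_ids.shape[1]
    target = next((b for b in BUCKETS if b >= seq_len), seq_len)  # no padding past the largest bucket
    if target == seq_len:
        return input_ids
    pad = torch.full(
        (input_ids.shape[0], target - seq_len), pad_token_id,
        dtype=input_ids.dtype, device=input_ids.device,
    )
    return torch.cat([input_ids, pad], dim=1)
```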
In short, it's not a full rewrite—think "device swap + import" for 80% of cases. Start with the PyTorch/XLA docs or vLLM TPU guide for templates. If your code has specifics (e.g., share a snippet), I can pinpoint exact deltas!