100% remote role
Must work EST hours
Agency and C2C candidates will not be considered; visa sponsorship is not available
Machine Learning Engineer: Framework Migration & Systems Optimization (PyTorch to JAX)
We are seeking a specialized Machine Learning Engineer with deep expertise in the high-performance AI stack. This role isn't just about "translating" code; it's about re-architecting Large Language Models (LLMs) to thrive in a JAX-native environment, specifically targeting TPU and GPU clusters at scale. You will bridge the gap between high-level PyTorch research implementations and the functional, XLA-optimized world of JAX, ensuring that our models achieve maximum throughput and hardware efficiency.
1. Core Framework Migration
Structural Porting: Manually migrate complex PyTorch LLM architectures (Transformers, MoE, SSMs) into JAX-based frameworks (Equinox, Flax, or Pax).
State Management: Transition imperative PyTorch state management to JAX's purely functional paradigm, handling PRNGKey management and immutable state updates with precision.
Weight Translation: Develop robust pipelines for checkpoint conversion, ensuring numerical parity between frameworks via rigorous unit testing and error tolerance checks.
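The functional state handling described above is the crux of most PyTorch-to-JAX ports. A minimal sketch (the helper names `init_params` and `forward` are illustrative, not from any framework) showing explicit PRNG key splitting, immutable parameter pytrees, and a numerical-parity check against a NumPy reference:

```python
import jax
import jax.numpy as jnp
import numpy as np

# In JAX, model state (params) and randomness (PRNG keys) are explicit values:
# every "update" returns a new pytree instead of mutating in place.

def init_params(key, d_in, d_out):
    # Split the key so each parameter gets independent randomness.
    w_key, b_key = jax.random.split(key)
    return {
        "w": jax.random.normal(w_key, (d_in, d_out)) * 0.02,
        "b": jnp.zeros((d_out,)),
    }

def forward(params, x):
    return x @ params["w"] + params["b"]

key = jax.random.PRNGKey(0)
params = init_params(key, 4, 2)
x = jnp.ones((3, 4))
y = forward(params, x)

# A ported checkpoint can be validated against the reference implementation
# with an explicit error tolerance (here a NumPy re-implementation).
y_ref = np.asarray(x) @ np.asarray(params["w"]) + np.asarray(params["b"])
assert np.allclose(np.asarray(y), y_ref, atol=1e-5)
```

The same tolerance-check pattern scales up to full checkpoint conversion: load both frameworks' weights, run identical inputs, and assert per-tensor closeness.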
2. Advanced Profiling & Numerical Stability
Bottleneck Analysis: Use NVIDIA Nsight and the TensorBoard Profiler to identify XLA compilation overheads, excessive rematerialization, or un-fused kernels.
Numerical Debugging: Implement precision-tracking tools to ensure that BF16 or FP8 training runs remain stable during the transition, preventing gradient divergence.
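A small sketch of the kind of precision-tracking check this implies: quantize inputs to BF16, recompute in FP32, and measure the relative error against a pure-FP32 reference. The threshold here is illustrative, not a production value:

```python
import jax.numpy as jnp
import numpy as np

# Quantify the precision lost by storing activations in BF16 versus FP32.
x = np.linspace(-1.0, 1.0, 1024).astype(np.float32)

ref = jnp.sum(jnp.asarray(x) ** 2)  # FP32 reference

# Round-trip through BF16 storage, then accumulate in FP32.
low = jnp.sum(jnp.asarray(x, dtype=jnp.bfloat16).astype(jnp.float32) ** 2)

rel_err = abs(float(low) - float(ref)) / abs(float(ref))
# BF16 keeps ~8 mantissa bits, so per-element relative error is ~2^-9.
assert rel_err < 1e-2
```

In a real migration the same comparison is run layer-by-layer between the PyTorch and JAX models to localize where divergence begins.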
3. Scaling & Distributed Training
Parallelism Strategies: Implement and optimize Fully Sharded Data Parallelism (FSDP) equivalents in JAX (using pjit or sharding APIs).
Hybrid Parallelism: Design 3D parallelism strategies (Data, Pipeline, and Tensor) tailored for the interconnect topology (e.g., NVLink or TPU ICI) of the target hardware.
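A minimal sketch of JAX's sharding APIs mentioned above, assuming a single-host setup (the axis name "data" is our choice, not a JAX builtin); the identical code scales to multi-device slices:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D device mesh over whatever devices are available.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Shard the batch dimension of a tensor across the 'data' axis,
# the FSDP-style placement referenced above.
x = jnp.arange(8 * 4, dtype=jnp.float32).reshape(8, 4)
x_sharded = jax.device_put(x, NamedSharding(mesh, P("data", None)))

# jit compiles against the input sharding; XLA inserts collectives as needed.
y = jax.jit(lambda a: a * 2.0)(x_sharded)
```

For 3D parallelism the mesh simply gains more named axes (e.g., "data", "model", "pipe"), and PartitionSpec entries map tensor dimensions onto them.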
4. Hardware-Aware Optimization
XLA Mastery: Understand and influence the XLA (Accelerated Linear Algebra) compiler's behavior. You will optimize HLO (High-Level Optimizer) graphs to minimize compile ("jit") time and maximize run-time efficiency.
Memory Management: Apply optimizations like Selective Activation Checkpointing and memory-efficient attention (FlashAttention-2 JAX implementations) based on the specific HBM (High Bandwidth Memory) constraints of the hardware.
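Selective activation checkpointing, as described above, is exposed in JAX as jax.checkpoint (also known as jax.remat). A minimal sketch with an illustrative two-layer block (the names `mlp_block` and the shapes are our own):

```python
import jax
import jax.numpy as jnp

# jax.checkpoint drops intermediate activations in the forward pass and
# recomputes them during backprop, trading FLOPs for HBM.
def mlp_block(x, w1, w2):
    h = jnp.tanh(x @ w1)   # activation that would otherwise be saved
    return jnp.sum(h @ w2)

# Wrap the block so its internals are rematerialized on the backward pass.
mlp_remat = jax.checkpoint(mlp_block)

key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
x  = jax.random.normal(k1, (16, 32))
w1 = jax.random.normal(k2, (32, 64)) * 0.1
w2 = jax.random.normal(k3, (64, 8)) * 0.1

g_plain = jax.grad(mlp_block, argnums=1)(x, w1, w2)
g_remat = jax.grad(mlp_remat, argnums=1)(x, w1, w2)
# Gradients are numerically identical; only the memory/compute trade-off changes.
```

The "selective" part comes from jax.checkpoint's policy argument, which lets you choose which intermediates (e.g., matmul outputs) are saved versus recomputed.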
5. Fine-Tuning & Adaptation
Efficient Fine-Tuning: Port PyTorch-based PEFT (LoRA, DoRA) methods into the JAX stack.
Architectural Evolution: Stay ahead of the curve by adapting JAX implementations for newer primitives like Mamba/SSMs, Grouped-Query Attention (GQA), and Linear Attention as they emerge in the research space.
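As a sketch of what porting PEFT methods into a functional stack looks like, here is a minimal LoRA linear layer (the names and the alpha/r scaling follow the LoRA paper's convention; this is illustrative, not a full PEFT library):

```python
import jax
import jax.numpy as jnp

# LoRA: the frozen base weight W is adapted by a low-rank update
# (alpha / r) * (x @ A) @ B, where A is (d_in, r) and B is (r, d_out).
def lora_linear(x, w_frozen, lora_a, lora_b, alpha=16.0):
    r = lora_a.shape[1]
    return x @ w_frozen + (alpha / r) * (x @ lora_a) @ lora_b

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
d_in, d_out, r = 32, 32, 4
w = jax.random.normal(k1, (d_in, d_out)) * 0.02
lora_a = jax.random.normal(k2, (d_in, r)) * 0.01
lora_b = jnp.zeros((r, d_out))      # zero-init => no change at step 0

x = jnp.ones((2, d_in))
# With B initialized to zero, the adapted layer matches the frozen base exactly.
assert jnp.allclose(lora_linear(x, w, lora_a, lora_b), x @ w)

# Only (lora_a, lora_b) receive gradients; w stays frozen outside argnums.
loss = lambda a, b: jnp.sum(lora_linear(x, w, a, b) ** 2)
grads = jax.grad(loss, argnums=(0, 1))(lora_a, lora_b)
```

In JAX, "freezing" the base weights falls out naturally: gradients are only taken with respect to the adapter arguments.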
Familiarity with the following technical stack and tooling is expected:
1. Core Frameworks & Libraries:
JAX Ecosystem: Expertise in Flax or Equinox (for model definition), Optax (for optimization/schedules), and Orbax (for checkpointing).
PyTorch Ecosystem: Deep knowledge of PyTorch 2.x, including torch.compile, DistributedDataParallel (DDP), and FSDP.
Intermediate Representations: Proficiency in HLO (High-Level Optimizer) and MLIR to understand how JAX code translates to hardware instructions.
Data Loaders: Experience migrating from torch.utils.data to Grain or tf.data for high-throughput JAX pipelines.
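The Optax/Orbax pattern above is built around pure-functional training steps: params in, new params out. A dependency-free sketch of that pattern with hand-rolled SGD on a toy linear-regression problem (Optax packages exactly this init/update shape with richer optimizers and schedules):

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def sgd_step(params, x, y, lr=0.1):
    grads = jax.grad(loss_fn)(params, x, y)
    # tree_map applies the update across the whole parameter pytree,
    # returning a brand-new pytree rather than mutating in place.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (64, 3))
true_w = jnp.array([[1.0], [-2.0], [0.5]])
y = x @ true_w

params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
for _ in range(200):
    params = sgd_step(params, x, y)
# The loss is driven close to zero on this linear problem.
```

Because the step is a pure function of (params, batch), it jit-compiles cleanly and checkpoints (e.g., via Orbax) are just serialized pytrees.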
2. Profiling & Observability
JAX Profiler / TensorBoard: For identifying XLA compilation bottlenecks and tracing device memory traffic.
NVIDIA Nsight Systems: To analyze GPU utilization, SM occupancy, and NVLink bandwidth.
Perfetto: For deep-dive trace analysis across multi-node TPU/GPU clusters.
3. Infrastructure & Hardware
Accelerator Hardware: Strong understanding of NVIDIA H100/A100 (Hopper/Ampere) architecture and Google TPU (v4/v5p) topology.
Orchestration: Experience with Slurm or Kubernetes (K8s) for managing large-scale training jobs.
Cloud Providers: Proficiency in Google Cloud Platform (GCP) for TPUs or AWS/Azure for high-end GPU instances.
Core Skills & Competencies
1. Software Engineering Excellence
Functional Programming: A shift in mindset from object-oriented programming (OOP) to pure functions, immutability, and stateless logic.
Asynchronous Programming: Understanding JAX's asynchronous dispatch model and how to avoid "host-sync" bottlenecks.
Testing Rigor: Ability to write property-based tests for numerical stability.
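The asynchronous-dispatch point above is easy to demonstrate: JAX calls return before the device finishes, so accurate timing (and avoiding accidental host syncs) requires block_until_ready(). A minimal sketch, with illustrative shapes:

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def heavy(x):
    return jnp.sum(x @ x.T)

x = jax.random.normal(jax.random.PRNGKey(0), (512, 512))
heavy(x).block_until_ready()           # warm-up: exclude compile time

t0 = time.perf_counter()
result = heavy(x)                      # returns almost immediately (async dispatch)
result.block_until_ready()             # actually waits for the device to finish
elapsed = time.perf_counter() - t0
# Calling float(result) or print(result) would also force a host sync --
# scattering such calls through a training loop is a classic throughput bug.
```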
2. Distributed Systems Knowledge
Collective Communications: Deep understanding of All-Reduce, All-Gather, and Reduce-Scatter primitives.
Network Topology: Understanding how rack-level interconnects (e.g., InfiniBand) affect the choice of 3D parallelism strategies.
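The All-Reduce primitive named above maps directly onto jax.lax.psum. A minimal sketch under pmap, assuming a single host (the axis name "devices" is our choice); with one local device the reduction is trivial but the semantics are identical at scale:

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

# Each device holds one shard of shape (4,); psum all-reduces across them.
shards = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)
summed = jax.pmap(lambda x: jax.lax.psum(x, "devices"),
                  axis_name="devices")(shards)

# Every device receives the same reduced result: that replication is what
# distinguishes All-Reduce from a plain Reduce.
```

All-Gather and Reduce-Scatter have analogous primitives (lax.all_gather, lax.psum_scatter) and the same named-axis interface.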
3. Mathematical & AI Domain Expertise (Desirable)
Linear Algebra: Mastery of tensor contractions, Einstein summation (einsum), and matrix decomposition.
Mixed Precision Training: Expert-level knowledge of Stochastic Rounding, Loss Scaling, and the nuances of BF16 vs. FP8 training.
Architecture Insight: Ability to decompose modern LLM components (KV Caches, Rotary Embeddings, Gated Linear Units) into their primitive mathematical operations.
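As an example of the einsum-level decomposition this section describes, here is scaled dot-product attention written as explicit tensor contractions (shapes and the helper name are illustrative):

```python
import jax
import jax.numpy as jnp

def attention(q, k, v):
    d = q.shape[-1]
    # Contract the head dimension of Q against K: (b, h, q_len, k_len) scores.
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(d)
    weights = jax.nn.softmax(scores, axis=-1)
    # Contract attention weights against V to produce the output.
    return jnp.einsum("bhqk,bhkd->bhqd", weights, v)

key = jax.random.PRNGKey(0)
kq, kk, kv = jax.random.split(key, 3)
q = jax.random.normal(kq, (2, 4, 8, 16))   # (batch, heads, seq, head_dim)
k = jax.random.normal(kk, (2, 4, 8, 16))
v = jax.random.normal(kv, (2, 4, 8, 16))
out = attention(q, k, v)                   # shape (2, 4, 8, 16)
```

Variants like GQA or linear attention are then just changes to these contraction patterns, which is why einsum fluency matters for the migrations described above.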