Senior Machine Learning Engineer with JAX (W2 Role)

Remote • Posted 2 hours ago • Updated 3 minutes ago
Contract - Independent • Contract - W2 • Depends on Experience

Job Details

Skills

  • Machine Learning • PyTorch • JAX • Large Language Models (LLMs) • XLA • Distributed Training

Summary

100% remote role.

Must work Eastern Time (EST) hours.

Agency and C2C submissions will not be considered, and visa sponsorship is not available.

Machine Learning Engineer: Framework Migration & Systems Optimization (PyTorch to JAX)

We are seeking a specialized Machine Learning Engineer with deep expertise in the high-performance AI stack. This role isn't just about "translating" code; it's about re-architecting Large Language Models (LLMs) to thrive in a JAX-native environment, specifically targeting TPU and GPU clusters at scale. You will bridge the gap between high-level PyTorch research implementations and the functional, XLA-optimized world of JAX, ensuring that our models achieve maximum throughput and hardware efficiency.

1. Core Framework Migration

Structural Porting: Manually migrate complex PyTorch LLM architectures (Transformers, MoE, SSMs) into JAX-based frameworks (Equinox, Flax, or Pax).
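
By way of illustration (a minimal, hypothetical sketch, not this team's codebase), a small PyTorch block maps onto a Flax linen module whose parameters live in an external pytree rather than on the object:

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    # Hypothetical JAX port of a PyTorch Linear -> GELU -> Linear block.
    class MLPBlock(nn.Module):
        hidden: int
        out: int

        @nn.compact
        def __call__(self, x):
            x = nn.Dense(self.hidden)(x)  # parameters are created lazily into a pytree
            x = nn.gelu(x)
            return nn.Dense(self.out)(x)

    model = MLPBlock(hidden=256, out=64)
    x = jnp.ones((8, 32))
    params = model.init(jax.random.PRNGKey(0), x)  # params are returned, not stored on the module
    y = model.apply(params, x)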

State Management: Transition imperative PyTorch state management to JAX's purely functional paradigm, handling PRNGKey management and immutable state updates with precision.
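
In practice that means explicit key splitting and returning new state rather than mutating it; a minimal sketch (names and shapes are illustrative):

    import jax
    import jax.numpy as jnp

    def init_state(key):
        k_param, k_next = jax.random.split(key)  # never reuse a PRNGKey
        return {"w": jax.random.normal(k_param, (4, 4)), "step": jnp.array(0)}, k_next

    def update(state, grad, lr=1e-2):
        # Pure function: returns a new pytree; the input state is left untouched.
        return {"w": state["w"] - lr * grad, "step": state["step"] + 1}

    state, key = init_state(jax.random.PRNGKey(42))
    key, sub = jax.random.split(key)
    grad = jax.random.normal(sub, (4, 4))  # stand-in for a real gradient
    state = update(state, grad)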

Weight Translation: Develop robust pipelines for checkpoint conversion, ensuring numerical parity between frameworks via rigorous unit testing and error tolerance checks.
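
A parity gate might look like the following sketch, where torch_state stands in for a PyTorch state_dict already exported to NumPy (layer names and tolerances are assumptions):

    import numpy as np
    import jax.numpy as jnp

    def convert(torch_state):
        # PyTorch nn.Linear stores weights as (out, in); Flax Dense expects (in, out).
        return {k: jnp.asarray(v.T) if k.endswith("weight") else jnp.asarray(v)
                for k, v in torch_state.items()}

    def check_parity(reference, candidate, atol=1e-5, rtol=1e-5):
        np.testing.assert_allclose(np.asarray(candidate), reference, atol=atol, rtol=rtol)

    torch_state = {"fc.weight": np.random.randn(64, 32).astype(np.float32),
                   "fc.bias": np.random.randn(64).astype(np.float32)}
    jax_params = convert(torch_state)
    x = np.random.randn(8, 32).astype(np.float32)
    ref_out = x @ torch_state["fc.weight"].T + torch_state["fc.bias"]  # PyTorch-side reference
    jax_out = jnp.asarray(x) @ jax_params["fc.weight"] + jax_params["fc.bias"]
    check_parity(ref_out, jax_out)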

2. Advanced Profiling & Numerical Stability

Bottleneck Analysis: Use NVIDIA Nsight and the TensorBoard Profiler to identify XLA compilation overheads, excessive rematerialization, or un-fused kernels.
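
On the JAX side, traces for TensorBoard or Perfetto can be captured with the built-in profiler; a minimal sketch (the trace directory is arbitrary):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def step(x):
        return jnp.tanh(x @ x.T).sum()

    x = jnp.ones((1024, 1024))
    step(x).block_until_ready()  # warm up so compile time stays out of the trace

    with jax.profiler.trace("/tmp/jax-trace"):  # open in TensorBoard or Perfetto
        step(x).block_until_ready()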

Numerical Debugging: Implement precision-tracking tools to ensure that BF16 or FP8 training runs remain stable during the transition, preventing gradient divergence.
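
One simple form of precision tracking is to replay a layer in BF16 against an FP32 reference and gate on finiteness and drift; a minimal sketch with illustrative shapes:

    import jax
    import jax.numpy as jnp

    def forward(w, x):
        return jax.nn.softmax(x @ w, axis=-1)

    key = jax.random.PRNGKey(0)
    w = jax.random.normal(key, (256, 256), dtype=jnp.float32)
    x = jax.random.normal(jax.random.split(key)[0], (32, 256), dtype=jnp.float32)

    ref = forward(w, x)  # FP32 reference
    low = forward(w.astype(jnp.bfloat16), x.astype(jnp.bfloat16)).astype(jnp.float32)

    assert bool(jnp.all(jnp.isfinite(low))), "BF16 run produced NaN/Inf"
    print("max abs drift vs FP32:", float(jnp.max(jnp.abs(ref - low))))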

3. Scaling & Distributed Training

Parallelism Strategies: Implement and optimize Fully Sharded Data Parallelism (FSDP) equivalents in JAX (using pjit or sharding APIs).
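
A minimal FSDP-style layout with the jax.sharding APIs might look as follows (the 8-way CPU mesh is faked via XLA_FLAGS purely so the sketch runs anywhere):

    import os
    os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # fake 8 devices on CPU

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
    from jax.experimental import mesh_utils

    mesh = Mesh(mesh_utils.create_device_mesh((8,)), axis_names=("fsdp",))

    # Shard the weight's rows and the activation's contraction dim across the mesh;
    # XLA inserts the collectives needed to complete the matmul.
    w = jax.device_put(jnp.ones((4096, 512)), NamedSharding(mesh, P("fsdp", None)))
    x = jax.device_put(jnp.ones((32, 4096)), NamedSharding(mesh, P(None, "fsdp")))

    @jax.jit
    def layer(x, w):
        return jnp.dot(x, w)

    y = layer(x, w)
    print(y.sharding)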

Hybrid Parallelism: Design 3D parallelism strategies (Data, Pipeline, and Tensor) tailored for the interconnect topology (e.g., NVLink or TPU ICI) of the target hardware.
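
At the mesh level this reduces to naming the device-grid axes after the parallelism dimensions; a sketch assuming an 8-device host (axis names are illustrative):

    import os
    os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

    import jax
    from jax.sharding import Mesh, PartitionSpec as P
    from jax.experimental import mesh_utils

    # A (2, 2, 2) grid; in production the shape mirrors the physical topology
    # (e.g., NVLink islands on the inner axes, IB/DCN on the outer one).
    devices = mesh_utils.create_device_mesh((2, 2, 2))
    mesh = Mesh(devices, axis_names=("data", "pipeline", "tensor"))

    batch_spec = P("data", None)     # activations sharded over the data axis
    weight_spec = P(None, "tensor")  # weights column-sharded for tensor parallelism
    print(mesh)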

4. Hardware-Aware Optimization

XLA Mastery: Understand and influence the behavior of the XLA (Accelerated Linear Algebra) compiler. You will optimize HLO (High-Level Optimizer) graphs to minimize compile ("jit") time and maximize run-time efficiency.
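
JAX exposes both stages of that pipeline directly; a minimal sketch for inspecting what tracing emits versus what XLA actually compiles:

    import jax
    import jax.numpy as jnp

    def f(x):
        return jnp.sin(x) * jnp.cos(x) + x

    lowered = jax.jit(f).lower(jnp.ones((128, 128)))
    print(lowered.as_text()[:400])    # StableHLO as emitted by tracing

    compiled = lowered.compile()
    print(compiled.as_text()[:400])   # optimized HLO after XLA passes (fusion etc.)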

Memory Management: Apply optimizations like Selective Activation Checkpointing and memory-efficient attention (FlashAttention-2 JAX implementations) based on the specific HBM (High Bandwidth Memory) constraints of the hardware.
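
In JAX, selective activation checkpointing is expressed with jax.checkpoint plus a save policy; a minimal sketch with illustrative shapes:

    from functools import partial
    import jax
    import jax.numpy as jnp

    # Save only matmul outputs with no batch dims (a common LLM policy);
    # everything else is rematerialized in the backward pass, trading FLOPs for HBM.
    @partial(jax.checkpoint, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
    def block(x, w1, w2):
        return jnp.tanh(jnp.tanh(x @ w1) @ w2)

    def loss(x, w1, w2):
        return block(x, w1, w2).sum()

    x = jnp.ones((64, 512)); w1 = jnp.ones((512, 2048)); w2 = jnp.ones((2048, 512))
    grads = jax.grad(loss, argnums=(1, 2))(x, w1, w2)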

5. Fine-Tuning & Adaptation

Efficient Fine-Tuning: Port PyTorch-based PEFT (LoRA, DoRA) methods into the JAX stack.
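
The core of a LoRA port is small; a minimal pure-JAX sketch (dimensions, scaling, and init are illustrative):

    import jax
    import jax.numpy as jnp

    def lora_linear(x, W, A, B, alpha=16.0):
        # W is the frozen base weight; only the low-rank factors A (d, r) and B (r, k) train.
        r = A.shape[1]
        return x @ W + (alpha / r) * (x @ A) @ B

    key = jax.random.PRNGKey(0)
    k_w, k_a = jax.random.split(key)
    d, k_out, r = 512, 512, 8
    W = jax.random.normal(k_w, (d, k_out)) * 0.02  # frozen
    A = jax.random.normal(k_a, (d, r)) * 0.01      # trainable
    B = jnp.zeros((r, k_out))                      # zero init: training starts exactly at W
    y = lora_linear(jnp.ones((4, d)), W, A, B)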

Architectural Evolution: Stay ahead of the curve by adapting JAX implementations for newer primitives like Mamba/SSMs, Grouped-Query Attention (GQA), and Linear Attention as they emerge in the research space.
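
For instance, GQA reduces to sharing each KV head across a group of query heads; a minimal single-sequence sketch (shapes are illustrative; no masking or caching):

    import jax
    import jax.numpy as jnp

    def gqa(q, k, v):
        # q: (T, Hq, D); k, v: (T, Hkv, D) with Hq a multiple of Hkv.
        group = q.shape[1] // k.shape[1]
        k = jnp.repeat(k, group, axis=1)  # broadcast each KV head to its query group
        v = jnp.repeat(v, group, axis=1)
        scores = jnp.einsum("qhd,khd->hqk", q, k) / jnp.sqrt(q.shape[-1])
        return jnp.einsum("hqk,khd->qhd", jax.nn.softmax(scores, axis=-1), v)

    q = jnp.ones((16, 8, 64)); k = v = jnp.ones((16, 2, 64))  # 8 query heads, 2 KV heads
    out = gqa(q, k, v)  # (16, 8, 64)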

Familiarity with the following technical stack and tooling is expected:

1. Core Frameworks & Libraries:

JAX Ecosystem: Expertise in Flax or Equinox (for model definition), Optax (for optimization/schedules), and Orbax (for checkpointing).
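
How those pieces typically snap together, as a minimal hypothetical train step (Orbax would persist params and opt_state; omitted here):

    import jax
    import jax.numpy as jnp
    import optax

    def loss_fn(params, x, y):
        pred = x @ params["w"] + params["b"]
        return jnp.mean((pred - y) ** 2)

    params = {"w": jnp.zeros((32, 1)), "b": jnp.zeros((1,))}
    tx = optax.adamw(learning_rate=1e-3)
    opt_state = tx.init(params)

    @jax.jit
    def train_step(params, opt_state, x, y):
        loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
        updates, opt_state = tx.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state, loss

    params, opt_state, loss = train_step(params, opt_state, jnp.ones((8, 32)), jnp.ones((8, 1)))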

PyTorch Ecosystem: Deep knowledge of PyTorch 2.x, including torch.compile, DistributedDataParallel (DDP), and FSDP.

Intermediate Representations: Proficiency in HLO (High-Level Optimizer) and MLIR to understand how JAX code translates to hardware instructions.

Data Loaders: Experience migrating from torch.utils.data to Grain or tf.data for high-throughput JAX pipelines.
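
A common migration pattern keeps tf.data for the input pipeline and hands plain NumPy batches to JAX; a minimal sketch with synthetic data:

    import numpy as np
    import tensorflow as tf

    data = {"x": np.random.randn(1024, 32).astype(np.float32),
            "y": np.random.randn(1024, 1).astype(np.float32)}

    ds = (tf.data.Dataset.from_tensor_slices(data)
            .shuffle(1024)
            .batch(64, drop_remainder=True)   # static shapes avoid XLA recompilation
            .prefetch(tf.data.AUTOTUNE))

    for batch in ds.as_numpy_iterator():      # NumPy dicts, ready for a jitted train step
        pass  # e.g., params, opt_state, loss = train_step(params, opt_state, batch["x"], batch["y"])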

2. Profiling & Observability

JAX Profiler / TensorBoard: For identifying XLA compilation bottlenecks and tracing device memory traffic.

NVIDIA Nsight Systems: To analyze GPU utilization, SM occupancy, and NVLink traffic.

Perfetto: For deep-dive trace analysis across multi-node TPU/GPU clusters.

3. Infrastructure & Hardware

Accelerator Hardware: Strong understanding of NVIDIA H100/A100 (Hopper/Ampere) architecture and Google TPU (v4/v5p) topology.

Orchestration: Experience with Slurm or Kubernetes (K8s) for managing large-scale training jobs.

Cloud Providers: Proficiency in Google Cloud Platform (GCP) for TPUs, or AWS/Azure for high-end GPU instances.

Core Skills & Competencies

1. Software Engineering Excellence

Functional Programming: A shift in mindset from OOP (Object-Oriented Programming) to pure functions, immutability, and stateless logic.

Asynchronous Programming: Understanding JAX's asynchronous dispatch model and how to avoid "host-sync" bottlenecks.
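
The effect is easy to demonstrate; a minimal sketch timing dispatch separately from compute:

    import time
    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
        return (x @ x.T).sum()

    x = jnp.ones((2048, 2048))
    f(x).block_until_ready()      # warm up: compilation happens here

    t0 = time.perf_counter()
    y = f(x)                      # returns almost immediately: dispatch is asynchronous
    t1 = time.perf_counter()
    y.block_until_ready()         # explicit synchronization point
    t2 = time.perf_counter()
    print(f"dispatch {t1 - t0:.6f}s, compute {t2 - t1:.6f}s")
    # Anti-pattern: float(y) or y.item() inside the hot loop forces a hidden host sync.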

Testing Rigor: Ability to write property-based tests for numerical stability.
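
For example, a Hypothesis test can assert an algebraic invariant (here, shift invariance of softmax) over generated inputs; the property chosen is illustrative:

    import jax
    import jax.numpy as jnp
    import numpy as np
    from hypothesis import given, strategies as st

    @given(st.lists(st.floats(min_value=-50, max_value=50), min_size=2, max_size=64))
    def test_softmax_shift_invariant(xs):
        x = jnp.asarray(xs, dtype=jnp.float32)
        a = jax.nn.softmax(x)
        b = jax.nn.softmax(x - x.max())  # the numerically stable formulation
        np.testing.assert_allclose(np.asarray(a), np.asarray(b), atol=1e-6)

    test_softmax_shift_invariant()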

2. Distributed Systems Knowledge

Collective Communications: Deep understanding of All-Reduce, All-Gather, and Reduce-Scatter primitives.
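
In JAX these surface as lax collectives over a named axis; a minimal All-Reduce sketch (the 4-way CPU mesh is faked so it runs anywhere):

    import os
    os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

    import jax
    import jax.numpy as jnp

    def grad_sync(local_grad):
        # All-Reduce: every device receives the global sum, the primitive behind
        # data-parallel gradient averaging.
        return jax.lax.psum(local_grad, axis_name="devices")

    local = jnp.arange(4, dtype=jnp.float32).reshape(4, 1)  # one value per device
    synced = jax.pmap(grad_sync, axis_name="devices")(local)
    print(synced.ravel())  # [6. 6. 6. 6.]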

Network Topology: Understanding how rack-level interconnects (e.g., InfiniBand) affect the choice of 3D parallelism strategies.

3. Mathematical & AI Domain Expertise (Desirable)

Linear Algebra: Mastery of tensor contractions, Einstein summation (einsum), and matrix decomposition.
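
For instance, attention scores and their weighted sum fall out as two einsum contractions (shapes are illustrative):

    import jax
    import jax.numpy as jnp

    B, H, T, D = 2, 8, 128, 64
    q = jnp.ones((B, H, T, D)); k = jnp.ones((B, H, T, D)); v = jnp.ones((B, H, T, D))

    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(D)  # QK^T in one contraction
    out = jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), v)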

Mixed Precision Training: Expert-level knowledge of Stochastic Rounding, Loss Scaling, and the nuances of BF16 vs. FP8 training.
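
A static loss-scaling sketch (production systems use a dynamic scaler that grows and shrinks on overflow; the scale constant here is arbitrary):

    import jax
    import jax.numpy as jnp

    SCALE = 2.0 ** 12  # static loss scale

    def scaled_loss(w, x):
        return jnp.mean((x @ w) ** 2) * SCALE    # scale up before the backward pass

    def grads_with_unscale(w, x):
        g = jax.grad(scaled_loss)(w, x) / SCALE  # unscale before the optimizer update
        ok = jnp.all(jnp.isfinite(g))            # overflow check: skip the step if False
        return g, ok

    w = jnp.ones((32, 4), dtype=jnp.bfloat16)
    g, ok = grads_with_unscale(w, jnp.ones((8, 32), dtype=jnp.bfloat16))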

Architecture Insight: Ability to decompose modern LLM components (KV Caches, Rotary Embeddings, Gated Linear Units) into their primitive mathematical operations.
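
Rotary embeddings are a good example: they reduce to a per-position 2D rotation over feature pairs; a minimal sketch (interleaved-pair variant, illustrative base):

    import jax.numpy as jnp

    def rope(x, base=10000.0):
        # x: (T, D) with D even. Only primitive ops: outer product, sin, cos, reshape.
        T, D = x.shape
        inv_freq = 1.0 / (base ** (jnp.arange(0, D, 2) / D))
        theta = jnp.arange(T)[:, None] * inv_freq[None, :]  # (T, D/2) rotation angles
        x1, x2 = x[:, 0::2], x[:, 1::2]
        rot1 = x1 * jnp.cos(theta) - x2 * jnp.sin(theta)
        rot2 = x1 * jnp.sin(theta) + x2 * jnp.cos(theta)
        return jnp.stack([rot1, rot2], axis=-1).reshape(T, D)

    q_rot = rope(jnp.ones((128, 64)))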

  • Dice Id: 10107189
  • Position Id: 8913262