Big Board Alerts

November 22, 2023
How Amazon Search M5 saved 30% for LLM training cost by using AWS Trainium

For decades, Amazon has pioneered and innovated in machine learning (ML), bringing delightful experiences to its customers. From the earliest days, Amazon has used ML for various use cases such as book recommendations, search, and fraud detection. As in the rest of the industry, advancements in accelerated hardware have allowed Amazon teams to pursue model architectures using neural networks and deep learning (DL).

The M5 program within Amazon Search owns the discovery learning strategy for Amazon and builds large-scale models that are multi-lingual, multi-locale, multi-entity, multitask, and multi-modal (text, image, and video). The M5 program has been serving universal embeddings and large-scale foundation models to hundreds of ML teams across Amazon while maintaining strict controls over cost. To achieve this, the M5 team regularly evaluates new techniques to reduce cost.

Like many ML organizations, M5 relies heavily on accelerators to speed up DL training and inference. When AWS launched purpose-built accelerators with the first release of AWS Inferentia in 2020, the M5 team quickly began to use them to deploy production workloads more efficiently, both saving cost and reducing latency. Last year, AWS launched its AWS Trainium accelerators, which optimize performance per cost for developing and building next-generation DL models. In this post, we discuss how M5 was able to reduce the cost to train their models by 30%, and share some of the best practices we learned along the way.

Trainium instances

With the advances in purpose-built accelerators, Amazon also provides compelling accelerators in the form of AWS Inferentia and Trainium. As their names imply, these chips are optimized for inference and training workloads, respectively. For large-scale training of foundation models that reach billions of parameters, Trainium Trn1 and Trn1n instances are ideal choices. Trn1 instances are powered by the state-of-the-art NeuronCore-v2 and have a copious amount of accelerator compute and memory. Trn1n instances additionally offer greater networking bandwidth (1,600 Gbps), making them well suited for performant, cost-optimized training.

To use accelerators, you need a software layer to support them. For Trn and Inf chips, the AWS Neuron SDK unlocks Amazon purpose-built accelerators with the help of PyTorch XLA. PyTorch XLA converts PyTorch's eager mode into a lazy, graph-based execution mode: operations are recorded into computation graphs, which are then compiled for the accelerator. PyTorch Neuron (part of the Neuron SDK) enables PyTorch users to train their models on Trainium NeuronCores with a few lines of code.
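To make the eager-versus-lazy distinction concrete, here is a toy pure-Python sketch. It is not how PyTorch XLA is implemented; the `LazyTracer` and `Node` classes and their `mark_step()` method are hypothetical stand-ins. Operations are only recorded, nothing is computed until `mark_step()`, and each distinct graph structure is "compiled" once and then reused. For simplicity the graph signature here is just the sequence of op names; a real one would also depend on tensor shapes and dtypes.

```python
class Node:
    """One recorded operation; its value is filled in only when the graph runs."""

    def __init__(self, op, fn, args):
        self.op = op
        self.fn = fn
        self.args = args
        self.value = None


class LazyTracer:
    """Toy lazy mode: ops are recorded, then compiled and run at mark_step()."""

    def __init__(self):
        self.pending = []       # ops recorded since the last mark_step()
        self.cache = {}         # graph signature -> "compiled" graph
        self.compile_count = 0

    def record(self, op, fn, *args):
        node = Node(op, fn, args)
        self.pending.append(node)
        return node  # a handle to a value that does not exist yet

    def mark_step(self):
        # Identical graph structures share one "compiled" artifact.
        signature = tuple(node.op for node in self.pending)
        if signature not in self.cache:
            self.compile_count += 1
            self.cache[signature] = self._execute  # stand-in for real compilation
        self.cache[signature](self.pending)
        self.pending = []

    @staticmethod
    def _execute(nodes):
        for node in nodes:
            args = [a.value if isinstance(a, Node) else a for a in node.args]
            node.value = node.fn(*args)


tracer = LazyTracer()
for step in range(3):
    x = tracer.record("mul", lambda a, b: a * b, 2.0, step)
    y = tracer.record("add", lambda a, b: a + b, x, 1.0)
    assert y.value is None   # recorded, not yet computed
    tracer.mark_step()       # "compile" (first step only), then run
    print(step, y.value, tracer.compile_count)
```

Because every step records the same graph, the toy compiles once and reuses the result on later steps, which is the behavior that makes graph-based execution pay off.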

Model and workload

The M5 team trains and deploys foundational models and universal representations to assist various teams across Amazon in bringing delight to customers. One such model is a text encoder model followed by a multi-layer perceptron (MLP) with explicit or implicit feature interactions defined by the neural network architecture with hundreds of millions of trainable parameters. This model is trained on billions of tokens, and is used to generate millions of embeddings in an offline batch inference setting. These embeddings are inputs to a customer-facing tier-1 Amazon service.

The infrastructure for the production pipeline uses AWS Batch with fair share queuing strategies, with an EFA-enabled multi-node trn1.32xlarge cluster as the compute for model training. Functionally, the production pipeline performs incremental model training, evaluation of the trained model, and offline batch inference on the trained model, all using PyTorch as the underlying DL library.


Delighting our customers is a foremost tenet. Given the customer-facing nature of the pipeline, it’s critical that all service-level agreements (SLAs) be met without regressions. We identified two critical acceptance criteria to adapt our existing GPU production pipeline and transition it to Trainium:

Model quality – The quality of our models directly impacts customer experience. We require less than a 0.1% difference in model quality between GPU and Trainium.
Training throughput – We retrain our models periodically to provide the freshest experience to our customers. We require that model convergence be achieved within a predefined period of time (such as 1 week) to meet our production SLAs.
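The throughput criterion reduces to simple arithmetic: given a measured steady-state training throughput, project the wall-clock time to convergence and compare it with the SLA window. A minimal sketch follows; the token budget and throughput in the example are hypothetical placeholders, not M5's numbers, and the sketch assumes throughput stays constant over the run.

```python
from datetime import timedelta


def projected_time_to_convergence(tokens_to_converge: float,
                                  tokens_per_second: float) -> timedelta:
    """Wall-clock time to process the token budget at a steady throughput."""
    if tokens_per_second <= 0:
        raise ValueError("throughput must be positive")
    return timedelta(seconds=tokens_to_converge / tokens_per_second)


def meets_sla(tokens_to_converge: float, tokens_per_second: float,
              sla: timedelta = timedelta(weeks=1)) -> bool:
    """True if the projected run fits inside the SLA window."""
    return projected_time_to_convergence(tokens_to_converge, tokens_per_second) <= sla


# Hypothetical example: 2 billion tokens at 5,000 tokens/s.
eta = projected_time_to_convergence(2e9, 5_000)
print(eta, meets_sla(2e9, 5_000))
```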

In the following sections, we share our journey of working backward from these criteria, and our learnings to support Amazon-scale production workloads.

Training script

Before starting with model training, we need to make changes to the training script to make it XLA compliant. Given the size of the model, we use distributed data parallel (DDP) to train the model. DDP allows us to increase the throughput of model training by scaling up the number of machines used to run model training, without any code changes. We followed the instructions provided in the Neuron PyTorch MLP training tutorial to add XLA-specific constructs in our training scripts. These code changes are straightforward to implement. The following are some significant technical learnings from the exercise that greatly improved our model throughput:
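One reason DDP scales without script changes is that each worker only needs its rank and the world size to pick its share of the data. The sketch below mimics, in simplified form, the round-robin index assignment performed by PyTorch's `torch.utils.data.distributed.DistributedSampler` (here without shuffling, and padded by wrapping around so every rank gets the same number of samples). The `shard_indices` helper is hypothetical and is not taken from the M5 training script.

```python
import math
from typing import List


def shard_indices(num_samples: int, rank: int, world_size: int) -> List[int]:
    """Sample indices that worker `rank` handles in one epoch (no shuffling)."""
    if num_samples <= 0:
        raise ValueError("num_samples must be positive")
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    per_rank = math.ceil(num_samples / world_size)
    total = per_rank * world_size
    indices = list(range(num_samples))
    padding = total - num_samples
    # Wrap around to pad, so every rank gets the same number of samples
    # and all replicas run the same number of steps.
    indices += (indices * math.ceil(padding / num_samples))[:padding]
    return indices[rank:total:world_size]


# Adding machines only changes world_size; the script itself is unchanged.
for world_size in (2, 4):
    print(world_size, [shard_indices(10, r, world_size) for r in range(world_size)])
```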

Placement of xm.mark_step() – xm.mark_step() compiles and runs the lazily collected computation graphs. Invoking mark_step too many times will lead to a larger number of small graphs, whereas invoking it too few times will lead to fewer but larger graphs. Depending on your application, the throughput and implementation of your model training will vary with the placement of xm.mark_step(). Our implementation places one xm.mark_step() after a forward and backward pass, and one after the optimizer step.
Data loader wrapping with XLA multiprocessing device loader – This is a critical step that can be easily missed. The multiprocessing device loader torch_xla.distributed.parallel_loader.MpDeviceLoader loads training data on each XLA device with options to preload and overlap data loading with device runs for improving throughput. The device loader also invokes xm.mark_step() and is therefore able to build graphs for data loading to device from host.
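The overlap that `MpDeviceLoader` provides can be pictured as a producer/consumer pair: a background thread prepares upcoming batches into a bounded queue while the main loop consumes the current one. This is only a conceptual sketch using the standard library (`threading`, `queue`); the `prefetch` helper is hypothetical, and it neither transfers data to an XLA device nor calls `mark_step()`, both of which the real loader handles.

```python
import queue
import threading
from typing import Callable, Iterable, Iterator, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def prefetch(batches: Iterable[T], prepare: Callable[[T], U],
             depth: int = 2) -> Iterator[U]:
    """Yield prepare(batch) for each batch, preparing up to `depth` batches ahead."""
    buffer: queue.Queue = queue.Queue(maxsize=depth)

    def producer() -> None:
        try:
            for batch in batches:
                buffer.put(("item", prepare(batch)))  # blocks while the buffer is full
        except Exception as exc:
            buffer.put(("error", exc))  # hand the failure to the consumer
        else:
            buffer.put(("done", None))

    threading.Thread(target=producer, daemon=True).start()
    while True:
        kind, payload = buffer.get()
        if kind == "done":
            return
        if kind == "error":
            raise payload
        yield payload  # the producer keeps preparing while the caller works


for batch in prefetch(range(5), lambda b: b * 10):
    print(batch)
```

A single producer thread feeding a FIFO queue preserves batch order; the bounded queue caps memory use by limiting how far ahead loading can run.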

Compilation for Trainium

Traditionally, the model development cycle with GPUs involves making changes to the model or training script and directly running it on the GPU device. Accelerators such as Trainium that use XLA require an additional step before model training can be run on the accelerator. XLA computation graphs can only be run after they have been compiled. Generally, there are two ways to perform this compilation: Ahead of Time (AOT), where you trace and compile all graphs first and then run them, or Just In Time (JIT), where graphs are traced, compiled, and run as they are encountered. The Neuron SDK provides both of these out of the box. Typically, AOT compilation is performed first. Graphs are then run after this compilation. If new graphs are encountered, the Neuron runtime invokes a JIT compilation before running them. To perform AOT compilation, the Neuron SDK provides neuron_parallel_compile, a compilation utility that extracts graphs from a trial run of the training script and performs parallel AOT compilation.
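The AOT-then-JIT flow can be summarized with a toy compile cache: a trial run collects the graph signatures it encounters, all of them are compiled up front (in parallel, as `neuron_parallel_compile` does conceptually), and during real training any signature not already compiled triggers a JIT compile. The `GraphCache` class and the use of (step name, batch shape) pairs as graph signatures are illustrative assumptions, not how the Neuron SDK works internally.

```python
from concurrent.futures import ThreadPoolExecutor


class GraphCache:
    """Toy model of AOT compilation with a JIT fallback."""

    def __init__(self):
        self.compiled = set()
        self.aot_compiles = 0
        self.jit_compiles = 0

    @staticmethod
    def _compile(signature):
        # Stand-in for an expensive compiler invocation.
        return signature

    def aot_compile(self, trial_signatures):
        # Compile every distinct graph seen during the trial run, in parallel.
        todo = set(trial_signatures) - self.compiled
        with ThreadPoolExecutor() as pool:
            for sig in pool.map(self._compile, todo):
                self.compiled.add(sig)
                self.aot_compiles += 1

    def run(self, signature):
        # During training, an unseen graph forces a JIT compile before it runs.
        if signature not in self.compiled:
            self.compiled.add(self._compile(signature))
            self.jit_compiles += 1


cache = GraphCache()
trial = [("train_step", (32, 128))] * 5 + [("eval_step", (32, 128))]
cache.aot_compile(trial)
for shape in [(32, 128)] * 100 + [(17, 128)]:  # the last batch is ragged
    cache.run(("train_step", shape))
print(cache.aot_compiles, cache.jit_compiles)
```

The single ragged batch at the end is enough to trigger a JIT compile in the middle of training, which is the motivation for the static-shape constraints discussed next.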

An important aspect of AOT compilation is to ensure that no new computation graphs are created over the course of training. One source of new computation graphs (and therefore recompilations) is dynamic shapes of the training batches during model training. We found that using static shapes and fixed-size batches eliminates training time compilations and greatly improves training throughput without any effect on model accuracy. By enforcing such constraints on training, we observed that only 4–5 steps of model training, one step of model validation, and checkpointing the model one time is required for tracing all the graphs during AOT compilation. It’s important to note that the Neuron SDK is constantly evolving, and in the future will support dynamic shapes as well.
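A minimal sketch of enforcing static shapes, assuming token-ID inputs: every sequence is truncated or padded to a fixed length, and the incomplete final batch is dropped, so each batch has the same `(batch_size, seq_len)` shape. The pad ID of 0 and the helper names are hypothetical; in PyTorch, the drop-last policy corresponds to `drop_last=True` on `torch.utils.data.DataLoader`.

```python
from typing import Iterator, List, Sequence

PAD_ID = 0  # hypothetical padding token ID


def pad_or_truncate(tokens: Sequence[int], seq_len: int,
                    pad_id: int = PAD_ID) -> List[int]:
    """Return exactly seq_len token IDs."""
    clipped = list(tokens[:seq_len])
    return clipped + [pad_id] * (seq_len - len(clipped))


def static_batches(examples: Sequence[Sequence[int]], batch_size: int,
                   seq_len: int) -> Iterator[List[List[int]]]:
    """Yield only full batches, each of shape (batch_size, seq_len)."""
    full = len(examples) - len(examples) % batch_size  # drop the ragged tail
    for start in range(0, full, batch_size):
        yield [pad_or_truncate(ex, seq_len) for ex in examples[start:start + batch_size]]


examples = [[1, 2, 3], [4], [5, 6, 7, 8, 9], [1], [2]]
for batch in static_batches(examples, batch_size=2, seq_len=4):
    print(batch)
```

A real pipeline would usually also produce an attention mask so that padding positions do not influence the model's outputs or loss.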

Furthermore, the compiled graphs are stored in the Neuron Persistent Cache on disk or in an Amazon Simple Storage Service (Amazon S3) bucket. This is especially useful for production workloads where the model architecture and training configuration don't change, because the overhead of compilation is incurred just one time. Using the cache is as simple as setting an environment flag:


The Neuron compiler also provides three compiler-level optimization options (O1, O2, O3) to balance compilation time and model run throughput. O1 enables core optimizations on the compute graph and minimizes compilation time, O3 provides improved model run throughput at the cost of higher compilation time, and O2 (default option) is a balance between the two. For our use case, we used the O1 optimization and observed an 86% reduction in compilation time with no change to model accuracy metrics, while observing approximately a 5–7% reduction in throughput compared to the default optimization (O2). Depending on the use case, you can choose different levels of optimization.
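Whether a faster-compiling but slower-running optimization level pays off depends on how long training runs relative to compilation. The break-even arithmetic can be sketched as below. The helper names and the compile and training hours are hypothetical inputs; only the 86% compile-time reduction echoes the figure above, and 6% is simply the midpoint of the 5–7% throughput reduction.

```python
def end_to_end_hours(o2_compile_hours: float, o2_train_hours: float,
                     compile_reduction: float = 0.86,
                     throughput_reduction: float = 0.06) -> dict:
    """Compile plus training hours at the default O2 versus O1."""
    o1_compile = o2_compile_hours * (1 - compile_reduction)
    # Lower throughput means the same number of steps takes proportionally longer.
    o1_train = o2_train_hours / (1 - throughput_reduction)
    return {"O2": o2_compile_hours + o2_train_hours, "O1": o1_compile + o1_train}


def break_even_train_hours(o2_compile_hours: float,
                           compile_reduction: float = 0.86,
                           throughput_reduction: float = 0.06) -> float:
    """O2 training time beyond which O2's faster runs outweigh O1's faster compile."""
    saved = o2_compile_hours * compile_reduction
    return saved * (1 - throughput_reduction) / throughput_reduction


# Hypothetical: 2 hours to compile at O2, 10 hours of training per run.
print(end_to_end_hours(2.0, 10.0))
print(round(break_even_train_hours(2.0), 1))
```

When the Neuron Persistent Cache already holds the compiled graphs, the compile term largely disappears from later runs, which shifts the balance toward the higher-throughput setting.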

To summarize, we used the following flags for compilation:

NEURON_CC_FLAGS="--target trn1 --auto-cast all --auto-cast-type bf16 --model-type transformer --optlevel O1"
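If you prefer to set these flags from inside the training script rather than in the job's shell environment, one common approach is to assign the variable through `os.environ` early in the script, before any graph is compiled. The sketch below assembles the same flags from a list so each one is easy to review; the exact point by which the variable must be set is an assumption worth verifying against the Neuron documentation for your SDK version.

```python
import os
import shlex

# The compiler flags from the text, one option/value pair per line.
NEURON_FLAGS = [
    "--target", "trn1",
    "--auto-cast", "all",
    "--auto-cast-type", "bf16",
    "--model-type", "transformer",
    "--optlevel", "O1",
]

# Assumption: set this before torch_xla compiles any graph in this process.
os.environ["NEURON_CC_FLAGS"] = shlex.join(NEURON_FLAGS)
print(os.environ["NEURON_CC_FLAGS"])
```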

Checkpoint compatibility

When compilation is complete, we can proceed to train our models on Trainium. As mentioned earlier, we train our models incrementally, meaning we load a previously trained model checkpoint and continue training with new data. PyTorch and PyTorch XLA allow seamless transitions between accelerators through checkpoint interoperability. This flexibility enabled us to load the previous GPU model directly and continue training on Trainium machines. This was critical to ensure that we can initialize our model with the best previously trained model without any production downtime or loss in model accuracy.

Because the GPU model was saved using standard PyTorch model saving utilities, we were able to use the PyTorch checkpoint loading utility to load the GPU model on Trainium devices.

For example, on GPU/CPU, you can save the model with the following code:

import torch
torch.save(model.state_dict(), PATH)

Then you load the model back on Trainium:

import torch
import torch_xla.core.xla_model as xm
xla_device = xm.xla_device()
model = MyModel(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(xla_device)

Similarly, you can save the model on Trainium with the following code:

import torch_xla.core.xla_model as xm
# automatically moves the data to CPU for the master device
xm.save(model.state_dict(), PATH)

And load the model back on GPU/CPU:

import torch
model = MyModel(*args, **kwargs)
model.load_state_dict(torch.load(PATH))  # can be any device

In fact, because we use DDP for model training, model loading is agnostic to the number of machines used to train the previous checkpoint. This allows us to horizontally scale the Trn1 fleet with no code changes or adverse effects on model training. These PyTorch-based checkpoints can be used directly or even converted to TorchScript for inference use cases on AWS Inferentia2 or other accelerators.

Operational stability

It cannot be emphasized enough that running workloads in production requires multiple SLAs to be met. For our use case, apart from the model quality and training throughput SLAs, it’s imperative that the production pipeline be operationally stable, meaning minimal downtime and disruptions during model training, evaluation, and inference.

As with the existing GPU-based pipeline, we added numerous mechanisms to make the pipeline operationally stable. Before starting model training, we run multiple sanity tests to assess the health of the machines. These tests generally include simple tensor operations to verify the health of the accelerator devices. We have observed that for distributed training, it's important to also run tests that verify collective communication between instances. We used the NCCOM test suite from the Neuron SDK to achieve this, running a variety of operations such as all-gather, all-reduce, and reduce-scatter.
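To make the semantics of the three collectives concrete, the sketch below gives pure-Python reference versions of all-reduce (sum), all-gather, and reduce-scatter over per-rank lists, and checks them the way a sanity test might: every rank contributes known values and each result is compared with the expected answer. This is not the NCCOM test suite and performs no real communication; the function names are hypothetical, and a real check would run the collectives on the devices and compare against references like these.

```python
from typing import List

Vector = List[float]


def all_reduce(per_rank: List[Vector]) -> List[Vector]:
    """Every rank receives the element-wise sum of all ranks' vectors."""
    total = [sum(values) for values in zip(*per_rank)]
    return [list(total) for _ in per_rank]


def all_gather(per_rank: List[Vector]) -> List[Vector]:
    """Every rank receives the concatenation of all ranks' vectors."""
    gathered = [x for vec in per_rank for x in vec]
    return [list(gathered) for _ in per_rank]


def reduce_scatter(per_rank: List[Vector]) -> List[Vector]:
    """Sum element-wise, then rank r keeps the r-th equal-sized chunk."""
    world = len(per_rank)
    total = [sum(values) for values in zip(*per_rank)]
    if len(total) % world:
        raise ValueError("vector length must be divisible by the world size")
    chunk = len(total) // world
    return [total[r * chunk:(r + 1) * chunk] for r in range(world)]


def sanity_check(world_size: int = 4, length: int = 8) -> bool:
    # Rank r contributes the constant value r + 1, so every expected result is known.
    inputs = [[float(r + 1)] * length for r in range(world_size)]
    expected_sum = world_size * (world_size + 1) / 2
    ok = all(v == expected_sum for out in all_reduce(inputs) for v in out)
    ok &= all(len(out) == world_size * length for out in all_gather(inputs))
    ok &= all(out == [expected_sum] * (length // world_size)
              for out in reduce_scatter(inputs))
    return ok


print(sanity_check())
```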

Even after following the suggestions we’ve mentioned, we have observed that transient issues are inevitable in any pipeline, irrespective of the underlying accelerator. To build resiliency in any training pipeline, we recommend building in retry mechanisms to resolve these potential issues. We use AWS Batch automated retries to retry jobs that encounter a transient failure during model training. These restarts can be costly if a failure is encountered towards the end of training. To counter this problem, we have adapted our training scripts to load a previously trained model checkpoint and continue training from that point. With this functionality, we are able to aggressively restart failed training jobs with minimal overhead.
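The resume-on-retry pattern can be sketched as follows: the training loop periodically writes a checkpoint holding the step counter and state, and on start-up it resumes from the latest checkpoint if one exists, so a retried job repeats at most one checkpoint interval of work. Everything here is hypothetical illustration: the JSON checkpoint file, the `train` and `run_with_retries` helpers, and the simulated failures. In the production pipeline, AWS Batch performs the retries, and a real checkpoint would hold model and optimizer state dicts.

```python
import json
import os
import tempfile


class TransientError(RuntimeError):
    """Stand-in for a recoverable failure, such as a lost node."""


def train(ckpt_path: str, total_steps: int, ckpt_every: int, fail_at: set) -> dict:
    """Toy training loop that resumes from ckpt_path if it exists."""
    state = {"step": 0, "loss_sum": 0.0}
    if os.path.exists(ckpt_path):
        with open(ckpt_path) as f:
            state = json.load(f)  # resume instead of starting over
    while state["step"] < total_steps:
        if state["step"] in fail_at:
            fail_at.discard(state["step"])  # each simulated failure fires once
            raise TransientError(f"simulated failure at step {state['step']}")
        state["loss_sum"] += 1.0 / (state["step"] + 1)  # placeholder for a real step
        state["step"] += 1
        if state["step"] % ckpt_every == 0:
            tmp_path = ckpt_path + ".tmp"
            with open(tmp_path, "w") as f:
                json.dump(state, f)
            os.replace(tmp_path, ckpt_path)  # atomic swap: no half-written checkpoint
    return state


def run_with_retries(max_attempts: int, **train_kwargs) -> dict:
    """Mimics a job scheduler's automatic retries for transient failures."""
    for attempt in range(1, max_attempts + 1):
        try:
            return train(**train_kwargs)
        except TransientError:
            if attempt == max_attempts:
                raise
    raise ValueError("max_attempts must be at least 1")


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as workdir:
        path = os.path.join(workdir, "checkpoint.json")
        final = run_with_retries(3, ckpt_path=path, total_steps=20,
                                 ckpt_every=5, fail_at={12, 17})
        print(final["step"])
```

Writing to a temporary file and swapping it in with `os.replace` means a crash during checkpointing leaves the previous checkpoint intact.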

With these resiliency mechanisms in place, we were able to achieve 98.5% success rates for our workloads on Trn1, comparable to our existing GPU pipeline success rates.


To validate the accuracy of our models, we initialized two models from the same GPU checkpoint, and trained one on Trainium and the other on a comparable GPU. Both models were trained with the same training hyperparameters. The dataset used for metrics calculation is a holdout dataset, and we evaluate the model's accuracy on this dataset every N global steps. In the following graph, the x-axis is the global step and the y-axis is the model accuracy. We observed less than 0.1% difference in model accuracy at each point.
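The parity check can be expressed as a comparison of two evaluation curves recorded at the same global steps. The sketch assumes accuracy is reported as a fraction and reads "less than 0.1% difference" as an absolute gap below 0.001; a relative threshold would be an equally reasonable reading, so adjust `tolerance` to your own definition. The `parity_violations` helper and the accuracy values are made up for illustration.

```python
from typing import Dict, List, Tuple


def parity_violations(curve_a: Dict[int, float], curve_b: Dict[int, float],
                      tolerance: float = 0.001) -> List[Tuple[int, float]]:
    """Return (global_step, gap) wherever the two curves differ by >= tolerance."""
    common_steps = sorted(curve_a.keys() & curve_b.keys())
    if not common_steps:
        raise ValueError("curves share no evaluation steps")
    gaps = [(step, abs(curve_a[step] - curve_b[step])) for step in common_steps]
    return [(step, gap) for step, gap in gaps if gap >= tolerance]


# Hypothetical holdout accuracy evaluated every N = 1000 global steps.
gpu = {1000: 0.8012, 2000: 0.8125, 3000: 0.8190}
trainium = {1000: 0.8008, 2000: 0.8131, 3000: 0.8187}
print(parity_violations(gpu, trainium))  # an empty list means parity held at every point
```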

Furthermore, to evaluate the cost-effectiveness of model training, we prefer to compare the wall-clock time taken to reach model convergence. We believe this provides a more practical view of cost savings than measures such as cost per token or achieved FLOPS per dollar. Considering the training time of trn1.32xlarge and comparable Amazon Elastic Compute Cloud (Amazon EC2) instances, we have observed that Trainium offers up to 30% lower cost to reach model convergence.
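Cost to convergence combines the hourly instance price, the number of instances, and the wall-clock time to converge. A minimal calculator is sketched below; the helper names, prices, and hours are hypothetical placeholders (check current EC2 pricing for real numbers), chosen only to show how a saving of about 30% would be computed.

```python
def cost_to_convergence(hourly_price: float, num_instances: int,
                        hours_to_converge: float) -> float:
    """Total on-demand cost of a multi-node training run, in dollars."""
    return hourly_price * num_instances * hours_to_converge


def savings_fraction(baseline_cost: float, candidate_cost: float) -> float:
    """Fraction of the baseline cost saved by the candidate (0.3 means 30% cheaper)."""
    return 1.0 - candidate_cost / baseline_cost


# Hypothetical inputs: same cluster size, different prices and convergence times.
baseline = cost_to_convergence(hourly_price=40.0, num_instances=4, hours_to_converge=100.0)
candidate = cost_to_convergence(hourly_price=21.0, num_instances=4, hours_to_converge=133.0)
print(f"{savings_fraction(baseline, candidate):.1%}")
```

In this made-up example the candidate takes longer to converge yet still costs less, which is why comparing cost at convergence is more informative than comparing hourly prices or per-step speed alone.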


There are many factors to consider when evaluating different accelerators for your DL workloads. Some of the most important are model quality, throughput, cost, and availability. It is paramount to ensure that your model quality and throughput are not sacrificed based on the accelerator you choose.

Thanks to our partnership and collaboration with the Annapurna Neuron team, the Amazon Search M5 team has been able to save up to 30% in cost by moving to Trainium. The team is able to use Trainium and achieve model quality and throughput parity with comparable accelerators in the market. Checkpoint interoperability and minimal code changes with support for XLA have allowed M5 to choose between multiple accelerators for their workloads. This has enabled the M5 team to take advantage of the large compute power of Trainium and build accelerator-agnostic solutions to delight customers. From an operational standpoint, Trainium has proven capable of supporting tier-1 services at Amazon scale. The M5 team continues to move more workloads to Trainium to provide the best models for Amazon at the lowest costs.

In summary, the M5 team has been able to perform cost-effective, production-grade ML training by adding Trainium to the fleet of accelerators. We encourage you to take a look at Trainium and other Neuron devices like AWS Inferentia to reap the benefits of purpose-built Amazon silicon for ML workloads. Get started easily with one of the many tutorials featuring different models, like Llama 2, available on Trainium.

About the Authors

Jerry Mannil is a software engineer at Amazon Search. He works on improving the efficiency, robustness, and scalability of the distributed training infrastructure.

Ken Su is a software engineer at Amazon Search. He works on improving training efficiency and scalable distributed training workflow. Outside work, he likes hiking and tennis.

RJ is an Engineer within Amazon. He builds and optimizes systems for distributed training and works on optimizing systems to reduce latency for ML inference. Outside work, he is exploring the use of generative AI for building food recipes.

Abhinandan Patni is a Senior Software Engineer at Amazon Search. He focuses on building systems and tooling for scalable distributed deep learning training and real time inference.

James Park is a Solutions Architect at Amazon Web Services. He works with to design, build, and deploy technology solutions on AWS, and has a particular interest in AI and machine learning. In his spare time he enjoys seeking out new cultures, new experiences, and staying up to date with the latest technology trends. You can find him on LinkedIn.

