Zero-shot text classification with Amazon SageMaker JumpStart

Natural language processing (NLP) is the field of machine learning (ML) concerned with giving computers the ability to understand text and spoken words in the same way humans can. State-of-the-art architectures such as the transformer now achieve near-human performance on downstream NLP tasks such as text summarization, text classification, and entity recognition.

Large language models (LLMs) are transformer-based models trained on large amounts of unlabeled text. They range from hundreds of millions of parameters (BERT) to over a trillion (MiCS), a size that makes single-GPU training impractical. Because of this scale and complexity, training an LLM from scratch is a challenge that few organizations can afford. A common practice for NLP downstream tasks is to take a pre-trained LLM and fine-tune it. For more information about fine-tuning, refer to Domain-adaptation Fine-tuning of Foundation Models in Amazon SageMaker JumpStart on Financial data and Fine-tune transformer language models for linguistic diversity with Hugging Face on Amazon SageMaker.

Zero-shot learning in NLP allows a pre-trained LLM to generate responses to tasks that it hasn’t been explicitly trained for (even without fine-tuning). Zero-shot text classification, specifically, is an NLP task in which a model classifies text into classes it did not see during training. This contrasts with supervised classification, where a model can only classify text that belongs to classes in the training data.

We recently launched zero-shot classification model support in Amazon SageMaker JumpStart. SageMaker JumpStart is the ML hub of Amazon SageMaker that provides access to pre-trained foundation models (FMs), LLMs, built-in algorithms, and solution templates to help you quickly get started with ML. In this post, we show how you can perform zero-shot classification using pre-trained models in SageMaker JumpStart. You will learn how to use the SageMaker JumpStart UI and SageMaker Python SDK to deploy the solution and run inference using the available models.

Zero-shot learning

Zero-shot classification is a paradigm where a model can classify new, unseen examples that belong to classes that were not present in the training data. For example, a language model that has been trained to understand human language can be used to classify New Year’s resolution tweets into classes like career, health, and finance, without being explicitly trained on that text classification task. This is in contrast to fine-tuning, which re-trains the model (through transfer learning), whereas zero-shot learning requires no additional training.

The following diagram illustrates the differences between transfer learning (left) vs. zero-shot learning (right).

Transfer learning vs Zero-shot

Yin et al. proposed a framework for creating zero-shot classifiers using natural language inference (NLI). The framework poses the sequence to be classified as an NLI premise and constructs a hypothesis from each candidate label. For example, if we want to evaluate whether a sequence belongs to the class politics, we could construct a hypothesis of “This text is about politics.” The probabilities for entailment and contradiction are then converted to label probabilities. As a quick review, NLI considers two sentences: a premise and a hypothesis. The task is to determine whether the hypothesis is true (entailment), false (contradiction), or undetermined (neutral) given the premise. The following table provides some examples.

| Premise | Label | Hypothesis |
| --- | --- | --- |
| A man inspects the uniform of a figure in some East Asian country. | Contradiction | The man is sleeping. |
| An older and younger man smiling. | Neutral | Two men are smiling and laughing at the cats playing on the floor. |
| A soccer game with multiple males playing. | Entailment | Some men are playing a sport. |
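
To make the NLI-based approach more concrete, the following is a minimal sketch of how candidate labels can be turned into hypotheses and how entailment and contradiction logits can be converted to label scores. It is one common way to do the conversion, not necessarily the exact logic inside the deployed model. The function names are hypothetical, and the logits are made-up numbers standing in for the output of an NLI model; the neutral logit is simply ignored here.

```python
import math


def build_hypotheses(labels, template="This text is about {}."):
    """Turn each candidate label into an NLI hypothesis sentence."""
    return [template.format(label) for label in labels]


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def label_probabilities(nli_logits, multi_label=False):
    """Convert per-label (contradiction, entailment) logits into label scores.

    nli_logits holds one (contradiction_logit, entailment_logit) pair per label.
    """
    if multi_label:
        # Each label is judged independently: softmax over its own
        # contradiction/entailment pair, keeping the entailment probability.
        return [softmax([contra, entail])[1] for contra, entail in nli_logits]
    # Single-label: softmax over the entailment logits of all labels,
    # so the scores compete with each other and sum to 1.
    return softmax([entail for _, entail in nli_logits])


labels = ["politics", "sports", "finance"]
hypotheses = build_hypotheses(labels)
# Made-up logits for illustration only; a real NLI model would produce these.
example_logits = [(2.0, -1.0), (-1.5, 3.0), (0.5, 0.0)]
single = label_probabilities(example_logits)
multi = label_probabilities(example_logits, multi_label=True)
```

In the single-label case the scores form a distribution over the candidate labels, while in the multi-label case each score is an independent probability, so several labels (or none) can score highly for the same text.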

Solution overview

In this post, we discuss the following:

  • How to deploy pre-trained zero-shot text classification models using the SageMaker JumpStart UI and run inference on the deployed model using short text data
  • How to use the SageMaker Python SDK to access the pre-trained zero-shot text classification models in SageMaker JumpStart and use the inference script to deploy the model to a SageMaker endpoint for a real-time text classification use case
  • How to use the SageMaker Python SDK to access pre-trained zero-shot text classification models and use SageMaker batch transform for a batch text classification use case

SageMaker JumpStart provides one-click fine-tuning and deployment for a wide variety of pre-trained models across popular ML tasks, as well as a selection of end-to-end solutions that solve common business problems. These features remove the heavy lifting from each step of the ML process, simplifying the development of high-quality models and reducing time to deployment. The JumpStart APIs allow you to programmatically deploy and fine-tune a vast selection of pre-trained models on your own datasets.

The JumpStart model hub provides access to a large number of NLP models that enable transfer learning and fine-tuning on custom datasets. As of this writing, the JumpStart model hub contains over 300 text models and hosts popular models such as Stable Diffusion, Flan T5, Alexa TM, and Bloom.

Note that by following the steps in this section, you will deploy infrastructure to your AWS account that may incur costs.

Deploy a standalone zero-shot text classification model

In this section, we demonstrate how to deploy a zero-shot classification model using SageMaker JumpStart. You can access pre-trained models through the JumpStart landing page in Amazon SageMaker Studio. Complete the following steps:

  1. In SageMaker Studio, open the JumpStart landing page.
    Refer to Open and use JumpStart for more details on how to navigate to SageMaker JumpStart.
  2. In the Text Models carousel, locate the “Zero-Shot Text Classification” model card.
  3. Choose View model to access the facebook-bart-large-mnli model.
    Alternatively, you can search for the zero-shot classification model in the search bar and get to the model in SageMaker JumpStart.
  4. Specify a deployment configuration, SageMaker hosting instance type, endpoint name, Amazon Simple Storage Service (Amazon S3) bucket name, and other required parameters.
  5. Optionally, you can specify security configurations like AWS Identity and Access Management (IAM) role, VPC settings, and AWS Key Management Service (AWS KMS) encryption keys.
  6. Choose Deploy to create a SageMaker endpoint.

This step takes a couple of minutes to complete. When it’s complete, you can run inference against the SageMaker endpoint that hosts the zero-shot classification model.

In the following video, we show a walkthrough of the steps in this section.

Use JumpStart programmatically with the SageMaker SDK

In the SageMaker JumpStart section of SageMaker Studio, under Quick start solutions, you can find the solution templates. SageMaker JumpStart solution templates are one-click, end-to-end solutions for many common ML use cases. As of this writing, over 20 solutions are available for multiple use cases, such as demand forecasting, fraud detection, and personalized recommendations, to name a few.

The “Zero Shot Text Classification with Hugging Face” solution provides a way to classify text without the need to train a model for specific labels (zero-shot classification) by using a pre-trained text classifier. The default zero-shot classification model for this solution is the facebook-bart-large-mnli (BART) model. For this solution, we use the 2015 New Year’s Resolutions dataset to classify resolutions. A subset of the original dataset containing only the Resolution_Category (ground truth label) and the text columns is included in the solution’s assets.

New year's resolutions table

For synchronous (real-time) inference, the input data includes text strings, a list of desired categories for classification, and a flag indicating whether the classification is multi-label. For asynchronous (batch) inference, we provide the same information (a list of text strings, the list of categories for each string, and the multi-label flag) in a JSON Lines formatted text file.

Zero-shot input example
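
As a rough illustration of the batch input, the following sketch serializes one request per line in JSON Lines format. It assumes that each line uses the same fields as the real-time request for the BART model described later in this post (inputs, parameters, candidate_labels, and multi_label); the function name and the sample texts are hypothetical.

```python
import json


def to_json_lines(texts, candidate_labels, multi_label=False):
    """Serialize one request object per line (JSON Lines)."""
    lines = []
    for text in texts:
        record = {
            "inputs": text,
            "parameters": {
                "candidate_labels": candidate_labels,
                "multi_label": multi_label,
            },
        }
        lines.append(json.dumps(record))
    # JSON Lines files conventionally end with a trailing newline.
    return "\n".join(lines) + "\n"


# Hypothetical sample texts and categories, for illustration only.
batch_body = to_json_lines(
    ["Save more money this year", "Run a half marathon"],
    ["Health", "Finance", "Career"],
)
```

The resulting string can be written to a local file and uploaded to Amazon S3 before starting a batch job.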

The result of the inference is a JSON object that looks something like the following screenshot.

Zero-shot output example

We have the original text in the sequence field, the labels used for the text classification in the labels field, and the probability assigned to each label (in the same order of appearance) in the field scores.
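
Because labels and scores are aligned by position, reading the prediction comes down to pairing the two lists. The following is a minimal sketch with hypothetical helper names and a made-up result object shaped like the fields described above; the 0.5 threshold is an arbitrary illustration, not a recommended value.

```python
def top_label(result):
    """Return the (label, score) pair with the highest score."""
    pairs = zip(result["labels"], result["scores"])
    return max(pairs, key=lambda pair: pair[1])


def labels_above(result, threshold=0.5):
    """Return labels whose score meets a threshold (useful for multi-label results)."""
    return [
        label
        for label, score in zip(result["labels"], result["scores"])
        if score >= threshold
    ]


# Made-up result with rounded scores, for illustration only.
example_result = {
    "sequence": "read more books",
    "labels": ["Personal Growth", "Health", "Finance"],
    "scores": [0.62, 0.30, 0.08],
}
best_label, best_score = top_label(example_result)
```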

To deploy the Zero Shot Text Classification with Hugging Face solution, complete the following steps:

  1. On the SageMaker JumpStart landing page, choose Models, notebooks, solutions in the navigation pane.
  2. In the Solutions section, choose Explore All Solutions.
    Amazon SageMaker JumpStart landing page
  3. On the Solutions page, choose the Zero Shot Text Classification with Hugging Face model card.
  4. Review the deployment details and if you agree, choose Launch.
    Zero-shot text classification with hugging face

The deployment will provision a SageMaker real-time endpoint for real-time inference and an S3 bucket for storing the batch transformation results.

The following diagram illustrates the architecture of this method.

Zero-shot text classification solution architecture

Perform real-time inference using a zero-shot classification model

In this section, we review how to use the Python SDK to run zero-shot text classification (using any of the available models) in real time using a SageMaker endpoint.

  1. First, we configure the inference payload request to the model. This is model dependent, but for the BART model, the input is a JSON object with the following structure:

        {
            "inputs": # The text to be classified
            "parameters": {
                "candidate_labels": # A list of the labels we want to use for the text classification
                "multi_label": True | False
            }
        }

  2. Note that the BART model is not explicitly trained on the candidate_labels. We will use the zero-shot classification technique to classify the text sequence to unseen classes. The following code is an example using text from the New Year’s resolutions dataset and the defined classes:

        classification_categories = ['Health', 'Humor', 'Personal Growth', 'Philanthropy', 'Leisure', 'Career', 'Finance', 'Education', 'Time Management']
        data_zero_shot = {
            "inputs": "#newyearsresolution :: read more books, no scrolling fb/checking email b4 breakfast, stay dedicated to pt/yoga to squash my achin' back!",
            "parameters": {
                "candidate_labels": classification_categories,
                "multi_label": False
            }
        }

  3. Next, you can invoke a SageMaker endpoint with the zero-shot payload. The SageMaker endpoint is deployed as part of the SageMaker JumpStart solution.

        import json
        import boto3

        # SageMaker runtime client used to call the deployed endpoint
        runtime = boto3.client("sagemaker-runtime")

        # sagemaker_endpoint_name is the name of the endpoint deployed by the solution
        response = runtime.invoke_endpoint(
            EndpointName=sagemaker_endpoint_name,
            ContentType="application/json",
            Body=json.dumps(data_zero_shot),
        )
        parsed_response = json.loads(response["Body"].read())

  4. The inference response object contains the original sequence, the labels sorted by score from max to min, and the scores per label:

        {
            'sequence': "#newyearsresolution :: read more books, no scrolling fb/checking email b4 breakfast, stay dedicated to pt/yoga to squash my achin' back!",
            'labels': ['Personal Growth', 'Health', 'Time Management', 'Leisure', 'Education', 'Humor', 'Career', 'Philanthropy', 'Finance'],
            'scores': [0.4198768436908722, 0.2169460505247116, 0.16591140627861023, 0.09742163866758347, 0.031757451593875885, 0.027988269925117493, 0.015974704176187515, 0.015464971773326397, 0.008658630773425102]
        }
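
The steps above can be wrapped into a small helper. The following is a sketch under stated assumptions: build_payload and classify are hypothetical names, and FakeRuntime is a stand-in that imitates the shape of the runtime client response so the example runs without an AWS account. With a real endpoint, you would pass a boto3 "sagemaker-runtime" client and your endpoint name instead.

```python
import io
import json


def build_payload(text, candidate_labels, multi_label=False):
    """Build the request body in the structure used for the BART model."""
    if not candidate_labels:
        raise ValueError("candidate_labels must contain at least one label")
    return {
        "inputs": text,
        "parameters": {
            "candidate_labels": list(candidate_labels),
            "multi_label": multi_label,
        },
    }


def classify(runtime_client, endpoint_name, text, candidate_labels, multi_label=False):
    """Invoke the endpoint and return the parsed JSON response."""
    payload = build_payload(text, candidate_labels, multi_label)
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=json.dumps(payload),
    )
    return json.loads(response["Body"].read())


class FakeRuntime:
    """Offline stand-in for a sagemaker-runtime client (illustration only)."""

    def invoke_endpoint(self, EndpointName, ContentType, Body):
        request = json.loads(Body)
        labels = request["parameters"]["candidate_labels"]
        # Uniform placeholder scores; a real endpoint returns model scores.
        scores = [1.0 / len(labels)] * len(labels)
        body = json.dumps(
            {"sequence": request["inputs"], "labels": labels, "scores": scores}
        )
        return {"Body": io.BytesIO(body.encode("utf-8"))}


result = classify(
    FakeRuntime(),
    "hypothetical-endpoint-name",
    "stay dedicated to yoga",
    ["Health", "Finance"],
)
```

Passing the client as an argument keeps the helper easy to test and lets you reuse one client across many requests.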

Run a SageMaker batch transform job using the Python SDK

This section describes how to run batch transform inference with the zero-shot classification facebook-bart-large-mnli model using the SageMaker Python SDK. Complete the following steps:

  1. Format the input data in JSON lines format and upload the file to Amazon S3.
    SageMaker batch transform will perform inference on the data points uploaded in the S3 file.
  2. Set up the model deployment artifacts with the following parameters:
    1. model_id – Use huggingface-zstc-facebook-bart-large-mnli.
    2. deploy_image_uri – Use the image_uris Python SDK function to get the pre-built SageMaker Docker image for the model_id. The function returns the Amazon Elastic Container Registry (Amazon ECR) URI.
    3. deploy_source_uri – Use the script_uris utility API to retrieve the S3 URI that contains scripts to run pre-trained model inference. We specify the script_scope as inference.
    4. model_uri – Use the model_uris utility API to get the model artifacts from Amazon S3 for the specified model_id.

        # Imports
        from sagemaker import image_uris, model_uris, script_uris

        # Set model ID and version
        model_id, model_version = (
            "huggingface-zstc-facebook-bart-large-mnli",
            "*",
        )

        # Retrieve the inference Docker container URI. This is the base Hugging Face container image for the default model above.
        deploy_image_uri = image_uris.retrieve(
            region=None,
            framework=None,  # Automatically inferred from model_id
            image_scope="inference",
            model_id=model_id,
            model_version=model_version,
            instance_type="ml.g4dn.xlarge",
        )

        # Retrieve the inference script URI. This includes all dependencies and scripts for model loading, inference handling, and more.
        deploy_source_uri = script_uris.retrieve(model_id=model_id, model_version=model_version, script_scope="inference")

        # Retrieve the model URI. This includes the pre-trained model and parameters.
        model_uri = model_uris.retrieve(model_id=model_id, model_version=model_version, model_scope="inference")

  3. Use HF_TASK to define the task for the Hugging Face transformers pipeline and HF_MODEL_ID to define the model used to classify the text:

        # Hub model configuration
        hub = {
            "HF_MODEL_ID": "facebook/bart-large-mnli",  # The model_id from the Hugging Face Hub
            "HF_TASK": "zero-shot-classification",  # The NLP task that you want to use for predictions
        }

    For a complete list of tasks, see Pipelines in the Hugging Face documentation.

  4. Create a Hugging Face model object to be deployed with the SageMaker batch transform job:

        from sagemaker.huggingface import HuggingFaceModel

        # Create HuggingFaceModel class
        huggingface_model_zero_shot = HuggingFaceModel(
            model_data=model_uri,  # path to your trained sagemaker model
            env=hub,  # configuration for loading model from Hub
            role=role,  # IAM role with permissions to create an endpoint
            transformers_version="4.17",  # Transformers version used
            pytorch_version="1.10",  # PyTorch version used
            py_version="py38",  # Python version used
        )

  5. Create a transformer to run a batch job:

        from sagemaker.s3 import s3_path_join

        # Create transformer to run a batch job
        batch_job = huggingface_model_zero_shot.transformer(
            instance_count=1,
            instance_type="ml.m5.xlarge",
            strategy="SingleRecord",
            assemble_with="Line",
            # we are using the same s3 path to save the output with the input
            output_path=s3_path_join("s3://", sagemaker_config["S3Bucket"], "zero_shot_text_clf", "results"),
        )

  6. Start a batch transform job and use S3 data as input:

        batch_job.transform(
            data=data_upload_path,
            content_type="application/json",
            split_type="Line",
            logs=False,
            wait=True,
        )

You can monitor your batch processing job on the SageMaker console (choose Batch transform jobs under Inference in the navigation pane). When the job is complete, you can check the model prediction output in the S3 file specified in output_path.
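
After downloading the output file, you may want to compare the predictions against the ground truth labels (for example, the Resolution_Category column). The following sketch assumes that each output line is a JSON object with labels and scores fields, like the real-time response; check the actual output of your job before relying on that shape. The helper names and sample values are hypothetical.

```python
import json


def parse_batch_output(output_text):
    """Parse line-assembled batch output into a list of result dicts."""
    return [json.loads(line) for line in output_text.splitlines() if line.strip()]


def predicted_labels(results):
    """Pick the highest-scoring label from each result."""
    predictions = []
    for result in results:
        scores = result["scores"]
        best_index = max(range(len(scores)), key=lambda i: scores[i])
        predictions.append(result["labels"][best_index])
    return predictions


def accuracy(predictions, ground_truth):
    """Fraction of predictions that match the ground truth labels."""
    if len(predictions) != len(ground_truth):
        raise ValueError("predictions and ground_truth must have the same length")
    if not predictions:
        return 0.0
    matches = sum(p == t for p, t in zip(predictions, ground_truth))
    return matches / len(predictions)


# Made-up output lines and ground truth, for illustration only.
sample_output = "\n".join(
    [
        json.dumps({"labels": ["Health", "Finance"], "scores": [0.9, 0.1]}),
        json.dumps({"labels": ["Finance", "Career"], "scores": [0.7, 0.3]}),
        json.dumps({"labels": ["Career", "Health"], "scores": [0.6, 0.4]}),
    ]
)
predictions = predicted_labels(parse_batch_output(sample_output))
score = accuracy(predictions, ["Health", "Finance", "Health"])
```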

For a list of all the available pre-trained models in SageMaker JumpStart, refer to Built-in Algorithms with pre-trained Model Table. Use the keyword “zstc” (short for zero-shot text classification) in the search bar to locate all the models capable of doing zero-shot text classification.

Clean up

After you’re done running the notebook, make sure to delete all resources created in the process so that the assets deployed in this guide stop incurring costs. The code to clean up the deployed resources is provided in the notebooks associated with the zero-shot text classification solution and model.
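
If you deployed a standalone endpoint and want to remove it yourself, the following sketch shows one possible approach; it is not the cleanup code from the notebooks. The helper name is hypothetical, and RecordingClient is an offline stand-in for a boto3 "sagemaker" client so the example runs without an AWS account. It covers only the endpoint, its endpoint configuration, and the models behind it, not S3 buckets or other solution resources.

```python
def delete_endpoint_resources(sm_client, endpoint_name):
    """Delete an endpoint, its endpoint config, and the models behind it."""
    endpoint = sm_client.describe_endpoint(EndpointName=endpoint_name)
    config_name = endpoint["EndpointConfigName"]
    config = sm_client.describe_endpoint_config(EndpointConfigName=config_name)
    model_names = [variant["ModelName"] for variant in config["ProductionVariants"]]

    # Look everything up first, then delete from the endpoint downward.
    sm_client.delete_endpoint(EndpointName=endpoint_name)
    sm_client.delete_endpoint_config(EndpointConfigName=config_name)
    for model_name in model_names:
        sm_client.delete_model(ModelName=model_name)
    return {"endpoint": endpoint_name, "config": config_name, "models": model_names}


class RecordingClient:
    """Offline stand-in that records delete calls (illustration only)."""

    def __init__(self):
        self.calls = []

    def describe_endpoint(self, EndpointName):
        return {"EndpointConfigName": EndpointName + "-config"}

    def describe_endpoint_config(self, EndpointConfigName):
        return {"ProductionVariants": [{"ModelName": "hypothetical-model"}]}

    def delete_endpoint(self, EndpointName):
        self.calls.append(("delete_endpoint", EndpointName))

    def delete_endpoint_config(self, EndpointConfigName):
        self.calls.append(("delete_endpoint_config", EndpointConfigName))

    def delete_model(self, ModelName):
        self.calls.append(("delete_model", ModelName))


client = RecordingClient()
deleted = delete_endpoint_resources(client, "hypothetical-endpoint")
```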

Default security configurations

The SageMaker JumpStart models are deployed using the following default security configurations:

To learn more about SageMaker security-related topics, check out Configure security in Amazon SageMaker.

Conclusion

In this post, we showed you how to deploy a zero-shot classification model using the SageMaker JumpStart UI and perform inference using the deployed endpoint. We used the SageMaker JumpStart New Year’s resolutions solution to show how you can use the SageMaker Python SDK to build an end-to-end solution and implement a zero-shot classification application. SageMaker JumpStart provides access to hundreds of pre-trained models and solutions for tasks like computer vision, natural language processing, recommendation systems, and more. Try out the solution on your own and let us know your thoughts.

About the authors

David Laredo is a Prototyping Architect at AWS Envision Engineering in LATAM, where he has helped develop multiple machine learning prototypes. Previously, he has worked as a Machine Learning Engineer and has been doing machine learning for over 5 years. His areas of interest are NLP, time series, and end-to-end ML.

Vikram Elango is an AI/ML Specialist Solutions Architect at Amazon Web Services, based in Virginia, US. Vikram helps financial and insurance industry customers with design and thought leadership to build and deploy machine learning applications at scale. He is currently focused on natural language processing, responsible AI, inference optimization, and scaling ML across the enterprise. In his spare time, he enjoys traveling, hiking, cooking, and camping with his family.

Dr. Vivek Madan is an Applied Scientist with the Amazon SageMaker JumpStart team. He got his PhD from University of Illinois at Urbana-Champaign and was a Post Doctoral Researcher at Georgia Tech. He is an active researcher in machine learning and algorithm design and has published papers in EMNLP, ICLR, COLT, FOCS, and SODA conferences.


