Balkrishna Pandey
trl package, ImportError: cannot import name 'top_k_top_p_filtering' from 'transformers'

Recently I ran into an ImportError in my Python code while trying to fine-tune the LLaMA-2 7B model. I was using the following code.

import torch  # Importing PyTorch, used below to clear the GPU cache.
import transformers  # Importing the transformers library, which provides tools for working with transformer models.

from trl import SFTTrainer  # Importing SFTTrainer from the trl (transformer reinforcement learning) package.

# Setting the padding token of the tokenizer to be the same as its end-of-sentence token.
base_tokenizer.pad_token = base_tokenizer.eos_token

# Clearing the GPU cache to free up memory and avoid potential out-of-memory issues.
torch.cuda.empty_cache()

# Initializing the Supervised Fine-Tuning (SFT) Trainer.
trainer = SFTTrainer(
    model=base_model,  # The model to be fine-tuned.
    train_dataset=train_data,  # The dataset for training.
    eval_dataset=test_data,  # The dataset for evaluation.
    dataset_text_field="prompt",  # The field in the dataset that contains the text to be processed.
    peft_config=lora_config,  # The PEFT (Parameter-Efficient Fine-Tuning) configuration, here using LORA.
    args=transformers.TrainingArguments(  # Configuration for the training process.
        per_device_train_batch_size=1,  # Batch size per device.
        gradient_accumulation_steps=16,  # Number of steps to accumulate gradients before updating model weights.
        warmup_steps=50,  # Absolute number of warmup steps for the learning rate scheduler.
        max_steps=-1,  # Maximum number of training steps; -1 means the training length is taken from num_train_epochs instead.
        learning_rate=1e-5,  # The learning rate for optimization.
        logging_dir="./logs",  # Directory where training logs will be stored.
        logging_first_step=True,  # Log the first training step, useful for debugging.
        logging_steps=20,  # Frequency of logging training information.
        evaluation_strategy="steps",  # Strategy to perform model evaluation.
        optim="adamw_torch",  # The optimizer to be used.
        eval_steps=50,  # Number of steps before performing evaluation.
        output_dir="/opt/app-root/src/data/v8-finance-3/outputs",  # Directory to store output files.
        load_best_model_at_end=True,  # Whether to load the best model at the end of training.
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(base_tokenizer, mlm=False),  # Data collator for language modeling.
)

And here is the error message.

---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[47], line 3
      1 import transformers  # Importing the transformers library, which provides tools for working with transformer models.
----> 3 from trl import SFTTrainer  # Importing SFTTrainer from the trl (transformer reinforcement learning) package.
      5 # Setting the padding token of the tokenizer to be the same as its end-of-sentence token.
      6 base_tokenizer.pad_token = base_tokenizer.eos_token

File /opt/app-root/lib64/python3.9/site-packages/trl/__init__.py:5
      1 # flake8: noqa
      3 __version__ = "0.7.11"
----> 5 from .core import set_seed
      6 from .environment import TextEnvironment, TextHistory
      7 from .extras import BestOfNSampler

File /opt/app-root/lib64/python3.9/site-packages/trl/core.py:25
     23 import torch.nn.functional as F
     24 from torch.nn.utils.rnn import pad_sequence
---> 25 from transformers import top_k_top_p_filtering
     27 from .import_utils import is_npu_available, is_xpu_available
     30 try:

ImportError: cannot import name 'top_k_top_p_filtering' from 'transformers' (/opt/app-root/lib64/python3.9/site-packages/transformers/__init__.py)

I found that a fix for this package has already been committed in the following PR, but it has not been released yet: https://github.com/huggingface/trl/pull/1415
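
The root cause is that the installed transformers build no longer exports top_k_top_p_filtering, which trl 0.7.11 still tries to import at load time. You can confirm this in your own environment with a minimal check like the sketch below (assuming transformers is already installed):

# Minimal diagnostic: check whether the symbol trl 0.7.11 expects is still
# exported by the installed transformers package.
import transformers

print(transformers.__version__)
# False here means the import inside trl/core.py will fail with this ImportError.
print(hasattr(transformers, "top_k_top_p_filtering"))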

So here is how you can install trl directly from that commit with pip to fix the error:

pip install git+https://github.com/huggingface/trl.git@7630f877f91c556d9e5a3baa4b6e2894d90ff84c
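
After installing from that commit, the import should go through. A quick sanity check (run it in a fresh Python session so the old trl module is not still cached):

# Sanity check after reinstalling trl from the pinned commit.
import trl
from trl import SFTTrainer  # should no longer fail on top_k_top_p_filtering

print(trl.__version__)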

If you are using a requirements.txt file, add the following line instead:

trl @ git+https://github.com/huggingface/trl.git@7630f877f91c556d9e5a3baa4b6e2894d90ff84c
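
Then reinstall with pip install -r requirements.txt so the pinned commit is picked up instead of the released trl package.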
