First, install conda; you can follow the earlier article on installing conda.
- Create a new conda environment
```shell
conda create -n myenv python=3.10
```
- Activate the conda environment
```shell
conda activate myenv
```
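Before installing anything, it can help to confirm from Python that the right environment is active. The sketch below relies on `conda activate` setting the `CONDA_DEFAULT_ENV` environment variable, which standard conda installs do. It also checks the interpreter against the 3.10 requested above. The function name `check_env` is hypothetical.

```python
import os
import sys


def check_env(expected_env="myenv", expected_python=(3, 10)):
    """Report whether the active conda env and Python version match expectations."""
    # `conda activate <name>` exports CONDA_DEFAULT_ENV=<name>; None means no env is active
    active = os.environ.get("CONDA_DEFAULT_ENV")
    current = sys.version_info[:2]
    return {
        "active_env": active,
        "env_ok": active == expected_env,
        "python": current,
        "python_ok": current == tuple(expected_python),
    }


if __name__ == "__main__":
    print(check_env())
```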
- Install the transformers library
```shell
pip install transformers
```
- Install PyTorch (CPU-only build)
```shell
conda install pytorch torchvision torchaudio cpuonly -c pytorch
```
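You can confirm that the CPU-only PyTorch build installed correctly by checking that `torch` imports, that a trivial tensor operation runs, and that CUDA reports as unavailable. The CUDA check is what you would expect from the `cpuonly` package. This is a minimal sketch. The helper `torch_status` is hypothetical.

```python
import importlib.util


def torch_status():
    """Report whether PyTorch is importable and, if so, its version and CUDA availability."""
    # find_spec returns None when the package is not installed, without importing it
    if importlib.util.find_spec("torch") is None:
        return {"installed": False}
    import torch

    # A tiny tensor operation confirms the CPU kernels actually run
    doubled = (torch.tensor([1.0, 2.0]) * 2).tolist()
    return {
        "installed": True,
        "version": str(torch.__version__),
        # With the cpuonly package this is expected to be False
        "cuda_available": torch.cuda.is_available(),
        "smoke_test": doubled,
    }


if __name__ == "__main__":
    print(torch_status())
```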
- Verify the installation
```python
# test_transformers.py
try:
    from transformers import pipeline
    print("Transformers library has been successfully installed.")
except ImportError as e:
    print("An error occurred:", e)
```
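The script above only tells you whether the import succeeded. When debugging, it is often more useful to see which versions are installed. For example, the charset_normalizer error in the troubleshooting note is resolved by upgrading that package. A minimal sketch using the standard-library `importlib.metadata` (Python 3.8+) follows. The default package list and the helper name `installed_versions` are assumptions, so adjust them as needed.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("transformers", "torch", "charset-normalizer")):
    """Map each distribution name to its installed version, or None if it is missing."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver if ver else 'not installed'}")
```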
- Download the GPT-2 model and test it
```python
# Import the pipeline function from the transformers library
from transformers import pipeline
# Initialize a text generation pipeline with the GPT-2 model
generator = pipeline('text-generation', model='gpt2')
# Generate text based on a prompt
generated_text = generator("Hello, my name is", max_length=30)
# Print the generated text
print(generated_text)
```
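In the example above, `max_length=30` caps the total length including the prompt tokens. The variant below caps only the newly generated tokens and makes sampling repeatable. It uses `set_seed`, `max_new_tokens`, `do_sample` and `num_return_sequences`, which are standard transformers generation options. The pipeline returns a list of dicts with a `generated_text` key. The helper names `extract_texts` and `generate` are hypothetical, so treat this as a sketch rather than the article's method.

```python
from typing import Dict, List


def extract_texts(outputs: List[Dict[str, str]]) -> List[str]:
    """Pull the generated strings out of a text-generation pipeline result."""
    return [item["generated_text"] for item in outputs]


def generate(prompt: str = "Hello, my name is", seed: int = 42) -> List[str]:
    """Generate two sampled continuations of `prompt` with GPT-2."""
    # Imported here so extract_texts can be used without transformers installed
    from transformers import pipeline, set_seed

    set_seed(seed)  # repeatable sampling on the same machine and library versions
    generator = pipeline("text-generation", model="gpt2")
    outputs = generator(
        prompt,
        max_new_tokens=20,  # limits only the newly generated tokens
        do_sample=True,
        num_return_sequences=2,
    )
    return extract_texts(outputs)
```

Call `print(generate())` to run it. The first call downloads the GPT-2 weights.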
If the test script reports the following error, upgrade charset_normalizer:
```text
An error occurred: cannot import name 'COMMON_SAFE_ASCII_CHARACTERS' from 'charset_normalizer.constant'
```
```shell
pip install --upgrade charset_normalizer
```
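To check whether your installed charset_normalizer is affected, before or after upgrading, a small sketch can test whether the symbol named in the error exists. The helper `has_symbol` is hypothetical. It only imports the module and looks for the attribute, and reports a missing module as `False` instead of raising.

```python
import importlib


def has_symbol(module_name, symbol):
    """Return True if `module_name` imports cleanly and defines `symbol`."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return False
    return hasattr(module, symbol)


if __name__ == "__main__":
    if has_symbol("charset_normalizer.constant", "COMMON_SAFE_ASCII_CHARACTERS"):
        print("charset_normalizer looks fine")
    else:
        print("Try: pip install --upgrade charset_normalizer")
```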