Transformers 库

                     

贡献者: int256

  • 本文处于草稿阶段。

   transformers 库是要使用 HuggingFace 网站的已经训练好的 torch 网络参数或 tensorflow 网络参数必备的一个包(现在访问 huggingface 网站需要科学上网)。

   安装可以使用 pip 命令直接安装:pip install transformers,使用 huggingface 网站提供的数据集或神经网络参数需要保证电脑里已经装有 PyTorch。(CPU 版的 torch 可以使用命令 pip install torch 安装。安装包较大,建议使用国内 pypi 的镜像站。)

   特别的,有的网络参数(.h5 文件)则是 tensorflow 的,那么需要保证电脑已经装有 tensorflow 包。

   对于网站上的每个网络参数,介绍一些常见的文件:

  1. readme.md 一类文件:介绍文档。
  2. license:这文件存储开源协议。
  3. config.json:这个文件存储了网络的一些配置。
  4. 分词配置:tokenizer.jsontokenizer_config.jsonvocab,其中 vocab 文件可能是 txt 或 json 格式。这些文件给分词器使用。
  5. pytorch_model.bin:存储了 PyTorch 网络的训练好的参数。
  6. tf_model.h5:存储了 tensorflow 网络的训练好的参数。
  7. special_tokens_map.json:特殊的 token 的一些配置。

   如果要下载一个网络参数,下载上述文件中存在的文件即可。

   例如对于一个 PyTorch 的一个网络,下载以下几个文件:

   假若我们现在已经下载好了,将上述文件放在了 ./package-name 目录下。

   观察 config.json,若发现网络模型是 "BertForSequenceClassification",或我们已经知道这个网络是一个 Bert 的句子分类网络了。那么就可以按如下例程的方式使用这个网络参数:

import torch
from transformers import AutoTokenizer, BertForSequenceClassification 
# 后面的 BertForSequenceClassification 是根据网络类型得到的

tokenizer = AutoTokenizer.from_pretrained("./package-name")
model = BertForSequenceClassification.from_pretrained("./package-name")
model.eval()

def get_output(text):
    output=[]
    model_input = tokenizer(text, return_tensors="pt", padding=True)
    model_output = model(**model_input, return_dict=False)
    prediction = torch.argmax(model_output[0].cpu(), dim=-1)
    prediction = [p.item() for p in prediction]
    for i in range(len(prediction)):
        if prediction[i]==1:
            output.append("Class A")
        else:
            output.append('Class B')
    return output

print(get_output(input()))

   类似的,对于一个 GPT2LMHead 模型的网络,例程如下:

import torch
from transformers import AutoTokenizer, GPT2LMHeadModel

model_name = "./package-name"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

prompt = input()  # 输入提示,可以替换为自己需要的内容
input_ids = tokenizer.encode(prompt, return_tensors="pt")

# 生成文本,max_length 指定生成文本的最大长度
# num_return_sequences 指定生成的文本序列数量,一般设置为 1
configs = {
    "max_length": 100,
    "num_return_sequences": 1,
    "pad_token_id": tokenizer.eos_token_id
}
output = model.generate(input_ids, **configs)

# skip_special_tokens=True 会忽略特殊的控制标记,只返回可读文本
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

print("Generated Text:", generated_text)

   对于其中的 config.json,我们可以通过其看出这个网络的一些配置,一个典型的配置文件以及各配置意义如下:

{
  "architectures": [
    "BertForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "type_vocab_size": 2,
  "vocab_size": 30522
}

   各个配置意义如下:

  1. architectures:模型的名称;
  2. attention_probs_dropout_prob:注意力的 dropout,默认为 0.1;
  3. directionality:文字编码方向采用的算法,一般为 bidi;
  4. hidden_act:编码器内激活函数,默认"gelu",还可为"relu"、"swish"或 "gelu_new";
  5. hidden_dropout_prob:词嵌入层或编码器的 dropout 配置,默认为 0.1;
  6. hidden_size:编码器内隐藏层神经元数量,默认 768;
  7. initializer_range:神经元权重的标准差,默认为 0.02;
  8. intermediate_size:编码器内全连接层的输入维度,默认 3072;
  9. layer_norm_eps:layer normalization 的 epsilon 值,默认为 1e-12;
  10. max_position_embeddings:模型使用的最大序列长度,默认为 512;
  11. model_type:模型的类型,一般是 bert;
  12. num_attention_heads:编码器内注意力头数,默认 12;
  13. num_hidden_layers:编码器内隐藏层层数,默认 12;
  14. pooler_fc_size:Pooler 层(相当于一个全连接层,作为分类器解决序列级 NLP 任务)的大小,默认也为 768;
  15. pooler_num_attention_heads:Pooler 层注意力头,默认 12;
  16. pooler_num_fc_layers:Pooler 连接层数,默认 3;
  17. pooler_size_per_head:每个注意力头的大小;
  18. pooler_type:Pooler 层类型;
  19. type_vocab_size:词汇表的类型,默认是 2;
  20. vocab_size:词汇数,bert 默认 30522,这是因为 bert 以中文字为单位进入输入。

                     

© 小时科技 保留一切权利