目录
知乎专栏的讨论:https://zhuanlan.zhihu.com/p/28680474
https://github.com/tensorflow/tensor2tensor
建议的超参是在Cloud TPUs或者8-GPU machines上训练好的
数据集方面,
--problem=image_imagenet
,或者是缩小版的数据集image_imagenet224, image_imagenet64, image_imagenet32
--problem=image_cifar10
,或者是关闭了data augmentation的--problem=image_cifar10_plain
--problem=image_cifar100
--problem=image_mnist
模型方面,
--model=resnet --hparams_set=resnet_50
,resnet的top-1 accuracy能达到76%以上。--model=xception --hparams_set=xception_base
--model=shake_shake --hparams_set=shakeshake_big
。当--train_steps=700000
的时候,在CIFAR-10上,可以达到97% accuracy。数据集方面,
--problem=languagemodel_ptb10k
--problem=languagemodel_ptb_characters
--problem=languagemodel_lm1b32k
--problem=languagemodel_lm1b_characters
模型方面,建议直接上--model=transformer
,
--hparams_set=transformer_small
--hparams_set=transformer_base
IMDB数据集:--problem=sentiment_imdb
建议使用模型--model=transformer_encoder
,由于这个数据集很小,使用--hparams_set=transformer_tiny
,以及比较少的训练步数就行了--train_steps=2000
。
数据集:Librispeech (English speech to text)
--problem=librispeech
--problem=librispeech_clean
将CNN/DailyMail的文章摘要成一些句子的数据集:--problem=summarize_cnn_dailymail32k
模型使用--model=transformer
,超参使用--hparams_set=transformer_prepend
,这样可以得到不错的ROUGE scores。
数据集:
--problem=translate_ende_wmt32k
--problem=translate_enfr_wmt32k
--problem=translate_encs_wmt32k
--problem=translate_enzh_wmt32k
--problem=translate_envi_iwslt32k
如果要将源语言和目标语言调换,那么可以直接加一个_rev
,也就是German-English即为--problem=translate_ende_wmt32k_rev
翻译问题,建议使用--model=transformer
,在8 GPUs上训练300K steps之后,在English-German上可以达到28的BLEU。如果在单GPU上,建议使用--hparams_set=transformer_base_single_gpu
,在大数据集上(例如English-French),想要达到很好的效果,使用大模型--hparams_set=transformer_big
。
针对机器翻译问题的一个基本流程:
# See what problems, models, and hyperparameter sets are available.
# You can easily swap between them (and add new ones).
t2t-trainer --registry_help
# 1. 设置模型参数
PROBLEM=translate_ende_wmt32k
MODEL=transformer
HPARAMS=transformer_base_single_gpu
DATA_DIR=$HOME/t2t_data
TMP_DIR=/tmp/t2t_datagen
TRAIN_DIR=$HOME/t2t_train/$PROBLEM/$MODEL-$HPARAMS
mkdir -p $DATA_DIR $TMP_DIR $TRAIN_DIR
# 2. 生成数据
t2t-datagen \
--data_dir=$DATA_DIR \
--tmp_dir=$TMP_DIR \
--problem=$PROBLEM
# 3. 训练
# * If you run out of memory, add --hparams='batch_size=1024'.
t2t-trainer \
--data_dir=$DATA_DIR \
--problem=$PROBLEM \
--model=$MODEL \
--hparams_set=$HPARAMS \
--output_dir=$TRAIN_DIR
# 4. Decode
DECODE_FILE=$DATA_DIR/decode_this.txt
echo "Hello world" >> $DECODE_FILE
echo "Goodbye world" >> $DECODE_FILE
echo -e 'Hallo Welt\nAuf Wiedersehen Welt' > ref-translation.de
BEAM_SIZE=4
ALPHA=0.6
t2t-decoder \
--data_dir=$DATA_DIR \
--problem=$PROBLEM \
--model=$MODEL \
--hparams_set=$HPARAMS \
--output_dir=$TRAIN_DIR \
--decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA" \
--decode_from_file=$DECODE_FILE \
--decode_to_file=translation.en
# 5. 查看翻译结果
cat translation.en
# 6. Evaluate the BLEU score
# Note: Report this BLEU score in papers, not the internal approx_bleu metric.
t2t-bleu --translation=translation.en --reference=ref-translation.de
tensorflow.Example
的protobuf标准化处理过的TFRecord
文件。所有的problem都在problem_hparams.py中定义了,或者通过@registry.register_problem
进行注册。
直接运行t2t-datagen
可以查看目前支持的problems列表。
所有的modalities定义在modality.py中。
T2TModel
定义了核心的tensor-to-tensor变换,与input/output的modality或者task无关。模型的输入是dense tensors,输出是dense tensors,可以在final step中被一个modality进行变换(例如通过一直final linear transform,产出logits供softmax over classes使用)。
在models目录下的init.py中import了所有model。
这些模型的基类是T2TModel,都是通过@registry.register_model来进行注册的。
超参数集合通过@registry.register_hparams进行注册,并通过tf.contrib.training.HParams编码成对象。
HParams对model和problem都是可用的。
common_hparams.py中定义了基本的超参集,而且超参集的函数可以组成其他超参集的函数。
支持分布式训练,参考https://tensorflow.github.io/tensor2tensor/distributed_training.html,包括:
参考:https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/insights
首先,安装nodejs的npm:https://nodejs.org/en/
然后用npm安装Bower:
npm install -g bower
然后用bower对本项目的insights部分进行安装:
pushd tensor2tensor/insights/polymer
bower install
popd
还需要
pip install oauth2client
然后写一个json,表示模型的各种配置:
{
"configuration": [{
"source_language": "en",
"target_language": "de",
"label": "transformers_wmt32k",
"transformer": {
"model": "transformer",
"model_dir": "/tmp/t2t/train",
"data_dir": "/tmp/t2t/data",
"hparams": "",
"hparams_set": "transformer_base_single_gpu",
"problem": "translate_ende_wmt32k"
},
}],
"language": [{
"code": "en",
"name": "English",
},{
"code": "de",
"name": "German",
}]
}
然后启动:
t2t-insights-server \
--configuration=configuration.json \
--static_path=`pwd`/tensor2tensor/insights/polymer