MT-DNN 핵심 요약
- BERT에 Multi-task learning(GLUE의 9개 Task 활용)을 수행하여 성능 개선한 모델
- 다양한 Task의 Supervised Dataset을 모두 사용하여 대용량 데이터로 학습
- Multi-task Learning을 통해 특정 Task에 Overfitting되지 않도록 Regularization
- 다음의 Task들에 대해 State-of-the Art 성능 (BERT 보다 높은 성능)
- 8개의 GLUE Task
- SNLI와 SciTail Task로 Domain Adaptation
다음의 순서로 논문을 설명하도록하곘습니다.
3. Model Architecture & Training
1. Language Representation
단어 혹은 문장을 Vector 형태로 표현하는 것을 Language Representation이라고 합니다. 최근 가장 인기를 얻고있는 BERT를 비롯하여 다양한 방식의 Language Representation이 존재하는데 이는 크게 다음 2가지로 설명할 수 있습니다.
Language Model Pre-training
대표적인 Language Model은 ELMo, BERT 등이 있으며 가장 큰 특징은 대용량의 unsupervised dataset을 활용하여 모델을 학습한다는 것입니다.
예를 들어, BERT는 Book Corpus와 Wikipedia를 활용하여 모델을 학습하는데 다음 2가지 방법을 활용합니다.
- Masked Word Prediction
- 문장이 주어졌을 때 특정 Word를 Masking하고 다른 주변 Word들을 활용하여 Masked Word를 예측하는 방식으로 학습
- ex) my dog is [Mask] -> my dog is hairy
- Next Sentence Prediction
- 문장이 2개 주어졌을 때, 2개의 문장이 연결된 문장인지 아닌지를 Classification 하는 방식으로 학습
- ex) Input = the man went to the store [SEP] he bought a gallon of milk → IsNext
MT-DNN을 이해하기위해서는 BERT를 꼭 이해해야합니다.BERT에 대한 자세한 설명은 아래 블로그를 참고하기 바랍니다.
Multi-Task Learning
Multi-Task Learning이란 여러가지 Supervised Task를 1개의 모델을 통해 학습하고자 하는 것입니다.
그렇다면, 왜 여러가지 Task를 학습하고자 하는 것일까요? 만약에 어떤 사람이 스키 타는 법을 배웠다면 그렇지 않은 사람 보다 아마도 스케이트를 잘 탈 것입니다. 즉, Task 간 배운 지식이 서로 영향을 주어 성능을 향상시킬 수 있다는 가정을 가지고 모델을 학습하는 것입니다.
특히, Machine Learning 모델을 Multi-task Learning 시 크게 다음 2가지의 이점이 있습니다.
- 대용량 Supervised Dataset을 활용하여 학습할 수 있다.
- Supervised Dataset은 Task에 따라 제한된량을 가지고 있기 때문에 데이터량이 적을 경우 성능이 상당히 저하되지만, Multi-task Learning 시 이러한 데이터를 모두 합쳐서 활용하므로 비교적 대용량 데이터를 사용할 수 있다.
- 모델이 특정 Task에 Overfitting 되지 않도록 Regularization 효과를 줄 수 있다.
MT-DNN은 Language Model Pre-Training을 활용한 BERT에 Multi-task learning을 적용하여 성능 개선한 모델입니다. 그렇다면 어떤 Task들을 Multi-task learning에 활용 했고 모델을 어떻게 학습 시켰는지 살펴 봅시다.
2. Tasks
MT-DNN은 GLUE의 9개의 Task를 Multi-task learning에 활용합니다. 이를 아래 4개의 분류로 나누어 설명합니다.
Single Sentence Classification
하나의 문장이 주어졌을 때 문장의 class를 분류하는 Task
- CoLA - 문장이 문법적으로 맞는지 분류 (True/False)
- ex) The book was written by John. -> True
- ex) Books were sent to each other by the students. -> False
- SST-2 - 영화 Review 문장의 감정 분류 (Positive/Negative)
Text Similarity
문장 쌍이 주어졌을 때, 점수를 예측하는 Regression Task
- STS-B - 문장 간의 의미적 유사도를 점수로 예측
Pairwise Text Classification
문장 쌍이 주어졌을 때, 문장의 관계를 분류하는 Task
- RTE, MNLI - 문장 간의 의미적 관계를 3가지로 분류 (Entailment, Contradiction, Neutral)
- QQP, MRPC - 문장 간 의미가 같음 여부를 분류 (True/False)
Relevance Ranking
질문 문장과 지문이 주어졌을 때, 지문 중 정답이 있는 문장을 Ranking을 통해 찾는 Task
- QNLI - 질문과 해당 지문 중 한 문장이 쌍으로 주어졌을 때 해당 지문 문장에 질문의 답이 있는지 여부를 분류 (True/False)
- MT-DNN에서는 이를 Rank 방식으로 바꾸어 모든 지문 문장에 정답이 있을 가능성을 Scoring 하여 가장 Score가 높은 지문 문장만을 True로 분류하는 방식으로 Task 수행
3. Model Architecture & Training
MT-DNN은 위 그림과 같이 Transformer로 구성된 Shared Layer(Lexicon Encoder, Transformer Encoder)와 Task 별로 다른 Task Specific Layer로 구성됩니다. 여기서 Shared Layer는 BERT와 동일하지만 Task specific Layer는 BERT의 Task Specific Layer와 일부 다르게 구성됩니다.
모델 학습시에는 무작위로 특정 Task의 Data를 Batch로 뽑아서 학습하게 되는데 이 때, Shared Layer는 계속해서 학습되는 반면 Task Specific Layer는 해당 Task의 Data로 학습시만 학습됩니다. (Task 별로 Task Specific Layer가 1개씩 존재하며 Task에 적합한 Loss Function이 구성됩니다.)
Lexicon Encoder
Task에 따라 한개의 문장 혹은 여러개의 문장을 Input으로 받는 Layer로 그림과 같이 3가지 Embedding을 활용하여 Input을 구성 하게 됩니다.
- Token Embedding - Wordpiece Embedding Vector
- 1번째 Token은 [CLS] Token(Class Token)으로 추후 Output에서 Classification 등을 위해 사용됨
- 각 문장은 Wordpiece로 Tokenization 되며 [SEP] Token이 두 문장 사이의 구분자로 사용됨
- Sentence Embedding - 1번째 혹은 2번째 문장임을 표현하는 Vector
- Positional Embedding - 각 Token의 위치 정보를 표현하는 Vector
Transformer Encoder
Lexicon Encoder로 부터 각 Token의 Input Vector를 입력으로 받아 Transformer를 통해 각 Token의 Output Vector를 추출합니다. Transformer이기 때문에 생성된 Vector는 Self Attention을 통해 주변 Token 정보를 반영한 Contextual Embedding Vector입니다.
Single-Sentence Classification Output
하나의 문장을 분류하는 Task는 [CLS] Token을 사용합니다. 수식과 같이 [CLS] Token과 Task Specific Parameter의 곱에 Softmax를 취하여 Output을 추출합니다.
Text Similarity Output
Single Sentence Classification과 마찬가지로 [CLS] Token을 활용하며 Regression Task이므로 Task Specific Parameter와 곱하고 sigmoid function을 사용하여 Score를 예측합니다.
Pairwise Text Classification Output
두 문장 간의 의미 관계 등을 분류 하기 위해서 Stochastic Answer Network(SAN)라는 것을 사용합니다. SAN의 핵심 Idea는 Multi-Step Reasoning입니다. 즉, 1번에 Classification 결과를 예측하지 않고 여러번의 예측을 통한 Reasoning으로 결과를 예측하고자 하는 것입니다. 문장 문장 간 의미 관계를 분류하는 MNLI 문제를 예시로 들어볼게요. 다음 두 문장의 의미가 같은지 여부를 판단하려면 어떻게 해야할까요?
- If you need this book, it is probably too late unless you are about to take an SAT or GRE
- It’s never too late, unless you’re about to take a test.
먼저, SAT와 GRE가 test라는 것을 유추하고 이후에 두 문장의 의미가 비슷한지 여부를 판단하는 적어도 2 step의 Reasoning이 필요할 것입니다. 이를 위해, RNN으로 time step 마다 Classification 결과를 예측하고 해당 결과들을 조합하여 사용하고자 하는 것이 SAN입니다.
두 문장을 각각 Hypothiesis 문장, Premise 문장이라고 할 때, RNN(GRU)의 Input은 Hypothesis 문장을 그리고 Hidden State는 Premise 문장을 Embedding 하는데 사용됩니다. 여기서 문장 Embedding이란 Transformer에 의해 생성된 Token Vector들의 Weighted Sum을 의미합니다.
즉, Input x는 이전 Hidden State(Hypothesis 문장의 Embedding)를 고려한 Premise 문장의 Embedding Vector, Hidden State는 Input값(Premise 문장 Embedding)을 고려하여 변형한 Hypothesis 문장의 Embedding이 됩니다. SAN은 time step 마다 문장 간 관계를 고려하여 각 문장의 Embedding을 변형해 가며 생성한다고 볼 수 있습니다.
각 time step의 Classification 예측은 위 식을 통해 이루어집니다. Softmax를 구성하는 Vector는 Heuristic하게 구성된 것이며 이 곳에서 부터 얻은 Idea입니다. 두 문장 자체의 Embedding Vector, 그리고 두 문장 간 관계(차의 크기와 Dot Product)를 concat하여 구성된 Vector를 활용하여 문장 간 관계를 분류합니다. Multi-step Reasoning 이므로 만약 K번 RNN을 통해 예측했다면 K번 결과의 평균값을 통해 최종 결과를 예측하게 됩니다.
Relevance Ranking Output
Relevance Ranking은 Question과 문장 Pair를 Input으로 넣어 생성한 [CLS] Token에 Sigmoid를 취하여 문장 별로 점수를 Scoring하고 가장 높은 점수 만 Question에 해당하는 정답이 있다고 예측하는 방식으로 Classification을 수행합니다.
4. Evaluation
위의 표는 GLUE Task에 대한 성능 평가 결과를 나타냅니다. MT-DNN은 BERT 보다 전체 성능이 약 1.8% 향상되어 82.2%로 현재(2019.05 기준) GLUE Leaderboard에서 가장 높은 성능을 보이고 있습니다. (표의 성능 지표는 Accuracy 혹은 Accuracy/F1-score를 의미합니다.)
주목해야할 점 중 하나는 Dataset이 적은 Task(MRPC, RTE)의 경우 비교적 높은 성능 향상이 있었다는 것입니다. 이는 다른 Task를 통해 학습한 정보를 모델이 잘 활용한 덕분에 얻은 결과로 볼 수 있습니다.
위의 표에서 ST-DNN은 Multi-task learning은 수행하지 않고 Task Speicific Layer만 바꾸어 수행한 결과를 나타냅니다. 논문에서 SAN을 사용한 Multi-step Reasoning을 하나의 강점으로 제안했으나 BERT의 결과 보다 큰 성능 개선은 없어 보입니다. 다만, SAN을 사용하기 때문에 BERT와 달리 Pairwise Text Classification에서 [CLS] Token을 사용하지 않고 Wordpiece Token들을 사용했다는게 특이한점입니다. 또한, Relavance Ranking의 경우 BERT 보다 상당히 높은 성능을 보이는데 BERT는 Ranking 방식을 사용하지 않고 모든 문장 Pair에 대해 결과를 따로 예측하기 때문입니다.
마지막으로 위의 표는 MT-DNN을 GLUE Task와 다른 Task인 SNLI와 SciTail에 적용했을 경우 BERT와의 성능 비교를 보여줍니다. 그래프의 x축은 Fine Tuning에 사용된 SNLI와 SciTail의 Dataset량을 나타냅니다. 사실, SNLI와 SciTail은 Natural Language Inference Task로 MT-DNN은 이에 대해 벌써 학습되었고 BERT는 그렇지 않으므로 적은 데이터량으로 Fine Tuning 시 MT-DNN의 성능이 훨씬 높은 것은 당연한 것으로 보입니다. 다만, MT-DNN은 문장의 의미를 BERT에 비해 더 잘 학습한 상태인 것은 분명해 보이니 모델 사용시 상황에 따라 MT-DNN을 BERT 대신 사용하는 것은 확실히 이점이 있을 것 같습니다.
NLP Language Representation transformer 고급