Urvashi Khandelwal, Omer Levy, Dan Jurafsky, Luke Zettlemoyer, and Mike Lewis
Stanford University & Facebook
Published at
ICLR 2020
Abstract
$k$-nearest neighbors($k$NN) 모델을 통한 pre-trained neural language model(LM)인 $k$NN-LMs를 제안함
Nearest neighbors는 LM 학습데이터에 있는 text 데이터에 대해 pre-trained LM embedding space 상의 거리에 따라 계산됨
이 $k$NN 기반의 augmentation 기법을 WikiText-103 LM을 적용하였고 state-of-the-art perplexity of 15.79를 기록하였다. 이는 2.9 point 향상이며 추가 학습은 하지 않음
Large training set에 대해서도 효율적으로 scaling up이 되고, domain adaptation도 효과적으로 이루어짐. 역시 추가 학습은 필요 없음
정성적으로 보았을 때, factual knowledge 같은 희귀한 패턴을 예측하는데 상당히 도움이 됨
이와 함께, 이런 실험 결과들을 통해 다음 단어를 예측하는 것보다 sequence of text에 대한 similarity를 학습하는 게 더 쉽고, nearest neighbor search가 long tail 패턴에 대한 Language modeling에 효과적이라는 것을 알 수 있음
1. Introduction
LM는 전형적으로 다음의 두가지 subproblems을 품
sentence prefix를 fixed-sized representation에 매핑시키는 문제
이 representation을 사용해서 다음 단어를 예측하는 문제
“Representation learning 문제가 다음 단어 예측 문제보다 더 쉽다”는 가정 하에 새로운 language modeling approach를 제안함
기존 LM의 prefix embedding을 사용해서 LM이 첫번째 문제를 더 잘한다는 강력한 증거를 제시함
(이하 Abatract의 내용과 같음)
2. Nearest Neighbor Language modeling
LM은 기본적으로 주어진 context sequence of tokens $c_t = (w_1, …, w_{t-1})$에 대해서 다음에 올 target token인 $w_t$에 대한 분포 $p(w_t
c_t)$를 예측함
간단히 $k$NN-LM은 pre-trained LM에 nearest neighbors retrieval mechanism을 더한 형태로서 추가적인 학습은 하지 않음
*context-target 쌍은 **inference 시에 사용할 key-value datastore에 저장함 (Figure 1 참고)
Datastore
$f(\cdot)$는 어떤 context $c$를 pre-trained LM을 이용해서 fixed-length vector representation으로 매핑해주는 함수임
$i$-th training example $(c_i, w_i) \in \mathcal{D}$에 대해서 key-value 쌍을 $(k_i, v_i)$ 만듬. 여기서 $k_i$는 context representation $f(c_i)$가 되고, $v_i$는 target word $w_i$가 됨
최종적으로 datastore $(\mathcal{K}, \mathcal{V})$는 $\mathcal{D}$에 있는 모든 training examples로 만든 모든 key-value 쌍이 되고, 다음과 같은 수식으로 나타낼 수 있음:
일단 주어진 input context $x$에 대해서 모델은 다음 단어에 대한 distribution $p_{LM}(y
x)$와 context representation $f(x)$를 만들어냄
여기서 모델은 datastore에 $f(x)$에 대한 $k$-nearest neighbors $\mathcal{N}$을 검색 쿼리로 던짐 (검색은 distance function $d(\cdot, \cdot)$을 따르며 논문의 실험에서는 $L^2$를 사용)
그렇게 얻어낸 nearest neighbor $\mathcal{N}$에 대해서 negative distance에 대한 softmax를 기반으로 neighbor에 대한 distribution을 구함. 검색된 target이 중복될 경우 이들의 softmax probability을 aggregating함
이런 대용량의 datastore에 대한 검색을 위해 FAISS (Johnson et al., 2017)을 사용함
FAISS는 high-dimensional vector space에 대해 빠르고 memory-efficient 하게 nearest neighbor를 검색할 수 있는 라이브러리임
뒤에 실험 파트에서 나오지만 distance metric으로 $L^2$를 썼을 때가 inner product를 썼을 때보다 더 좋았음
3. Experimental Setup
Data
WikiText-103: standard benchmark by Merity et al. (2017) for autoregressive language modeling. training 103M tokens, devset & testset 250K, respectively
Books: Toronto Books Cropus (Zhu et al., 2015). 0.7B tokens
Wiki-3B: English Wikipedia. 2.87B tokens
Wiki-100M: random 100M token subset of Wiki-3B
WikiText-103을 제외하고는 모두 BPE를 사용
Model Architecture
LM으로 Transformer (Vaswani et al., 2017) Decoder를 사용함
Baevski & Auli (2019)에서 설명한 정확한 구조와 optimization을 따랐음
구조: Transformer Decoder + sinusoidal position embedding and stack 16 block (자세한 내용은 해당 논문의 Section 4.1 참고)
Optimization: Nesterov’s accelerated gradient method (Sutskever et al., 2013)을 사용 (자세한 내용은 해당 논문의 Section 4.5를 참고)
This model consists of 16 layers, each with 16 self-attention heads, 1024 dimensional hidden states, and 4096 demensional feedforward layers, amounting to 247M trainalble parameters.
Adaptive inputs (Baevski & Auli, 2019)과 adaptive softmax (Grave et al., 2017b) with tied weights (Press & Wolf, 2017)을 적용하였음
WikiText-103에만 적용하고 나머지 데이터셋에는 적용하지 않음
Evaluation
Trained to minimize the negative log-likelihood of the training corpus
Evaluated by perplexity (exponentiated negative log-likelihood) on held out data
최대 2560 tokens (in WikiText-103) 을 extra prior context로 주어지고 각 test example 별로 512 tokens에 대해서 ppl scoring을 함. 다른 데이터셋에 대해서는 512 tokens을 extra prior context로 줌
$k$NN-LM
Sentence prefix의 representation이자 $k$NN-LM의 datastore에 사용할 key는 1024-dimensional representation이며, Transformer LM의 final layer에서 feed forward network에 들어가기 전 hidden vector를 사용함 (after self-attention and layernorm; Section 5에서 자세히 설명)
학습된 LM로 training set에 대해 single forwarding 하고 이를 datastore의 key(vector), value(next word)로서 활용함
Forwarding할 때, 각 target token에 대해 WikiText-103는 최소 1536 tokens이 prior context를 제공하였고, 이외의 데이터셋은 512 tokens을 제공함
FAISS
Index는 1M 개의 randomly sampled key를 사용하여 4096개의 cluster centroid를 학습시켜서 만듬
효율을 위해 key(vector)는 64-bytes로 quantization함. 다만 WikiText-103은 full precision을 사용
Inference 시에 $k=1024$ neighbors를 검색하고 최대 32개의 cluster centroid만을 보도록 제한함