Dr. House ChatBot Design
The goal of this project is to hold a “chat” with a Dr. House language model.
See current demo here.
Status: DRAFT. This project recreates an earlier version that used the older, pre-Keras TF 1.0 seq2seq package from 2016. There have been significant LM improvements since then (namely, Transformers).
Objective
Language modelling is quite general, and the medium-term idea is to perform protein language modelling in some form. The main goal here is to re-familiarize with the latest language modelling toolkits by building an LM-powered web-app.
There are several reasons why a Dr. House chatbot makes a particularly good language modelling project:
- Data: eight seasons of House talking back and forth with a variety of doctors, patients, and other miscellaneous characters. The transcripts are all available unofficially here.
- Modelling: In a sequence-to-sequence model, the “prior” sentence can be translated into a response from Dr. House. Further, transformers can use attention over an entire prior conversation to generate the next word.
- Scope: by limiting to a single character’s interactions, a relatively small language model should provide sufficient coverage.
- Inference: The smaller model is more amenable to an embedded web environment, e.g. via ONNX Runtime Web or TensorFlow.JS.
- Chat UI: preact.js is a good library for a small Chat UI.
- Fun: it is fun to imagine discussing something with Dr. House.
Data
The simplest format would be a single input, followed by a single response:
{
"input": "I'm playing racquetball tomorrow night, with Taub.",
"response": "Why would you hide that?"
}
However, this format discards important aspects of the back-and-forth nature of the dialogue. It could make more sense to model the data as a longer-form conversation, with a “target speaker” that is modelled. Conversational text such as movie scripts tends to have more turn-taking and exchange than NMT or other “typical” seq2seq tasks. Transformer architectures in particular can query for specific pieces of information in the larger context to generate the next word, likely leading to dramatic quality improvements over RNNs. To capture these benefits, the same exchange could be stored with its surrounding context, in a format like the following:
{
"target_speaks_at_turns": [0, 2, 4, 6, 8, ...],
"conversation": [
"Anyone sitting here?",
"Just my persona.",
"You know, it's amazing the way people cling on to insults. Or what they think are insults. (He takes a sandwich and fries off of Wilson’s plate and puts them on his own)",
"So that wasn't an insult?",
"I'm not suggesting that, like our patient, you're hiding a dark, sarcastic core beneath a candy shell of compulsive niceness. (House has pulled a fork out of his breast pocket)",
"I'm not always nice. I'm not nice to you.",
"Because you know nice bores me. Hence, still nice. No, I'm suggesting that you have no core. You're what whoever you're with needs you to be. Okay, I guess that could be insulting. The interesting question is why. Why do you think the world will end in chaos and destruction if you're not there to save it? (He starts eating Wilson’s lunch)",
"Because when my parents put me in the rocket and sent me here, they said, 'James, you will grow to manhood under a yellow sun.'",
"And why'd you lie about monster trucks?",
"I didn’t.",
"I checked your appointment book. You got tomorrow night marked off, but you didn't put down what you were doing. So you thought someone might look at the book —",
"I'm playing racquetball tomorrow night, with Taub.",
"Why would you hide that?",
"Because the world revolves around you. I devote time to anyone else, you'd end up stalking me and harassing them.",
"You say that as though it wouldn't be fun.",
"And maybe I didn't want to rub your nose in the fact that we'd be doing something you can no longer do. Because I'm nice.",
]
}
BeautifulSoup and Requests can be used to download/parse the conversations (e.g. this tutorial).
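As a rough sketch of that scraping step (the URL and the paragraph-per-line page structure below are assumptions, not the real site layout):

import requests
from bs4 import BeautifulSoup

# Hypothetical URL; the real page addresses come from the unofficial
# transcript site linked above.
TRANSCRIPT_URL = "https://example.com/house-md/season-1/episode-1"


def fetch_transcript_lines(url: str = TRANSCRIPT_URL) -> list:
  """Download one transcript page and return its paragraphs as raw lines."""
  response = requests.get(url, timeout=30)
  response.raise_for_status()
  soup = BeautifulSoup(response.text, "html.parser")
  # Many transcript pages keep each spoken line in its own <p> tag; the
  # selector will need adjusting to the actual page structure.
  return [p.get_text(" ", strip=True) for p in soup.find_all("p")]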
Modelling
Model types: Two powerhouse architectures for language modelling are the LSTM and the Transformer. Since the LSTM supports streaming, it can be more efficient at inference time; however, this efficiency comes at a cost in accuracy. Transformers use self-attention to explicitly query for long-range dependencies in the input, and can dramatically improve modelling quality. An encoder-decoder architecture, in which the decoder generates a full response conditioned on an encoded representation of the input context, can further improve performance.
The loss function can be the cross entropy of the “next” token, i.e. the negative log-likelihood $-\log P(y_t \mid y_{1 \dots t - 1})$. At decoding time, either sampling from this distribution or taking its $\arg\max$ can be used.
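As a small illustration of those two decoding options (greedy argmax vs. sampling, with an optional temperature; the function name is just for illustration):

import numpy as np


def sample_next_token(probs: np.ndarray, greedy: bool = False,
                      temperature: float = 1.0) -> int:
  """Pick the next token id from the model's softmax distribution."""
  if greedy:
    return int(np.argmax(probs))
  # Re-sharpen (or flatten) the distribution with a temperature, then sample.
  logits = np.log(probs + 1e-9) / temperature
  scaled = np.exp(logits - logits.max())
  scaled /= scaled.sum()
  return int(np.random.choice(len(scaled), p=scaled))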
The loss will be masked to apply only to tokens where the target speaker (Dr. House) talks. A special <EOS> token will demarcate speaker changes. At inference time, the model will receive the prior conversation as input, decode until an <EOS> token is emitted, and then pass the turn back to the user. Modelling a stream of text with a turn-taking <EOS> token also has some small benefits:
- The language modelling task can come in two modes: a generic LM task predicting the next token, followed by a fine-tuning stage using the masked approach (a sketch of the masked loss follows this list). Since the dataset is relatively small (only scripts from House), this two-stage approach can be helpful.
- The LM can pick up on longer-term dependencies (especially if a Transformer with causal attention is used).
- Bucketing is easier, since the input is now a single 1D sequence instead of 2D input/response pairs.
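A minimal sketch of the masked loss described above (the tensor names and shapes are assumptions about the eventual input pipeline):

import tensorflow as tf


def masked_next_token_loss(labels, logits, house_mask):
  """Cross entropy over next-token predictions, counted only on House's turns.

  labels:     (batch, seq) int token ids, already shifted to be the "next" token.
  logits:     (batch, seq, vocab) model outputs.
  house_mask: (batch, seq) 1.0 where the target speaker (Dr. House) talks,
              0.0 elsewhere (and on padding).
  """
  per_token = tf.keras.losses.sparse_categorical_crossentropy(
      labels, logits, from_logits=True)  # (batch, seq)
  per_token *= house_mask
  # Normalize by the number of unmasked tokens so the loss scale stays stable.
  return tf.reduce_sum(per_token) / tf.maximum(tf.reduce_sum(house_mask), 1.0)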
Experimental notes follow:
The downside to masking was implementing a tf.function in a custom training loop, rather than using the Keras built-in training loop (compile and then fit). The custom training loop provides flexibility at the cost of verbosity. The Keras class_weight feature could provide some improvements here.
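For the custom training loop itself, a sketch of one tf.function train step, reusing the masked_next_token_loss sketch from the Modelling section (the batch structure is an assumption):

import tensorflow as tf


@tf.function
def train_step(model, optimizer, tokens, labels, house_mask):
  """One gradient step with the masked next-token loss."""
  with tf.GradientTape() as tape:
    logits = model(tokens, training=True)
    loss = masked_next_token_loss(labels, logits, house_mask)
  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss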
The ml-collections library is helpful for creating training configurations.
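For example, a training configuration might look like the following (the field names are placeholders, not the project's actual configuration):

import ml_collections


def get_config() -> ml_collections.ConfigDict:
  config = ml_collections.ConfigDict()
  config.vocab_size = 10_000
  config.embedding_dim = 256
  config.lstm_units = 512
  config.batch_size = 32
  config.learning_rate = 1e-3
  config.max_epochs = 20
  return config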
Training and testing configurations were split on a per-episode level, rather than a per-example level. This split ensures that the test set remains relatively novel.
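A possible sketch of the per-episode split, assuming each training example carries an episode_id field:

import random


def split_by_episode(examples, test_fraction=0.1, seed=0):
  """Hold out whole episodes so test conversations are never seen in training."""
  episode_ids = sorted({ex['episode_id'] for ex in examples})
  random.Random(seed).shuffle(episode_ids)
  num_test = max(1, int(len(episode_ids) * test_fraction))
  test_ids = set(episode_ids[:num_test])
  train = [ex for ex in examples if ex['episode_id'] not in test_ids]
  test = [ex for ex in examples if ex['episode_id'] in test_ids]
  return train, test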
A reference for different gradient descent optimizers is here.
Inference
The model should run in the browser, so it can be served as part of this website. While this is more difficult to configure, removing any dependency on a separate web server will be more sustainable in the long term. Fewer moving parts will enable faster iteration.
Another alternative would be AWS Lambda, which removes the overhead of maintaining a separate server. This approach would likely be viable. If the in-browser load time is too high, the model can be moved to a Lambda function relatively easily to keep loading latency low for the user. If the load time is acceptable, the simplicity of on-device inference is preferred.
The user will need to load a vocabulary file, and the client will have to do some basic tokenization to pre-process the input text. For this reason, it can be beneficial to use simple tokenization (either whole-word or character-level), rather than relying on a large WordPiece model to tokenize the text.
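A sketch of the kind of minimal whole-word tokenization the client would mirror (the vocabulary-file format and the UNK id are assumptions):

UNK_ID = 1  # Assumed id reserved for out-of-vocabulary words.


def load_vocab(path: str) -> dict:
  """Vocabulary file with one token per line; the line number is the id."""
  with open(path, encoding='utf-8') as f:
    return {token.strip(): i for i, token in enumerate(f)}


def tokenize(text: str, vocab: dict) -> list:
  """Whole-word tokenization: lowercase, split on whitespace, map to ids."""
  return [vocab.get(word, UNK_ID) for word in text.lower().split()]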
In terms of JavaScript neural net runtimes, I see two primary options (one for TensorFlow, and one for PyTorch).
(1) TensorFlow.JS has mature documentation:
- Conversion is documented. It may be difficult to debug errors, but the process is relatively straightforward.
- There are working NLP examples in Keras as well as a big demo page and gallery page with a few relevant text-based demos (seq2seq, language modelling, and translation).
- Benchmarking software is available in-browser.
(2) ONNX documentation is also pretty good:
- The conversion API is simple on the surface, as it is for Keras (torch.onnx); a sketch of the export call follows this list.
- ONNX Runtime Web examples are more sparse, and are not running in-browser.
- MobileNet demo is here. There is extensive documentation about the model.
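For reference, the torch.onnx export call mentioned above looks roughly like this (the module here is a stand-in with the right input/output signature, not the actual model):

import torch

# Stand-in module: token ids in, logits over the vocabulary out.
model = torch.nn.Sequential(
    torch.nn.Embedding(10_000, 256),
    torch.nn.Linear(256, 10_000),
)
dummy_tokens = torch.zeros((1, 16), dtype=torch.long)
torch.onnx.export(
    model, (dummy_tokens,), 'house_lm.onnx',
    input_names=['tokens'], output_names=['logits'],
    dynamic_axes={'tokens': {1: 'seq_len'}, 'logits': {1: 'seq_len'}},
    opset_version=13)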
Conversion
The architecture of a language model differs between training and inference.
- At training time, one needs to compute the output for every timestep in the sequence, in order to apply a loss at each timestep (e.g. cross-entropy on predicting the next token).
- At inference time, one typically streams the predictions, and only needs the output for the latest timestep.
The LSTM char-seq2seq Keras tutorial hints at this issue in the inference section: the code creates a separate copy of the model for inference (via copy-paste-modify).
To reuse the same model in either mode, the keyword-spotting streaming library (github) provides some inspiration. In particular, the stateful model $LSTM(x_t)$ is technically a stateless function of its input and prior state, $LSTM(x_t, S_{t-1})$. The inference loop can track the external state, in a loop resembling the following code snippet:
from typing import List

import numpy as np
from tensorflow import keras

# EOS_TOKEN_ID is assumed to be defined alongside the vocabulary.


def Decode(lstm: keras.Model, input_token_ids: List[int],
           eos_token_id: int = EOS_TOKEN_ID,
           max_decoder_len: int = 20) -> List[int]:
  # Initialize the model state from the inputs.
  # (zero_states and the [tokens, state] call signature are assumed
  # conveniences of the streaming-style model described above.)
  state = lstm.zero_states()
  softmax, *state = lstm([input_token_ids, state])
  # The decode function is a greedy argmax
  # (could also implement beam search).
  decode = lambda probs: int(np.argmax(probs))
  # Decode the rest of the sequence: keep predicting the next token
  # until the model emits end-of-sentence. The max decoder length
  # ensures termination, even in the case of model error.
  next_token = decode(softmax)
  decoder_outputs = [next_token]
  while (next_token != eos_token_id and
         len(decoder_outputs) < max_decoder_len):
    # Run the model for another timestep,
    # with the prior output and prior state as input.
    softmax, *state = lstm([[next_token], state])
    # Decode the next token, and store it.
    next_token = decode(softmax)
    decoder_outputs.append(next_token)
  # Finished.
  return decoder_outputs
To achieve the desired inference behavior, try:
- Using the return_state mode in the LSTM, instead of return_sequences (used for training); see the sketch after the weight-copy snippet below.
- During model creation, append the state tuple from each layer as an output.
- Theory: One can “cast” the training version of the LSTM model to an inference version of the model (making the two modifications above) using the Keras weight-restoring mechanism, i.e. saving and loading only the model’s weight values. In particular, the following code snippet may work, provided that the GetConfig function returns all the necessary layer configurations:
# LoadModel, CreateModel, and GetConfig are project-specific helpers.
training_model = LoadModel(training_dir)
inference_model = CreateModel(GetConfig(training_dir), mode='inference')
for training_layer, inference_layer in zip(
    training_model.layers, inference_model.layers):
  inference_layer.set_weights(training_layer.get_weights())
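As a sketch of what a CreateModel helper supporting both modes might look like for a single-layer LSTM LM (the config fields and the exact layer stack are assumptions):

from tensorflow import keras


def CreateModel(config, mode: str = 'training') -> keras.Model:
  """config is assumed to carry vocab_size and lstm_units fields."""
  vocab_size, units = config.vocab_size, config.lstm_units
  tokens = keras.Input(shape=(None,), dtype='int32', name='tokens')
  embedded = keras.layers.Embedding(vocab_size, units)(tokens)
  if mode == 'training':
    # Training: a prediction (and a loss) at every timestep.
    outputs = keras.layers.LSTM(units, return_sequences=True)(embedded)
    logits = keras.layers.Dense(vocab_size)(outputs)
    return keras.Model(tokens, logits)
  # Inference: single-step, with the recurrent state exposed as extra
  # inputs/outputs so the decode loop can carry it across calls.
  state_h = keras.Input(shape=(units,), name='state_h')
  state_c = keras.Input(shape=(units,), name='state_c')
  outputs, h, c = keras.layers.LSTM(units, return_state=True)(
      embedded, initial_state=[state_h, state_c])
  logits = keras.layers.Dense(vocab_size)(outputs)
  return keras.Model([tokens, state_h, state_c], [logits, h, c])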
Weight Tying
Weight tying is helpful for small LMs: if the vocabulary is large (e.g. 10k) relative to a small LSTM LM, most of the parameters end up in the embedding layer.
A note on the SavedModel format and weight tying: the SavedModel format traces the Keras model and, in calling each layer, records the operations. This means that a simple operation like WeightTying does not duplicate variables that it uses twice, provided you do not save them as local variables. Below is a simple implementation of a custom Keras Layer to emit logits (the Activation layer can add the softmax on top).
import tensorflow as tf
from tensorflow import keras


class WeightTiedLayer(keras.layers.Layer):
  '''Implements weight tying to the embedding layer.

  Usage:
    logits = WeightTiedLayer(embedding)(outputs)
  '''

  def __init__(self, embedding_layer=None, **kwargs):
    super(WeightTiedLayer, self).__init__(**kwargs)
    self.embedding_layer = embedding_layer

  def build(self, input_shape):
    # One bias term per vocabulary entry.
    self.bias = self.add_weight(
        shape=(self.embedding_layer.input_dim,),
        name='bias',
        initializer='random_normal',
        trainable=True)

  def call(self, inputs):
    # Shapes:
    # - inputs: (batch, seq, N)
    # - embedding_weights: (vocab_size, N)
    # - embedding_weights.T: (N, vocab_size)
    embedding_weights = self.embedding_layer.weights[0]
    output = inputs @ tf.transpose(embedding_weights)
    return output + self.bias
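A possible usage sketch, building on the layer above (the LSTM width must match the embedding dimension so the tied matmul shapes line up):

from tensorflow import keras

vocab_size, dim = 10_000, 256
tokens = keras.Input(shape=(None,), dtype='int32')
embedding = keras.layers.Embedding(vocab_size, dim)
hidden = keras.layers.LSTM(dim, return_sequences=True)(embedding(tokens))
logits = WeightTiedLayer(embedding)(hidden)
probs = keras.layers.Activation('softmax')(logits)
model = keras.Model(tokens, probs)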
UI
The UI can be quite simple; the main challenge will be to design the integration.
- For the UI itself, the user hits Enter in a textbox, triggering an inference session of the model. The model then emits some text (and while the model is executing, the user cannot submit more text). React/Preact are good tools for a UI this simple.
- The main integration issues will be the asynchronous response from the model and the configuration of the input. There are two input modes: MOST_RECENT and FULL. The former passes only the last message the user sent, while the latter passes the full conversation to the model. FULL could be most beneficial for Transformer-based models; however, an LSTM baseline could still be helpful.
- It is also possible that the model is slow to load on the client. Making the model smaller may help ameliorate this drawback, but perhaps at the cost of accuracy. If needed, the JavaScript model could be moved to a “serverless” AWS Lambda function; however, for low usage, the Lambda function could have significant load time.