# Decoding Strategies in Large Language Models – TechToday

The tokenizer, Byte-Pair Encoding in this case, translates each token in the input text into a corresponding token identifier. GPT-2 then uses these token IDs as input and tries to predict the next most likely token. Finally, the model generates logits, which are converted to probabilities using a softmax function.

For example, the model assigns a 17% probability to “of” being the next token after “I have a dream”. This output essentially represents a ranked list of possible next tokens in the sequence. More formally, we denote this probability as P(of | I have a dream) = 17%.
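To make the logits-to-probabilities step concrete, here is a minimal softmax sketch over a toy five-token vocabulary (the logit values are made up for illustration, not real GPT-2 scores):

```python
import numpy as np

def softmax(logits):
    # Subtract the max logit for numerical stability before exponentiating
    exp_logits = np.exp(logits - np.max(logits))
    return exp_logits / exp_logits.sum()

# Toy logits for a 5-token vocabulary (illustrative values only)
logits = np.array([2.0, 1.5, 0.8, 0.3, -1.0])
probabilities = softmax(logits)

print(probabilities)        # the highest logit gets the highest probability
print(probabilities.sum())  # the probabilities always sum to 1
```

The softmax preserves the ranking of the logits, so the token with the largest raw score always ends up with the largest probability.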

Autoregressive models like GPT predict the next token in a sequence based on the previous tokens. Consider a sequence of tokens w = (w₁, w₂, …, wₜ). The joint probability of this sequence, P(w), can be broken down as:

P(w) = P(w₁, w₂, …, wₜ) = ∏ᵢ₌₁ᵗ P(wᵢ | w₁, w₂, …, wᵢ₋₁)

For each token wᵢ in the sequence, P(wᵢ | w₁, w₂, …, wᵢ₋₁) represents the conditional probability of wᵢ given all the previous tokens (w₁, w₂, …, wᵢ₋₁). GPT-2 calculates this conditional probability for each of the 50,257 tokens in its vocabulary.
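The chain rule above can be sketched with hypothetical conditional probabilities (the numbers below are made up for illustration, not actual GPT-2 outputs). Note that in practice we work with log probabilities, since the product of many small conditionals quickly underflows:

```python
import math

# Hypothetical conditional probabilities for the sequence "I have a dream"
conditionals = {
    "I": 0.05,      # P(I)
    "have": 0.30,   # P(have | I)
    "a": 0.40,      # P(a | I have)
    "dream": 0.10,  # P(dream | I have a)
}

# Chain rule: the joint probability is the product of the conditionals
joint = math.prod(conditionals.values())

# In log space the product becomes a sum, which avoids numerical underflow
log_joint = sum(math.log(p) for p in conditionals.values())

print(joint)                      # 0.0006
print(math.exp(log_joint))        # same value, recovered from log space
```

This is exactly why the decoding code later in the article scores tokens by their log probability rather than the raw probability.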

This leads to the question: how do we use these probabilities to generate text? This is where decoding strategies such as greedy search and beam search come into play.

Greedy search is a decoding method that takes the most likely token at each step as the next in the sequence. Simply put, it only keeps the most likely token at each stage, discarding all other potential options. Using our example:

• Step 1: Input: “I have a dream” → Most likely token: “of”
• Step 2: Input: “I have a dream of” → Most likely token: “being”
• Step 3: Input: “I have a dream of being” → Most likely token: “a”
• Step 4: Input: “I have a dream of being a” → Most likely token: “doctor”
• Step 5: Input: “I have a dream of being a doctor” → Most likely token: “.”

Although this approach may seem intuitive, it is important to note that greedy search is short-sighted: it only considers the most likely token at each step without considering the overall effect on the sequence. This property makes it fast and efficient, since it doesn’t need to track multiple sequences, but it also means that it can miss better sequences that would have started with slightly less likely tokens.
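This short-sightedness is easy to see in a toy two-step example (the probabilities below are invented for illustration): greedy search commits to the locally best first token, even when a slightly worse first choice leads to a far more likely sequence overall.

```python
# Hypothetical two-step continuation probabilities (illustrative numbers only)
step1 = {"of": 0.40, "that": 0.35}           # P(token | "I have a dream")
step2 = {
    "of":   {"peace": 0.10, "being": 0.15},  # P(token | "... of")
    "that": {"one": 0.90, "someday": 0.05},  # P(token | "... that")
}

# Greedy: pick the single most likely token at each step
first = max(step1, key=step1.get)
second = max(step2[first], key=step2[first].get)
greedy_prob = step1[first] * step2[first][second]

# Exhaustive search over both steps finds the best overall sequence
best = max(
    ((w1, w2, step1[w1] * step2[w1][w2]) for w1 in step1 for w2 in step2[w1]),
    key=lambda t: t[2],
)

print(first, second, greedy_prob)  # greedy picks "of being", joint prob ≈ 0.06
print(best)                        # "that one" wins overall, joint prob ≈ 0.315
```

Greedy locks in “of” because 0.40 > 0.35, but every continuation of “of” is weak; the sequence starting with “that” is more than five times as likely overall. Beam search addresses exactly this failure mode by keeping several candidate sequences alive at once.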

Next, we illustrate the greedy search implementation using graphviz and networkx. We select the token ID with the highest score, calculate its log probability (we take the log to simplify calculations), and add it to the tree. We repeat this process for five tokens.

```python
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch

def get_log_prob(logits, token_id):
    # Compute the softmax of the logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    log_probabilities = torch.log(probabilities)
    # Get the log probability of the token
    token_log_probability = log_probabilities[token_id].item()
    return token_log_probability

def greedy_search(input_ids, node, length=5):
    if length == 0:
        return input_ids

    outputs = model(input_ids)
    predictions = outputs.logits

    # Get the logits for the next token and take the argmax (greedy choice)
    logits = predictions[0, -1, :]
    token_id = torch.argmax(logits).unsqueeze(0)

    # Compute the score of the predicted token
    token_score = get_log_prob(logits, token_id)

    # Add the predicted token to the list of input ids
    new_input_ids = torch.cat([input_ids, token_id.unsqueeze(0)], dim=-1)

    # Add node and edge to the graph
    next_token = tokenizer.decode(token_id, skip_special_tokens=True)
    current_node = list(graph.successors(node))[0]
    graph.nodes[current_node]['tokenscore'] = np.exp(token_score) * 100
    graph.nodes[current_node]['token'] = next_token + f"_{length}"

    # Recursive call
    input_ids = greedy_search(new_input_ids, current_node, length - 1)
    return input_ids

# Parameters
length = 5
beams = 1

# Create a balanced tree with height 'length'
graph = nx.balanced_tree(1, length, create_using=nx.DiGraph())

# Add 'tokenscore', 'cumscore', and 'token' attributes to each node
for node in graph.nodes:
    graph.nodes[node]['tokenscore'] = 100
    graph.nodes[node]['token'] = text

# Start generating text
output_ids = greedy_search(input_ids, 0, length=length)
output = tokenizer.decode(output_ids.squeeze().tolist(), skip_special_tokens=True)
print(f"Generated text: {output}")
```
```
Generated text: I have a dream of being a doctor.
```

Our greedy search generates the same text as the transformers library: “I have a dream of being a doctor.” Let’s visualize the tree we have created.

```python
import matplotlib.pyplot as plt
import networkx as nx
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap

def plot_graph(graph, length, beams, score):
    fig, ax = plt.subplots(figsize=(3 + 1.2 * beams**length, max(5, 2 + length)), dpi=300, facecolor='white')

    # Create positions for each node
    pos = nx.nx_agraph.graphviz_layout(graph, prog="dot")

    # Normalize the colors along the range of token scores
    if score == 'token':
        scores = [data['tokenscore'] for _, data in graph.nodes(data=True) if data['token'] is not None]
    elif score == 'sequence':
        scores = [data['sequencescore'] for _, data in graph.nodes(data=True) if data['token'] is not None]
    vmin = min(scores)
    vmax = max(scores)
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
    cmap = LinearSegmentedColormap.from_list('rg', ["r", "y", "g"], N=256)

    # Draw the nodes
    nx.draw_networkx_nodes(graph, pos, node_size=2000, node_shape='o', alpha=1,
                           linewidths=4, node_color=scores, cmap=cmap)

    # Draw the edges
    nx.draw_networkx_edges(graph, pos)

    # Draw the labels
    if score == 'token':
        labels = {node: data['token'].split('_')[0] + f"\n{data['tokenscore']:.2f}%"
                  for node, data in graph.nodes(data=True) if data['token'] is not None}
    elif score == 'sequence':
        labels = {node: data['token'].split('_')[0] + f"\n{data['sequencescore']:.2f}"
                  for node, data in graph.nodes(data=True) if data['token'] is not None}
    nx.draw_networkx_labels(graph, pos, labels=labels, font_size=10)
    plt.box(False)

    # Add a colorbar
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    if score == 'token':
        fig.colorbar(sm, ax=ax, orientation='vertical', pad=0, label='Token probability (%)')
    elif score == 'sequence':
        fig.colorbar(sm, ax=ax, orientation='vertical', pad=0, label='Sequence score')
    plt.show()

# Plot graph
plot_graph(graph, length, 1.5, 'token')
```