
Commit: GPT-2 Fix
michaelbenayoun committed Nov 8, 2022
1 parent c94d930, commit a582a46
Showing 5 changed files with 22 additions and 10 deletions.
3 changes: 2 additions & 1 deletion optimum/graphcore/fx/transformation_manager.py
@@ -16,6 +16,7 @@

import copy
import functools
+import operator
from typing import Iterator, List, Tuple, Union

import torch
@@ -123,6 +124,6 @@ def compose_reversible_transformations(self, optimization_level: int) -> Reversi
(1, MergeLinears()),
# (1, FuseBiasInLinear()),
# Those change the computation, but are actually needed for fp16 stability.
-(0, ClipValuesSymmetric(1e4, exclude_targets=("view",))),
+(0, ClipValuesSymmetric(1e4, include_targets=(torch.add, torch.mul, operator.add, operator.mul))),
(0, ClipValues(1e-4, float("inf"), include_targets=(torch.nn.LayerNorm,))),
)
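The hunk above replaces the exclude-list ("view") with an explicit include-list of the arithmetic ops whose outputs get clipped for fp16 stability. As a rough illustration of the underlying technique only, here is a minimal sketch using plain torch.fx rather than the optimum-graphcore ClipValuesSymmetric implementation; the function name and the TinyModel example are made up for demonstration:

```python
import operator

import torch
import torch.fx


def clip_selected_ops_symmetric(
    gm: torch.fx.GraphModule,
    clip_value: float = 1e4,
    include_targets=(torch.add, torch.mul, operator.add, operator.mul),
) -> torch.fx.GraphModule:
    """Clamp the output of the targeted ops to [-clip_value, clip_value].

    Bounding intermediate sums and products like this is a common way to avoid
    overflow when the model runs in fp16.
    """
    if clip_value < 0:
        raise ValueError(f"clip_value must be >= 0, got {clip_value}.")
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target in include_targets:
            with gm.graph.inserting_after(node):
                clamp = gm.graph.call_function(
                    torch.clamp, args=(node,), kwargs={"min": -clip_value, "max": clip_value}
                )
            # Point every consumer of the original node at the clamped value...
            node.replace_all_uses_with(clamp)
            # ...then restore the clamp node's own input, which was rewritten too.
            clamp.args = (node,)
    gm.graph.lint()
    gm.recompile()
    return gm


class TinyModel(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y) * 3  # traces to torch.add and operator.mul nodes


traced = torch.fx.symbolic_trace(TinyModel())
clipped = clip_selected_ops_symmetric(traced, clip_value=1e4)
print(clipped.graph)  # torch.clamp nodes now follow the add and mul nodes
```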
14 changes: 11 additions & 3 deletions optimum/graphcore/fx/transformations.py
@@ -204,7 +204,9 @@ def __init__(
):
if clip_value < 0:
raise ValueError(f"The provided clip value must be equal or greater than 0, but here {clip_value}.")
-return super().__init__(-clip_value, clip_value, exclude_targets=exclude_targets)
+return super().__init__(
+    -clip_value, clip_value, include_targets=include_targets, exclude_targets=exclude_targets
+)


class OutlineAttribute(ReversibleTransformation):
@@ -406,7 +408,9 @@ def sort_nodes_function(node):

embedding_node = max(embedding_nodes, key=sort_nodes_function)
if embedding_node.op == "call_function":
-raise NotImplementedError("VocabEmbeddingToSerializedEmbedding does not support torch.nn.functional.embedding yet.")
+raise NotImplementedError(
+    "VocabEmbeddingToSerializedEmbedding does not support torch.nn.functional.embedding yet."
+)

split = embedding_node.target.rsplit(".", maxsplit=1)
if len(split) == 1:
@@ -520,7 +524,11 @@ def transform(self, graph_module: "GraphModule") -> "GraphModule":


class ShareEmbeddingComputation(Transformation):
-def __init__(self, name_regex: Optional[str] = None, allowed_embedding_classes: Union[Tuple[Type], Type] = (torch.nn.Embedding, SerializedEmbedding)):
+def __init__(
+    self,
+    name_regex: Optional[str] = None,
+    allowed_embedding_classes: Union[Tuple[Type], Type] = (torch.nn.Embedding, SerializedEmbedding),
+):
self.name_regex = re.compile(name_regex) if name_regex else None
self.allowed_embedding_classes = allowed_embedding_classes
if not isinstance(self.allowed_embedding_classes, tuple):
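Most of the transformations.py hunks are line-wrapping, but the first one also threads the new include_targets argument from ClipValuesSymmetric through to its parent class. A minimal sketch of that delegation pattern with simplified stand-in classes (these are not the library's actual ClipValues / ClipValuesSymmetric bodies):

```python
from typing import Optional, Tuple


class ClipValues:
    """Stand-in base: records a [min_value, max_value] range plus target filters."""

    def __init__(
        self,
        min_value: float,
        max_value: float,
        include_targets: Optional[Tuple] = None,
        exclude_targets: Optional[Tuple] = None,
    ):
        self.min_value = min_value
        self.max_value = max_value
        self.include_targets = include_targets
        self.exclude_targets = exclude_targets


class ClipValuesSymmetric(ClipValues):
    """Symmetric variant: one clip_value expands to the range [-clip_value, clip_value]."""

    def __init__(self, clip_value: float, include_targets=None, exclude_targets=None):
        if clip_value < 0:
            raise ValueError(f"clip_value must be >= 0, got {clip_value}.")
        super().__init__(
            -clip_value, clip_value, include_targets=include_targets, exclude_targets=exclude_targets
        )


symmetric = ClipValuesSymmetric(1e4, include_targets=("add", "mul"))
assert (symmetric.min_value, symmetric.max_value) == (-1e4, 1e4)
```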
1 change: 0 additions & 1 deletion optimum/graphcore/fx/utils.py
@@ -36,7 +36,6 @@

# TODO: keep this until transformers >= 4.23.2
class GCProxy(HFProxy):
-
@property
def dtype(self):
return self.__getattr__("dtype")
4 changes: 2 additions & 2 deletions optimum/graphcore/models/deberta/modeling_deberta.py
@@ -43,11 +43,11 @@
DEFAULT_TRANSFORMATION_MANAGER,
AddPoptorchBlock,
AddPoptorchBlocksInSeries,
+LinearToSerializedLinear,
OutlineAttribute,
RecomputationCheckpoint,
-VocabEmbeddingToSerializedEmbedding,
-LinearToSerializedLinear,
TieWeights,
+VocabEmbeddingToSerializedEmbedding,
symbolic_trace_pipelined_model,
)
from ...modeling_utils import OnehotGather, PipelineMixin, get_layer_ipu, register
10 changes: 7 additions & 3 deletions optimum/graphcore/models/gpt2/modeling_gpt2.py
@@ -37,6 +37,7 @@
symbolic_trace_pipelined_model,
)
from ...modeling_utils import PipelineMixin, get_layer_ipu, register
+from .optimized_gpt2_attn import OptimizedGPT2Attention


logger = logging.get_logger(__name__)
@@ -69,7 +70,7 @@ def get_transformations(self):
layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu)
transformations = [
AddPoptorchBlock("Token Embedding", 0, "transformer.wte", log_insertions=log_insertions),
AddPoptorchBlock("Position Embedding", 1, "transformer.wtp", log_insertions=log_insertions),
AddPoptorchBlock("Position Embedding", 0, "transformer.wpe", log_insertions=log_insertions),
OutlineAttribute("transformer.ln_f", "LayerNorm"),
AddPoptorchBlocksInSeries("Layer", layer_ipu, r"transformer.h.[0-9]+", log_insertions=log_insertions),
# Only one of the following AddPoptorchBlock, will actually add a block.
@@ -84,7 +85,7 @@ def get_transformations(self):
)
)
if self.ipu_config.embedding_serialization_factor > 1:
-transformations.append(VocabEmbeddingToSerializedEmbedding())
+transformations.append(VocabEmbeddingToSerializedEmbedding("transformer.wte"))

return transformations

@@ -96,6 +97,9 @@ def parallelize(self):
- Adds recomputation checkpoints
"""
PipelineMixin.parallelize(self)
+if not isinstance(self, torch.fx.GraphModule):
+    for layer in self.transformer.h:
+        layer.attn.__class__ = OptimizedGPT2Attention
if self.ipu_config.embedding_serialization_factor > 1:
self.resize_vocab(False)
traced = symbolic_trace_pipelined_model(self)
@@ -137,7 +141,7 @@ def get_transformations(self):
layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu)
transformations = [
AddPoptorchBlock("Token Embedding", 0, "transformer.wte", log_insertions=log_insertions),
AddPoptorchBlock("Position Embedding", 1, "transformer.wtp", log_insertions=log_insertions),
AddPoptorchBlock("Position Embedding", 0, "transformer.wpe", log_insertions=log_insertions),
OutlineAttribute("transformer.ln_f", "LayerNorm"),
AddPoptorchBlocksInSeries("Layer", layer_ipu, r"transformer.h.[0-9]+", log_insertions=log_insertions),
AddPoptorchBlock("LM Head", 0, "lm_head", log_insertions=log_insertions),
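The parallelize() hunk is the substantive part of the fix: when the model is not already an FX GraphModule, each transformer block's attn module is retargeted to OptimizedGPT2Attention by reassigning its __class__, so the existing weights are reused and only the forward implementation changes. A toy illustration of that instance-patching pattern (the classes below are placeholders, not optimum-graphcore's actual attention classes):

```python
import torch


class BaseAttention(torch.nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        return self.proj(x)


class OptimizedAttention(BaseAttention):
    """Same parameters and state dict as BaseAttention; only forward() differs.

    Because no new attributes are introduced, reassigning __class__ on an
    existing instance is safe: parameters stay where they are and only the
    method resolution changes.
    """

    def forward(self, x):
        # Placeholder for a fused / IPU-friendly attention implementation.
        return torch.nn.functional.relu(self.proj(x))


attn = BaseAttention(hidden_size=8)
attn.__class__ = OptimizedAttention  # in-place retargeting, as done for layer.attn above
out = attn(torch.randn(2, 8))
print(out.shape)  # torch.Size([2, 8]), computed by OptimizedAttention.forward
```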
