| import re |
|
|
| from flax.core.frozen_dict import freeze |
| from flax.traverse_util import flatten_dict, unflatten_dict |
| from jax.experimental import PartitionSpec as P |
|
|
| |
| |
# Sentinel for a parameter path that no partition rule has matched.
# set_partitions seeds every flattened path with it and rejects any path
# that still holds it after the rules have been applied.
_unmatched = object()


# NOTE(review): not referenced in the visible part of this file; presumably a
# sentinel standing in for empty sub-dicts - confirm before relying on it.
empty_dict = object()
|
|
|
|
| def _match(qs, ks): |
| """Return True if regexes in qs match any window of strings in tuple ks.""" |
| |
| qts = tuple(map(lambda x: re.compile(x + "$"), qs)) |
| for i in range(len(ks) - len(qs) + 1): |
| matches = [x.match(y) for x, y in zip(qts, ks[i:])] |
| if matches and all(matches): |
| return True |
| return False |
|
|
|
|
| def _replacement_rules(rules): |
| def replace(key, val): |
| for rule, replacement in rules: |
| if _match(rule, key): |
| return replacement |
| return val |
|
|
| return replace |
|
|
|
|
| def _get_partition_rules(): |
| return [ |
| |
| (("embed_positions", "embedding"), P("mp", None)), |
| (("embed_tokens", "embedding"), P("mp", None)), |
| (("rel_bias", "embedding"), P(None, "mp")), |
| |
| (("(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")), |
| (("out_proj", "kernel"), P("mp", None)), |
| |
| (("Dense_0", "kernel"), P(None, "mp")), |
| (("GLU.*", "Dense_1", "kernel"), P(None, "mp")), |
| (("GLU.*", "Dense_2", "kernel"), P("mp", None)), |
| (("FFN.*", "Dense_1", "kernel"), P("mp", None)), |
| |
| (("(bias|scale)",), None), |
| (("lm_head", "kernel"), P(None, "mp")), |
| |
| (("(head_scale|tau)",), None), |
| ] |
|
|
|
|
def set_partitions(in_dict, use_scan):
    """Build a frozen tree of partition specs mirroring ``in_dict``.

    Every path produced by ``flatten_dict(in_dict)`` is looked up in the rule
    table from ``_get_partition_rules``; the first matching rule supplies the
    value for that path (a ``PartitionSpec`` or ``None``). The leaf values of
    ``in_dict`` themselves are not inspected.

    Args:
        in_dict: Nested dict whose key paths are to be assigned specs.
        use_scan: When truthy, every non-``None`` spec whose path contains a
            ``FlaxBartEncoderLayers`` or ``FlaxBartDecoderLayers`` component
            gets an extra leading ``None`` entry.

    Returns:
        The result of ``freeze(unflatten_dict(...))`` over the per-path
        specs, i.e. a frozen tree with the same nesting as ``in_dict``.

    Raises:
        AssertionError: If any path matches no rule. Each such path is
            printed before the error is raised.
    """
    replace = _replacement_rules(_get_partition_rules())
    result = {path: replace(path, _unmatched) for path in flatten_dict(in_dict)}

    # Validate before the scan rewrite: unpacking the sentinel into a
    # PartitionSpec would otherwise fail with an unrelated TypeError and hide
    # the real problem. Raised explicitly so it also fires under ``python -O``.
    missing = [path for path, spec in result.items() if spec is _unmatched]
    for path in missing:
        print(f"Unmatched -> {path}")
    if missing:
        raise AssertionError("Incomplete partition spec.")

    if use_scan:
        # Prepend a None entry to the specs of the layer-stack parameters
        # (presumably the stacked layer axis added by scanning - confirm).
        scanned = ("FlaxBartEncoderLayers", "FlaxBartDecoderLayers")
        result = {
            path: (
                P(None, *spec)
                if spec is not None and any(name in path for name in scanned)
                else spec
            )
            for path, spec in result.items()
        }
    return freeze(unflatten_dict(result))
|
|