Conversation
Summary of Changes

This pull request implements support for a combined Expert Parallelism and Fully Sharded Data Parallel (EP+FSDP) strategy, targeting Mixture-of-Experts (MoE) models. The changes enable more efficient distributed training by sharding experts across a dedicated FSDP mesh while preserving the existing expert parallelism. This involves updates to expert sharding, integration into the FSDP wrapping mechanism, and improvements to gradient clipping for complex distributed tensor configurations.
Code Review
This pull request introduces support for EP+FSDP, a combination of expert parallelism and fully sharded data parallelism. The changes are spread across model parallelism logic, strategy definitions, and utilities for device mesh and gradient clipping. The implementation provides a new method for sharding and executing experts under this combined parallelism scheme. My review highlights two main concerns: a potential dead code path in the expert execution logic which could lead to confusion, and a restrictive assumption in a device mesh utility function that might limit its use with more complex distributed configurations.
/gemini review
Code Review
This pull request introduces support for a new parallel strategy, ep_fsdp, which combines Expert Parallelism (EP) with Fully Sharded Data Parallelism (FSDP). This is a significant and complex feature. The implementation is well-structured, involving a major refactoring of the expert parallelism logic, a new FSDP wrapping strategy, and custom EP-aware gradient clipping. The changes are accompanied by a new cookbook example and a rigorous precision test, which is excellent.
The core of the new design is the introduction of a separate ep_fsdp_device_mesh to manage EP and FSDP on experts, decoupling it from the main device mesh. This allows for flexible parallel configurations. The expert communication logic has been refactored into a new ep_utils.py file, improving modularity. The NativeFSDPStrategy is now much more sophisticated, applying different sharding strategies and mixed-precision policies to expert and non-expert layers.
Overall, this is a high-quality contribution that adds a powerful new capability. My feedback is minor, focusing on improving a type hint for clarity in a new utility file.
```python
def tokens_post_all2all(
    expert_outputs: torch.Tensor,
    routing_weights: torch.Tensor,
    selected_experts: int,
```
```python
self._enable_expert_parallel = self._should_enable_expert_parallel(self._expert_parallel_config,
                                                                   self.device_mesh)
self._expert_parallel_applied = False
# Store ep_size for later use (EP mesh construction, grad clip, etc.)
```
Could this logic be encapsulated inside NativeFSDPStrategy? It doesn't appear to be used anywhere else.
```python
fsdp_config=self._fsdp_config,
device_mesh=self.device_mesh,
enable_ep=self._enable_expert_parallel,
ep_fsdp_device_mesh=ep_fsdp_mesh,
```
```python
    return None
world_size = self.world_size
assert world_size % ep_size == 0, (f'world_size ({world_size}) must be divisible by ep_size ({ep_size})')
ep_fsdp_size = world_size // ep_size
```
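The FSDP degree for the experts is simply the quotient of the world size by the EP degree. A minimal standalone sketch of that derivation (function name hypothetical, mirroring the assert in the diff above):

```python
def infer_ep_fsdp_size(world_size: int, ep_size: int) -> int:
    # Hypothetical helper: the experts' FSDP degree is whatever remains
    # of the world size after the EP dimension is carved out.
    assert world_size % ep_size == 0, (
        f'world_size ({world_size}) must be divisible by ep_size ({ep_size})')
    return world_size // ep_size
```

For example, `infer_ep_fsdp_size(16, 8)` gives 2 (each expert shard is further sharded over 2 ranks), while `infer_ep_fsdp_size(8, 8)` gives 1 (pure EP, no extra sharding inside the expert group).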
If ep_fsdp_size can be computed here, does it still need to be passed in from outside?

Put differently, how is it decided here that FSDP should be enabled inside the EP group?
```python
mesh = (
    torch.arange(math.prod((ep_size, ep_fsdp_size)), dtype=torch.int).view(ep_fsdp_size,
                                                                           ep_size).transpose(0, 1))
return torch.distributed.DeviceMesh(self.device_type, mesh, mesh_dim_names=('ep', 'ep_fsdp'))
```
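To see which ranks the resulting ('ep', 'ep_fsdp') mesh groups together, here is a pure-Python sketch of the same arange/view/transpose layout (no torch required; names hypothetical):

```python
def build_ep_mesh_layout(world_size: int, ep_size: int):
    """Mirror torch.arange(world).view(ep_fsdp, ep).transpose(0, 1)
    with nested lists: entry [i][j] is the global rank at
    ep-index i, ep_fsdp-index j."""
    ep_fsdp_size = world_size // ep_size
    return [[j * ep_size + i for j in range(ep_fsdp_size)] for i in range(ep_size)]

layout = build_ep_mesh_layout(world_size=8, ep_size=4)
# 'ep' groups vary along dim 0 (fixed column): blocks of consecutive ranks,
# e.g. [0, 1, 2, 3] and [4, 5, 6, 7].
ep_groups = [[layout[i][j] for i in range(4)] for j in range(2)]
# 'ep_fsdp' groups vary along dim 1 (fixed row): ranks strided by ep_size,
# e.g. [0, 4], [1, 5], ...
ep_fsdp_groups = [layout[i] for i in range(4)]
```

So the transpose makes 'ep' the leading mesh dimension: EP communication stays among consecutive ranks (typically intra-node), while the expert FSDP shards stride across EP groups.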
to_torch_device_mesh already exists; consider reusing it?
```python
self.mesh = np.array(self.mesh)
```
```diff
-valid_dim_names = {'dp', 'fsdp', 'tp', 'pp', 'cp', 'ep'}
+valid_dim_names = {'dp', 'fsdp', 'tp', 'pp', 'cp', 'ep', 'ep_fsdp'}
```
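If the dim-name whitelist keeps growing with each new parallel strategy, a small validator keeps the failure mode explicit. A sketch, assuming the set in the diff above is the authoritative list (helper name hypothetical):

```python
VALID_DIM_NAMES = {'dp', 'fsdp', 'tp', 'pp', 'cp', 'ep', 'ep_fsdp'}

def validate_dim_names(names):
    # Hypothetical helper: reject any mesh dim name outside the whitelist
    # instead of failing later with an opaque mesh-construction error.
    unknown = [n for n in names if n not in VALID_DIM_NAMES]
    if unknown:
        raise ValueError(f'unsupported mesh dim names: {unknown}')
    return list(names)
```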
Consider using from_sizes; then this line shouldn't need to change.
```python
# EP: reduce over ep_fsdp_group, then ep_group
ep_val = _local_norm_stat(ep_params, norm_type)
if ep_fsdp_group is not None:
    op = dist.ReduceOp.MAX if math.isinf(norm_type) else dist.ReduceOp.SUM
```
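The two-stage reduction (over the ep_fsdp group first, then the ep group) works because both statistics compose across groups: MAX for the inf-norm, and the sum of p-th powers for a p-norm. A single-process sketch of the combining rule (function name hypothetical, simulating what the all_reduce computes):

```python
import math

def combine_group_norms(local_norms, norm_type):
    # Emulates an all_reduce of per-rank norm statistics within one group:
    # MAX for the inf-norm; otherwise sum the p-th powers and take the root.
    if math.isinf(norm_type):
        return max(local_norms)
    return sum(n ** norm_type for n in local_norms) ** (1.0 / norm_type)
```

For example, `combine_group_norms([3.0, 4.0], 2)` gives 5.0, and combining group results is associative: reducing [1, 2] and [2, 4] separately and then combining the two results equals reducing [1, 2, 2, 4] in one step. That associativity is exactly why reducing over ep_fsdp_group and then ep_group yields the correct global norm.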
With FSDP2, is this all_reduce actually unnecessary?
```diff
 total_norm_tensor = torch.tensor(local_norm, device=reduce_device, dtype=torch.float32)
 if dist.is_initialized():
-    dist.all_reduce(total_norm_tensor, op=dist.ReduceOp.MAX, group=group)
+    dist.all_reduce(total_norm_tensor, op=dist.ReduceOp.MAX, group=reduce_group)
```
PR type
PR information
Support ep_fsdp, a combined Expert Parallelism + FSDP strategy for MoE models.
Experiment results
Env&Config
NPU*8, Qwen3-30B-A3B, GBS=16, Grad_acc=4, ep_size=8, fsdp_size=8
Loss curve comparison: ep_fsdp vs. pure fsdp
Performance comparison