AttributeError: module 'jax' has no attribute 'tree_multimap'

Problem

AttributeError: module 'jax' has no attribute 'tree_multimap'


Solution

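Use jax.tree_util.tree_map instead of jax.tree_multimap. It takes the same arguments: a function followed by one or more trees with matching structure. (jax.tree_map also works in the JAX releases that removed tree_multimap, but it was itself later deprecated in favor of jax.tree.map, so jax.tree_util.tree_map is the most portable choice across versions.)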
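Here is a minimal sketch of the migration. It assumes a small dict of parameters and gradients; the names params and grads and the 0.1 step size are illustrative only. The only change is the function name. Every extra tree is still passed positionally after the first one:

```python
import jax
import jax.numpy as jnp

# Illustrative pytrees with identical structure.
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([0.1, 0.2]), "b": jnp.array(0.05)}

# Old call, raises AttributeError on newer JAX:
# updated = jax.tree_multimap(lambda p, g: p - 0.1 * g, params, grads)

# Replacement: tree_map applies the function leaf-by-leaf across all trees,
# which is what tree_multimap used to do.
updated = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
print(updated)
```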

From the JAX changelog:

- jax.tree_util.tree_multimap() has been removed. It has been deprecated since JAX release 0.3.5, and jax.tree_util.tree_map() is a direct replacement.
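Sometimes the failing call sits inside an old third-party library that you cannot edit. In that case, one stopgap (an assumption on my part, not an official JAX API) is to alias the removed name to jax.tree_util.tree_map before that library is imported or used. Upgrading the library or fixing the call site is the better solution:

```python
import jax
import jax.numpy as jnp

# Workaround shim (not part of JAX): restore the removed name only if it is
# missing, pointing it at the direct replacement, tree_map.
if not hasattr(jax, "tree_multimap"):
    jax.tree_multimap = jax.tree_util.tree_map

# A legacy-style call now resolves to tree_map.
a = [jnp.array(1.0), (jnp.array(2.0), jnp.array(3.0))]
b = [jnp.array(10.0), (jnp.array(20.0), jnp.array(30.0))]
summed = jax.tree_multimap(lambda x, y: x + y, a, b)
print(summed)
```

tree_map preserves container types, so the result keeps the list/tuple nesting of the inputs.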