diff --git a/.gitignore b/.gitignore index d52b3635c..b84734b6f 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ build/ /.devcontainer docs*/**/_autosummary docs*/_build +docs*/_collections docs*/**/tmp flaxlib_src/build flaxlib_src/builddir diff --git a/docs_nnx/conf.py b/docs_nnx/conf.py index 4b3a8c760..ac73e3f3d 100644 --- a/docs_nnx/conf.py +++ b/docs_nnx/conf.py @@ -65,8 +65,45 @@ 'codediff', 'flax_module', 'sphinx_design', + 'sphinx_collections', ] +# -- Sphinx-collections configuration ---------------------------------------- +# Custom function to clone optax repo and extract only the docs/api folder +def clone_optax_api_docs(config): + """Clone optax repo and copy only docs/api folder.""" + import tempfile + import shutil + from pathlib import Path + from git import Repo + + target = Path(config['target']) + if target.exists(): + return None # Already cloned, return None to skip writing + + with tempfile.TemporaryDirectory() as tmpdir: + Repo.clone_from( + 'https://github.com/google-deepmind/optax.git', + tmpdir, + depth=1, + ) + src = Path(tmpdir) / 'docs' / 'api' + shutil.copytree(src, target) + + return None # We handle file creation ourselves + + +collections = { + 'optax_api': { + 'driver': 'function', + 'source': clone_optax_api_docs, + 'target': 'optax_api', # Will be placed in _collections/optax_api + 'write_result': False, # We handle file creation ourselves + 'final_clean': False, # Keep the cloned files after build + 'clean': False, # Don't clean before building (we check if exists) + } +} + # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] diff --git a/docs_nnx/index.rst b/docs_nnx/index.rst index b02d1ae82..296838205 100644 --- a/docs_nnx/index.rst +++ b/docs_nnx/index.rst @@ -200,3 +200,10 @@ Learn more philosophy contributing api_reference/index + +.. toctree:: + :hidden: + :maxdepth: 2 + :caption: Optax API + + _collections/optax_api/index diff --git a/pyproject.toml b/pyproject.toml index 92856df13..d76983b28 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,8 @@ testing = [ docs = [ "sphinx==6.2.1", "sphinx-book-theme", + "sphinx-collections", + "gitpython", "Pygments>=2.6.1", "ipykernel", "myst_nb",