Add Sphinx extension for extracting prompt contents (#6671)

This commit is contained in:
Jakub Kuczys
2026-03-04 22:12:17 +01:00
committed by GitHub
parent bdc66c3f56
commit b7c11c016e
4 changed files with 163 additions and 0 deletions

154
docs/_ext/prompt_builder.py Normal file
View File

@@ -0,0 +1,154 @@
from __future__ import annotations
import json
import os
from typing import Any, Dict, List, Set
from docutils import nodes
from docutils.io import StringOutput
from docutils.nodes import Element
from sphinx.application import Sphinx
from sphinx.builders.text import TextBuilder
from sphinx.writers.text import TextWriter
from sphinx.util import logging
from sphinx.util.docutils import SphinxTranslator
logger = logging.getLogger(__name__)
class PromptTranslator(SphinxTranslator):
builder: PromptBuilder
def __init__(self, document: nodes.document, builder: PromptBuilder) -> None:
super().__init__(document, builder)
self.body = ""
self.prompts: List[Dict[str, str]] = []
def visit_document(self, node: Element) -> None:
pass
def depart_document(self, node: Element) -> None:
if not self.prompts:
self.body = ""
return
if self.builder.out_suffix.endswith(".json"):
self.body = json.dumps(self.prompts, indent=4)
else:
self.body = "\n".join(prompt["content"] for prompt in self.prompts)
def unknown_visit(self, node: Element) -> None:
pass
def unknown_departure(self, node: Element) -> None:
pass
def visit_prompt(self, node: Element) -> None:
self.prompts.append(
{
"language": node.attributes["language"],
"prompts": node.attributes["prompts"],
"modifiers": node.attributes["modifiers"],
"rawsource": node.rawsource,
"content": node.children[0],
}
)
class PromptWriter(TextWriter):
def translate(self) -> None:
visitor = self.builder.create_translator(self.document, self.builder)
self.document.walkabout(visitor)
self.output = visitor.body
class prompt(nodes.literal_block):
pass
class PromptBuilder(TextBuilder):
"""Extract prompts from documents."""
format = "json"
epilog = "The files with prompts are in %(outdir)s."
out_suffix = ".json"
default_translator_class = PromptTranslator
writer: PromptWriter
def init(self) -> None:
sphinx_prompt = __import__("sphinx-prompt")
def run(self) -> List[prompt]:
self.assert_has_content()
rawsource = "\n".join(self.content)
language = self.options.get("language") or "text"
prompts = [
p
for p in (
self.options.get("prompts") or sphinx_prompt.PROMPTS.get(language, "")
).split(",")
if p
]
modifiers = [
modifier for modifier in self.options.get("modifiers", "").split(",") if modifier
]
content = rawsource
if "auto" in modifiers:
parts = []
for line in self.content:
for p in prompts:
if line.startswith(p):
line = line[len(p) + 1 :].rstrip()
parts.append(line)
content = "\n".join(parts)
node = prompt(
rawsource,
content,
directive_content=self.content,
language=language,
prompts=self.options.get("prompts") or sphinx_prompt.PROMPTS.get(language, ""),
modifiers=modifiers,
)
return [node]
sphinx_prompt.PromptDirective.run = run
def prepare_writing(self, docnames: Set[str]) -> None:
del docnames
self.writer = PromptWriter(self)
def write_doc(self, docname: str, doctree: nodes.document) -> None:
self.writer.write(doctree, StringOutput(encoding="utf-8"))
if not self.writer.output:
# don't write empty files
return
filename = os.path.join(self.outdir, docname.replace("/", os.path.sep) + self.out_suffix)
os.makedirs(os.path.dirname(filename), exist_ok=True)
try:
with open(filename, "w", encoding="utf-8") as f:
f.write(self.writer.output)
except OSError as err:
logger.warning("error writing file %s: %s", filename, err)
class JsonPromptBuilder(PromptBuilder):
name = "jsonprompt"
out_suffix = ".json"
class TextPromptBuilder(PromptBuilder):
name = "textprompt"
out_suffix = ".txt"
def setup(app: Sphinx) -> Dict[str, Any]:
app.add_builder(JsonPromptBuilder)
app.add_builder(TextPromptBuilder)
return {
"version": "1.0",
"parallel_read_safe": True,
"parallel_write_safe": True,
}