Add snippet numbers to filenames in the Dev cog to fix exception formatting (#6135)

This commit is contained in:
Jakub Kuczys
2023-05-12 00:27:19 +02:00
committed by GitHub
parent e7d7eba68f
commit 70ca8ff1f4
2 changed files with 184 additions and 63 deletions

View File

@@ -23,7 +23,7 @@ import types
import re
import sys
from copy import copy
from typing import Any, Awaitable, Dict, Iterator, Literal, Type, TypeVar, Union
from typing import Any, Awaitable, Dict, Iterator, List, Literal, Tuple, Type, TypeVar, Union
from types import CodeType, TracebackType
import discord
@@ -83,13 +83,55 @@ def cleanup_code(content: str) -> str:
return content.strip("` \n")
class SourceCache:
MAX_SIZE = 1000
def __init__(self) -> None:
# estimated to take less than 100 kB
self._data: Dict[str, Tuple[str, int]] = {}
# this just keeps going up until the bot is restarted, shouldn't really be an issue
self._next_index = 0
def take_next_index(self) -> int:
next_index = self._next_index
self._next_index += 1
return next_index
def __getitem__(self, key: str) -> Tuple[List[str], int]:
value = self._data.pop(key) # pop to put it at the end as most recent
self._data[key] = value
# To mimic linecache module's behavior,
# all lines (including the last one) should end with \n.
source_lines = [f"{line}\n" for line in value[0].splitlines()]
# Note: while it might seem like a waste of time to always calculate the list of source lines,
# this is a necessary memory optimization. If all of the data in `self._data` were list,
# it could theoretically take up to 1000x as much memory.
return source_lines, value[1]
def __setitem__(self, key: str, value: Tuple[str, int]) -> None:
self._data.pop(key, None)
self._data[key] = value
if len(self._data) > self.MAX_SIZE:
del self._data[next(iter(self._data))]
class DevOutput:
def __init__(
self, ctx: commands.Context, *, source: str, filename: str, env: Dict[str, Any]
self,
ctx: commands.Context,
*,
source_cache: SourceCache,
filename: str,
source: str,
env: Dict[str, Any],
) -> None:
self.ctx = ctx
self.source = source
self.source_cache = source_cache
self.filename = filename
self.source_line_offset = 0
#: raw source - as received from the command after stripping the code block
self.raw_source = source
self.set_compilable_source(source)
self.env = env
self.always_include_result = False
self._stream = io.StringIO()
@@ -98,12 +140,14 @@ class DevOutput:
self._old_streams = []
@property
def source(self) -> str:
return self._original_source
def compilable_source(self) -> str:
"""Source string that we pass to async_compile()."""
return self._compilable_source
@source.setter
def source(self, value: str) -> None:
self._source = self._original_source = value
def set_compilable_source(self, compilable_source: str, *, line_offset: int = 0) -> None:
self._compilable_source = compilable_source
self.source_line_offset = line_offset
self.source_cache[self.filename] = (compilable_source, line_offset)
def __str__(self) -> str:
output = []
@@ -124,10 +168,8 @@ class DevOutput:
if tick and not self.formatted_exc:
await self.ctx.tick()
def set_exception(self, exc: Exception, *, line_offset: int = 0, skip_frames: int = 1) -> None:
self.formatted_exc = self.format_exception(
exc, line_offset=line_offset, skip_frames=skip_frames
)
def set_exception(self, exc: Exception, *, skip_frames: int = 1) -> None:
self.formatted_exc = self.format_exception(exc, skip_frames=skip_frames)
def __enter__(self) -> None:
self._old_streams.append(sys.stdout)
@@ -144,31 +186,49 @@ class DevOutput:
@classmethod
async def from_debug(
cls, ctx: commands.Context, *, source: str, env: Dict[str, Any]
cls, ctx: commands.Context, *, source: str, source_cache: SourceCache, env: Dict[str, Any]
) -> DevOutput:
output = cls(ctx, source=source, filename="<debug command>", env=env)
output = cls(
ctx,
source=source,
source_cache=source_cache,
filename=f"<debug command - snippet #{source_cache.take_next_index()}>",
env=env,
)
await output.run_debug()
return output
@classmethod
async def from_eval(
cls, ctx: commands.Context, *, source: str, env: Dict[str, Any]
cls, ctx: commands.Context, *, source: str, source_cache: SourceCache, env: Dict[str, Any]
) -> DevOutput:
output = cls(ctx, source=source, filename="<eval command>", env=env)
output = cls(
ctx,
source=source,
source_cache=source_cache,
filename=f"<eval command - snippet #{source_cache.take_next_index()}>",
env=env,
)
await output.run_eval()
return output
@classmethod
async def from_repl(
cls, ctx: commands.Context, *, source: str, env: Dict[str, Any]
cls, ctx: commands.Context, *, source: str, source_cache: SourceCache, env: Dict[str, Any]
) -> DevOutput:
output = cls(ctx, source=source, filename="<repl session>", env=env)
output = cls(
ctx,
source=source,
source_cache=source_cache,
filename=f"<repl session - snippet #{source_cache.take_next_index()}>",
env=env,
)
await output.run_repl()
return output
async def run_debug(self) -> None:
self.always_include_result = True
self._source = self.source
self.set_compilable_source(self.raw_source)
try:
compiled = self.async_compile_with_eval()
except SyntaxError as exc:
@@ -182,12 +242,14 @@ class DevOutput:
async def run_eval(self) -> None:
self.always_include_result = False
self._source = "async def func():\n%s" % textwrap.indent(self.source, " ")
self.set_compilable_source(
"async def func():\n%s" % textwrap.indent(self.raw_source, " "), line_offset=1
)
try:
compiled = self.async_compile_with_exec()
exec(compiled, self.env)
except SyntaxError as exc:
self.set_exception(exc, line_offset=1, skip_frames=3)
self.set_exception(exc, skip_frames=3)
return
func = self.env["func"]
@@ -195,13 +257,13 @@ class DevOutput:
with self:
self.result = await func()
except Exception as exc:
self.set_exception(exc, line_offset=1)
self.set_exception(exc)
async def run_repl(self) -> None:
self.always_include_result = False
self._source = self.source
self.set_compilable_source(self.raw_source)
executor = None
if self.source.count("\n") == 0:
if self.raw_source.count("\n") == 0:
# single statement, potentially 'eval'
try:
code = self.async_compile_with_eval()
@@ -231,14 +293,12 @@ class DevOutput:
self.env["_"] = self.result
def async_compile_with_exec(self) -> CodeType:
return async_compile(self._source, self.filename, "exec")
return async_compile(self.compilable_source, self.filename, "exec")
def async_compile_with_eval(self) -> CodeType:
return async_compile(self._source, self.filename, "eval")
return async_compile(self.compilable_source, self.filename, "eval")
def format_exception(
self, exc: Exception, *, line_offset: int = 0, skip_frames: int = 1
) -> str:
def format_exception(self, exc: Exception, *, skip_frames: int = 1) -> str:
"""
Format an exception to send to the user.
@@ -260,33 +320,44 @@ class DevOutput:
break
tb = tb.tb_next
# To mimic linecache module's behavior,
# all lines (including the last one) should end with \n.
source_lines = [f"{line}\n" for line in self._source.splitlines()]
filename = self.filename
# sometimes SyntaxError.text is None, sometimes it isn't
if (
issubclass(exc_type, SyntaxError)
and exc.filename == filename
and exc.lineno is not None
):
if exc.text is None:
# line numbers are 1-based, the list indexes are 0-based
exc.text = source_lines[exc.lineno - 1]
exc.lineno -= line_offset
if issubclass(exc_type, SyntaxError) and exc.lineno is not None:
try:
source_lines, line_offset = self.source_cache[exc.filename]
except KeyError:
pass
else:
if exc.text is None:
try:
# line numbers are 1-based, the list indexes are 0-based
exc.text = source_lines[exc.lineno - 1]
except IndexError:
# the frame might be pointing at a different source code, ignore...
pass
else:
exc.lineno -= line_offset
else:
exc.lineno -= line_offset
traceback_exc = traceback.TracebackException(exc_type, exc, tb)
py311_or_above = sys.version_info >= (3, 11)
stack_summary = traceback_exc.stack
for idx, frame_summary in enumerate(stack_summary):
if frame_summary.filename != filename:
try:
source_lines, line_offset = self.source_cache[frame_summary.filename]
except KeyError:
continue
lineno = frame_summary.lineno
if lineno is None:
continue
# line numbers are 1-based, the list indexes are 0-based
line = source_lines[lineno - 1]
try:
# line numbers are 1-based, the list indexes are 0-based
line = source_lines[lineno - 1]
except IndexError:
# the frame might be pointing at a different source code, ignore...
continue
lineno -= line_offset
# support for enhanced error locations in tracebacks
if py311_or_above:
@@ -327,6 +398,7 @@ class Dev(commands.Cog):
self._last_result = None
self.sessions = {}
self.env_extensions = {}
self.source_cache = SourceCache()
def get_environment(self, ctx: commands.Context) -> dict:
env = {
@@ -382,7 +454,9 @@ class Dev(commands.Cog):
env = self.get_environment(ctx)
source = cleanup_code(code)
output = await DevOutput.from_debug(ctx, source=source, env=env)
output = await DevOutput.from_debug(
ctx, source=source, source_cache=self.source_cache, env=env
)
self._last_result = output.result
await output.send()
@@ -415,7 +489,9 @@ class Dev(commands.Cog):
env = self.get_environment(ctx)
source = cleanup_code(body)
output = await DevOutput.from_eval(ctx, source=source, env=env)
output = await DevOutput.from_eval(
ctx, source=source, source_cache=self.source_cache, env=env
)
if output.result is not None:
self._last_result = output.result
await output.send()
@@ -483,7 +559,9 @@ class Dev(commands.Cog):
del self.sessions[ctx.channel.id]
return
output = await DevOutput.from_repl(ctx, source=source, env=env)
output = await DevOutput.from_repl(
ctx, source=source, source_cache=self.source_cache, env=env
)
try:
await output.send(tick=False)
except discord.Forbidden: