diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4d10d1e77356..9d252ff19252 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -220,7 +220,7 @@ jobs: # To speed-up process until ast_serialize is on PyPI. - name: Install pinned ast-serialize if: ${{ matrix.dev_ast_serialize }} - run: pip install ast-serialize@git+https://github.com/mypyc/ast_serialize.git@da5a16cf268dbec63ed6b2e6b715470576e2d1a6 + run: pip install ast-serialize@git+https://github.com/mypyc/ast_serialize.git@9f3645587cc2cc4a3d93183c8ef255b03f06f647 - name: Setup tox environment run: | diff --git a/mypy/build.py b/mypy/build.py index 4522bd1255fb..4fe6f52f5828 100644 --- a/mypy/build.py +++ b/mypy/build.py @@ -107,7 +107,15 @@ send, ) from mypy.messages import MessageBuilder -from mypy.nodes import Import, ImportAll, ImportBase, ImportFrom, MypyFile, SymbolTable +from mypy.nodes import ( + FileRawData, + Import, + ImportAll, + ImportBase, + ImportFrom, + MypyFile, + SymbolTable, +) from mypy.options import OPTIONS_AFFECTING_CACHE_NO_PLATFORM from mypy.partially_defined import PossiblyUndefinedVariableVisitor from mypy.semanal import SemanticAnalyzer @@ -148,7 +156,7 @@ ) from mypy.nodes import Expression from mypy.options import Options -from mypy.parse import parse +from mypy.parse import load_from_raw, parse from mypy.plugin import ChainedPlugin, Plugin, ReportConfigContext from mypy.plugins.default import DefaultPlugin from mypy.renaming import LimitedVariableRenameVisitor, VariableRenameVisitor @@ -400,7 +408,7 @@ def default_flush_errors( finally: for worker in workers: try: - send(worker.conn, SccRequestMessage(scc_id=None, import_errors={})) + send(worker.conn, SccRequestMessage(scc_id=None, import_errors={}, mod_data={})) except (OSError, IPCException): pass for worker in workers: @@ -419,7 +427,7 @@ def build_inner( workers: list[WorkerClient], ) -> BuildResult: if platform.python_implementation() == "CPython": - # Run gc less frequently, as 
otherwise we can spent a large fraction of + # Run gc less frequently, as otherwise we can spend a large fraction of # cpu in gc. This seems the most reasonable place to tune garbage collection. gc.set_threshold(200 * 1000, 30, 30) @@ -928,8 +936,6 @@ def __init__( self.import_options: dict[str, bytes] = {} # Cache for transitive dependency check (expensive). self.transitive_deps_cache: dict[tuple[int, int], bool] = {} - # Resolved paths for each module in build. - self.path_by_id: dict[str, str] = {} # Packages for which we know presence or absence of __getattr__(). self.known_partial_packages: dict[str, bool] = {} @@ -1052,16 +1058,30 @@ def is_module(self, id: str) -> bool: return find_module_simple(id, self) is not None def parse_file( - self, id: str, path: str, source: str, ignore_errors: bool, options: Options + self, + id: str, + path: str, + source: str, + ignore_errors: bool, + options: Options, + raw_data: FileRawData | None = None, ) -> MypyFile: """Parse the source of a file with the given name. Raise CompileError if there is a parse error. """ + imports_only = False + if self.workers and self.fscache.exists(path): + # Currently, we can use the native parser only for actual files. + imports_only = True t0 = time.time() if ignore_errors: self.errors.ignored_files.add(path) - tree = parse(source, path, id, self.errors, options=options) + if raw_data: + # If possible, deserialize from known binary data instead of parsing from scratch. 
+ tree = load_from_raw(path, id, raw_data, self.errors, options) + else: + tree = parse(source, path, id, self.errors, options=options, imports_only=imports_only) tree._fullname = id self.add_stats( files_parsed=1, @@ -1129,14 +1149,14 @@ def add_stats(self, **kwds: Any) -> None: def stats_summary(self) -> Mapping[str, object]: return self.stats - def submit(self, sccs: list[SCC]) -> None: + def submit(self, graph: Graph, sccs: list[SCC]) -> None: """Submit a stale SCC for processing in current process or parallel workers.""" if self.workers: - self.submit_to_workers(sccs) + self.submit_to_workers(graph, sccs) else: self.scc_queue.extend([(0, 0, scc) for scc in sccs]) - def submit_to_workers(self, sccs: list[SCC] | None = None) -> None: + def submit_to_workers(self, graph: Graph, sccs: list[SCC] | None = None) -> None: if sccs is not None: for scc in sccs: heappush(self.scc_queue, (-scc.size_hint, self.queue_order, scc)) @@ -1147,11 +1167,24 @@ def submit_to_workers(self, sccs: list[SCC] | None = None) -> None: import_errors = { mod_id: self.errors.recorded[path] for mod_id in scc.mod_ids - if (path := self.path_by_id[mod_id]) in self.errors.recorded + if (path := graph[mod_id].xpath) in self.errors.recorded } send( self.workers[idx].conn, - SccRequestMessage(scc_id=scc.id, import_errors=import_errors), + SccRequestMessage( + scc_id=scc.id, + import_errors=import_errors, + mod_data={ + mod_id: ( + # Although workers don't really need to know about details + # of dependencies, they will write cache, so we need to pass + # suppressed_deps_opts() as part of module data. + graph[mod_id].suppressed_deps_opts(), + tree.raw_data if (tree := graph[mod_id].tree) else None, + ) + for mod_id in scc.mod_ids + }, + ), ) def wait_for_done( @@ -1166,14 +1199,16 @@ def wait_for_done( The last item is only used for parallel processing. 
""" if self.workers: - return self.wait_for_done_workers() + return self.wait_for_done_workers(graph) if not self.scc_queue: return [], False, {} _, _, next_scc = self.scc_queue.pop(0) process_stale_scc(graph, next_scc, self) return [next_scc], bool(self.scc_queue), {} - def wait_for_done_workers(self) -> tuple[list[SCC], bool, dict[str, tuple[str, list[str]]]]: + def wait_for_done_workers( + self, graph: Graph + ) -> tuple[list[SCC], bool, dict[str, tuple[str, list[str]]]]: if not self.scc_queue and len(self.free_workers) == len(self.workers): return [], False, {} @@ -1188,7 +1223,7 @@ def wait_for_done_workers(self) -> tuple[list[SCC], bool, dict[str, tuple[str, l assert data.result is not None results.update(data.result) done_sccs.append(self.scc_by_id[scc_id]) - self.submit_to_workers() # advance after some workers are free. + self.submit_to_workers(graph) # advance after some workers are free. return ( done_sccs, bool(self.scc_queue) or len(self.free_workers) < len(self.workers), @@ -1582,7 +1617,7 @@ def exclude_from_backups(target_dir: str) -> None: def create_metastore(options: Options, parallel_worker: bool = False) -> MetadataStore: """Create the appropriate metadata store.""" if options.sqlite_cache: - # We use this flag in both coordinator and workers to seep up commits, + # We use this flag in both coordinator and workers to speed up commits, # see mypy.metastore.connect_db() for details. sync_off = options.num_workers > 0 or parallel_worker mds: MetadataStore = SqliteMetadataStore(_cache_dir_prefix(options), sync_off=sync_off) @@ -2456,12 +2491,10 @@ def new_state( # Parse the file (and then some) to get the dependencies. state.parse_file(temporary=temporary) state.compute_dependencies() - if manager.workers: - # We don't need parsed trees in coordinator process, we parse only to - # compute dependencies. 
Keep temporary tree until the caller uses it - if not temporary: - state.tree = None - del manager.modules[id] + if manager.workers and state.tree: + # We don't need imports in coordinator process anymore, we parse only to + # compute dependencies. + state.tree.imports = [] del manager.ast_cache[id] return state @@ -2522,6 +2555,9 @@ def __init__( self.add_ancestors() self.imports_ignored = imports_ignored self.size_hint = size_hint + # Pre-computed opaque value of suppressed_deps_opts() used + # to minimize amount of data sent to parallel workers. + self.known_suppressed_deps_opts: bytes | None = None def write(self, buf: WriteBuffer) -> None: """Serialize State for sending to build worker. @@ -2745,7 +2781,7 @@ def fix_cross_refs(self) -> None: # Methods for processing modules from source code. - def parse_file(self, *, temporary: bool = False) -> None: + def parse_file(self, *, temporary: bool = False, raw_data: FileRawData | None = None) -> None: """Parse file and run first pass of semantic analysis. Everything done here is local to the file. Don't depend on imported @@ -2769,7 +2805,6 @@ def parse_file(self, *, temporary: bool = False) -> None: with self.wrap_context(): source = self.source - self.source = None # We won't need it again. if self.path and source is None: try: path = manager.maybe_swap_for_shadow_path(self.path) @@ -2810,7 +2845,12 @@ def parse_file(self, *, temporary: bool = False) -> None: if not cached: ignore_errors = self.ignore_all or self.options.ignore_errors self.tree = manager.parse_file( - self.id, self.xpath, source, ignore_errors=ignore_errors, options=self.options + self.id, + self.xpath, + source, + ignore_errors=ignore_errors, + options=self.options, + raw_data=raw_data, ) else: # Reuse a cached AST @@ -2875,9 +2915,10 @@ def semantic_analysis_pass1(self) -> None: # # TODO: This should not be considered as a semantic analysis # pass -- it's an independent pass. 
- analyzer = SemanticAnalyzerPreAnalysis() - with self.wrap_context(): - analyzer.visit_file(self.tree, self.xpath, self.id, options) + if not options.native_parser: + analyzer = SemanticAnalyzerPreAnalysis() + with self.wrap_context(): + analyzer.visit_file(self.tree, self.xpath, self.id, options) # TODO: Do this while constructing the AST? self.tree.names = SymbolTable() if not self.tree.is_stub: @@ -3103,6 +3144,8 @@ def update_fine_grained_deps(self, deps: dict[str, set[str]]) -> None: def suppressed_deps_opts(self) -> bytes: if not self.suppressed: return b"" + if self.known_suppressed_deps_opts: + return self.known_suppressed_deps_opts buf = WriteBuffer() import_options = self.manager.import_options for dep in sorted(self.suppressed): @@ -3909,7 +3952,6 @@ def load_graph( if dep not in graph: st.suppress_dependency(dep) manager.plugin.set_modules(manager.modules) - manager.path_by_id = {id: graph[id].xpath for id in graph} manager.errors.global_watcher = False return graph @@ -4045,7 +4087,7 @@ def process_graph(graph: Graph, manager: BuildManager) -> None: # Broadcast graph to workers before computing SCCs to save a bit of time. # TODO: check if we can optimize by sending only part of the graph needed for given SCC. # For example only send modules in the SCC and their dependencies. - graph_message = GraphMessage(graph=graph, missing_modules=set(manager.missing_modules)) + graph_message = GraphMessage(graph=graph, missing_modules=manager.missing_modules) buf = WriteBuffer() graph_message.write(buf) graph_data = buf.getvalue() @@ -4093,7 +4135,7 @@ def process_graph(graph: Graph, manager: BuildManager) -> None: for scc in stale: for id in scc.mod_ids: graph[id].mark_as_rechecked() - manager.submit(stale) + manager.submit(graph, stale) still_working = True # We eagerly walk over fresh SCCs to reach as many stale SCCs as soon # as possible. Only when there are no fresh SCCs, we wait on scheduled stale ones. 
@@ -4219,7 +4261,9 @@ def process_stale_scc( if ( not manager.options.test_env and platform.python_implementation() == "CPython" - and manager.gc_freeze_cycles < MAX_GC_FREEZE_CYCLES + # Parallel workers perform loading in many smaller "pieces", so we + # should repeat the GC hack multiple times to actually benefit from it. + and (manager.gc_freeze_cycles < MAX_GC_FREEZE_CYCLES or manager.parallel_worker) ): # When deserializing cache we create huge amount of new objects, so even # with our generous GC thresholds, GC is still doing a lot of pointless @@ -4228,8 +4272,8 @@ def process_stale_scc( # generation with the freeze()/unfreeze() trick below. This is arguably # a hack, but it gives huge performance wins for large third-party # libraries, like torch. - if manager.gc_freeze_cycles > 0: - gc.collect() + gc.collect(generation=1) + gc.collect(generation=0) gc.disable() for prev_scc in fresh_sccs_to_load: manager.done_sccs.add(prev_scc.id) @@ -4237,7 +4281,7 @@ def process_stale_scc( if ( not manager.options.test_env and platform.python_implementation() == "CPython" - and manager.gc_freeze_cycles < MAX_GC_FREEZE_CYCLES + and (manager.gc_freeze_cycles < MAX_GC_FREEZE_CYCLES or manager.parallel_worker) ): manager.gc_freeze_cycles += 1 gc.freeze() @@ -4517,9 +4561,16 @@ class SccRequestMessage(IPCMessage): If scc_id is None, then it means that the coordinator requested a shutdown. 
""" - def __init__(self, *, scc_id: int | None, import_errors: dict[str, list[ErrorInfo]]) -> None: + def __init__( + self, + *, + scc_id: int | None, + import_errors: dict[str, list[ErrorInfo]], + mod_data: dict[str, tuple[bytes, FileRawData | None]], + ) -> None: self.scc_id = scc_id self.import_errors = import_errors + self.mod_data = mod_data @classmethod def read(cls, buf: ReadBuffer) -> SccRequestMessage: @@ -4530,6 +4581,13 @@ def read(cls, buf: ReadBuffer) -> SccRequestMessage: read_str(buf): [ErrorInfo.read(buf) for _ in range(read_int_bare(buf))] for _ in range(read_int_bare(buf)) }, + mod_data={ + read_str_bare(buf): ( + read_bytes(buf), + FileRawData.read(buf) if read_bool(buf) else None, + ) + for _ in range(read_int_bare(buf)) + }, ) def write(self, buf: WriteBuffer) -> None: @@ -4541,6 +4599,15 @@ def write(self, buf: WriteBuffer) -> None: write_int_bare(buf, len(errors)) for error in errors: error.write(buf) + write_int_bare(buf, len(self.mod_data)) + for mod, (suppressed_deps_opts, raw_data) in self.mod_data.items(): + write_str_bare(buf, mod) + write_bytes(buf, suppressed_deps_opts) + if raw_data is None: + write_bool(buf, False) + else: + write_bool(buf, True) + raw_data.write(buf) class SccResponseMessage(IPCMessage): @@ -4664,7 +4731,7 @@ def write(self, buf: WriteBuffer) -> None: class GraphMessage(IPCMessage): """A message wrapping the build graph computed by the coordinator.""" - def __init__(self, *, graph: Graph, missing_modules: set[str]) -> None: + def __init__(self, *, graph: Graph, missing_modules: dict[str, int]) -> None: self.graph = graph self.missing_modules = missing_modules # Send this data separately as it will be lost during state serialization. 
@@ -4675,7 +4742,7 @@ def read(cls, buf: ReadBuffer, manager: BuildManager | None = None) -> GraphMess assert manager is not None assert read_tag(buf) == GRAPH_MESSAGE graph = {read_str_bare(buf): State.read(buf, manager) for _ in range(read_int_bare(buf))} - missing_modules = {read_str_bare(buf) for _ in range(read_int_bare(buf))} + missing_modules = {read_str_bare(buf): read_int(buf) for _ in range(read_int_bare(buf))} message = GraphMessage(graph=graph, missing_modules=missing_modules) message.from_cache = {read_str_bare(buf) for _ in range(read_int_bare(buf))} return message @@ -4687,8 +4754,9 @@ def write(self, buf: WriteBuffer) -> None: write_str_bare(buf, mod_id) state.write(buf) write_int_bare(buf, len(self.missing_modules)) - for module in self.missing_modules: + for module, reason in self.missing_modules.items(): write_str_bare(buf, module) + write_int(buf, reason) write_int_bare(buf, len(self.from_cache)) for module in self.from_cache: write_str_bare(buf, module) diff --git a/mypy/build_worker/worker.py b/mypy/build_worker/worker.py index d5069731b54c..b35da8c412c7 100644 --- a/mypy/build_worker/worker.py +++ b/mypy/build_worker/worker.py @@ -27,22 +27,24 @@ from mypy import util from mypy.build import ( + SCC, AckMessage, BuildManager, + Graph, GraphMessage, SccRequestMessage, SccResponseMessage, SccsDataMessage, SourcesDataMessage, - load_graph, load_plugins, process_stale_scc, ) from mypy.defaults import RECURSION_LIMIT, WORKER_CONNECTION_TIMEOUT -from mypy.errors import CompileError, Errors, report_internal_error +from mypy.errors import CompileError, ErrorInfo, Errors, report_internal_error from mypy.fscache import FileSystemCache from mypy.ipc import IPCException, IPCServer, receive, send from mypy.modulefinder import BuildSource, BuildSourceSet, compute_search_paths +from mypy.nodes import FileRawData from mypy.options import Options from mypy.util import read_py_file from mypy.version import __version__ @@ -123,42 +125,24 @@ def serve(server: 
IPCServer, ctx: ServerContext) -> None: if manager is None: return - # Mirror the GC freeze hack in the coordinator. - if platform.python_implementation() == "CPython": - gc.disable() - try: - graph = load_graph(sources, manager) - except CompileError: - # CompileError during loading will be reported by the coordinator. - return - if platform.python_implementation() == "CPython": - gc.freeze() - gc.unfreeze() - gc.enable() - for id in graph: - manager.import_map[id] = graph[id].dependencies_set - # Ignore errors during local graph loading to check that receiving - # early errors from coordinator works correctly. - manager.errors.reset() - - # Notify worker we are done loading graph. + # Notify coordinator we are done with setup. send(server, AckMessage()) - - # Compare worker graph and coordinator, with parallel parser we will only use the latter. graph_data = GraphMessage.read(receive(server), manager) - assert set(manager.missing_modules) == graph_data.missing_modules - coordinator_graph = graph_data.graph - assert coordinator_graph.keys() == graph.keys() + # Update some manager data in-place as it has been passed to semantic analyzer. + manager.missing_modules |= graph_data.missing_modules + graph = graph_data.graph for id in graph: - assert graph[id].dependencies_set == coordinator_graph[id].dependencies_set - assert graph[id].suppressed_set == coordinator_graph[id].suppressed_set - send(server, AckMessage()) + manager.import_map[id] = graph[id].dependencies_set + # Link modules dicts, so that plugins will get access to ASTs as we parse them. + manager.plugin.set_modules(manager.modules) + # Notify coordinator we are ready to receive computed graph SCC structure. + send(server, AckMessage()) sccs = SccsDataMessage.read(receive(server)).sccs manager.scc_by_id = {scc.id: scc for scc in sccs} manager.top_order = [scc.id for scc in sccs] - # Notify coordinator we are ready to process SCCs. + # Notify coordinator we are ready to start processing SCCs. 
send(server, AckMessage()) while True: scc_message = SccRequestMessage.read(receive(server)) @@ -169,20 +153,17 @@ def serve(server: IPCServer, ctx: ServerContext) -> None: scc = manager.scc_by_id[scc_id] t0 = time.time() try: - for id in scc.mod_ids: - state = graph[id] - # Extra if below is needed only because we are using local graph. - # TODO: clone options when switching to coordinator graph. - if state.tree is None: - # Parse early to get errors related data, such as ignored - # and skipped lines before replaying the errors. - state.parse_file() - else: - state.setup_errors() - if id in scc_message.import_errors: - manager.errors.set_file(state.xpath, id, state.options) - for err_info in scc_message.import_errors[id]: - manager.errors.add_error_info(err_info) + if platform.python_implementation() == "CPython": + # Since we are splitting the GC freeze hack into multiple smaller freezes, + # we should collect young generations to not accumulate accidental garbage. + gc.collect(generation=1) + gc.collect(generation=0) + gc.disable() + load_states(scc, graph, manager, scc_message.import_errors, scc_message.mod_data) + if platform.python_implementation() == "CPython": + gc.freeze() + gc.unfreeze() + gc.enable() result = process_stale_scc(graph, scc, manager, from_cache=graph_data.from_cache) # We must commit after each SCC, otherwise we break --sqlite-cache. manager.metastore.commit() @@ -193,6 +174,34 @@ def serve(server: IPCServer, ctx: ServerContext) -> None: manager.add_stats(total_process_stale_time=time.time() - t0, stale_sccs_processed=1) +def load_states( + scc: SCC, + graph: Graph, + manager: BuildManager, + import_errors: dict[str, list[ErrorInfo]], + mod_data: dict[str, tuple[bytes, FileRawData | None]], +) -> None: + """Re-create full state of an SCC as it would have been in coordinator.""" + for id in scc.mod_ids: + state = graph[id] + # Re-clone options since we don't send them, it is usually faster than deserializing. 
+ state.options = state.options.clone_for_module(state.id) + suppressed_deps_opts, raw_data = mod_data[id] + state.parse_file(raw_data=raw_data) + # Set data that is needed to be written to cache meta. + state.known_suppressed_deps_opts = suppressed_deps_opts + assert state.tree is not None + import_lines = {imp.line for imp in state.tree.imports} + state.imports_ignored = { + line: codes for line, codes in state.tree.ignored_lines.items() if line in import_lines + } + # Replay original errors encountered during graph loading in coordinator. + if id in import_errors: + manager.errors.set_file(state.xpath, id, state.options) + for err_info in import_errors[id]: + manager.errors.add_error_info(err_info) + + def setup_worker_manager(sources: list[BuildSource], ctx: ServerContext) -> BuildManager | None: data_dir = os.path.dirname(os.path.dirname(__file__)) # This is used for testing only now. diff --git a/mypy/cache.py b/mypy/cache.py index 528abdde01bf..5a1d6c79219e 100644 --- a/mypy/cache.py +++ b/mypy/cache.py @@ -262,6 +262,7 @@ def read(cls, data: ReadBuffer, data_file: str) -> CacheMeta | None: LIST_BYTES: Final[Tag] = 23 TUPLE_GEN: Final[Tag] = 24 DICT_STR_GEN: Final[Tag] = 30 +DICT_INT_GEN: Final[Tag] = 31 # Misc classes. EXTRA_ATTRS: Final[Tag] = 150 diff --git a/mypy/main.py b/mypy/main.py index bc20eb38ca2b..14148720269a 100644 --- a/mypy/main.py +++ b/mypy/main.py @@ -97,6 +97,10 @@ def main( stdout, stderr, options.hide_error_codes, hide_success=bool(options.output) ) + if options.num_workers: + # Supporting both parsers would be really tricky, so just support the new one. 
+ options.native_parser = True + if options.allow_redefinition_new and not options.local_partial_types: fail( "error: --local-partial-types must be enabled if using --allow-redefinition-new", diff --git a/mypy/nativeparse.py b/mypy/nativeparse.py index 5a6566c8a266..6b5d5306a659 100644 --- a/mypy/nativeparse.py +++ b/mypy/nativeparse.py @@ -76,6 +76,7 @@ EllipsisExpr, Expression, ExpressionStmt, + FileRawData, FloatExpr, ForStmt, FuncDef, @@ -193,7 +194,7 @@ def add_error( def native_parse( - filename: str, options: Options, skip_function_bodies: bool = False + filename: str, options: Options, skip_function_bodies: bool = False, imports_only: bool = False ) -> tuple[MypyFile, list[dict[str, Any]], TypeIgnores]: """Parse a Python file using the native Rust-based parser. @@ -206,6 +207,8 @@ def native_parse( skip_function_bodies: If True, many function and method bodies are omitted from the AST, useful for parsing stubs or extracting signatures without full implementation details + imports_only: If True, create an empty MypyFile with actual serialized defs + stored in raw_data. 
Returns: A tuple containing: @@ -226,13 +229,18 @@ def native_parse( data = ReadBuffer(b) n = read_int(data) state = State(options) - defs = read_statements(state, data, n) + if imports_only: + defs = [] + else: + defs = read_statements(state, data, n) imports = deserialize_imports(import_bytes) node = MypyFile(defs, imports) node.path = filename node.is_partial_stub_package = is_partial_package + if imports_only: + node.raw_data = FileRawData(b, import_bytes, errors, dict(ignores), is_partial_package) # Merge deserialization errors with parsing errors all_errors = errors + state.errors return node, all_errors, ignores diff --git a/mypy/nodes.py b/mypy/nodes.py index a09094879843..21ee67647c59 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -30,6 +30,7 @@ import mypy.strconv from mypy.cache import ( + DICT_INT_GEN, DICT_STR_GEN, DT_SPEC, END_TAG, @@ -41,6 +42,7 @@ Tag, WriteBuffer, read_bool, + read_bytes, read_int, read_int_list, read_int_opt, @@ -52,6 +54,7 @@ read_str_opt_list, read_tag, write_bool, + write_bytes, write_int, write_int_list, write_int_opt, @@ -307,6 +310,56 @@ def read(cls, data: ReadBuffer) -> SymbolNode: Definition: _TypeAlias = tuple[str, "SymbolTableNode", Optional["TypeInfo"]] +class FileRawData: + """Raw (binary) data representing parsed, but not deserialized file.""" + + __slots__ = ("defs", "imports", "raw_errors", "ignored_lines", "is_partial_stub_package") + + defs: bytes + imports: bytes + raw_errors: list[dict[str, Any]] # TODO: switch to more precise type here. 
+ ignored_lines: dict[int, list[str]] + is_partial_stub_package: bool + + def __init__( + self, + defs: bytes, + imports: bytes, + raw_errors: list[dict[str, Any]], + ignored_lines: dict[int, list[str]], + is_partial_stub_package: bool, + ) -> None: + self.defs = defs + self.imports = imports + self.raw_errors = raw_errors + self.ignored_lines = ignored_lines + self.is_partial_stub_package = is_partial_stub_package + + def write(self, data: WriteBuffer) -> None: + write_bytes(data, self.defs) + write_bytes(data, self.imports) + write_tag(data, LIST_GEN) + write_int_bare(data, len(self.raw_errors)) + for err in self.raw_errors: + write_json(data, err) + write_tag(data, DICT_INT_GEN) + write_int_bare(data, len(self.ignored_lines)) + for line, codes in self.ignored_lines.items(): + write_int(data, line) + write_str_list(data, codes) + write_bool(data, self.is_partial_stub_package) + + @classmethod + def read(cls, data: ReadBuffer) -> FileRawData: + defs = read_bytes(data) + imports = read_bytes(data) + assert read_tag(data) == LIST_GEN + raw_errors = [read_json(data) for _ in range(read_int_bare(data))] + assert read_tag(data) == DICT_INT_GEN + ignored_lines = {read_int(data): read_str_list(data) for _ in range(read_int_bare(data))} + return FileRawData(defs, imports, raw_errors, ignored_lines, read_bool(data)) + + class MypyFile(SymbolNode): """The abstract syntax tree of a single source file.""" @@ -328,6 +381,7 @@ class MypyFile(SymbolNode): "plugin_deps", "future_import_flags", "_is_typeshed_file", + "raw_data", ) __match_args__ = ("name", "path", "defs") @@ -370,6 +424,8 @@ class MypyFile(SymbolNode): # Future imports defined in this file. Populated during semantic analysis. future_import_flags: set[str] _is_typeshed_file: bool | None + # For native parser store actual serialized data here. 
+ raw_data: FileRawData | None def __init__( self, @@ -400,6 +456,7 @@ def __init__( self.uses_template_strings = False self.future_import_flags = set() self._is_typeshed_file = None + self.raw_data = None def local_definitions(self) -> Iterator[Definition]: """Return all definitions within the module (including nested). diff --git a/mypy/parse.py b/mypy/parse.py index a87e786a2543..3caa881b31eb 100644 --- a/mypy/parse.py +++ b/mypy/parse.py @@ -3,9 +3,12 @@ import os import re +from librt.internal import ReadBuffer + from mypy import errorcodes as codes +from mypy.cache import read_int from mypy.errors import Errors -from mypy.nodes import MypyFile +from mypy.nodes import FileRawData, MypyFile from mypy.options import Options @@ -16,6 +19,7 @@ def parse( errors: Errors, options: Options, raise_on_error: bool = False, + imports_only: bool = False, ) -> MypyFile: """Parse a source file, without doing any semantic analysis. @@ -36,7 +40,10 @@ def parse( errors.set_file(fnam, module, options=options) tree, parse_errors, type_ignores = mypy.nativeparse.native_parse( - fnam, options, skip_function_bodies=strip_function_bodies + fnam, + options, + skip_function_bodies=strip_function_bodies, + imports_only=imports_only, ) # Convert type ignores list to dict tree.ignored_lines = dict(type_ignores) @@ -66,6 +73,7 @@ def parse( return tree # Fall through to fastparse for non-existent files + assert not imports_only if options.transform_source is not None: source = options.transform_source(source) import mypy.fastparse @@ -74,3 +82,43 @@ def parse( if raise_on_error and errors.is_errors(): errors.raise_error() return tree + + +def load_from_raw( + fnam: str, module: str | None, raw_data: FileRawData, errors: Errors, options: Options +) -> MypyFile: + """Load AST from parsed binary data. + + This essentially replicates parse() above but expects FileRawData instead of actually + parsing the source code in the file. 
+ """ + from mypy.nativeparse import State, deserialize_imports, read_statements + + # This part mimics the logic in native_parse(). + data = ReadBuffer(raw_data.defs) + n = read_int(data) + state = State(options) + defs = read_statements(state, data, n) + imports = deserialize_imports(raw_data.imports) + + tree = MypyFile(defs, imports) + tree.path = fnam + tree.ignored_lines = raw_data.ignored_lines + tree.is_partial_stub_package = raw_data.is_partial_stub_package + tree.is_stub = fnam.endswith(".pyi") + + # Report parse errors, this replicates the logic in parse(). + all_errors = raw_data.raw_errors + state.errors + errors.set_file(fnam, module, options=options) + for error in all_errors: + message = error["message"] + message = re.sub(r"^(\s*\w)", lambda m: m.group(1).upper(), message) + is_blocker = error.get("blocker", True) + error_code = error.get("code") + if error_code is None: + error_code = codes.SYNTAX + else: + error_code = codes.error_codes.get(error_code) or codes.SYNTAX + # Note we never raise in this function, so it should not be called in coordinator. + errors.report(error["line"], error["column"], message, blocker=is_blocker, code=error_code) + return tree