Featured image of post AI-Chess

AI-Chess

记录一个ai象棋的项目

注意:有什么代码设计逻辑上解决不了的问题,尝试抽象出一个中间层或者控制器类型的东西看看能不能解决

记录一下毕设

环境配置

在工程目录下创建conda虚拟环境: 使用 --prefix,环境会被创建在你指定的精确路径(这里是项目文件夹下的 env 子目录)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
# 推荐:直接在工程文件夹内创建环境,当前已经处于工程目录下
conda create --prefix ./env python=3.9

conda activate ./env

# 导出环境配置(方便共享):
conda env export > environment.yml

# 从 YAML 文件创建环境:
conda env create -f environment.yml

随着项目增多,每个项目都创建独立的 Conda 环境确实会占用大量磁盘空间,小型数据分析项目可共用一个环境,核心项目使用单独环境。

通过克隆基础环境减少重复安装:

1
2
3
4
5
6
7
# 创建基础环境
conda create --name base_py39 python=3.9

# 克隆出项目专用环境(共享基础包)
conda create --name projectA --clone base_py39
conda activate projectA
conda install packageA  # 仅安装项目特有包

注意--name是创建全局环境的,而--prefix是项目内环境的,不需要名字。

PyTorch库结构

由于每次用PyTorch导入,不知道应该导入哪一个是什么功能,记录以下PyTorch库的结构。

1. 核心模块

  • torch: 这是 PyTorch 的核心包,提供了所有基本的数据结构(如张量)和张量操作。
    • torch.Tensor: PyTorch 中最基本的数据结构,类似于 NumPy 的 ndarray,但支持 GPU 加速和自动微分。
    • torch.autograd: 实现了自动微分系统。当你对张量执行操作时,PyTorch 会构建一个计算图。autograd 利用这个图来自动计算梯度。
      • torch.autograd.Function: 用户可以自定义操作并提供其前向和反向传播的实现。
    • torch.nn: 这是构建神经网络的核心模块。它提供了各种预定义的层(如 Linear, Conv2d, ReLU 等)、损失函数(如 MSELoss, CrossEntropyLoss 等)和模型容器(如 Module, Sequential)。
      • torch.nn.Module: 所有神经网络模块的基类。你的自定义模型都应该继承自它。
      • torch.nn.functional: 包含了函数式的神经网络操作,这些操作没有自己的可学习参数(例如 F.relu, F.softmax)。
    • torch.optim: 包含了各种优化算法(如 SGD, Adam, RMSprop 等),用于更新模型的参数。
      • torch.optim.Optimizer: 所有优化器的基类。
    • torch.cuda: 提供了 CUDA 支持,允许你在 NVIDIA GPU 上进行张量操作和模型训练。
      • torch.cuda.is_available(): 检查 CUDA 是否可用。
      • torch.cuda.empty_cache(): 清空 CUDA 缓存。
    • torch.linalg: 提供了线性代数操作,例如矩阵乘法、分解、特征值计算等。
    • torch.jit: 提供了 JIT (Just-In-Time) 编译器,可以将 PyTorch 模型转换为可序列化的、优化的表示,用于推理部署。
      • torch.jit.script: 将 Python 代码转换为 TorchScript。
      • torch.jit.trace: 通过执行示例输入来记录模型的计算图。

2. 实用工具和领域特定库

  • torch.utils: 提供了各种实用工具。
    • torch.utils.data: 数据加载和处理工具。
      • torch.utils.data.Dataset: 用于封装数据样本及其标签的抽象类。
      • torch.utils.data.DataLoader: 迭代数据集并提供批量数据加载、并行加载和数据混洗等功能。
    • torch.utils.tensorboard: 集成了 TensorBoard,用于可视化训练过程。
    • torch.utils.bottleneck: 用于性能分析和调试。
  • torchvision: PyTorch 官方提供的计算机视觉库。
    • torchvision.datasets: 包含常用视觉数据集(如 ImageNet, CIFAR10, MNIST)。
    • torchvision.models: 包含预训练的计算机视觉模型(如 ResNet, VGG, AlexNet)。
    • torchvision.transforms: 图像预处理和数据增强变换。
    • torchvision.utils: 图像保存、网格显示等实用工具。
  • torchaudio: PyTorch 官方提供的音频处理库。
    • 包含音频数据集、模型和转换工具。
  • torchtext: PyTorch 官方提供的自然语言处理 (NLP) 库。
    • 包含文本数据集、词汇表、词嵌入和常用的 NLP 模型。
  • torch_geometric: (第三方库,但非常流行) 用于图神经网络 (GNN)。
  • torch_scatter: (第三方库,但非常流行) 提供散列操作,常与 torch_geometric 配合使用。
  • torchmetrics: (第三方库,但非常流行) 提供各种机器学习指标的实现。

3. 分布式训练

  • torch.distributed: 用于多 GPU 或多节点分布式训练。
    • torch.distributed.init_process_group: 初始化分布式环境。
    • torch.distributed.all_reduce, torch.distributed.broadcast: 实现张量在不同进程间的通信。
  • torch.nn.parallel: 提供一些简单的并行化策略。
    • torch.nn.DataParallel: 单机多 GPU 数据并行。
    • torch.nn.parallel.DistributedDataParallel (DDP): 推荐的多 GPU/多节点并行化方法,更高效、更灵活。

4. 混合精度训练

  • torch.amp: (Automatic Mixed Precision) 用于自动混合精度训练,利用 FP16 和 FP32 混合计算,以加速训练和减少内存消耗。
    • torch.amp.autocast: 自动进行类型转换的上下文管理器。
    • torch.cuda.amp.GradScaler: 用于解决混合精度训练中梯度过小的问题。

5. 高级特性和生态系统

  • torch.compile (PyTorch 2.0+): 引入的编译功能,可以显著提升模型训练和推理的速度,通过将 PyTorch 代码编译成优化的图表示。
  • TorchScript: PyTorch 的中间表示 (IR),允许你将模型从 Python 转换为一个可独立运行的图表示,方便部署到 C++ 或移动设备。
  • ONNX (Open Neural Network Exchange): PyTorch 可以方便地导出模型到 ONNX 格式,这是一种开放标准,允许在不同深度学习框架之间进行模型互操作。
  • PyTorch Lightning: 一个轻量级的 PyTorch 封装,提供了一个高级接口,用于管理训练循环、分布式训练、日志记录等,从而减少样板代码。
  • Hugging Face Transformers: 虽然不是 PyTorch 核心库的一部分,但它是 PyTorch 生态系统的重要组成部分,提供了大量预训练的 NLP 模型和工具。

整体架构

主要说来,程序主要由两个部分组成:

  1. 神经网络和MCTS组成的算法
  2. 根据象棋规则来生成动作

主要算法

主要算法就是AlphaGo Zero提出的算法。AlphaZero 的自我对弈(Self-Play)更新机制是其核心创新之一,它使得 AlphaZero 能够从零开始,在没有人类专家数据的情况下,学习并超越人类顶尖棋手的水平。这个机制是强化学习的一个典范。

主要流程为以下几个步骤:

  1. 神经网络初始化
    • 训练开始时,神经网络的参数是随机初始化的
    • 神经网络双分支输出,一个策略(policy)输出$p$,一个价值(value)输出$v$
      • 策略(Policy)输出 p: 预测在当前状态下,每个合法走法的概率分布。
      • 价值(Value)输出 v: 预测当前状态下,当前玩家最终获胜的概率(或预期回报)。
  2. 自我对弈生成训练数据:
    • 在每轮自我对弈中,程序会利用当前的神经网络来指导蒙特卡洛树搜索(MCTS)进行决策,并生成一盘完整的对局。
    • 蒙特卡洛树搜索(MCTS)的迭代过程:
      • 选择(Selection): 从根节点(当前局面)开始,根据 MCTS 的 UCB(Upper Confidence Bound)或 PUCB(Polynomial Upper Confidence Bound)公式(结合了神经网络的先验概率 p 和访问次数 N)选择一条路径,直到达到一个未完全扩展的节点。
      • 扩展(Expansion): 如果到达的节点不是一个终止局面,就使用神经网络对该节点进行一次评估,得到该局面的策略 p 和价值 v。同时,将这个新节点及其子节点添加到搜索树中。
      • 模拟(Simulation,在AlphaZero中通常省略或简化): 但在AlphaZero中,神经网络的价值输出 v 直接取代了随机模拟,提供了更准确的估计。
      • 反向传播(Backpropagation): 将神经网络评估得到的价值 v 以及对局的最终结果(胜利或失败,通常为 +1 或 -1)沿着选择路径向上回传,更新路径上所有节点的访问次数 N、总价值 Q 和平均价值 W。
    • 生成走法策略 π: 在 MCTS 进行了一定数量的模拟(例如800次)后,不再直接使用神经网络的策略输出 p 来选择走法。而是根据 MCTS 树中每个走法的访问次数 N(s,a) 来生成一个更强大的走法策略 π,通常是按访问次数的幂次方(如 N(s,a)1/t,其中 t 是一个温度参数,用于控制探索与利用的平衡)进行归一化。访问次数越多的走法,表明 MCTS 认为该走法越有潜力。
    • 收集训练样本: 每一步棋的 (状态 $s_t$​, MCTS 策略 $π_t$​, 最终胜负 z) 作为一个训练样本被收集起来。其中 $s_t$​ 是当前局面,$π_t$​ 是由 MCTS 产生的走法概率分布, z 是最终游戏结果(胜利为 +1,失败为 -1,和棋为 0)。
  3. 神经网络训练
    • 收集到足够的数据就能展开训练了,就是深度学习中最小化损失函数的过程。
  4. 模型更新与迭代
    • 每次训练我们希望得到的是新的强化过程,也就是比之前更强的模型。训练好的新神经网络会与旧的神经网络进行评估。如果新网络表现更好(通常通过在竞技场中进行对局来判断,胜率高过一个阈值就更新),它就会取代旧网络,成为下一轮自我对弈生成数据的基准模型。
    • 这个过程持续迭代:自我对弈生成数据 -> 训练更新神经网络 -> 评估新旧模型 -> 替换模型。

杂项

清零梯度

我们对一个批次训练时每次都要清零梯度。原因是PyTorch 的设计理念是,默认情况下,梯度是累积的(accumulated)。这意味着当你调用 loss.backward() 时,新计算出的梯度会加到张量(torch.Tensor)的 .grad 属性中(如果该张量是模型的可学习参数)。 举个例子:

  1. 第一次 loss.backward():计算并生成梯度 G1​,将其存储在参数的 .grad 属性中。此时,param.grad 等于 G1​。
  2. 第二次 loss.backward():计算并生成梯度 G2​。如果此时不清零,PyTorch 会将 G2​ 加到现有的 param.grad 上。所以,param.grad 会变成 G1​+G2​。

唯一例外:梯度累积(Gradient Accumulation)

梯度积累技术是唯一一个有意不清零梯度的场景。在这种情况下,我们确实需要将多个小批次的梯度累加起来,以模拟更大的有效批量大小。但即使在这种情况下,当累积到一定步数并执行 optimizer.step() 后,我们仍然会调用 optimizer.zero_grad() 来清零,为下一轮的梯度积累做准备。

PyTorch训练一个批次的流程

每个批次的训练步骤:

  1. 设置模型为训练模式: model.train() (非常重要,影响 Dropout 和 BatchNorm 的行为)。
  2. 获取输入和目标: 从数据加载器中获取一个批次的输入数据和对应的真实标签。
    • 设备转移: 将数据移动到与模型相同的设备上 (CPU 或 GPU):inputs, labels = inputs.to(device), labels.to(device)
  3. 清零梯度: optimizer.zero_grad()。在每次迭代前,清除上次计算的梯度,因为 PyTorch 默认会累积梯度。
  4. 前向传播 (Forward Pass): 将输入数据输入模型,获得模型的预测输出。
    • outputs = model(inputs)
  5. 计算损失 (Calculate Loss): 将模型的输出与真实标签输入损失函数,计算当前的损失值。
    • loss = criterion(outputs, labels)
  6. 反向传播 (Backward Pass): loss.backward()。根据损失值,自动计算所有可学习参数的梯度。
  7. 参数更新 (Optimizer Step): optimizer.step()。使用计算出的梯度来更新模型的参数。
  8. (可选)学习率调度: scheduler.step() (如果使用了学习率调度器)。

detach分离计算图

detach() 是 PyTorch 中 torch.Tensor 对象的一个方法,它的核心作用是将一个张量从当前的计算图中分离出来,使其不再跟踪梯度

  1. 停止跟踪梯度(Stop Tracking Gradients)

    • detach() 后的张量,其 requires_grad 属性会变为 False
    • 即使原始张量 xrequires_gradTruex.detach() 得到的张量 yrequires_grad 也会是 False
    • 当对 y 进行后续操作时,这些操作将不再被记录到计算图中,也不会为 y 或其后续操作的张量计算梯度。
  2. 共享底层存储(Share Underlying Storage)

    • detach() 返回的张量与原始张量共享相同的底层数据存储。
    • 这意味着,如果你修改了其中一个张量的数据(例如,通过 y.add_(1) 进行原地操作),那么另一个张量的数据也会相应改变。
    • 但是,这种数据共享是单向的:对分离出的张量 y 的任何会导致其值改变的操作(如 y.add_(1)),虽然会修改原始张量 x 的值,但不会在 x 的计算图中记录这个操作。PyTorch 会在反向传播时检测到这种“原地修改”并报错,因为它无法正确地计算梯度。

防止不必要的梯度计算和内存开销

  • 在某些情况下,你可能需要使用某个张量的值进行计算,但你并不希望这个计算过程被记录到计算图中,也不希望为这个张量计算梯度。
  • 例如,在训练循环中,如果你想记录损失值但不希望损失值的计算过程影响到模型参数的梯度计算,你可以 loss.detach().item()
  • 另一个例子是,当某个中间结果在后续计算中不需要通过反向传播来更新其上游参数时。

多进程

由于自我对弈是CPU上计算的,显然可以采用多线程或者多进程。但由于Python存在GIL锁的机制,无法真正实现多线程。全局解释器锁(GIL, Global Interpreter Lock)是 Python(尤其是 CPython 解释器)中的一个机制。它的作用是同一时刻只允许一个线程执行 Python 字节码,即使你在多核 CPU 上开启了多个线程,实际上同一时刻只有一个线程在执行 Python 代码。

  • 任何线程在执行 Python 代码前,必须先获得 GIL。
  • 这意味着即使你创建了多个线程,这些线程也不能真正并行地执行 Python 代码(C 扩展库释放 GIL 时除外)。

所以Python的多线程适用于I/O密集型任务:

  • I/O 密集型任务(如网络请求、文件读写、数据库操作等)大部分时间都在等待外部资源响应,而不是消耗 CPU。
  • 当线程遇到 I/O 操作时,GIL 会被释放,其他线程可以获得 GIL 并继续执行。
  • 这样,多线程可以在等待 I/O 的间隙切换执行,提高程序整体的资源利用率和吞吐量。
  • 典型场景:爬虫、网络服务器、日志处理等。

而对于本任务是计算密集型的,不适合使用多线程,采用多进程

  • multiprocessing 模块通过创建多个独立的 Python 进程,每个进程都有自己的 Python 解释器和内存空间,各自拥有独立的 GIL。
  • 这样可以实现真正的多核并行运算,充分利用多核 CPU 的计算能力。
  • 适合 CPU 密集型任务,如自我对弈、神经网络推理、数据处理等。

如果是多线程,线程与线程之间属于同一个进程的话,是同属一个内存空间的,共享全局变量和内存。而如果是多进程,进程间内存独立,数据不能直接共享,需要用队列、管道、共享内存等方式通信,开销较大。

日志模块

logging 模块采用模块化设计,主要包含以下四类组件:

  1. Logger (记录器): 这是应用程序代码直接使用的接口。你可以通过 logging.getLogger(name) 来获取一个 Logger 实例。name 参数是可选的,如果提供,则会创建一个具名的 Logger;如果不提供,则会返回根 Logger (root logger)。Logger 有不同的日志级别,只有级别高于或等于 Logger 设定阈值的日志消息才会被处理。

  2. Handler (处理器): 处理器负责将 Logger 创建的日志记录(LogRecord)发送到适当的目标。常见的 Handler 有:

    • StreamHandler: 将日志输出到控制台(标准输出或标准错误)。
    • FileHandler: 将日志输出到文件。
    • RotatingFileHandler: 类似于 FileHandler,但会在文件达到一定大小时自动轮转(创建新的日志文件)。
    • TimedRotatingFileHandler: 类似于 RotatingFileHandler,但会根据时间(例如每天、每周)进行日志轮转。
    • SMTPHandler: 将日志通过电子邮件发送。
    • HTTPHandler: 将日志通过 HTTP GET 或 POST 请求发送到 Web 服务器。
  3. Formatter (格式化器): 格式化器指定最终输出中日志记录的样式。你可以定义日志消息的格式,包括时间、文件名、行号、日志级别、消息内容等。

  4. Filter (过滤器): 过滤器提供了更细粒度的控制,用于决定哪些日志记录应该被输出。你可以在 Logger 或 Handler 上添加过滤器,以进一步筛选日志。

日志级别 (Logging Levels)

logging 模块定义了以下标准日志级别(从低到高):

  • DEBUG (10): 详细的调试信息,通常只在开发阶段使用。
  • INFO (20): 确认程序按预期工作。
  • WARNING (30): 表示发生了一些意外事件,或将来可能会出现问题(但程序仍在正常运行)。这是默认的级别。
  • ERROR (40): 表示由于严重问题,程序无法执行某些功能。
  • CRITICAL (50): 表示发生了非常严重的错误,程序可能无法继续运行。

层级结构

Logger 的层级结构 (Hierarchical Loggers): 当你使用 logging.getLogger(__name__) 时,__name__ 会是模块的完全限定名。例如,main.py 中的 Logger 名称是 'main'module_a.py 中的 Logger 名称是 'module_a'my_package.sub_module 中的 Logger 名称是 'my_package.sub_module'logging 模块有一个层级结构,Logger 会将其日志消息传递(“传播”)给它们的父 Logger,直到根 Logger。根 Logger 是所有 Logger 的祖先。 在 main.py 中,我们配置了根 Logger。因此,module_amodule_b 中 Logger 发出的日志消息会向上冒泡到根 Logger,然后被根 Logger 的 Handler 处理,最终输出到控制台和文件。

logging 模块的 Logger 形成一个树状的层级结构。这种层级结构是通过 Logger 名称中的点号(.)来表示的。

  • 根 Logger (Root Logger):

    • logging.getLogger()(不带任何参数)返回的是根 Logger。它是所有其他 Logger 的祖先。
    • 它的名称是一个空字符串 ''
  • 具名 Logger (Named Loggers):

    • 当你调用 logging.getLogger('some_name') 时,你获取的是一个具名 Logger。
    • 如果 Logger 的名称包含点号,那么点号前面的部分就是它的父 Logger 的名称。

父子关系的确定规则:

  1. 名称是前缀: 如果一个 Logger 的名称是另一个 Logger 名称的点分隔前缀,那么前者就是后者的父 Logger。
    • 例如:Logger 'a.b' 的父 Logger 是 'a'。Logger 'a' 的父 Logger 是根 Logger ''
  2. 根 Logger 是所有 Logger 的祖先: 所有的具名 Logger 最终都追溯到根 Logger。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17

import logging

# 1. 根 Logger
root_logger = logging.getLogger() # 名称是 ''

# 2. 具名 Logger 'app'
app_logger = logging.getLogger('app')

# 3. 具名 Logger 'app.module_a'
module_a_logger = logging.getLogger('app.module_a')

# 4. 具名 Logger 'app.module_b'
module_b_logger = logging.getLogger('app.module_b')

# 5. 具名 Logger 'app.module_a.sub_module'
sub_module_logger = logging.getLogger('app.module_a.sub_module')

它们之间的父子关系如下:

  • root_logger (名称 '') 是所有 Logger 的父 Logger。
  • app_logger (名称 'app') 的父 Logger 是 root_logger
  • module_a_logger (名称 'app.module_a') 的父 Logger 是 app_logger
  • module_b_logger (名称 'app.module_b') 的父 Logger 是 app_logger
  • sub_module_logger (名称 'app.module_a.sub_module') 的父 Logger 是 module_a_logger

要点:

  • 日志传播 (Propagation): 默认情况下,子 Logger 会将它收到的日志消息传递给它的父 Logger,这个过程会一直持续到根 Logger。根 Logger 会将这些日志消息传递给它所关联的所有 Handler 进行实际的输出。
  • __name__ 的作用: 在模块文件中使用 logging.getLogger(__name__) 是一个最佳实践。__name__ 会自动设置为当前模块的完全限定名(例如,如果文件是 my_package/my_module.py,那么 __name__ 就是 'my_package.my_module')。这使得日志的层级结构自然地与你的代码模块结构对应起来,便于管理和追踪日志来源。

记录一个BUG

报错信息是这样的:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
Traceback (most recent call last):
  File "D:\PycharmProjects\AlphaZero\main.py", line 66, in <module>
    train_p.start()
  File "D:\app\anaconda3\Lib\multiprocessing\process.py", line 121, in start
    self._popen = self._Popen(self)
                  ^^^^^^^^^^^^^^^^^
  File "D:\app\anaconda3\Lib\multiprocessing\context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\app\anaconda3\Lib\multiprocessing\context.py", line 337, in _Popen
    return Popen(process_obj)
           ^^^^^^^^^^^^^^^^^^
  File "D:\app\anaconda3\Lib\multiprocessing\popen_spawn_win32.py", line 95, in __init__
    reduction.dump(process_obj, to_child)
  File "D:\app\anaconda3\Lib\multiprocessing\reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
TypeError: cannot pickle '_thread.RLock' object
Exception in thread Thread-1 (_monitor):
Traceback (most recent call last):
  File "D:\app\anaconda3\Lib\multiprocessing\connection.py", line 328, in _recv_bytes

关键在于TypeError: cannot pickle '_thread.RLock' object Exception in thread Thread-1 (_monitor):。这个错误发生在 multiprocessing 模块尝试序列化 (pickle) 对象以在进程间传递时。

错误的根本原因

multiprocessing 模块在 Windows 系统上默认使用 “spawn” 启动方式。当它启动一个新进程时,它需要将父进程中的一些对象序列化(使用 pickle 模块)并传递给子进程。

_thread.RLock 是一个递归锁 (Reentrant Lock) 对象,它通常用于线程同步。RLock 对象是不能被 pickle 化的。 这意味着你不能直接将一个包含 RLock 对象的 Logger 或 Handler 实例传递给子进程。

在 Windows 系统上,multiprocessing 模块默认使用 “spawn” 或 “forkserver” 启动方法(而不是 Unix 上的 “fork”)。

  • spawn (Windows 默认,macOS 也建议使用): 子进程是通过导入父进程模块来启动的。这意味着当一个子进程启动时,它会重新执行父进程模块中的所有代码。
  • fork (Unix/Linux 默认): 子进程是父进程的副本,它会继承父进程的所有内存空间和打开的文件描述符。

如果在 if __name__ == "__main__": 块之外定义了像 FileHandlerStreamHandler 这样的对象,那么当子进程重新导入模块时,它们也会被重新创建。这些 Handler 内部可能包含 RLock 或其他不可 pickle 的对象,导致 TypeError: cannot pickle '_thread.RLock' object

我在 TrainPipeline 类的 __init__ 方法中直接实例化了 logging.handlers.QueueHandler 并将其添加到了 self.root Logger。当主进程启动一个子进程时,它需要将子进程要执行的 target 函数以及 args 中传递的所有对象进行序列化 (pickle),然后传递给新进程。

1
2
training_pipeline = train.TrainPipeline(init_model='current_policy.pkl', log_queue=log_queue)
train_p = Process(target=training_pipeline.run, args=(lock,)) # <--- 问题在这!

原因分析:

  1. training_pipeline = train.TrainPipeline(...) 这一行代码在 主进程if __name__ == '__main__': 块内执行。这意味着 TrainPipeline 的一个实例 training_pipeline 在主进程中被创建了。
  2. TrainPipeline__init__ 方法:TrainPipeline__init__ 方法中,有以下代码:
1
2
3
4
5
    self.root = logging.getLogger()
    self.root.setLevel(logging.DEBUG)
    self.qh = logging.handlers.QueueHandler(log_queue)
    self.root.addHandler(self.qh)
    self.worker_logger = logging.getLogger(__name__)

training_pipeline 在主进程中被实例化时,self.root 获取的是主进程的根 Logger。虽然 QueueHandler 本身可以被 pickle,但 logging.getLogger() 获取的 Logger 对象(尤其是根 Logger)及其内部状态(包括可能关联的 Manager 对象),在某些情况下可能包含不可 pickle 的内部锁对象(例如 _thread.RLock)。 3. target=training_pipeline.run 当将一个实例方法training_pipeline.run)作为 Processtarget 时,multiprocessing 模块会隐式地尝试序列化 training_pipeline 这个实例本身,以便在子进程中重新构建它并调用其 run 方法。如果 TrainPipeline 实例本身被当作参数传递给 Process,或者 run 方法被当作 targetself 隐式传递,那么 TrainPipeline 实例及其所有属性(包括 self.rootself.qh)都将被尝试 pickle。 - 由于 training_pipeline 实例在主进程中创建时,其属性(如 self.rootself.qh)已经关联了主进程的日志系统内部对象(可能包含 RLock),因此在尝试序列化 training_pipeline 实例时,就会遇到 TypeError: cannot pickle '_thread.RLock' object

解决方案

核心思想是:所有涉及到进程间通信(如 QueueHandlerQueueListener)的日志配置,以及那些会在子进程中运行的类实例的创建,都应该在子进程的上下文中完成。

  1. 创建包装函数作为 Processtarget 而不是直接将 training_pipeline.run 作为 target,而是创建一个新的函数,这个函数将作为 Processtarget。在这个新函数内部,再实例化 TrainPipelineCollectPipeline,并调用它们的 run 方法。

  2. log_queue 作为参数传递: 确保 log_queue 被正确地作为参数传递给子进程的 target 函数,以便子进程能够使用它来配置自己的 QueueHandler

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# 为训练进程定义一个包装函数
# main.py

# 为训练进程定义一个包装函数
def run_train_pipeline_in_process(log_queue, init_model_path, shared_lock):
    """
    此函数将在一个单独的进程中运行。
    TrainPipeline 的实例化及其日志配置在此进程内部完成。
    """
    # 在子进程中实例化 TrainPipeline
    train_pipeline = train.TrainPipeline(log_queue=log_queue, init_model=init_model_path)
    # 调用 TrainPipeline 的 run 方法,并传入共享锁
    train_pipeline.run(shared_lock)

# 为数据收集进程定义一个包装函数
def run_collect_pipeline_in_process(process_id, init_model_path, shared_lock, log_queue):
    """
    此函数将在一个单独的进程中运行。
    CollectPipeline 的实例化及其日志配置在此进程内部完成。
    """
    # collect.CollectPipeline 的 __init__ 方法可能也需要 log_queue
    # 如果 collect.py 也有自己的日志配置,它也应该像 train.py 一样接收 log_queue
    # 这里假设 collect.py 也会使用 QueueHandler,所以需要传入 log_queue
    # 注意:你需要修改 collect.py 的 CollectPipeline.__init__ 以接受 log_queue
    collecting_pipeline = collect.CollectPipeline(process_id, init_model=init_model_path, log_queue=log_queue)
    collecting_pipeline.run(shared_lock)

PyQt5

PyQt5 基本概念

  1. QApplication: 每个 PyQt5 应用程序都需要一个 QApplication 对象。它负责处理事件循环、命令行参数解析等。
  2. QWidget: 所有用户界面对象的基类。它可以是窗口、按钮、标签等。
  3. QtCore : 主要和时间、文件与文件夹、各种数据、流、URLs、mime 类文件、进程与线程一起使用。
  4. QtGui : 图形用户界面组件,包含了窗口系统、事件处理、2D 图像、基本绘画、字体和文字类。
  5. 布局管理器 (Layout Managers): 用于组织和排列窗口中的小部件。常见的有 QVBoxLayout (垂直布局)、QHBoxLayout (水平布局)、QGridLayout (网格布局) 等。
  6. 信号与槽 (Signals and Slots): 这是 PyQt5 的核心机制。当一个事件发生时(例如按钮被点击),会发出一个“信号”,你可以将这个信号连接到某个“槽”函数上,槽函数会在信号发出时被执行。
  7. 事件处理 (Event Handling): PyQt5 通过事件循环来处理用户的输入和其他系统事件。
  8. 绘图 (Painting): 你可以使用 QPainterQWidget 上绘制图形、文本和图像。

在 Qt 框架中,对象的继承关系是其核心设计模式之一,它构建了 Qt 强大的功能和灵活的架构。理解 Qt 的继承体系对于有效使用 Qt 进行开发至关重要。

Qt 对象模型的基石:QObject

Qt 中所有可交互的、支持信号与槽机制的对象都继承自一个共同的基类:QObject

QObject 的关键特性:

  1. 信号与槽(Signals & Slots): 这是 Qt 独有的机制,用于对象之间的通信。QObject 提供了实现这一机制所需的基础设施。任何继承自 QObject 的类都可以定义信号和槽,实现解耦的事件处理。
  2. 对象树(Object Tree): QObject 对象可以组织成一个父子层次结构,形成一个“对象树”。当父对象被删除时,它的所有子对象也会自动被删除(即子对象会被 deleteLater() 标记为删除),这有助于管理内存,避免内存泄漏。
    • 例如,你在窗口上放置一个按钮,按钮就是窗口的子对象。当你关闭窗口时,按钮也会随之销毁。
  3. 属性系统(Property System): QObject 提供了一个元对象系统(Meta-Object System),允许在运行时查询对象的属性、调用槽、发出信号等。属性系统允许你定义自定义属性,并进行持久化、动画等操作。
  4. 动态类型信息: 运行时类型信息(RTTI)在 Qt 中通过 qobject_cast<T>()inherits() 等方法实现,允许你在运行时安全地进行类型转换和检查。

Qt 主要的继承层次

QObject 派生出了 Qt 应用程序中常见的各种类:

  1. QObject (基类)

    • 所有需要信号与槽、对象树等功能的类都继承自它。
  2. QObjectQPaintDevice

    • QPaintDevice 是所有可以被 QPainter 绘制的对象的基类。它定义了绘图操作所需的基本接口。
    • 例子:QPixmap, QImage, QWidget, QPrinter 等。
  3. QObjectQWidget

    • QWidget 是所有用户界面(UI)对象的基类。它代表了一个可见的、可以接收鼠标和键盘事件的矩形区域。
    • QWidget 继承自 QObject (因此支持信号与槽、对象树) 和 QPaintDevice (因此可以被绘制)。
    • 它提供基本的几何管理(大小、位置)、事件处理(鼠标、键盘)、绘画事件等。
    • 重要子类(构成大部分 GUI 界面):
      • QMainWindow: 提供一个带有菜单栏、工具栏、状态栏和中心部件的主应用程序窗口。
      • QDialog: 用于弹出对话框,例如文件选择对话框、设置对话框等。
      • QPushButton: 按钮。
      • QLabel: 显示文本或图片。
      • QLineEdit: 单行文本输入框。
      • QTextEdit: 多行文本编辑器。
      • QCheckBox: 复选框。
      • QRadioButton: 单选按钮。
      • QComboBox: 下拉列表。
      • QListWidget, QTableWidget, QTreeWidget: 用于显示列表、表格和树形结构数据。
      • 布局管理器(虽然它们不是 QWidget 的子类,但与 QWidget 紧密协作):QVBoxLayout, QHBoxLayout, QGridLayout, QFormLayout 等。它们是 QLayout 的子类,而 QLayout 又是 QObject 的子类。
  4. QObjectQAbstractItemModel / QAbstractItemView (模型/视图框架)

    • Qt 的模型/视图编程提供了一种将数据(模型)和显示数据的方式(视图)分离的强大机制。
    • QAbstractItemModel 是所有数据模型的抽象基类。
    • QAbstractItemView 是所有视图的抽象基类。
    • 例子:
      • 模型: QStringListModel, QStandardItemModel (自定义数据模型通常会继承 QAbstractItemModel 或其子类)。
      • 视图: QListView, QTableView, QTreeView (这些视图类也继承自 QWidget)。
  5. QObjectQNetworkAccessManager (网络)

    • 用于执行网络请求,如 HTTP、FTP 等。
  6. QObjectQTimer (定时器)

    • 用于在指定时间间隔后发出信号。

…还有许多其他领域特定的类,如数据库(QSqlDatabase)、多媒体(QMediaPlayer)、图形视图框架(QGraphicsView, QGraphicsScene)等,它们也大多直接或间接继承自 QObject

继承关系图示(简化版)
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
                     QObject
                       |
        +----------------------------------------+
        |                |                       |
QPaintDevice         QLayout                 QNetworkAccessManager
        |                |                       |
        |                |                       |
        +-- QWidget -----+                       |
               |                                 |
  +------------+------------+                    |
  |            |            |                    |
QMainWindow    QDialog    QPushButton, QLabel,   |
                           QLineEdit, ...        |
                                                 |
                       QAbstractItemModel        QTimer
                       QAbstractItemView
为什么 Qt 采用这种继承结构?
  1. 统一的事件处理: 所有的 QObject 实例都可以参与到 Qt 的事件循环中,通过信号与槽进行通信。
  2. 内存管理: 对象树简化了内存管理,减少了程序员手动管理对象生命周期的负担。
  3. 可扩展性: 通过继承,开发者可以在现有控件的基础上轻松创建自定义控件,复用大量现有功能。
  4. 元对象系统: 这是 Qt 强大的动态特性(如属性系统、翻译、信号与槽的反射)的基础,而 QObject 及其继承体系是启用这些特性的关键。
  5. 一致的 API: 开发者在 Qt 应用程序中与不同类型的对象交互时,会发现其 API 风格和行为非常一致,降低了学习曲线。

理解 Qt 的继承关系有助于你更好地选择合适的基类来开发自己的组件,利用 Qt 提供的丰富功能,并遵循其推荐的设计模式。

在 Qt 的对象模型中,父子关系(Parent-Child Relationship) 是通过 QObject 类及其派生类来确立的,它是 Qt 内存管理和对象生命周期管理的核心机制之一。这种关系主要通过两种方式确立:

  1. 在构造函数中指定父对象(最常见且推荐的方式)
  2. 通过 setParent() 方法设置父对象

父对象

在构造函数中指定父对象 (Constructor Parameter)

这是在 Qt 中确立父子关系最常见、最直接也是最推荐的方式。几乎所有继承自 QObject 的类,包括所有的 QWidget 控件,它们的构造函数都带有一个可选的 parent 参数。

语法:

1
child_object = ChildClass(parent_object)

示例:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29

import sys
from PyQt5.QtWidgets import QApplication, QMainWindow, QPushButton, QVBoxLayout, QWidget

class MyWindow(QMainWindow):
    def __init__(self):
        super().__init__()
        self.setWindowTitle("父子关系示例")
        self.setGeometry(100, 100, 400, 300)

        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        layout = QVBoxLayout(central_widget) # 布局管理器通常将其父部件作为参数

        # 创建一个按钮,并将其父对象设置为 central_widget
        self.button1 = QPushButton("按钮 1", central_widget)
        layout.addWidget(self.button1)

        # 创建另一个按钮,也将其父对象设置为 central_widget
        self.button2 = QPushButton("按钮 2")
        self.button2.setParent(central_widget) # 也可以这样设置,但不如构造函数直接

        layout.addWidget(self.button2)

if __name__ == '__main__':
    app = QApplication(sys.argv)
    window = MyWindow()
    window.show()
    sys.exit(app.exec_())

解释:

  • 在上面的例子中,QPushButton("按钮 1", central_widget) 这一行代码,将 central_widget 指定为 button1 的父对象。
  • QVBoxLayout(central_widget) 也是类似,布局管理器会自动将其管理的控件添加到其父部件上。
  • 当一个 QObject 实例的构造函数接收一个 QObject 类型的 parent 参数时,它会将自身(子对象)添加到 parent 对象的子对象列表中。
通过 setParent() 方法设置父对象

你也可以在对象创建之后,通过调用其 setParent() 方法来确立或修改父子关系。

语法:

1
child_object.setParent(parent_object)

示例(接上文):

1
2
3
4
5
# 创建另一个按钮,先不指定父对象
self.button2 = QPushButton("按钮 2")
# 之后再将其父对象设置为 central_widget
self.button2.setParent(central_widget)
layout.addWidget(self.button2) # 注意:布局管理器会把其添加的部件自动设为布局的父部件的子部件

需要注意的几点:

  • 一个子对象只能有一个父对象。 如果你多次调用 setParent(),最新的调用会覆盖之前的父子关系。
  • 内存管理: 当父对象被销毁时,所有作为其子对象的 QObject 实例也会被自动销毁。这是 Qt 强大的内存管理机制。你通常不需要手动 del 子对象,除非它们没有父对象。
    • 重要提示: 如果一个 QObject 没有父对象,那么它的生命周期需要你手动管理。这在创建独立的窗口(如 QMainWindowQDialog)时很常见,因为它们是应用程序的顶级窗口,通常没有父对象。
  • 可见性: 如果一个 QWidget 有父对象,通常它的显示会受到父对象的限制。子部件不会显示在父部件之外。
  • 布局管理器: 当你使用布局管理器 (QHBoxLayout, QVBoxLayout, QGridLayout 等) 将控件添加到布局中时,布局管理器会自动处理控件的父子关系,通常会将布局的父部件设置为被添加控件的父部件。这也是为什么在上面的例子中,即使 button2 最初没有指定父对象,但通过 layout.addWidget(self.button2) 后,它最终也会成为 central_widget 的子部件。
父子关系的好处
  1. 自动内存管理: 最重要的好处是避免了内存泄漏。你不需要担心何时释放子对象,Qt 会在父对象销毁时自动清理。
  2. 层次结构组织: 方便管理和组织复杂的 UI 界面,形成清晰的对象树。
  3. 事件传播: 事件(如键盘事件、鼠标事件)可以沿着对象树从父对象传播到子对象,或反之,方便事件处理。
  4. 属性继承: 某些属性(如字体、调色板)可能会从父部件传递给子部件,简化了样式设置。

总而言之,在 Qt 中确立父子关系最常见的做法是在构造函数中指定父对象,这不仅简洁,也确保了内存管理的正确性。

Qt 的事件传播机制

当一个鼠标事件(比如 mousePressEvent)发生时,Qt 的事件系统会遵循一套规则来决定哪个部件应该处理这个事件:

  1. 事件发生并首先发送给最顶层的部件: 当用户点击屏幕上的某个点时,Qt 首先会确定哪个 最顶层(top-level)QWidget 包含这个点击点(例如,你的主窗口)。
  2. 事件向下传播到最深层的子部件: 然后,Qt 会从这个顶级部件开始,沿着对象树向下遍历,找到位于点击点下方的 最深层(innermost) 的子部件。
  3. 事件分发给最深层部件: Qt 会将 mousePressEvent 首先分发给这个最深层的子部件(也就是你实际点击的那个控件)。
  4. 事件冒泡(Bubble Up)/ 默认处理:
    • 如果最深层的子部件重写了 mousePressEvent 并且没有调用 super().mousePressEvent(event) 那么这个事件就被这个子部件“消费”了,它不会继续向上冒泡到它的父部件。这意味着父部件的 mousePressEvent 不会被触发。
    • 如果最深层的子部件没有重写 mousePressEvent 那么 Qt 会调用其父类的默认 mousePressEvent 实现,事件会继续向上冒泡到其父部件。
    • 如果最深层的子部件重写了 mousePressEvent 但调用了 super().mousePressEvent(event) 那么子部件的逻辑会先执行,然后事件会继续向上冒泡,触发父部件的 mousePressEvent(如果父部件也重写了)。

信号与槽

Qt 的信号与槽(Signals & Slots)机制是 Qt 框架的核心特性之一,它用于对象之间进行通信。这种机制替代了传统的 C++ 回调函数(callbacks)或函数指针,提供了一种类型安全、松散耦合的方式来处理事件和实现模块之间的通信。

什么是信号(Signals)?
  • 定义:当一个特定事件发生时,一个对象会发出(emit)一个信号。例如,QPushButton 在被点击时会发出 clicked() 信号。
  • 触发:信号由对象自动生成,以响应某些内部状态变化或用户操作。例如,当用户点击按钮时,按钮对象就会自动发出 clicked() 信号。
  • 特性
    • 自动生成:信号是自动生成的,你不需要手动编写代码来发出它们,只需声明并连接。
    • 参数:信号可以带有参数,这些参数可以传递事件的相关信息。例如,QSlidervalueChanged() 信号可以传递当前滑块的值。
    • 无返回类型:信号没有返回类型,它们不能返回任何值。
    • 独立于接收者:发出信号的对象(发送者)不需要知道是哪个对象(或哪些对象)在接收它的信号。这种松散耦合是信号与槽机制的关键优势。
什么是槽(Slots)?
  • 定义:槽是普通的 C++ 函数(或 Python 方法),当与之连接的信号被发出时,槽就会被调用。
  • 作用:槽用于响应信号。例如,当 QPushButton 发出 clicked() 信号时,你可以将这个信号连接到一个槽,该槽负责执行某个操作,如更新文本标签或打开新窗口。
  • 特性
    • 普通函数:槽可以是任何普通的 Python 方法或 C++ 函数(包括静态函数、全局函数、或类的成员函数)。
    • 参数匹配:槽的参数必须与连接到它的信号的参数兼容。这意味着槽可以接受信号传递的所有参数,或者更少的参数(从右侧开始省略)。
    • 可以是虚函数:槽可以是虚函数,这使得它们可以在子类中被重写,从而实现多态行为。
    • 可以是私有、保护或公共:槽的访问权限没有限制,但通常为了与其他类通信,它们会被声明为公共的。
如何连接(Connecting)信号和槽?

使用 connect() 方法来建立信号与槽之间的连接。

基本语法:

1
sender.signal.connect(receiver.slot)
  • sender:发出信号的对象。
  • signal:发送者对象的某个信号。
  • receiver:接收信号的对象。
  • slot:接收者对象的某个槽(方法)。

示例:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
from PyQt5.QtWidgets import QApplication, QPushButton, QLabel, QVBoxLayout, QWidget
from PyQt5.QtCore import Qt

app = QApplication([])
window = QWidget()
layout = QVBoxLayout()

button = QPushButton("点击我")
label = QLabel("你好,世界!")
label.setAlignment(Qt.AlignCenter)

# 连接信号和槽
# 当 button 发出 clicked() 信号时,调用 label 的 setText() 槽
button.clicked.connect(lambda: label.setText("按钮被点击了!"))

layout.addWidget(button)
layout.addWidget(label)
window.setLayout(layout)
window.show()
app.exec_()

在这个例子中,button 是发送者,clicked() 是信号,label 是接收者,setText() 是槽。

信号与槽的优势
  1. 松散耦合(Loose Coupling):发送者和接收者彼此独立,发送者不需要知道接收者的任何信息(除了它将发出的信号)。这种解耦使得组件更容易复用和维护。
  2. 类型安全(Type Safety):Qt 的 connect 机制会在连接时检查信号和槽的参数类型是否兼容,从而避免运行时错误。
  3. 可重用性(Reusability):由于松散耦合,组件可以更容易地在不同的上下文和应用程序中被重用。
  4. 清晰的通信路径:代码中信号和槽的连接清晰地表明了对象间的通信路径,使得代码更容易理解。
  5. 一对多 / 多对一连接
    • 一个信号可以连接到多个槽。
    • 多个信号可以连接到同一个槽。
    • 一个信号可以连接到另一个信号(信号转发)。
自定义信号(Custom Signals)

除了 Qt 内置的信号,你也可以在自己的类中定义和发出自定义信号。在 PyQt 中,这通过 pyqtSignal 来实现。

步骤:

  1. 导入 pyqtSignalfrom PyQt5.QtCore import pyqtSignal
  2. 在类中声明信号:作为类属性声明,指定信号将传递的参数类型。
  3. 在方法中发出信号:使用 emit() 方法发出信号。

示例:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from PyQt5.QtWidgets import QApplication, QWidget, QPushButton, QVBoxLayout, QLabel
from PyQt5.QtCore import pyqtSignal, QObject

class MyEmitter(QObject):
    # 声明一个自定义信号,不带参数
    my_signal_no_args = pyqtSignal()
    # 声明一个带一个字符串参数的信号
    my_signal_with_str = pyqtSignal(str)
    # 声明一个带两个整数参数的信号
    my_signal_with_ints = pyqtSignal(int, int)

    def __init__(self, parent=None):
        super().__init__(parent)
        self.counter = 0

    def do_something(self):
        self.counter += 1
        print("Doing something...")
        self.my_signal_no_args.emit() # 发出不带参数的信号
        self.my_signal_with_str.emit(f"操作 {self.counter} 完成!") # 发出带字符串参数的信号
        self.my_signal_with_ints.emit(self.counter, self.counter * 10) # 发出带整数参数的信号

class MyReceiver(QObject):
    def __init__(self, label, parent=None):
        super().__init__(parent)
        self.label = label

    def slot_no_args(self):
        self.label.setText("信号被接收了!")
        print("槽:不带参数的信号被触发。")

    def slot_with_str(self, message):
        self.label.setText(f"接收到消息: {message}")
        print(f"槽:接收到字符串信号: {message}")

    def slot_with_ints(self, num1, num2):
        print(f"槽:接收到两个整数信号: {num1}, {num2}")

app = QApplication(sys.argv)
window = QWidget()
layout = QVBoxLayout()

emitter = MyEmitter()
status_label = QLabel("等待操作...")
receiver = MyReceiver(status_label)

# 连接自定义信号到槽
emitter.my_signal_no_args.connect(receiver.slot_no_args)
emitter.my_signal_with_str.connect(receiver.slot_with_str)
emitter.my_signal_with_ints.connect(receiver.slot_with_ints)

action_button = QPushButton("执行操作")
action_button.clicked.connect(emitter.do_something) # 连接按钮点击到发射器的do_something方法

layout.addWidget(action_button)
layout.addWidget(status_label)
window.setLayout(layout)
window.show()
sys.exit(app.exec_())

面向对象编程

好久没回忆这里的内容了,之前写了很多C,搞忘记了这里。

面向对象编程(Object-Oriented Programming, OOP)简介

在解释继承和父对象之前,我们先快速回顾一下面向对象编程的基本思想。

面向对象编程是一种编程范式,它将程序中的数据和操作数据的方法组织成一个个对象。对象是类的实例。

核心概念:

  • 类(Class): 它是创建对象的蓝图或模板。类定义了对象的属性(数据)和行为(方法)。例如,一个 汽车 类可能定义了 颜色品牌 等属性,以及 启动加速 等方法。
  • 对象(Object): 它是类的实例。一个 汽车 类的对象可以是“我的红色宝马”或者“邻居的蓝色丰田”。每个对象都有自己的属性值。
  • 封装(Encapsulation): 将数据(属性)和操作数据的方法(行为)捆绑在一起,形成一个独立的单元(对象)。它隐藏了对象的内部实现细节,只暴露必要的接口。
  • 多态(Polymorphism): 允许不同类的对象对同一个消息做出不同的响应。例如,汽车摩托车 都有 启动 方法,但它们的具体启动方式可能不同。

继承(Inheritance)

继承是面向对象编程中一个非常强大的机制,它允许一个类(子类/派生类)从另一个已存在的类(父类/基类)中获取(继承)属性和方法。

核心思想:

  • 代码复用: 子类可以直接使用父类中已经定义好的属性和方法,而无需重新编写,大大减少了代码冗余。
  • 建立“is-a”关系: 继承表达了一种“是(is-a)”的关系。例如,“狗是一种动物”,“轿车是一种汽车”。
  • 扩展性: 子类可以在继承父类的基础上,添加自己特有的属性和方法,或者重写(覆盖)父类的方法,以实现更具体或不同的行为。

现在,我们想创建 狗 (Dog)猫 (Cat) 类。它们都是动物,所以它们应该拥有动物的基本属性(名字)和行为(叫、吃)。这时,我们就可以让 DogCat 继承 Animal 类:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
class Dog(Animal):  # Dog 继承 Animal
    def __init__(self, name, breed):
        super().__init__(name) # 调用父类 Animal 的构造函数
        self.breed = breed # Dog 特有的属性

    def speak(self): # 重写父类的 speak 方法
        print(f"{self.name} (一只{self.breed}) 汪汪叫!")

    def fetch(self): # Dog 特有的方法
        print(f"{self.name} 正在捡球。")

class Cat(Animal): # Cat 继承 Animal
    def __init__(self, name, color):
        super().__init__(name) # 调用父类 Animal 的构造函数
        self.color = color # Cat 特有的属性

    def speak(self): # 重写父类的 speak 方法
        print(f"{self.name} (一只{self.color}猫) 喵喵叫!")

    def scratch(self): # Cat 特有的方法
        print(f"{self.name} 正在抓挠。")
  • Animal父类
  • DogCat子类
  • DogCat 自动获得了 Animal 类的 name 属性以及 eat 方法。
  • DogCat 分别重写了 speak 方法,使其更具体地表达了狗和猫的叫声。
  • Dog 添加了特有的 breed 属性和 fetch 方法。
  • Cat 添加了特有的 color 属性和 scratch 方法。

父对象 / 父类(Parent Class / Base Class / Superclass)

“父对象”和“父类”是同一个概念的不同表达方式,但“父类”是更常用和准确的术语。

  • 父类(Parent Class / Base Class / Superclass):

    • 定义: 被其他类继承的类。它提供通用的属性和方法,供子类共享。
    • 作用: 作为子类的基础,定义了子类共有的特性和行为。
    • 在上面的例子中,Animal 就是 DogCat 的父类。
  • 父对象:

    • 这个词组通常不直接用来指代类本身
    • 在某些上下文语境中,它可能间接指代父类的实例。例如,如果 dog_instanceDog 类的一个对象,那么我们可能会说 dog_instance 的“父类型”是 Animal,或者说 dog_instance 是从 Animal “派生”出来的。

组合与继承

1. 继承关系(Inheritance)
  • 含义: 继承表达的是一种“is-a”(是…一种)的关系。一个子类“是”一个父类。
    • 例如:狗是一种动物 (Dog is an Animal),汽车是一种交通工具 (Car is a Vehicle)。
  • 实现方式: 子类直接从父类派生,获得父类的属性和方法。子类可以扩展或重写父类的功能。
  • 优点: * 代码复用: 子类无需重新实现父类已有的功能。
    • 多态性: 允许使用父类引用来处理子类对象,提高代码的灵活性和可扩展性。
    • 层次结构: 能够清晰地表示类之间的分类和泛化关系。
  • 缺点:
    • 紧耦合: 子类和父类之间存在强烈的依赖关系。父类的改变可能会影响所有子类。
    • 单一继承的限制: 许多语言(如 Java、C#、Python 的普通类)只支持单继承,即一个子类只能有一个直接父类。这可能导致“类爆炸”或难以建模多方面特性的情况。
    • 违反封装: 子类可以访问父类的受保护成员,一定程度上破坏了封装性。
    • “脆弱的基类”问题: 父类的一些修改(即使是很小的)也可能导致子类行为异常。
2. 组合关系(Composition)
  • 含义: 组合表达的是一种“has-a”(拥有…)的关系。一个类“拥有”另一个类的对象作为其成员。
    • 例如:汽车拥有一个引擎 (Car has an Engine),电脑拥有一个CPU (Computer has a CPU)。
  • 实现方式: 一个类(被称为容器类复合类)在其内部包含另一个类(被称为被包含类组件类)的实例作为其属性。
  • 优点:
    • 松耦合: 容器类和被包含类之间的依赖性较弱。容器类只需要知道如何与被包含类的公共接口进行交互,而不需要关心其内部实现细节。
    • 高内聚: 每个类只负责自己的功能,职责更单一。
    • 灵活性: 可以更容易地替换或修改组件,而不需要修改容器类的代码(只要接口不变)。
    • 避免“脆弱的基类”问题: 组合关系下,组件类的修改对容器类的影响远小于继承关系下父类的修改对子类的影响。
    • 解决多重继承的复杂性: 当一个类需要多种功能时,可以通过组合多个组件来实现,避免了多重继承可能带来的复杂性和歧义。
  • 缺点:
    • 功能委托: 容器类需要显式地将被包含类的功能暴露出来(通过方法调用),不像继承那样自动获得所有公共方法。
    • 对象创建和管理: 容器类可能需要负责创建和管理其组件对象的生命周期。
  • 适用场景:
    • 当一个类是另一个类的组成部分,而不是其具体类型时。
    • 当需要构建灵活、可配置的系统时。
    • 当一个类需要多种不同的功能,而这些功能可以通过独立的组件来提供时。
    • “优先使用组合而不是继承”(Prefer composition over inheritance)是面向对象设计的一个重要原则,尤其是在需要高灵活性和低耦合度的场景。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class Engine:
    def start(self):
        print("Engine started.")
    def stop(self):
        print("Engine stopped.")

class Wheel:
    def rotate(self):
        print("Wheel rotating.")

class Car:
    def __init__(self):
        self.engine = Engine() # Car has an Engine (组合关系)
        self.wheels = [Wheel(), Wheel(), Wheel(), Wheel()] # Car has Wheels (组合关系)

    def drive(self):
        self.engine.start() # 委托给 engine 对象
        for wheel in self.wheels:
            wheel.rotate() # 委托给 wheel 对象
        print("Car is driving.")

    def stop_car(self):
        self.engine.stop()
        print("Car stopped.")

# 使用
my_car = Car()
my_car.drive()
my_car.stop_car()
6346436
使用 Hugo 构建
主题 StackJimmy 设计