I came across the following code today (a learning rate scheduler class), and it revived a question that has puzzled me for a long time: what exactly is the __init__() special method for, and why do Python classes use __init__() at all?

import torch


class LRScheduler():
    """
    Learning rate scheduler. If the validation loss does not decrease for the
    given number of `patience` epochs, then the learning rate will decrease by
    the given `factor`.
    """
    def __init__(self, optimizer, patience=7, min_lr=1e-6, factor=0.5):
        """
        new_lr = old_lr * factor
        :param optimizer: the optimizer we are using
        :param patience: how many epochs to wait before updating the lr
        :param min_lr: least lr value to reduce to while updating
        :param factor: factor by which the lr should be updated
        """
        self.optimizer = optimizer
        self.patience = patience
        self.min_lr = min_lr
        self.factor = factor
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            patience=self.patience,
            factor=self.factor,
            min_lr=self.min_lr,
            verbose=True
        )

    def __call__(self, val_loss):
        # Step the wrapped scheduler with the latest validation loss.
        self.lr_scheduler.step(val_loss)

__init__ is a special method, usually described as the class's initializer (or, loosely, its constructor), so its purpose is largely self-explanatory: whenever you create an instance of the class, Python calls this method automatically.
The first parameter of __init__ is self, which refers to the instance being created; the parameters defined after it are used to initialize the instance's attributes.
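A minimal sketch makes this concrete (the Point class below is made up purely for illustration and is not part of the original code):

class Point:
    def __init__(self, x, y):
        # Python calls this automatically when Point(...) is evaluated;
        # self is the newly created instance being initialized.
        self.x = x
        self.y = y

p = Point(3, 4)    # __init__ runs here without being called explicitly
print(p.x, p.y)    # 3 4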

So why does the code above define a class, class LRScheduler():, rather than a plain function, def LRScheduler():?
The LRScheduler class carries configuration such as optimizer, patience and factor. Wrapping these parameters and the update logic in a class keeps them together, so you do not have to pass a long list of arguments around on every call, and the __call__() special method then lets the instance itself be called like a function.
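To see __call__() in isolation, here is a minimal sketch (the Multiplier class is invented for illustration): the configuration is stored once in __init__, and __call__ lets the instance be invoked like a function afterwards.

class Multiplier:
    def __init__(self, factor):
        self.factor = factor        # configuration stored once

    def __call__(self, x):
        # Runs when the instance is called like a function: instance(x)
        return x * self.factor

double = Multiplier(2)
print(double(21))   # 42 -- the instance itself is callable

LRScheduler works the same way: the constructor arguments are stored once in __init__, and calling scheduler(val_loss) in the training loop below is what triggers __call__.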

import torch
import torch.optim as optim

# Suppose we already have a model (placeholder)
model = ...
# Create an optimizer for its parameters
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Create an instance of the learning rate scheduler
scheduler = LRScheduler(optimizer, patience=5, min_lr=1e-7, factor=0.1)

# Total number of training epochs (placeholder)
num_epochs = ...

# Inside the training loop
for epoch in range(num_epochs):
    # ... train the model ...
    # compute the validation loss
    val_loss = ...
    # update the learning rate via __call__
    scheduler(val_loss)
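For comparison, here is a sketch of the same loop without the wrapper class, calling torch.optim.lr_scheduler.ReduceLROnPlateau directly (reusing the model / num_epochs / val_loss placeholders above). It works, but the scheduler configuration and the step call are no longer bundled behind one small, callable object:

import torch
import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr=0.001)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', patience=5, factor=0.1, min_lr=1e-7
)

for epoch in range(num_epochs):
    # ... train the model, compute val_loss ...
    lr_scheduler.step(val_loss)   # same effect as scheduler(val_loss) above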
Please credit the source when reposting.