IBM Qiskit量子机器学习速成(五)

量子核训练

在上一篇文章中,我们探讨了量子核的创建与简单应用。本文将详细阐述如何训练量子核。

首先我们导入本次需要的所有包

# External imports
from pylab import cm
from sklearn import metrics
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

# Qiskit imports
from qiskit import QuantumCircuit
from qiskit.circuit import ParameterVector
from qiskit.visualization import circuit_drawer
from qiskit.circuit.library import ZZFeatureMap
from qiskit_algorithms.optimizers import SPSA
from qiskit_machine_learning.kernels import TrainableFidelityQuantumKernel
from qiskit_machine_learning.kernels.algorithms import QuantumKernelTrainer
from qiskit_machine_learning.algorithms import QSVC
from qiskit_machine_learning.datasets import ad_hoc_data

为了训练我们的量子核,我们要自定义一个自己的优化器回调类

class QKTCallback:
    """Callback wrapper class."""

    def __init__(self) -> None:
        self._data = [[] for i in range(5)]

    def callback(self, x0, x1=None, x2=None, x3=None, x4=None):
        """
        Args:
            x0: number of function evaluations
            x1: the parameters
            x2: the function value
            x3: the stepsize
            x4: whether the step was accepted
        """
        self._data[0].append(x0)
        self._data[1].append(x1)
        self._data[2].append(x2)
        self._data[3].append(x3)
        self._data[4].append(x4)

    def get_callback_data(self):
        return self._data

    def clear_callback_data(self):
        self._data = [[] for i in range(5)]

出于训练目的,我们要事先准备好数据集

adhoc_dimension = 2
X_train, y_train, X_test, y_test, adhoc_total = ad_hoc_data(
    training_size=20,
    test_size=5,
    n=adhoc_dimension,
    gap=0.3,
    plot_data=False,
    one_hot=False,
    include_sample_total=True,
)

为了将经典数据导入到量子电路中,我们需要加入一组特征映射。不出意外,我们使用ZZFeatureMap。

# Create a rotational layer to train. We will rotate each qubit the same amount.
training_params = ParameterVector("θ", 1)
fm0 = QuantumCircuit(2)
fm0.ry(training_params[0], 0)
fm0.ry(training_params[0], 1)

# Use ZZFeatureMap to represent input data
fm1 = ZZFeatureMap(2)

# Create the feature map, composed of our two circuits
fm = fm0.compose(fm1)

print(circuit_drawer(fm))
print(f"Trainable parameters: {
     training_params}")

下一步,我们正式创建量子核

# Instantiate quantum kernel
quant_kernel = TrainableFidelityQuantumKernel(feature_map=fm, training_parameters=training_params)

# Set up the optimizer
cb_qkt = QKTCallback()
spsa_opt = SPSA(maxiter=10, callback=cb_qkt.callback, learning_rate=0.05, perturbation=0.05)

# Instantiate a quantum kernel trainer.
qkt = QuantumKernelTrainer(
    quantum_kernel=quant_kernel, loss="svc_loss", optimizer=spsa_opt, initial_point=[np.pi / 2]
)

创建好量子核之后,我们直接训练它

# Train the kernel using QKT directly
qka_results = qkt.fit(X_train, y_train)
optimized_kernel = qka_results.quantum_kernel
print(qka_results)

和上一章一样,我们可以在支撑向量分类器中使用量子核

# Use QSVC for classification
qsvc = QSVC(quantum_kernel=optimized_kernel)

# Fit the QSVC
qsvc.fit(X_train, y_train)

# Predict the labels
labels_test = qsvc.predict(X_test)

# Evalaute the test accuracy
accuracy_test = metrics.balanced_accuracy_score(y_true=y_test, y_pred=labels_test)
print(f"accuracy test: {
     accuracy_test}")

相关推荐

  1. IBM Qiskit量子机器学习速成

    2023-12-10 08:34:02       33 阅读
  2. IBM Qiskit量子机器学习速成(二)

    2023-12-10 08:34:02       38 阅读
  3. IBM Qiskit量子机器学习速成(三)

    2023-12-10 08:34:02       36 阅读
  4. 机器学习速成

    2023-12-10 08:34:02       32 阅读
  5. 量子机器学习量子机器学习的介绍

    2023-12-10 08:34:02       28 阅读
  6. 机器视觉学习)—— 图像的几何

    2023-12-10 08:34:02       18 阅读

最近更新

  1. TCP协议是安全的吗?

    2023-12-10 08:34:02       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2023-12-10 08:34:02       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2023-12-10 08:34:02       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2023-12-10 08:34:02       20 阅读

热门阅读

  1. CMMI认证有什么意义

    2023-12-10 08:34:02       42 阅读
  2. WPF(Windows Presentation Foundation) 的 Menu控件

    2023-12-10 08:34:02       34 阅读
  3. 深入探讨MySQL数据库的InnoDB存储引擎架构

    2023-12-10 08:34:02       54 阅读
  4. SpringMVC-Servlet

    2023-12-10 08:34:02       42 阅读
  5. ESP32网络编程-OTA方式升级固件(基于Arduino IDE)

    2023-12-10 08:34:02       38 阅读
  6. SQL命令---修改数据库的编码

    2023-12-10 08:34:02       38 阅读
  7. Oracle 怎樣修改DB_NAME

    2023-12-10 08:34:02       37 阅读