from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister, execute, Aer
from qiskit.visualization import plot_histogram
from qiskit.quantum_info import Statevector

def create_teleportation_circuit(initial_state=None):
    """Create a circuit that teleports the state of qubit 0 onto qubit 2.

    Qubit 0 holds the message; qubits 1 and 2 share a Bell pair. Qubit 0 is
    measured into the 1-bit register ``crz`` and qubit 1 into the 1-bit
    register ``crx``. Those bits then classically control the X and Z
    corrections applied to qubit 2.

    Args:
        initial_state: Optional state for the message qubit, in any form
            accepted by ``QuantumCircuit.initialize`` (e.g. a list of two
            complex amplitudes). When ``None`` (the default) the message is
            |+>, prepared with a single Hadamard gate.

    Returns:
        QuantumCircuit: a 3-qubit circuit with two 1-bit classical registers
        (``crz`` then ``crx``). Qubit 2 is left unmeasured so callers can
        inspect or measure it themselves.
    """
    # We need 3 qubits and 2 classical bits (one bit per register).
    qr = QuantumRegister(3, name="q")
    crz = ClassicalRegister(1, name="crz")
    crx = ClassicalRegister(1, name="crx")
    circuit = QuantumCircuit(qr, crz, crx)
    message, alice, bob = qr[0], qr[1], qr[2]

    # Step 1: Prepare the message state to be teleported on qubit 0.
    if initial_state is None:
        circuit.h(message)  # default message: |+> = (|0> + |1>) / sqrt(2)
    else:
        circuit.initialize(initial_state, message)
    circuit.barrier()

    # Step 2: Create a Bell pair between qubit 1 and qubit 2.
    circuit.h(alice)
    circuit.cx(alice, bob)
    circuit.barrier()

    # Step 3: Rotate qubits 0 and 1 from the Bell basis into the
    # computational basis, so the measurement in step 4 is a Bell measurement.
    circuit.cx(message, alice)
    circuit.h(message)
    circuit.barrier()

    # Step 4: Measure qubits 0 and 1. Bits are addressed through their named
    # registers rather than flat clbit indices, so the wiring does not depend
    # on the order in which the registers were added to the circuit.
    circuit.measure(message, crz[0])
    circuit.measure(alice, crx[0])
    circuit.barrier()

    # Step 5: Correct qubit 2 using the classical bits. Qubit 2 is now in
    # X^crx Z^crz |psi>, so X must be undone first and then Z.
    circuit.x(bob).c_if(crx, 1)
    circuit.z(bob).c_if(crz, 1)

    return circuit

if __name__ == "__main__":
    # NOTE: `execute` and `qiskit.Aer` were removed in Qiskit 1.0; this
    # script targets the pre-1.0 API used by the imports at the top.
    shots = 1024

    # Create the teleportation circuit
    teleportation_circuit = create_teleportation_circuit()
    print("Quantum Teleportation Circuit:")
    print(teleportation_circuit)

    # Statevector simulation. The circuit contains mid-circuit measurements,
    # so the simulator follows ONE randomly sampled measurement branch and
    # the printed vector can differ from run to run. Qubits 0 and 1 end up in
    # whichever basis states were measured; qubit 2 should hold |+>.
    backend = Aer.get_backend('statevector_simulator')
    job = execute(teleportation_circuit, backend)
    result = job.result()
    statevector = result.get_statevector(teleportation_circuit)

    print("\nFinal Statevector:")
    print(statevector)

    # Sampling the circuit as-is only records the two Bell-measurement bits
    # (registers crx and crz). These are expected to be uniformly random and
    # say nothing about qubit 2, which is never measured here.
    backend_sim = Aer.get_backend('qasm_simulator')
    job_sim = execute(teleportation_circuit, backend_sim, shots=shots)
    result_sim = job_sim.result()
    counts = result_sim.get_counts(teleportation_circuit)

    print("\nMeasurement Outcomes:")
    print(counts)

    # Actual verification: the message was |+> = H|0>, so applying H to
    # qubit 2 should map a correctly teleported state back to |0>. Every
    # shot should then measure 0.
    verification_circuit = teleportation_circuit.copy()
    result_reg = ClassicalRegister(1, name="result")
    verification_circuit.add_register(result_reg)
    verification_circuit.barrier()
    verification_circuit.h(2)
    verification_circuit.measure(2, result_reg[0])

    job_verify = execute(verification_circuit, backend_sim, shots=shots)
    verify_counts = job_verify.result().get_counts(verification_circuit)

    # Multi-register count keys are space-separated with the most recently
    # added register first, so the first field is the "result" register.
    failures = sum(
        hits for key, hits in verify_counts.items() if key.split()[0] != "0"
    )

    print("\nVerification Outcomes (first bit = qubit 2 after undoing H):")
    print(verify_counts)
    if failures == 0:
        print("Teleportation verified: qubit 2 measured 0 in every shot.")
    else:
        print(f"Teleportation check failed in {failures} of {shots} shots.")

    plot_histogram(counts).show()
