{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# State and state_typical"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "np.set_printoptions(linewidth=200)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## State\n",
    "A quantum state (density matrix) can be represented as a linear combination of basis matrices.\n",
    "`vec` holds the coefficients of this linear combination as a numpy array.  \n",
    "\n",
    "Example.  \n",
    "$z_0 = \\begin{bmatrix} 1 & 0 \\\\ 0 & 0 \\\\ \\end{bmatrix}$ can be written as the following linear combination of the normalized Pauli basis: $1/\\sqrt{2} \\cdot I/\\sqrt{2} + 0 \\cdot X/\\sqrt{2} + 0 \\cdot Y/\\sqrt{2} + 1/\\sqrt{2} \\cdot  Z/\\sqrt{2}$. In this case, the `vec` of $z_0$ is $[ 1/\\sqrt{2}, 0, 0, 1/\\sqrt{2} ]$."
   ]
  },
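  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The decomposition in the example above can be checked with plain numpy. The following is a minimal sketch, not part of quara: all `sketch_*` names are introduced here only for illustration. It assumes that each coefficient is the Hilbert-Schmidt inner product $\\text{Tr}[B_i^{\\dagger} \\rho]$ with the normalized Pauli basis element $B_i$, which is valid because this basis is orthonormal."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Normalized Pauli basis: I, X, Y, Z, each divided by sqrt(2)\n",
    "sketch_paulis = [\n",
    "    np.array([[1, 0], [0, 1]], dtype=complex),\n",
    "    np.array([[0, 1], [1, 0]], dtype=complex),\n",
    "    np.array([[0, -1j], [1j, 0]], dtype=complex),\n",
    "    np.array([[1, 0], [0, -1]], dtype=complex),\n",
    "]\n",
    "sketch_basis = [p / np.sqrt(2) for p in sketch_paulis]\n",
    "\n",
    "# Density matrix of z0\n",
    "sketch_rho = np.array([[1, 0], [0, 0]], dtype=complex)\n",
    "\n",
    "# Coefficient i is Tr[B_i^dagger rho]; it is real because rho and B_i are Hermitian\n",
    "sketch_coeffs = np.array([np.trace(b.conj().T @ sketch_rho).real for b in sketch_basis])\n",
    "print(f\"coefficients: {sketch_coeffs}\")\n",
    "\n",
    "# Rebuilding the matrix from the coefficients should recover rho\n",
    "sketch_rebuilt = sum(c * b for c, b in zip(sketch_coeffs, sketch_basis))\n",
    "print(f\"recovers rho: {np.allclose(sketch_rebuilt, sketch_rho)}\")"
   ]
  },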
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A State can be generated in the following ways:\n",
    "\n",
    "- Generate from the `state_typical` module\n",
    "- Generate a State object directly\n",
    "\n",
    "To generate a State from the `state_typical` module, specify a CompositeSystem and a state name (e.g. \"z0\")."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Type:\n",
      "State\n",
      "\n",
      "Dim:\n",
      "2\n",
      "\n",
      "Vec:\n",
      "[0.70710678 0.         0.         0.70710678]\n"
     ]
    }
   ],
   "source": [
    "from quara.objects.composite_system_typical import generate_composite_system\n",
    "from quara.objects.state_typical import generate_state_from_name\n",
    "\n",
    "c_sys = generate_composite_system(\"qubit\", 1)\n",
    "state = generate_state_from_name(c_sys, \"z0\")\n",
    "print(state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To generate a State object directly, pass a CompositeSystem and a numpy array to the constructor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Type:\n",
      "State\n",
      "\n",
      "Dim:\n",
      "2\n",
      "\n",
      "Vec:\n",
      "[0.70710678 0.         0.         0.70710678]\n"
     ]
    }
   ],
   "source": [
    "from quara.objects.composite_system import CompositeSystem\n",
    "from quara.objects.elemental_system import ElementalSystem\n",
    "from quara.objects.matrix_basis import get_normalized_pauli_basis\n",
    "from quara.objects.state import State\n",
    "\n",
    "basis = get_normalized_pauli_basis(1)\n",
    "e_sys = ElementalSystem(0, basis)\n",
    "c_sys = CompositeSystem([e_sys])\n",
    "vec = np.array([1, 0, 0, 1]) / np.sqrt(2)\n",
    "state = State(c_sys, vec)\n",
    "print(state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### specific properties\n",
    "The property `vec` of State is a numpy array specified by the constructor argument `vec`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vec: [0.70710678 0.         0.         0.70710678]\n"
     ]
    }
   ],
   "source": [
    "state = State(c_sys, vec)\n",
    "print(f\"vec: {state.vec}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The property `dim` of State is the square root of the number of elements of `vec`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dim: 2\n",
      "square root of the size of element of vec: 2\n"
     ]
    }
   ],
   "source": [
    "print(f\"dim: {state.dim}\")\n",
    "print(f\"square root of the size of element of vec: {int(np.sqrt(len(vec)))}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### functions to check constraints\n",
    "The `is_eq_constraint_satisfied()` function returns True if and only if $\\text{Tr}[\\rho] = 1$, where $\\rho$ is the density matrix of the State."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "is_eq_constraint_satisfied(): True\n",
      "trace of density matrix: (0.9999999999999998+0j)\n"
     ]
    }
   ],
   "source": [
    "print(f\"is_eq_constraint_satisfied(): {state.is_eq_constraint_satisfied()}\")\n",
    "print(f\"trace of density matrix: {np.trace(state.to_density_matrix())}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `is_ineq_constraint_satisfied()` function returns True if and only if $\\rho$ is a positive semidefinite matrix, where $\\rho$ is the density matrix of the State."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "is_ineq_constraint_satisfied(): True\n",
      "is_positive_semidefinite(): True\n"
     ]
    }
   ],
   "source": [
    "print(f\"is_ineq_constraint_satisfied(): {state.is_ineq_constraint_satisfied()}\")\n",
    "print(f\"is_positive_semidefinite(): {state.is_positive_semidefinite()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### projection functions\n",
    "The `calc_proj_eq_constraint()` function calculates the projection of the State onto the equality constraint.  \n",
    "This function replaces the first element of `vec` with $1/\\sqrt{d}$, where $d$ is `dim`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vec: [0.70710678 2.         3.         4.        ]\n"
     ]
    }
   ],
   "source": [
    "vec = np.array([1.0, 2.0, 3.0, 4.0])\n",
    "state = State(c_sys, vec, is_physicality_required=False)\n",
    "proj_state = state.calc_proj_eq_constraint()\n",
    "print(f\"vec: {proj_state.vec}\")"
   ]
  },
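  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why replacing the first element is enough: in the normalized Pauli basis used in this notebook, only the first basis element $I/\\sqrt{d}$ has a nonzero trace, so $\\text{Tr}[\\rho] = \\sqrt{d} \\cdot$ `vec[0]`. The following numpy sketch repeats the same replacement by hand. It is an illustration under that assumption, not quara's implementation, and the `sketch_*` names are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "sketch_d = 2\n",
    "sketch_vec = np.array([1.0, 2.0, 3.0, 4.0])\n",
    "\n",
    "# Replace only the coefficient of I/sqrt(d); the other coefficients are untouched\n",
    "sketch_proj = sketch_vec.copy()\n",
    "sketch_proj[0] = 1 / np.sqrt(sketch_d)\n",
    "print(f\"projected vec: {sketch_proj}\")\n",
    "\n",
    "# Tr[rho] = sqrt(d) * vec[0], because the other basis elements are traceless\n",
    "print(f\"trace after projection: {np.sqrt(sketch_d) * sketch_proj[0]}\")"
   ]
  },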
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `calc_proj_ineq_constraint()` function calculates the projection of the State onto the inequality constraint as follows:  \n",
    "\n",
    "- Computes the eigenvalue decomposition of the density matrix $\\rho$ of the State, $\\rho = U \\Lambda U^{\\dagger}$, where $\\Lambda = \\text{diag}[\\lambda_0, \\dots , \\lambda_{d-1}]$, and $\\lambda_{i} \\in \\mathbb{R}$.\n",
    "- $\\lambda^{\\prime}_{i} := \\begin{cases} \\lambda_{i} & (\\lambda_{i} \\geq 0) \\\\ 0 & (\\lambda_{i} < 0) \\end{cases}$\n",
    "- $\\Lambda^{\\prime} = \\text{diag}[\\lambda^{\\prime}_0, \\dots , \\lambda^{\\prime}_{d-1}]$\n",
    "- $\\rho^{\\prime} = U \\Lambda^{\\prime} U^{\\dagger}$\n",
    "- The projection of State is $|\\rho^{\\prime} \\rangle\\rangle$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "density matrix before projection: \n",
      "[[ 1.+0.j  0.+0.j]\n",
      " [ 0.+0.j -1.+0.j]]\n",
      "density matrix after projection: \n",
      "[[1.+0.j 0.+0.j]\n",
      " [0.+0.j 0.+0.j]]\n",
      "vec after projection: \n",
      "[0.70710678 0.         0.         0.70710678]\n"
     ]
    }
   ],
   "source": [
    "vec = np.sqrt(2) * np.array([0, 0, 0, 1])\n",
    "state = State(c_sys, vec, is_physicality_required=False)\n",
    "print(f\"density matrix before projection: \\n{state.to_density_matrix()}\")\n",
    "proj_state = state.calc_proj_ineq_constraint()\n",
    "print(f\"density matrix after projection: \\n{proj_state.to_density_matrix()}\")\n",
    "print(f\"vec after projection: \\n{proj_state.vec}\")"
   ]
  },
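  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The steps listed above can also be reproduced with plain numpy. This is a minimal sketch rather than quara's implementation: it uses `np.linalg.eigh`, which assumes the input matrix is Hermitian, and the `sketch_*` names are hypothetical. The input is the same matrix that was projected in the previous cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hermitian matrix with one negative eigenvalue, so it is not a valid density matrix\n",
    "sketch_rho = np.array([[1, 0], [0, -1]], dtype=complex)\n",
    "\n",
    "# rho = U diag(lambda) U^dagger with real eigenvalues lambda\n",
    "sketch_eigvals, sketch_u = np.linalg.eigh(sketch_rho)\n",
    "\n",
    "# Replace negative eigenvalues with zero and keep the others\n",
    "sketch_clipped = np.where(sketch_eigvals < 0, 0, sketch_eigvals)\n",
    "\n",
    "# rho' = U diag(lambda') U^dagger\n",
    "sketch_rho_proj = sketch_u @ np.diag(sketch_clipped) @ sketch_u.conj().T\n",
    "print(f\"density matrix after projection: \\n{sketch_rho_proj}\")"
   ]
  },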
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### functions to transform parameters\n",
    "The `to_stacked_vector()` function returns a one-dimensional numpy array of all variables. For State, this is equal to `vec`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "to_stacked_vector(): [0.         0.         0.         1.41421356]\n",
      "vec: [0.         0.         0.         1.41421356]\n"
     ]
    }
   ],
   "source": [
    "print(f\"to_stacked_vector(): {state.to_stacked_vector()}\")\n",
    "print(f\"vec: {state.vec}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If `on_para_eq_constraint` is True, then the first element of `vec` is equal to $1/\\sqrt{d}$, where $d$ is `dim`. Thus, State is characterized by the second and subsequent elements of `vec`.  \n",
    "Therefore, the `to_var()` function returns only the second and subsequent elements of `vec` when `on_para_eq_constraint` is True, and all elements of `vec` when it is False."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "vec = np.array([1, 0, 0, 1]) / np.sqrt(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "to_var(): [0.         0.         0.70710678]\n"
     ]
    }
   ],
   "source": [
    "# on_para_eq_constraint=True\n",
    "state = State(c_sys, vec, on_para_eq_constraint=True)\n",
    "print(f\"to_var(): {state.to_var()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "to_var(): [0.70710678 0.         0.         0.70710678]\n"
     ]
    }
   ],
   "source": [
    "# on_para_eq_constraint=False\n",
    "state = State(c_sys, vec, on_para_eq_constraint=False)\n",
    "print(f\"to_var(): {state.to_var()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
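    "### functions to generate special objects\n",
    "`generate_zero_obj()` returns a State whose `vec` is all zeros. `generate_origin_obj()` returns a State whose `vec` has $1/\\sqrt{d}$ as its first element and zeros elsewhere, where $d$ is `dim`."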
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "zero: [0. 0. 0. 0.]\n",
      "origin: [0.70710678 0.         0.         0.        ]\n"
     ]
    }
   ],
   "source": [
    "zero_state = state.generate_zero_obj()\n",
    "print(f\"zero: {zero_state.vec}\")\n",
    "origin_state = state.generate_origin_obj()\n",
    "print(f\"origin: {origin_state.vec}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### supports arithmetic operations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sum: [ 6.  8. 10. 12.]\n",
      "subtraction: [-4. -4. -4. -4.]\n",
      "right multiplication: [2. 4. 6. 8.]\n",
      "left multiplication: [2. 4. 6. 8.]\n",
      "division: [0.5 1.  1.5 2. ]\n"
     ]
    }
   ],
   "source": [
    "vec1 = np.array([1.0, 2.0, 3.0, 4.0])\n",
    "state1 = State(c_sys, vec1, is_physicality_required=False)\n",
    "vec2 = np.array([5.0, 6.0, 7.0, 8.0])\n",
    "state2 = State(c_sys, vec2, is_physicality_required=False)\n",
    "\n",
    "print(f\"sum: {(state1 + state2).vec}\")\n",
    "print(f\"subtraction: {(state1 - state2).vec}\")\n",
    "print(f\"right multiplication: {(2 * state1).vec}\")\n",
    "print(f\"left multiplication: {(state1 * 2).vec}\")\n",
    "print(f\"division: {(state1 / 2).vec}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### calc_gradient functions\n",
    "Calculates the gradient of the State with respect to the variable at the specified index."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vec: [1. 0. 0. 0.]\n"
     ]
    }
   ],
   "source": [
    "grad_state = state.calc_gradient(0)\n",
    "print(f\"vec: {grad_state.vec}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### convert_basis function\n",
    "Returns `vec` converted to the specified basis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vec: [1.+0.j 0.+0.j 0.+0.j 0.+0.j]\n"
     ]
    }
   ],
   "source": [
    "from quara.objects.matrix_basis import get_comp_basis\n",
    "\n",
    "state = generate_state_from_name(c_sys, \"z0\")\n",
    "converted_vec = state.convert_basis(get_comp_basis())\n",
    "print(f\"vec: {converted_vec}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### to_density_matrix function\n",
    "Returns the density matrix of the State."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "to_density_matrix(): \n",
      "[[1.+0.j 0.+0.j]\n",
      " [0.+0.j 0.+0.j]]\n",
      "to_density_matrix_with_sparsity(): \n",
      "[[1.+0.j 0.+0.j]\n",
      " [0.+0.j 0.+0.j]]\n"
     ]
    }
   ],
   "source": [
    "state = generate_state_from_name(c_sys, \"z0\")\n",
    "print(f\"to_density_matrix(): \\n{state.to_density_matrix()}\")\n",
    "print(f\"to_density_matrix_with_sparsity(): \\n{state.to_density_matrix_with_sparsity()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### some utility functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "is_trace_one(): True\n",
      "is_hermitian(): True\n",
      "is_positive_semidefinite(): True\n",
      "calc_eigenvalues(): [0.9999999999999998, 0.0]\n"
     ]
    }
   ],
   "source": [
    "print(f\"is_trace_one(): {state.is_trace_one()}\")\n",
    "print(f\"is_hermitian(): {state.is_hermitian()}\")\n",
    "print(f\"is_positive_semidefinite(): {state.is_positive_semidefinite()}\")\n",
    "print(f\"calc_eigenvalues(): {state.calc_eigenvalues()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## state_typical\n",
    "The `generate_state_object_from_state_name_object_name()` function in the `state_typical` module generates objects related to State.  \n",
    "It has the following arguments:\n",
    "\n",
    "- `state_name` - name of the state. The valid names can be listed by executing the `get_state_names()` function. The tensor product of the states \"a\" and \"b\" is written \"a_b\".\n",
    "- `object_name` - one of the following strings:\n",
    "  - \"pure_state_vector\" - vector of pure state.\n",
    "  - \"density_mat\" - density matrix.\n",
    "  - \"density_matrix_vector\" - vectorized density matrix.\n",
    "  - \"state\" - State object.\n",
    "- `c_sys` - CompositeSystem of the objects related to State. Specify it when `object_name` is \"density_matrix_vector\" or \"state\".\n",
    "- `is_physicality_required` - whether the generated object is required to satisfy the physicality constraints, by default True."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "from quara.objects.state_typical import (\n",
    "    get_state_names,\n",
    "    generate_state_object_from_state_name_object_name,\n",
    ")\n",
    "\n",
    "#get_state_names()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### object_name = \"pure_state_vector\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1.+0.j 0.+0.j]\n"
     ]
    }
   ],
   "source": [
    "vec = generate_state_object_from_state_name_object_name(\"z0\", \"pure_state_vector\")\n",
    "print(vec)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### object_name = \"density_mat\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[1.+0.j 0.+0.j]\n",
      " [0.+0.j 0.+0.j]]\n"
     ]
    }
   ],
   "source": [
    "density_mat = generate_state_object_from_state_name_object_name(\"z0\", \"density_mat\")\n",
    "print(density_mat)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### object_name = \"density_matrix_vector\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.70710678 0.         0.         0.70710678]\n"
     ]
    }
   ],
   "source": [
    "c_sys = generate_composite_system(\"qubit\", 1)\n",
    "vec = generate_state_object_from_state_name_object_name(\"z0\", \"density_matrix_vector\", c_sys=c_sys)\n",
    "print(vec)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### object_name = \"state\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Type:\n",
      "State\n",
      "\n",
      "Dim:\n",
      "2\n",
      "\n",
      "Vec:\n",
      "[0.70710678 0.         0.         0.70710678]\n"
     ]
    }
   ],
   "source": [
    "c_sys = generate_composite_system(\"qubit\", 1)\n",
    "state = generate_state_object_from_state_name_object_name(\"z0\", \"state\", c_sys=c_sys)\n",
    "print(state)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.7 64-bit ('quara': pipenv)",
   "metadata": {
    "interpreter": {
     "hash": "e0c99f005b0837bce79602fafa142c135e273dc311d0671484e4c834ee0e9c24"
    }
   },
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7-final"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
