o
    jg/                     @   s   d dl Z d dlZd dlmZmZ d dlmZ d dlmZm	Z	 d dl
mZ d dlmZ d dlmZ G dd	 d	eZd
d ZG dd deZdS )    N)_sympifysympify)Expr)BasicTuple)ImmutableDenseNDimArray)Symbol)Integerc                   @   s   e Zd ZdZdd Zedd Zedd Zedd	 Zed
d Z	edd Z
edd Zedd Zdd Zdd Zedd Zedd Zedd Zdd Zdd Zd d! Zd"d# Zd$d% Zd&S )'ArrayComprehensiona  
    Generate a list comprehension.

    Explanation
    ===========

    If there is a symbolic dimension, for example, say [i for i in range(1, N)] where
    N is a Symbol, then the expression will not be expanded to an array. Otherwise,
    calling the doit() function will launch the expansion.

    Examples
    ========

    >>> from sympy.tensor.array import ArrayComprehension
    >>> from sympy import symbols
    >>> i, j, k = symbols('i j k')
    >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
    >>> a
    ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
    >>> a.doit()
    [[11, 12, 13], [21, 22, 23], [31, 32, 33], [41, 42, 43]]
    >>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k))
    >>> b.doit()
    ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k))
    c                 O   s   t dd |D rtdt|g}|| || tj| g|R i |}|jdd  |_| 	|j|_
t|j
|_| |j
|_|S )Nc                 s        | ]}t |d kpdV  qdS    Nlen.0l r   ^/var/www/html/zoom/venv/lib/python3.10/site-packages/sympy/tensor/array/array_comprehension.py	<genexpr>%       z-ArrayComprehension.__new__.<locals>.<genexpr>KArrayComprehension requires values lower and upper bound for the expression   )any
ValueErrorr   extend_check_limits_validityr   __new___args_limits_calculate_shape_from_limits_shaper   _rank_calculate_loop_size
_loop_sizeclsfunctionsymbolsassumptionsarglistobjr   r   r   r   $   s   
zArrayComprehension.__new__c                 C   s
   | j d S )aA  The function applied across limits.

        Examples
        ========

        >>> from sympy.tensor.array import ArrayComprehension
        >>> from sympy import symbols
        >>> i, j = symbols('i j')
        >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
        >>> a.function
        10*i + j
        r   )r   selfr   r   r   r'   1   s   
zArrayComprehension.functionc                 C      | j S )au  
        The list of limits that will be applied while expanding the array.

        Examples
        ========

        >>> from sympy.tensor.array import ArrayComprehension
        >>> from sympy import symbols
        >>> i, j = symbols('i j')
        >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
        >>> a.limits
        ((i, 1, 4), (j, 1, 3))
        r   r,   r   r   r   limitsA   s   zArrayComprehension.limitsc                 C   s@   | j j}| jD ]\}}}|| |j|j}||}q|S )a)  
        The set of the free_symbols in the array.
        Variables appeared in the bounds are supposed to be excluded
        from the free symbol set.

        Examples
        ========

        >>> from sympy.tensor.array import ArrayComprehension
        >>> from sympy import symbols
        >>> i, j, k = symbols('i j k')
        >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
        >>> a.free_symbols
        set()
        >>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k+3))
        >>> b.free_symbols
        {k}
        )r'   free_symbolsr   discardunion)r-   expr_free_symvarinfsupcurr_free_symsr   r   r   r1   R   s   
zArrayComprehension.free_symbolsc                 C      dd | j D S )aL  The tuples of the variables in the limits.

        Examples
        ========

        >>> from sympy.tensor.array import ArrayComprehension
        >>> from sympy import symbols
        >>> i, j, k = symbols('i j k')
        >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
        >>> a.variables
        [i, j]
        c                 S   s   g | ]}|d  qS )r   r   r   r   r   r   
<listcomp>{   s    z0ArrayComprehension.variables.<locals>.<listcomp>r/   r,   r   r   r   	variablesm   s   zArrayComprehension.variablesc                 C   r9   )zThe list of dummy variables.

        Note
        ====

        Note that all variables are dummy variables since a limit without
        lower bound or upper bound is not accepted.
        c                 S   s    g | ]}t |d kr|d qS )r   r   r   r   r   r   r   r:      s     z4ArrayComprehension.bound_symbols.<locals>.<listcomp>r/   r,   r   r   r   bound_symbols}   s   
z ArrayComprehension.bound_symbolsc                 C   r.   )aE  
        The shape of the expanded array, which may have symbols.

        Note
        ====

        Both the lower and the upper bounds are included while
        calculating the shape.

        Examples
        ========

        >>> from sympy.tensor.array import ArrayComprehension
        >>> from sympy import symbols
        >>> i, j, k = symbols('i j k')
        >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
        >>> a.shape
        (4, 3)
        >>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k+3))
        >>> b.shape
        (4, k + 3)
        )r!   r,   r   r   r   shape   s   zArrayComprehension.shapec                 C   s,   | j D ]\}}}t||tr dS qdS )a  
        Test if the array is shape-numeric which means there is no symbolic
        dimension.

        Examples
        ========

        >>> from sympy.tensor.array import ArrayComprehension
        >>> from sympy import symbols
        >>> i, j, k = symbols('i j k')
        >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
        >>> a.is_shape_numeric
        True
        >>> b = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, k+3))
        >>> b.is_shape_numeric
        False
        FT)r   r   atomsr   )r-   _r6   r7   r   r   r   is_shape_numeric   s
   z#ArrayComprehension.is_shape_numericc                 C   r.   )a9  The rank of the expanded array.

        Examples
        ========

        >>> from sympy.tensor.array import ArrayComprehension
        >>> from sympy import symbols
        >>> i, j, k = symbols('i j k')
        >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
        >>> a.rank()
        2
        )r"   r,   r   r   r   rank   s   zArrayComprehension.rankc                 C   s   | j jrtd| j S )a  
        The length of the expanded array which means the number
        of elements in the array.

        Raises
        ======

        ValueError : When the length of the array is symbolic

        Examples
        ========

        >>> from sympy.tensor.array import ArrayComprehension
        >>> from sympy import symbols
        >>> i, j = symbols('i j')
        >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
        >>> len(a)
        12
        z Symbolic length is not supported)r$   r1   r   r,   r   r   r   __len__   s   zArrayComprehension.__len__c                 C   s   g }|D ]K\}}}t |}t |}t|trt| }nt |}|t||| tdd ||fD r7td||kdkrAtd||jv sK||jv rOtdq|S )Nc                 s   s0    | ]}t |t p|tt| kV  qd S N)
isinstancer   r>   r   r	   )r   ir   r   r   r      s    (z<ArrayComprehension._check_limits_validity.<locals>.<genexpr>zABounds should be an Expression(combination of Integer and Symbol)Tz-Lower bound should be inferior to upper boundz)Variable should not be part of its bounds)	r   rD   listr   appendr   	TypeErrorr   r1   )r&   r'   r0   
new_limitsr5   r6   r7   r   r   r   r      s$   

z)ArrayComprehension._check_limits_validityc                 C   s   t dd |D S )Nc                 S   s   g | ]\}}}|| d  qS r   r   )r   r?   r6   r7   r   r   r   r:      s    zCArrayComprehension._calculate_shape_from_limits.<locals>.<listcomp>)tuple)r&   r0   r   r   r   r       s   z/ArrayComprehension._calculate_shape_from_limitsc                 C   s"   |sdS d}|D ]}|| }q|S )Nr   r   r   )r&   r=   	loop_sizer   r   r   r   r#      s   
z'ArrayComprehension._calculate_loop_sizec                 K   s   | j s| S |  S rC   )r@   _expand_array)r-   hintsr   r   r   doit  s   zArrayComprehension.doitc                 C   s<   g }t jdd | jD  D ]
}|| | qt|| jS )Nc                 S   s    g | ]\}}}t ||d  qS rJ   )range)r   r5   r6   r7   r   r   r   r:     s    z4ArrayComprehension._expand_array.<locals>.<listcomp>)	itertoolsproductr   rG   _get_elementr   r=   )r-   resvaluesr   r   r   rM     s   
z ArrayComprehension._expand_arrayc                 C   s,   | j }t| j|D ]
\}}|||}q	|S rC   )r'   zipr;   subs)r-   rU   tempr5   valr   r   r   rS     s   zArrayComprehension._get_elementc                 C   s   | j r	|   S td)a  Transform the expanded array to a list.

        Raises
        ======

        ValueError : When there is a symbolic dimension

        Examples
        ========

        >>> from sympy.tensor.array import ArrayComprehension
        >>> from sympy import symbols
        >>> i, j = symbols('i j')
        >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
        >>> a.tolist()
        [[11, 12, 13], [21, 22, 23], [31, 32, 33], [41, 42, 43]]
        z-A symbolic array cannot be expanded to a list)r@   rM   tolistr   r,   r   r   r   rZ     s   zArrayComprehension.tolistc                 C   s<   ddl m} | jstd| jdkrtd||   S )aE  Transform the expanded array to a matrix.

        Raises
        ======

        ValueError : When there is a symbolic dimension
        ValueError : When the rank of the expanded array is not equal to 2

        Examples
        ========

        >>> from sympy.tensor.array import ArrayComprehension
        >>> from sympy import symbols
        >>> i, j = symbols('i j')
        >>> a = ArrayComprehension(10*i + j, (i, 1, 4), (j, 1, 3))
        >>> a.tomatrix()
        Matrix([
        [11, 12, 13],
        [21, 22, 23],
        [31, 32, 33],
        [41, 42, 43]])
        r   )Matrixz/A symbolic array cannot be expanded to a matrix   zDimensions must be of size of 2)sympy.matricesr[   r@   r   r"   rM   tomatrix)r-   r[   r   r   r   r^   3  s   
zArrayComprehension.tomatrixN)__name__
__module____qualname____doc__r   propertyr'   r0   r1   r;   r<   r=   r@   rA   rB   classmethodr   r    r#   rO   rM   rS   rZ   r^   r   r   r   r   r
   
   s<    









		r
   c                 C   s"   dd }t | t|o| j|jkS )Nc                   S   s   dS )Nr   r   r   r   r   r   <lambda>U  s    zisLambda.<locals>.<lambda>)rD   typer_   )vLAMBDAr   r   r   isLambdaT  s   ri   c                   @   s,   e Zd ZdZdd Zedd Zdd ZdS )	ArrayComprehensionMapa[  
    A subclass of ArrayComprehension dedicated to map external function lambda.

    Notes
    =====

    Only the lambda function is considered.
    At most one argument in lambda function is accepted in order to avoid ambiguity
    in value assignment.

    Examples
    ========

    >>> from sympy.tensor.array import ArrayComprehensionMap
    >>> from sympy import symbols
    >>> i, j, k = symbols('i j k')
    >>> a = ArrayComprehensionMap(lambda: 1, (i, 1, 4))
    >>> a.doit()
    [1, 1, 1, 1]
    >>> b = ArrayComprehensionMap(lambda a: a+1, (j, 1, 4))
    >>> b.doit()
    [2, 3, 4, 5]

    c                 O   s   t dd |D rtdt|std| ||}tj| g|R i |}|j|_| |j|_	t
|j	|_| |j	|_||_|S )Nc                 s   r   r   r   r   r   r   r   r   r  r   z0ArrayComprehensionMap.__new__.<locals>.<genexpr>r   zData type not supported)r   r   ri   r   r   r   r   r   r    r!   r   r"   r#   r$   _lambdar%   r   r   r   r   q  s   zArrayComprehensionMap.__new__c                    s   G  fdddt }|S )Nc                       s   e Zd Z fddZdS )z%ArrayComprehensionMap.func.<locals>._c                    s   t  jg|R i |S rC   )rj   rk   )r&   argskwargsr,   r   r   r     s   z-ArrayComprehensionMap.func.<locals>._.__new__N)r_   r`   ra   r   r   r,   r   r   r?     s    r?   )rj   )r-   r?   r   r,   r   func  s   zArrayComprehensionMap.funcc                 C   sD   | j }| j jjdkr| }|S | j jjdkr |tdd |}|S )Nr   r   c                 S   s   | | S rC   r   )abr   r   r   re     s    z4ArrayComprehensionMap._get_element.<locals>.<lambda>)rk   __code__co_argcount	functoolsreduce)r-   rU   rX   r   r   r   rS     s   z"ArrayComprehensionMap._get_elementN)r_   r`   ra   rb   r   rc   rn   rS   r   r   r   r   rj   X  s    
rj   )rs   rQ   sympy.core.sympifyr   r   sympy.core.exprr   
sympy.corer   r   sympy.tensor.arrayr   sympy.core.symbolr   sympy.core.numbersr	   r
   ri   rj   r   r   r   r   <module>   s      L