
Commit 66da2c5

shreyas3156 authored and ricardoV94 committed
Renamed aliases from at to pt
1 parent 4da5edf commit 66da2c5


86 files changed (+1917, −1913 lines)

benchmarks/benchmarks/benchmarks.py

+3 −3

@@ -18,7 +18,7 @@
 import numpy as np
 import pandas as pd
 import pytensor
-import pytensor.tensor as at
+import pytensor.tensor as pt
 
 import pymc as pm
 
@@ -61,8 +61,8 @@ def mixture_model(random_seed=1234):
     mu = pm.Normal("mu", mu=0.0, sigma=10.0, shape=w_true.shape)
     enforce_order = pm.Potential(
         "enforce_order",
-        at.switch(mu[0] - mu[1] <= 0, 0.0, -np.inf)
-        + at.switch(mu[1] - mu[2] <= 0, 0.0, -np.inf),
+        pt.switch(mu[0] - mu[1] <= 0, 0.0, -np.inf)
+        + pt.switch(mu[1] - mu[2] <= 0, 0.0, -np.inf),
     )
     tau = pm.Gamma("tau", alpha=1.0, beta=1.0, shape=w_true.shape)
     pm.NormalMixture("x_obs", w=w, mu=mu, tau=tau, observed=x)
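
Only the import alias changes in this hunk; the switch-based ordering penalty behaves exactly as before. As a minimal, hedged sketch of what such a penalty evaluates to under the new alias (the variable name and values below are illustrative, not taken from the benchmark):

    import numpy as np
    import pytensor.tensor as pt

    mu = pt.vector("mu")
    # 0.0 when mu[0] <= mu[1], -inf otherwise; same pattern as the Potential above
    penalty = pt.switch(mu[0] - mu[1] <= 0, 0.0, -np.inf)

    print(penalty.eval({mu: np.array([0.5, 1.0, 2.0])}))  # 0.0
    print(penalty.eval({mu: np.array([2.0, 1.0, 0.5])}))  # -inf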

docs/source/PyMC_and_PyTensor.rst

+26 −26

@@ -34,25 +34,25 @@ First, we need to define symbolic variables for our inputs (this
 is similar to eg SymPy's `Symbol`)::
 
     import pytensor
-    import pytensor.tensor as at
+    import pytensor.tensor as pt
     # We don't specify the dtype of our input variables, so it
     # defaults to using float64 without any special config.
-    a = at.scalar('a')
-    x = at.vector('x')
-    # `at.ivector` creates a symbolic vector of integers.
-    y = at.ivector('y')
+    a = pt.scalar('a')
+    x = pt.vector('x')
+    # `pt.ivector` creates a symbolic vector of integers.
+    y = pt.ivector('y')
 
 Next, we use those variables to build up a symbolic representation
 of the output of our function. Note that no computation is actually
 being done at this point. We only record what operations we need to
 do to compute the output::
 
     inner = a * x**3 + y**2
-    out = at.exp(inner).sum()
+    out = pt.exp(inner).sum()
 
 .. note::
 
-   In this example we use `at.exp` to create a symbolic representation
+   In this example we use `pt.exp` to create a symbolic representation
    of the exponential of `inner`. Somewhat surprisingly, it
    would also have worked if we used `np.exp`. This is because numpy
    gives objects it operates on a chance to define the results of

@@ -77,8 +77,8 @@ We can call this function with actual arrays as many times as we want::
 
 For the most part the symbolic PyTensor variables can be operated on
 like NumPy arrays. Most NumPy functions are available in `pytensor.tensor`
-(which is typically imported as `at`). A lot of linear algebra operations
-can be found in `at.nlinalg` and `at.slinalg` (the NumPy and SciPy
+(which is typically imported as `pt`). A lot of linear algebra operations
+can be found in `pt.nlinalg` and `pt.slinalg` (the NumPy and SciPy
 operations respectively). Some support for sparse matrices is available
 in `pytensor.sparse`. For a detailed overview of available operations,
 see :mod:`the pytensor api docs <pytensor.tensor>`.

@@ -88,9 +88,9 @@ NumPy arrays are operations involving conditional execution.
 
 Code like this won't work as expected::
 
-    a = at.vector('a')
+    a = pt.vector('a')
     if (a > 0).all():
-        b = at.sqrt(a)
+        b = pt.sqrt(a)
     else:
         b = -a
 

@@ -100,28 +100,28 @@ and according to the rules for this conversion, things that aren't empty
 containers or zero are converted to `True`. So the code is equivalent
 to this::
 
-    a = at.vector('a')
-    b = at.sqrt(a)
+    a = pt.vector('a')
+    b = pt.sqrt(a)
 
-To get the desired behaviour, we can use `at.switch`::
+To get the desired behaviour, we can use `pt.switch`::
 
-    a = at.vector('a')
-    b = at.switch((a > 0).all(), at.sqrt(a), -a)
+    a = pt.vector('a')
+    b = pt.switch((a > 0).all(), pt.sqrt(a), -a)
 
 Indexing also works similarly to NumPy::
 
-    a = at.vector('a')
+    a = pt.vector('a')
     # Access the 10th element. This will fail when a function build
     # from this expression is executed with an array that is too short.
     b = a[10]
 
     # Extract a subvector
     b = a[[1, 2, 10]]
 
-Changing elements of an array is possible using `at.set_subtensor`::
+Changing elements of an array is possible using `pt.set_subtensor`::
 
-    a = at.vector('a')
-    b = at.set_subtensor(a[:10], 1)
+    a = pt.vector('a')
+    b = pt.set_subtensor(a[:10], 1)
 
     # is roughly equivalent to this (although pytensor avoids
     # the copy if `a` isn't used anymore)

@@ -167,7 +167,7 @@ this is happening::
     # in exactly this way!
     model = pm.Model()
 
-    mu = at.scalar('mu')
+    mu = pt.scalar('mu')
     model.add_free_variable(mu)
     model.add_logp_term(pm.Normal.dist(0, 1).logp(mu))
 

@@ -195,15 +195,15 @@ is roughly equivalent to this::
 
     # For illustration only, not real code!
     model = pm.Model()
-    mu = at.scalar('mu')
+    mu = pt.scalar('mu')
     model.add_free_variable(mu)
     model.add_logp_term(pm.Normal.dist(0, 1).logp(mu))
 
-    sd_log__ = at.scalar('sd_log__')
+    sd_log__ = pt.scalar('sd_log__')
     model.add_free_variable(sd_log__)
     model.add_logp_term(corrected_logp_half_normal(sd_log__))
 
-    sigma = at.exp(sd_log__)
+    sigma = pt.exp(sd_log__)
     model.add_deterministic_variable(sigma)
 
     model.add_logp_term(pm.Normal.dist(mu, sigma).logp(data))

@@ -214,8 +214,8 @@ PyTensor operation on them::
 
     design_matrix = np.array([[...]])
    with pm.Model() as model:
-        # beta is a at.dvector
+        # beta is a pt.dvector
         beta = pm.Normal('beta', 0, 1, shape=len(design_matrix))
-        predict = at.dot(design_matrix, beta)
+        predict = pt.dot(design_matrix, beta)
         sigma = pm.HalfCauchy('sigma', beta=2.5)
         pm.Normal('y', mu=predict, sigma=sigma, observed=data)
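
For readers updating their own code, the rename only touches the import alias; the graph-building workflow described in this document is unchanged. A minimal sketch, with arbitrary example values, of building and compiling the expression from the page above under the new alias:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    a = pt.scalar("a")
    x = pt.vector("x")
    y = pt.ivector("y")

    # Record the symbolic computation; nothing is evaluated yet
    inner = a * x**3 + y**2
    out = pt.exp(inner).sum()

    # Compile the graph into a callable and run it on concrete inputs
    func = pytensor.function([a, x, y], out)
    print(func(1.0, np.array([0.1, 0.2]), np.array([1, 2], dtype="int32")))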

docs/source/contributing/implementing_distribution.md

+8 −8

@@ -129,7 +129,7 @@ Here is how the example continues:
 
 ```python
 
-import pytensor.tensor as at
+import pytensor.tensor as pt
 from pymc.pytensorf import floatX, intX
 from pymc.distributions.continuous import PositiveContinuous
 from pymc.distributions.dist_math import check_parameters

@@ -146,12 +146,12 @@ class Blah(PositiveContinuous):
     # We pass the standard parametrizations to super().dist
     @classmethod
     def dist(cls, param1, param2=None, alt_param2=None, **kwargs):
-        param1 = at.as_tensor_variable(intX(param1))
+        param1 = pt.as_tensor_variable(intX(param1))
         if param2 is not None and alt_param2 is not None:
             raise ValueError("Only one of param2 and alt_param2 is allowed.")
         if alt_param2 is not None:
             param2 = 1 / alt_param2
-        param2 = at.as_tensor_variable(floatX(param2))
+        param2 = pt.as_tensor_variable(floatX(param2))
 
         # The first value-only argument should be a list of the parameters that
         # the rv_op needs in order to be instantiated

@@ -161,19 +161,19 @@ class Blah(PositiveContinuous):
     # the variable, given the implicit `rv`, `size` and `param1` ... `paramN`.
     # This is typically a "representative" point such as the the mean or mode.
     def moment(rv, size, param1, param2):
-        moment, _ = at.broadcast_arrays(param1, param2)
+        moment, _ = pt.broadcast_arrays(param1, param2)
         if not rv_size_is_none(size):
-            moment = at.full(size, moment)
+            moment = pt.full(size, moment)
         return moment
 
     # Logp returns a symbolic expression for the elementwise log-pdf or log-pmf evaluation
     # of the variable given the `value` of the variable and the parameters `param1` ... `paramN`.
     def logp(value, param1, param2):
-        logp_expression = value * (param1 + at.log(param2))
+        logp_expression = value * (param1 + pt.log(param2))
 
         # A switch is often used to enforce the distribution support domain
-        bounded_logp_expression = at.switch(
-            at.gt(value >= 0),
+        bounded_logp_expression = pt.switch(
+            pt.gt(value >= 0),
             logp_expression,
             -np.inf,
         )
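
The support-bounding pattern in this guide can also be tried outside a distribution class. A small standalone sketch (the log-density expression is made up for illustration, and the support check is written here as `pt.ge(value, 0)`):

    import numpy as np
    import pytensor.tensor as pt

    value = pt.vector("value")
    logp_expression = -(value**2)  # placeholder log-density term

    # Replace the log-density with -inf wherever the value falls outside the support
    bounded = pt.switch(pt.ge(value, 0), logp_expression, -np.inf)

    print(bounded.eval({value: np.array([1.0, -1.0])}))  # [-1., -inf]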

docs/source/glossary.md

+2 −2

@@ -128,9 +128,9 @@ tensor_like
   Any scalar or sequence that can be interpreted as a {class}`~pytensor.tensor.TensorVariable`. In addition to TensorVariables, this includes NumPy ndarrays, scalars, lists and tuples (possibly nested). Any argument accepted by `pytensor.tensor.as_tensor_variable` is tensor_like.
 
   ```{jupyter-execute}
-  import pytensor.tensor as at
+  import pytensor.tensor as pt
 
-  at.as_tensor_variable([[1, 2.0], [0, 0]])
+  pt.as_tensor_variable([[1, 2.0], [0, 0]])
   ```
 
 :::::

pymc/data.py

+4 −4

@@ -22,7 +22,7 @@
 import numpy as np
 import pandas as pd
 import pytensor
-import pytensor.tensor as at
+import pytensor.tensor as pt
 import xarray as xr
 
 from pytensor.compile.sharedvalue import SharedVariable

@@ -164,7 +164,7 @@ def assert_all_scalars_equal(scalar, *scalars):
     else:
         return Assert(
             "All variables shape[0] in Minibatch should be equal, check your Minibatch(data1, data2, ...) code"
-        )(scalar, at.all([scalar == s for s in scalars]))
+        )(scalar, pt.all([scalar == s for s in scalars]))
 
 
 def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size: int):

@@ -185,7 +185,7 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size:
     >>> mdata1, mdata2 = Minibatch(data1, data2, batch_size=10)
     """
 
-    tensor, *tensors = tuple(map(at.as_tensor, (variable, *variables)))
+    tensor, *tensors = tuple(map(pt.as_tensor, (variable, *variables)))
     upper = assert_all_scalars_equal(*[t.shape[0] for t in (tensor, *tensors)])
     slc = minibatch_index(0, upper, size=batch_size)
     for i, v in enumerate((tensor, *tensors)):

@@ -435,7 +435,7 @@ def Data(
     if mutable:
         x = pytensor.shared(arr, name, **kwargs)
     else:
-        x = at.as_tensor_variable(arr, name, **kwargs)
+        x = pt.as_tensor_variable(arr, name, **kwargs)
 
     if isinstance(dims, str):
         dims = (dims,)
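
Callers of `Minibatch` are unaffected by the alias change; usage stays as in the docstring shown above. A small sketch with arbitrary data shapes:

    import numpy as np
    import pymc as pm

    data1 = np.random.rand(100, 10)
    data2 = np.random.rand(100, 1)

    # Both arrays are sliced with the same random index of 10 rows per draw
    mdata1, mdata2 = pm.Minibatch(data1, data2, batch_size=10)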

pymc/distributions/bound.py

+5 −5

@@ -14,7 +14,7 @@
 import warnings
 
 import numpy as np
-import pytensor.tensor as at
+import pytensor.tensor as pt
 
 from pytensor.tensor import as_tensor_variable
 from pytensor.tensor.random.op import RandomVariable

@@ -72,8 +72,8 @@ def logp(value, distribution, lower, upper):
         -------
         TensorVariable
         """
-        res = at.switch(
-            at.or_(at.lt(value, lower), at.gt(value, upper)),
+        res = pt.switch(
+            pt.or_(pt.lt(value, lower), pt.gt(value, upper)),
             -np.inf,
             logp(distribution, value),
         )

@@ -126,8 +126,8 @@ def logp(value, distribution, lower, upper):
         -------
         TensorVariable
         """
-        res = at.switch(
-            at.or_(at.lt(value, lower), at.gt(value, upper)),
+        res = pt.switch(
+            pt.or_(pt.lt(value, lower), pt.gt(value, upper)),
             -np.inf,
             logp(distribution, value),
         )

pymc/distributions/censored.py

+12 −12

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import numpy as np
-import pytensor.tensor as at
+import pytensor.tensor as pt
 
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.random.op import RandomVariable

@@ -101,16 +101,16 @@ def dist(cls, dist, lower, upper, **kwargs):
 
     @classmethod
     def rv_op(cls, dist, lower=None, upper=None, size=None):
-        lower = at.constant(-np.inf) if lower is None else at.as_tensor_variable(lower)
-        upper = at.constant(np.inf) if upper is None else at.as_tensor_variable(upper)
+        lower = pt.constant(-np.inf) if lower is None else pt.as_tensor_variable(lower)
+        upper = pt.constant(np.inf) if upper is None else pt.as_tensor_variable(upper)
 
         # When size is not specified, dist may have to be broadcasted according to lower/upper
-        dist_shape = size if size is not None else at.broadcast_shape(dist, lower, upper)
+        dist_shape = size if size is not None else pt.broadcast_shape(dist, lower, upper)
         dist = change_dist_size(dist, dist_shape)
 
         # Censoring is achieved by clipping the base distribution between lower and upper
         dist_, lower_, upper_ = dist.type(), lower.type(), upper.type()
-        censored_rv_ = at.clip(dist_, lower_, upper_)
+        censored_rv_ = pt.clip(dist_, lower_, upper_)
 
         return CensoredRV(
             inputs=[dist_, lower_, upper_],

@@ -129,22 +129,22 @@ def change_censored_size(cls, dist, new_size, expand=False):
 
 @_moment.register(CensoredRV)
 def moment_censored(op, rv, dist, lower, upper):
-    moment = at.switch(
-        at.eq(lower, -np.inf),
-        at.switch(
-            at.isinf(upper),
+    moment = pt.switch(
+        pt.eq(lower, -np.inf),
+        pt.switch(
+            pt.isinf(upper),
             # lower = -inf, upper = inf
             0,
             # lower = -inf, upper = x
             upper - 1,
         ),
-        at.switch(
-            at.eq(upper, np.inf),
+        pt.switch(
+            pt.eq(upper, np.inf),
             # lower = x, upper = inf
             lower + 1,
             # lower = x, upper = x
             (lower + upper) / 2,
         ),
     )
-    moment = at.full_like(dist, moment)
+    moment = pt.full_like(dist, moment)
     return moment
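
Only the alias changes here as well; `pt.clip` still implements the censoring. A minimal usage sketch of the public `pm.Censored` API that this file backs (parameter values are arbitrary):

    import pymc as pm

    with pm.Model() as model:
        # The Normal base distribution is clipped to [-1, 1], mirroring the rv_op logic above
        x = pm.Censored("x", pm.Normal.dist(mu=0.0, sigma=1.0), lower=-1.0, upper=1.0)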
