1.13. Functions as Arguments
Reference: Chapter 5 of Computational Nuclear Engineering and Radiological Science Using Python, R. McClarren (2018)
1.13.1. Learning Objectives
After studying this notebook, completing the activities, and asking questions in class, you should be able to:
Pass a function as an argument to a function
Define and use lambda functions
1.13.2. Functions passed to Functions
Functions can also use other functions as part of their input. In the simple example below, f2 takes f1 as an input. Try to predict the output before running the cell.
def f1(x):
    return x*3

def f2(f,x):
    return f(x) + 2

print(f2(f1,4))
14
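Because a function is an ordinary Python object, the same function can be handed to different "driver" functions. The sketch below uses a hypothetical helper, apply_twice (not part of the original notebook), to show that code written in the style of f2 works with any one-argument callable, including built-ins such as abs.

```python
def f1(x):
    return x*3

def apply_twice(f, x):
    """Hypothetical helper: apply the function f to x two times."""
    return f(f(x))

print(apply_twice(f1, 4))    # f1(f1(4)) = f1(12) = 36
print(apply_twice(abs, -5))  # abs(abs(-5)) = 5
print(f1)                    # the function object itself; no call happens here
```

Note the difference between f1 (the function) and f1(4) (the value it returns): only the former can be passed as the f argument.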
Here’s a more complicated example. Again, try to predict its output before you run the cell.
def my_f(x):
    return (x/3)*5

def my_g(f,x):
    return f(x)**2+1

print(my_g(my_f,6))
101.0
1.13.3. Lambda Functions
Python also allows you to define a lambda function, which is basically a one-line function (a single expression) written without the whole def business. Here’s an example where the lambda function describes a straight line.
# `Normal` function definition that we've used before (old method)
def simple_line(x):
    '''evaluate the line 2x + 1
    Args:
        x: input value
    Returns:
        value of the line at x
    '''
    return 2.0*x + 1.0
import numpy as np
import matplotlib.pyplot as plt
# define the function `simple_line` with a single line of code using a lambda function (new method)
simple_line = lambda x: 2.0*x + 1.0
# Evaluate function at three values of x
print("The line at x = 0 is", simple_line(0))
print("The line at x = 1 is", simple_line(1))
print("The line at x = 2 is", simple_line(2))
# Evaluate function at many values of x
x = np.linspace(0,6,50)
y = simple_line(x)
# Make plot
plt.plot(x,y)
plt.ylabel("y")
plt.xlabel("x")
plt.show()
The line at x = 0 is 1.0
The line at x = 1 is 3.0
The line at x = 2 is 5.0
Lambda functions in Python are analogous to anonymous functions in MATLAB.
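Lambda functions are most often used as short, throwaway arguments to functions that expect a function. The sketch below uses two standard built-ins that take a function argument: sorted (through its key parameter) and map. The isotope half-lives are rounded textbook values used only for illustration.

```python
# (isotope, approximate half-life in years), for illustration only
isotopes = [("Cs-137", 30.1), ("Co-60", 5.27), ("H-3", 12.3)]

# key= receives a function; here a lambda picks out the half-life
by_half_life = sorted(isotopes, key=lambda pair: pair[1])
print(by_half_life)  # [('Co-60', 5.27), ('H-3', 12.3), ('Cs-137', 30.1)]

# map applies the lambda (the same line as simple_line) to each value
line_values = list(map(lambda x: 2.0*x + 1.0, [0, 1, 2]))
print(line_values)   # [1.0, 3.0, 5.0]
```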
Home Activity
Predict the output of the code below before you run it.
func_1 = (lambda x: x + x)(2)
func_2 = lambda x, y:x+y
print (func_1)
print (func_2(1,5))
Home Activity
Recreate the following function as a lambda function in the cell below it and test the function when r = 2 cm and h = 6 cm.
def cylinder_surface_area(r,h):
    '''solve for the surface area of a cylinder
    Args:
        r: radius of cylinder (cm)
        h: height of cylinder (cm)
    Returns:
        surface area of the cylinder
    '''
    return 2*np.pi*r**2 + 2*np.pi*r*h
print(cylinder_surface_area(2,6),"cm^2")
100.53096491487338 cm^2
# Create the above function as a lambda function
# Add your solution here
You can also use lambda functions inside other functions. Below, my_func builds and returns a lambda function.
def my_func(n):
    return lambda y : y * n
doubler_func = my_func(2) # this creates the doubler_func with 2 being passed as n to the lambda return function
tripler_func = my_func(3)
print(doubler_func(11)) # this takes the 11 as y with the 2 already being specified as n
print(tripler_func(11))
22
33
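A lambda returned from my_func remembers the n it was created with (it is a closure). One consequence worth knowing: a lambda looks up outside variables when it is called, not when it is defined. The sketch below contrasts lambdas created in a loop (which all see the final loop value) with two common ways of capturing the value at creation time.

```python
# Each lambda looks up n when it is called, so all three use the final n = 2.
late = [lambda y: y * n for n in range(3)]
print([f(10) for f in late])     # [20, 20, 20]

# A default argument captures the current value of n at creation time.
early = [lambda y, n=n: y * n for n in range(3)]
print([f(10) for f in early])    # [0, 10, 20]

# A factory function like my_func gives each lambda its own n.
def my_func(n):
    return lambda y: y * n

factory = [my_func(n) for n in range(3)]
print([f(10) for f in factory])  # [0, 10, 20]
```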
Home Activity
Copy the above function into the cell below. Modify it so that it also raises the lambda function to the nth power, then create a fun_func that passes 2 into my_func. Print fun_func(4).
# Add your solution here
Sometimes you will want nested lambda functions; you will see an application in the last example in this notebook.
func_squared = lambda x: x**2
func_product = lambda F, m: lambda x: F(x)*m
print(func_product(func_squared, 5)(2))
20
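The nested lambda func_product takes a function F and a number m and returns a new function. The sketch below writes the same idea with def statements, which can be easier to read; it should behave the same way as the lambda version above.

```python
def func_squared(x):
    return x**2

def func_product(F, m):
    """Return a new function x -> F(x)*m (same behavior as the nested lambda)."""
    def scaled(x):
        return F(x) * m
    return scaled

scaled_square = func_product(func_squared, 5)  # a function, not a number
print(scaled_square(2))                        # (2**2)*5 = 20
print(func_product(func_squared, 5)(3))        # (3**2)*5 = 45
```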
1.13.4. Tying it all together: Midpoint Integration
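Recall the midpoint formula for approximating an integral.
With only one element:
\[ \int_a^b f(x)\,dx \approx (b-a)\, f\left(\frac{a+b}{2}\right) \]
With \(N\) elements:
\[ \int_a^b f(x)\,dx \approx \Delta x \sum_{i=1}^{N} f(x_i) \]
where
\[ \Delta x = \frac{b-a}{N} \]
and
\[ x_i = a + \left(i - \tfrac{1}{2}\right)\Delta x. \]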
We want to integrate an arbitrary function using the midpoint rule. To do this, we need to evaluate the function at several values to calculate the midpoint sum.
In the following Python code, argument f is a function with a scalar input and scalar output.
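Before writing a general routine, it can help to apply the formulas by hand. The sketch below evaluates the one-element rule, \((b-a)\,f((a+b)/2)\), and the \(N\)-element rule, \(\Delta x \sum_i f(x_i)\) with \(\Delta x = (b-a)/N\) and \(x_i = a + (i - \tfrac{1}{2})\Delta x\), for \(f(x) = \sin(x)\) on \([0, \pi]\), whose exact integral is 2. The choice \(N = 4\) is arbitrary.

```python
import numpy as np

a, b = 0.0, np.pi

# One element: (b - a) * f((a + b)/2)
one_element = (b - a) * np.sin((a + b) / 2)
print(one_element)  # pi * sin(pi/2) = pi, about 3.1416

# N elements: dx times the sum of f at the midpoints x_i = a + (i - 1/2) dx
N = 4
dx = (b - a) / N
total = 0.0
for i in range(1, N + 1):
    total += np.sin(a + (i - 0.5) * dx)
print(total * dx)   # already much closer to the exact value of 2
```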
def midpoint_rule(f,a,b,num_intervals):
    """integrate function f using the midpoint rule
    Args:
        f: function to be integrated, it must take one argument
        a: lower bound of integral range
        b: upper bound of integral range
        num_intervals: the number of intervals to break [a,b] into
    Returns:
        estimate of the integral
    """
    L = lambda a,b: (b-a) #how big is the range using a lambda function (here lambda function is interchangeable with L = (b-a))
    dx = L(a,b)/num_intervals #how big is each interval
    #midpoints are a+dx/2, a+3dx/2, ..., b-dx/2
    midpoints = np.arange(num_intervals)*dx+0.5*dx+a
    integral = 0
    for point in midpoints:
        integral += f(point)
    return integral*dx
help(midpoint_rule) # use this command to see the summary of the function in the triple quotes
Help on function midpoint_rule in module __main__:
midpoint_rule(f, a, b, num_intervals)
integrate function f using the midpoint rule
Args:
f: function to be integrated, it must take one argument
a: lower bound of integral range
b: upper bound of integral range
num_intervals: the number of intervals to break [a,b] into
Returns:
estimate of the integral
a = 1.0
dx = 1.0
midpoints = np.arange(10)*dx+0.5*dx+a
print(midpoints)
[ 1.5 2.5 3.5 4.5 5.5 6.5 7.5 8.5 9.5 10.5]
Thus midpoint_rule is a function that takes another function, f, as an argument (input).
Now we will approximate
\[ \int_0^{\pi} \sin(x)\,dx \]
using \(N=10\) and the midpoint rule. The analytic solution (correct answer) is 2.
print(midpoint_rule(np.sin,0,np.pi,10))
2.0082484079079745
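Since np.sin accepts a whole array, the Python loop in midpoint_rule is not strictly necessary. The sketch below is a hypothetical vectorized variant (midpoint_rule_vectorized is not from the original notebook) that evaluates f at all midpoints at once. It assumes f works element-wise on NumPy arrays; a function written with if statements on scalars would not.

```python
import numpy as np

def midpoint_rule_vectorized(f, a, b, num_intervals):
    """Hypothetical vectorized midpoint rule; f must accept a NumPy array."""
    dx = (b - a) / num_intervals
    midpoints = a + (np.arange(num_intervals) + 0.5) * dx
    return np.sum(f(midpoints)) * dx

print(midpoint_rule_vectorized(np.sin, 0, np.pi, 10))           # should match midpoint_rule
print(midpoint_rule_vectorized(lambda x: x**2, 0, 1, 1000))     # close to 1/3
```

The trade-off is memory for speed: all midpoints and function values are stored in arrays, which is usually fine and much faster than a Python loop for large numbers of intervals.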
Home Activity
Approximate the integral below using the midpoint formula with 5 intervals. Store your answer in approx_integral.
# Add your solution here
1.13.5. Convergence Analysis
A central question in numerical analysis is: at what rate does an approximation converge to the true solution?
Let’s characterize the convergence of the midpoint approximation for this particular integral.
import math
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
# Define values of N to consider
num_intervals = 8 #number of interval sizes
intervals = 10**np.arange(num_intervals) #run several different intervals
print("We will consider the following values of N:")
print(intervals)
print(" ")
# Allocate an array to store result
integral_error = np.zeros(num_intervals)
# Define integration limits
a = 0
b = np.pi
# Loop over different values of N
for count, interval in enumerate(intervals):
    print("Considering N =",interval)
    integral_error[count] = np.fabs(2 - midpoint_rule(np.sin,a,b,interval))

# Create figure
plt.figure()
plt.loglog(intervals,integral_error,marker="o",markersize=10,linewidth=2)
plt.xlabel("# of intervals")
plt.ylabel("Error in midpoint rule")
plt.axis([1,1.5e7,1.0e-13,10])
plt.show()
We will consider the following values of N:
[ 1 10 100 1000 10000 100000 1000000 10000000]
Considering N = 1
Considering N = 10
Considering N = 100
Considering N = 1000
Considering N = 10000
Considering N = 100000
Considering N = 1000000
Considering N = 10000000
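For a smooth integrand the midpoint rule is second-order accurate: the error behaves roughly like \(C N^{-2}\), so doubling \(N\) should cut the error by about a factor of 4, and the points on the log-log plot should fall along a line of slope about \(-2\) (for very large \(N\), floating-point roundoff in the sum may eventually limit the accuracy). The sketch below estimates the observed order \(p\) from pairs of runs using \(p = \log(e_1/e_2)/\log(N_2/N_1)\). It uses a hypothetical vectorized midpoint function so it runs on its own.

```python
import numpy as np

def midpoint_rule_vectorized(f, a, b, num_intervals):
    """Hypothetical vectorized midpoint rule; f must accept a NumPy array."""
    dx = (b - a) / num_intervals
    midpoints = a + (np.arange(num_intervals) + 0.5) * dx
    return np.sum(f(midpoints)) * dx

Ns = np.array([10, 20, 40, 80, 160])
errors = np.array([abs(2.0 - midpoint_rule_vectorized(np.sin, 0, np.pi, N)) for N in Ns])

# Observed order from consecutive pairs, assuming error ~ C * N**(-p)
orders = np.log(errors[:-1] / errors[1:]) / np.log(Ns[1:] / Ns[:-1])
print(orders)  # each value should be close to 2
```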
Class Activity
Copy the code from above to below. Adapt it to analyze the integral from the previous home activity.
# Add your solution here
Class Activity
Approximate the exponential integral function below using the midpoint formula:
\[ E_n(x) = \int_1^{\infty} \frac{e^{-xt}}{t^n}\,dt \]
Use \(N = 10^6\) elements, \(x=1\), \(n=1\), and an upper bound of 10,000 (instead of infinity).
def exp_int_argument(t,n=1,x=1):
    ''' Exponential integral integrand
    Arguments:
        t: scalar
        n: scalar, default is n=1
        x: scalar, default is x=1
    Returns:
        f: value of integrand at t,n,x
    '''
    f = np.exp(-x*t)/t**n
    return f

approx_exp_integral = midpoint_rule(exp_int_argument, 1, 10000, 10**6)
print(approx_exp_integral)
The exact answer is 0.2193839343.
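midpoint_rule only passes one argument to f, so the extra parameters n and x of exp_int_argument must be fixed some other way. A lambda does this neatly. The sketch below computes \(E_1(1)\) and \(E_2(1)\) and checks \(E_2\) against the recurrence \(E_{n+1}(x) = \left(e^{-x} - x E_n(x)\right)/n\), which gives \(E_2(1) = e^{-1} - E_1(1)\). Assumptions: it uses a hypothetical vectorized midpoint function, and it stops at \(t = 50\) instead of 10,000 because \(e^{-t}\) is negligible (below \(10^{-21}\)) past that point.

```python
import numpy as np

def midpoint_rule_vectorized(f, a, b, num_intervals):
    """Hypothetical vectorized midpoint rule; f must accept a NumPy array."""
    dx = (b - a) / num_intervals
    midpoints = a + (np.arange(num_intervals) + 0.5) * dx
    return np.sum(f(midpoints)) * dx

def exp_int_argument(t, n=1, x=1):
    return np.exp(-x*t) / t**n

# The lambdas fix n and x so the integrator sees a function of t only
E1 = midpoint_rule_vectorized(lambda t: exp_int_argument(t, n=1, x=1), 1, 50, 10**5)
E2 = midpoint_rule_vectorized(lambda t: exp_int_argument(t, n=2, x=1), 1, 50, 10**5)
print(E1)                  # close to 0.2193839343
print(E2)                  # close to exp(-1) - E1
print(np.exp(-1) - E1)
```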
1.13.6. A Fancier Integration Function
Using matplotlib, we can make an even fancier integration function that also draws the midpoint rectangles.
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
def midpoint_rule_graphical(f,a,b,num_intervals,filename):
    """integrate function f using the midpoint rule
    Args:
        f: function to be integrated, it must take one argument
           and accept NumPy arrays (used to draw the smooth curve)
        a: lower bound of integral range
        b: upper bound of integral range
        num_intervals: the number of intervals to break [a,b] into
        filename: name of the file the figure is saved to
    Returns:
        estimate of the integral
    Side Effect:
        Plots intervals and areas of midpoint rule and saves the figure
    """
    # Create plot
    ax = plt.subplot(111)

    # Setup, similar to previous midpoint_rule function
    L = (b-a) #how big is the range
    dx = L/num_intervals #how big is each interval
    midpoints = np.arange(num_intervals)*dx+0.5*dx+a

    integral = 0
    # Loop over points
    for point in midpoints:
        # Evaluate function f once at the midpoint
        f_mid = f(point)
        # Calculate integral
        integral = integral + f_mid
        # Calculate vertices of the rectangle for this interval
        verts = [(point-dx/2,0), (point-dx/2,f_mid), (point+dx/2,f_mid), (point+dx/2,0)]
        # Draw rectangles
        poly = plt.Polygon(verts, facecolor='0.8', edgecolor='k')
        ax.add_patch(poly)

    # Draw smooth line for f
    smooth_x = np.linspace(a,b,10000)
    smooth_y = f(smooth_x)
    plt.plot(smooth_x, smooth_y, linewidth=1)

    # Add labels and title
    plt.xlabel("x")
    plt.ylabel("f(x)")
    plt.title("Integral Estimate is " + str(integral*dx))

    # Save figure
    plt.savefig(filename)

    # Return approximation for integral
    return integral*dx
# Call function
midpoint_rule_graphical(np.sin,0,2*np.pi,4,'C4_fig7.pdf')
0.0
Class Activity
Also apply midpoint_rule_graphical to the exponential integral from the previous class activity. The code below will not work until exp_int_argument is defined correctly.
# Add your solution here
1.13.7. Incorporating Lambda Functions with Midpoint Integration
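We can use lambda functions in our midpoint integration routine as well. Here we define a Gaussian as the integrand:
\[ f(x) = \frac{1}{\sqrt{\pi}}\, e^{-x^2} \]
The analytic answer to the integral over the whole real line, \(\int_{-\infty}^{\infty} f(x)\,dx\), is 1. The code below integrates over \([-3, 3]\) instead.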
# Define lambda function
gaussian = lambda x: np.exp(-x**2)/np.sqrt(np.pi) #function to compute gaussian
# Approximate integral using lambda function
midpoint_rule_graphical(gaussian,-3,3,20,'C4-with-lambda-func.pdf')
0.999980808068639
# Approximate integral by passing the lambda function directly, without naming it
midpoint_rule_graphical(lambda x: np.exp(-x**2)/np.sqrt(np.pi),-3,3,20,'C4-without-lambda-func.pdf')
0.999980808068639
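The estimate differs from 1 for two reasons: the range was truncated to \([-3, 3]\), and the midpoint rule itself has some error. Because \(\int_{-c}^{c} e^{-x^2}/\sqrt{\pi}\,dx = \operatorname{erf}(c)\), the standard-library function math.erf can separate the two. The sketch below (again with a hypothetical vectorized midpoint function) suggests most of the gap comes from truncation.

```python
import math
import numpy as np

def midpoint_rule_vectorized(f, a, b, num_intervals):
    """Hypothetical vectorized midpoint rule; f must accept a NumPy array."""
    dx = (b - a) / num_intervals
    midpoints = a + (np.arange(num_intervals) + 0.5) * dx
    return np.sum(f(midpoints)) * dx

gaussian = lambda x: np.exp(-x**2)/np.sqrt(np.pi)

estimate = midpoint_rule_vectorized(gaussian, -3, 3, 20)
exact_on_interval = math.erf(3)   # exact integral of gaussian over [-3, 3]
truncation = 1.0 - math.erf(3)    # area outside [-3, 3], about 2.2e-5

print("midpoint estimate:      ", estimate)
print("midpoint error on [-3,3]:", estimate - exact_on_interval)
print("truncation error:        ", truncation)
```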
1.13.8. Extension to Two Dimensional Functions
We will revisit numerical integration at the end of the semester. We can use two lambda functions to extend the midpoint rule to integrate two-dimensional functions, such as
\[ \int_0^{\pi}\int_0^{\pi} \sin(x)\sin(y)\,dx\,dy = 4. \]
Home Activity
Spend 5-10 minutes brainstorming how to approximate a two-dimensional integral by calling the midpoint_rule function twice (nested).
Class Activity
Walk through the code below together.
def midpoint_2D(f,ax,bx,ay,by,num_intervals_x,num_intervals_y):
    """Midpoint rule extended to 2D functions
    Arguments:
        f: function to be integrated. Takes two scalar inputs (x,y) and returns a scalar.
        ax: lower bound for dimension 1
        bx: upper bound for dimension 1
        ay: lower bound for dimension 2
        by: upper bound for dimension 2
        num_intervals_x: number of intervals for dimension 1
        num_intervals_y: number of intervals for dimension 2
    Returns:
        approximation to integral (scalar)
    """
    # For a given y, calculate the integral in dimension 1 (from ax to bx) using midpoint rule
    integral_over_x = lambda y: midpoint_rule(lambda x: f(x,y),ax,bx,num_intervals_x)

    # Apply midpoint rule to dimension 2
    return midpoint_rule(integral_over_x,ay,by,num_intervals_y)

# Estimate 2-dimensional integral
sin2 = lambda x,y:np.sin(x)*np.sin(y)
print("Estimate of the integral of sin(x)sin(y), over [0,pi] x [0,pi] is",
      midpoint_2D(sin2,0,np.pi,0,np.pi,1000,1000))
Estimate of the integral of sin(x)sin(y), over [0,pi] x [0,pi] is 4.000003289869757
This code defines two lambda functions. The inner one, lambda x: f(x,y), fixes y and treats f as a function of x alone. The outer one, integral_over_x, uses it to compute the integral over x for a given y, and is then passed to midpoint_rule as the function to integrate over y.
In other words, each time the outer midpoint rule evaluates integral_over_x at a midpoint in y, it runs a complete midpoint integration over x at that fixed y.
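The nested version calls f one point at a time from Python loops, which makes the \(1000 \times 1000\) case slow. If f accepts NumPy arrays, the same 2D midpoint sum can be formed in one step on a grid of midpoint pairs built with np.meshgrid. The sketch below (midpoint_2D_vectorized is a hypothetical name, not from the original notebook) computes \(\Delta x\,\Delta y \sum_{i,j} f(x_i, y_j)\), the same sum the nested calls produce.

```python
import numpy as np

def midpoint_2D_vectorized(f, ax, bx, ay, by, nx, ny):
    """Hypothetical vectorized 2D midpoint rule; f(X, Y) must accept arrays."""
    dx = (bx - ax) / nx
    dy = (by - ay) / ny
    x_mid = ax + (np.arange(nx) + 0.5) * dx
    y_mid = ay + (np.arange(ny) + 0.5) * dy
    X, Y = np.meshgrid(x_mid, y_mid)  # every (x_i, y_j) midpoint pair
    return np.sum(f(X, Y)) * dx * dy

sin2 = lambda x, y: np.sin(x) * np.sin(y)
print(midpoint_2D_vectorized(sin2, 0, np.pi, 0, np.pi, 1000, 1000))  # close to 4
```

The grid holds nx*ny values, so memory grows with the product of the interval counts; the nested lambda version avoids that at the cost of speed.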