forked from lkreidberg/batman
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsetup.py
98 lines (92 loc) · 4.05 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from __future__ import print_function
from distutils.core import setup
from distutils.extension import Extension
import numpy as np
from distutils.ccompiler import new_compiler
import os
import sys
import tempfile
"""
Check for OpenMP based on
https://github.com/MDAnalysis/mdanalysis/tree/develop/package/setup.py
retrieved 06/15/15
"""
def detect_openmp():
    """Does this compiler support OpenMP parallelization?

    Compiles and links a tiny probe program that calls
    ``omp_get_num_threads()``, first as-is and then with ``libgomp`` added
    to the link line.

    Returns
    -------
    (bool, bool)
        ``(hasopenmp, needs_gomp)``: whether one of the probes succeeded,
        and whether the extensions should be linked against ``gomp``.
    """
    compiler = new_compiler()
    print("Checking for OpenMP support... ")
    # Give the probe a prototype for omp_get_num_threads(). GCC >= 14 and
    # Clang >= 16 treat implicit function declarations as errors by default.
    # Without the header, the probe fails to compile even on OpenMP-capable
    # toolchains.
    include = '<omp.h>'
    hasopenmp = hasfunction(compiler, 'omp_get_num_threads()', include=include)
    # NOTE(review): kept from upstream. needs_gomp is True even when the first
    # probe linked without -lgomp; confirm before changing.
    needs_gomp = hasopenmp
    if not hasopenmp:
        # Retry with libgomp explicitly added to the link line.
        compiler.add_library('gomp')
        hasopenmp = hasfunction(compiler, 'omp_get_num_threads()',
                                include=include)
        needs_gomp = hasopenmp
    if hasopenmp:
        print("Compiler supports OpenMP")
    else:
        print("Did not detect OpenMP support.")
    return hasopenmp, needs_gomp
def hasfunction(cc, funcname, include=None, extra_postargs=None):
    """Return True if a minimal C program calling *funcname* compiles and links.

    Adapted from http://stackoverflow.com/questions/7018879/
    disabling-output-when-compiling-with-distutils

    Parameters
    ----------
    cc : distutils CCompiler
        Compiler used for the probe; only its ``compile`` and
        ``link_executable`` methods are called.
    funcname : str
        C expression written as a statement inside ``main``,
        e.g. ``'omp_get_num_threads()'``.
    include : str, optional
        Header spec written verbatim after ``#include`` (e.g. ``'<omp.h>'``).
    extra_postargs : list of str, optional
        Extra arguments forwarded to ``cc.compile``.

    Returns
    -------
    bool
        True if compiling and linking succeeded. Any exception while writing,
        compiling or linking the probe yields False. A failed probe is the
        expected "not available" answer, so it is deliberately not re-raised.
    """
    import shutil

    tmpdir = tempfile.mkdtemp(prefix='hasfunction-')
    devnull = oldstderr = stderr_fd = None
    try:
        try:
            fname = os.path.join(tmpdir, 'funcname.c')
            with open(fname, 'w') as f:
                if include is not None:
                    f.write('#include %s\n' % include)
                f.write('int main(void) {\n')
                f.write(' %s;\n' % funcname)
                f.write('}\n')
            # Hide the compiler's error output by pointing the process-level
            # stderr descriptor at the null device. os.devnull is portable,
            # unlike a hard-coded '/dev/null'. If sys.stderr has no usable
            # descriptor (None, or replaced by an IDE/test runner), run the
            # probe unsilenced instead of reporting a false negative.
            try:
                stderr_fd = sys.stderr.fileno()
                sys.stderr.flush()
                devnull = open(os.devnull, 'w')
                oldstderr = os.dup(stderr_fd)
                os.dup2(devnull.fileno(), stderr_fd)
            except (AttributeError, OSError, ValueError):
                pass
            objects = cc.compile([fname], output_dir=tmpdir,
                                 extra_postargs=extra_postargs)
            cc.link_executable(objects, os.path.join(tmpdir, "a.out"))
        except Exception:
            return False
        return True
    finally:
        if oldstderr is not None:
            # Flush anything Python buffered while redirected, then restore
            # the original descriptor and release the duplicate (no fd leak).
            sys.stderr.flush()
            os.dup2(oldstderr, stderr_fd)
            os.close(oldstderr)
        if devnull is not None:
            devnull.close()
        # Remove the probe source, objects and executable.
        shutil.rmtree(tmpdir, ignore_errors=True)
# Probe the compiler once at build time; every extension below shares the result.
has_openmp, needs_gomp = detect_openmp()

# -std=c99 is passed in both cases; -fopenmp only when the probe succeeded.
if has_openmp:
    parallel_args = ['-fopenmp', '-std=c99']
else:
    parallel_args = ['-std=c99']

# Link gomp explicitly whenever detect_openmp() says it is needed.
if needs_gomp:
    parallel_libraries = ['gomp']
else:
    parallel_libraries = []


def _openmp_extension(module_name, source_path):
    """Return an Extension for *module_name* built from the single C file
    *source_path*, using the shared ``parallel_args``/``parallel_libraries``."""
    return Extension(module_name, [source_path],
                     extra_compile_args=parallel_args,
                     libraries=parallel_libraries)


_nonlinear_ld = _openmp_extension('robin._nonlinear_ld', 'c_src/_nonlinear_ld.c')
_quadratic_ld = _openmp_extension('robin._quadratic_ld', 'c_src/_quadratic_ld.c')
_uniform_ld = _openmp_extension('robin._uniform_ld', 'c_src/_uniform_ld.c')
_logarithmic_ld = _openmp_extension('robin._logarithmic_ld', 'c_src/_logarithmic_ld.c')
_exponential_ld = _openmp_extension('robin._exponential_ld', 'c_src/_exponential_ld.c')
_custom_ld = _openmp_extension('robin._custom_ld', 'c_src/_custom_ld.c')
_power2_ld = _openmp_extension('robin._power2_ld', 'c_src/_power2_ld.c')
_rsky = _openmp_extension('robin._rsky', 'c_src/_rsky.c')
_eclipse = _openmp_extension('robin._eclipse', 'c_src/_eclipse.c')
setup(
    name='robin-package',
    version="0.1.2",
    author='Brett Morris & Laura Kreidberg',
    author_email='bmmorris@uw.edu',
    url='https://github.com/bmorris3/robin',
    packages=['robin'],
    # The License metadata field is a single string. A list would be written
    # into PKG-INFO as its repr, or rejected by newer setuptools.
    license='GNU GPLv3',
    description='Fast transit light curve modeling',
    classifiers=[
        'Development Status :: 5 - Production/Stable',
        'Intended Audience :: Science/Research',
        'Topic :: Scientific/Engineering',
        'Programming Language :: Python',
    ],
    # NumPy headers are needed by the C extensions.
    include_dirs=[np.get_include()],
    install_requires=['numpy'],
    # The setuptools keyword is ``extras_require``. The previous
    # ``extras_requires`` spelling was ignored with only an "Unknown
    # distribution option" warning, so this extra was never registered.
    extras_require={
        'matplotlib': ['matplotlib'],
    },
    ext_modules=[_nonlinear_ld, _quadratic_ld, _uniform_ld, _rsky, _eclipse,
                 _logarithmic_ld, _exponential_ld, _power2_ld, _custom_ld],
)