| | |
| |
|
| | from setuptools import find_packages, setup |
| |
|
| | import os |
| | import subprocess |
| | import sys |
| | import time |
| | import torch |
| | from torch.utils.cpp_extension import (BuildExtension, CppExtension, |
| | CUDAExtension) |
| |
|
# Path (relative to the current working directory) of the generated version
# module: written by write_version_py() and read back by get_hash() and
# get_version().
version_file = 'basicsr/version.py'
| |
|
| |
|
def readme():
    """Return the long description passed to ``setup()``.

    This is a stub: an empty description is used rather than the contents
    of any README file.
    """
    description = ''
    return description
| | |
| | |
| | |
| |
|
| |
|
def get_git_hash():
    """Return the SHA of the current git ``HEAD``.

    Runs ``git rev-parse HEAD`` in the current working directory with a
    minimal, locale-neutral environment.

    Returns:
        str: The full commit hash, or ``'unknown'`` when git cannot be
        executed, exits with a non-zero status (e.g. not inside a
        repository), or prints nothing.
    """

    def _minimal_ext_cmd(cmd):
        # Forward only the variables needed to locate and run git, and force
        # the C locale so git's output is never localized.
        env = {
            k: os.environ[k]
            for k in ('SYSTEMROOT', 'PATH', 'HOME')
            if k in os.environ
        }
        env['LANGUAGE'] = 'C'
        env['LANG'] = 'C'
        env['LC_ALL'] = 'C'
        # check=True turns a failing git call into CalledProcessError instead
        # of silently handing back empty stdout.
        return subprocess.run(
            cmd, stdout=subprocess.PIPE, env=env, check=True).stdout

    try:
        out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
        sha = out.strip().decode('ascii')
    except (OSError, subprocess.CalledProcessError, UnicodeDecodeError):
        return 'unknown'

    # Guard against a successful-but-empty answer as well.
    return sha or 'unknown'
| |
|
| |
|
def get_hash():
    """Return a short build hash used as the local version label.

    Resolution order:
      1. If ``.git`` exists in the current directory, the first 7 characters
         of ``get_git_hash()``.
      2. Otherwise, if ``version_file`` exists, the local label (the text
         after the last ``'+'``) of ``basicsr.version.__version__``, or
         ``'unknown'`` when that version string contains no ``'+'``.
      3. Otherwise ``'unknown'``.

    Raises:
        ImportError: If ``version_file`` exists but ``basicsr.version``
            cannot be imported.
    """
    if os.path.exists('.git'):
        return get_git_hash()[:7]
    if not os.path.exists(version_file):
        return 'unknown'

    try:
        from basicsr.version import __version__
    except ImportError as err:
        raise ImportError('Unable to get git version') from err

    # write_version_py() stores '<short_version>+<sha>'; without a '+' there is
    # no hash to recover, so don't return the version number itself as a hash.
    _, plus, local = __version__.rpartition('+')
    return local if plus else 'unknown'
| |
|
| |
|
def write_version_py():
    """Generate ``version_file`` from the ``VERSION`` file and the build hash.

    Reads the short version (a dotted string such as ``'1.3.4'``) from
    ``VERSION`` in the current working directory, appends ``'+' + get_hash()``
    to form ``__version__``, and writes a Python module defining
    ``__version__``, ``short_version`` and ``version_info``.
    """
    content = """# GENERATED VERSION FILE
# TIME: {}
__version__ = '{}'
short_version = '{}'
version_info = ({})
"""
    sha = get_hash()
    with open('VERSION', 'r') as f:
        SHORT_VERSION = f.read().strip()
    # Numeric components are emitted bare (ints); others, e.g. 'rc1', quoted.
    info_items = [
        x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')
    ]
    VERSION_INFO = ', '.join(info_items)
    if len(info_items) == 1:
        # '(1)' is just a parenthesized int; the trailing comma makes it a
        # real 1-tuple.
        VERSION_INFO += ','
    VERSION = SHORT_VERSION + '+' + sha

    version_file_str = content.format(time.asctime(), VERSION, SHORT_VERSION,
                                      VERSION_INFO)
    with open(version_file, 'w') as f:
        f.write(version_file_str)
| |
|
| |
|
def get_version(filename=None):
    """Return ``__version__`` as defined in the generated version module.

    Args:
        filename (str, optional): Path of the version module to execute.
            Defaults to the module-level ``version_file``.

    Returns:
        str: The value the module binds to ``__version__``.

    Raises:
        KeyError: If the module does not define ``__version__``.
    """
    if filename is None:
        filename = version_file
    # Execute into an explicit namespace. Reading exec() results back through
    # locals() inside a function is unreliable and fails on Python 3.13+,
    # where locals() returns a fresh snapshot (PEP 667).
    namespace = {}
    with open(filename, 'r') as f:
        exec(compile(f.read(), filename, 'exec'), namespace)
    return namespace['__version__']
| |
|
| |
|
def make_cuda_ext(name, module, sources, sources_cuda=None):
    """Create a compiled extension for one ``basicsr`` op.

    A ``CUDAExtension`` is built when ``torch.cuda.is_available()`` or the
    environment variable ``FORCE_CUDA`` equals ``'1'``. In that case
    ``WITH_CUDA`` is defined, nvcc half-precision flags are added, and
    ``sources_cuda`` is appended to ``sources``. Otherwise a CPU-only
    ``CppExtension`` is built from ``sources`` alone.

    Args:
        name (str): Extension name; the full module name is
            ``f'{module}.{name}'``.
        module (str): Dotted package of the extension. Its path form is also
            the directory that every source path is joined onto.
        sources (list[str]): Sources compiled in every build. Not modified.
        sources_cuda (list[str], optional): Extra sources for CUDA builds.

    Returns:
        The object returned by ``CUDAExtension`` or ``CppExtension``.
    """
    if sources_cuda is None:
        sources_cuda = []
    define_macros = []
    extra_compile_args = {'cxx': []}

    if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
        define_macros += [('WITH_CUDA', None)]
        extension = CUDAExtension
        extra_compile_args['nvcc'] = [
            '-D__CUDA_NO_HALF_OPERATORS__',
            '-D__CUDA_NO_HALF_CONVERSIONS__',
            '-D__CUDA_NO_HALF2_OPERATORS__',
        ]
        # Build a new list: `sources += sources_cuda` would extend the
        # caller's list in place.
        sources = list(sources) + list(sources_cuda)
    else:
        print(f'Compiling {name} without CUDA')
        extension = CppExtension

    return extension(
        name=f'{module}.{name}',
        sources=[os.path.join(*module.split('.'), p) for p in sources],
        define_macros=define_macros,
        extra_compile_args=extra_compile_args)
| |
|
| |
|
def get_requirements(filename='requirements.txt'):
    """Return the ``install_requires`` list passed to ``setup()``.

    Always returns an empty list, so pip is told about no runtime
    dependencies. ``filename`` is accepted for backward compatibility but is
    not read.

    NOTE(review): parsing of ``filename`` appears to have been disabled
    deliberately; if dependencies should be declared again, read the file
    here (one requirement per line).
    """
    return []
| |
|
| |
|
if __name__ == '__main__':
    # '--no_cuda_ext' is a flag of this script, not of setuptools: strip it
    # from argv before setup() parses the command line, and build no ops.
    build_ext_modules = '--no_cuda_ext' not in sys.argv
    if build_ext_modules:
        # (extension name, package, common sources, CUDA-only sources)
        ext_specs = [
            ('deform_conv_ext', 'basicsr.models.ops.dcn',
             ['src/deform_conv_ext.cpp'],
             ['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']),
            ('fused_act_ext', 'basicsr.models.ops.fused_act',
             ['src/fused_bias_act.cpp'],
             ['src/fused_bias_act_kernel.cu']),
            ('upfirdn2d_ext', 'basicsr.models.ops.upfirdn2d',
             ['src/upfirdn2d.cpp'],
             ['src/upfirdn2d_kernel.cu']),
        ]
        ext_modules = [
            make_cuda_ext(
                name=ext_name,
                module=ext_pkg,
                sources=ext_sources,
                sources_cuda=ext_cuda_sources)
            for ext_name, ext_pkg, ext_sources, ext_cuda_sources in ext_specs
        ]
    else:
        sys.argv.remove('--no_cuda_ext')
        ext_modules = []

    write_version_py()
    print("setup start")
    # Top-level directories that must not be packaged.
    excluded_dirs = ('options', 'datasets', 'experiments', 'results',
                     'tb_logger', 'wandb')
    setup(
        name='basicsr',
        version=get_version(),
        description='Open Source Image and Video Super-Resolution Toolbox',
        long_description=readme(),
        author='Xintao Wang',
        author_email='xintao.wang@outlook.com',
        keywords='computer vision, restoration, super resolution',
        url='https://github.com/xinntao/BasicSR',
        packages=find_packages(exclude=excluded_dirs),
        classifiers=[
            'Development Status :: 4 - Beta',
            'License :: OSI Approved :: Apache Software License',
            'Operating System :: OS Independent',
            'Programming Language :: Python :: 3',
            'Programming Language :: Python :: 3.7',
            'Programming Language :: Python :: 3.8',
        ],
        license='Apache License 2.0',
        setup_requires=['cython', 'numpy'],
        install_requires=get_requirements(),
        ext_modules=ext_modules,
        cmdclass={'build_ext': BuildExtension},
        zip_safe=False)
| |
|